Commit 22d2d1cd authored by Stan Moore's avatar Stan Moore
Browse files

Fix issue in pair_snap_kokkos memory_usage

parent 0d7bee40
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -48,6 +48,7 @@ public:
  void coeff(int, char**);
  void init_style();
  void compute(int, int);
  double memory_usage();

  template<int NEIGHFLAG, int EVFLAG>
  KOKKOS_INLINE_FUNCTION
+19 −0
Original line number Diff line number Diff line
@@ -615,3 +615,22 @@ void PairSNAPKokkos<DeviceType>::v_tally_xyz(EV_FLOAT &ev, const int &i, const i
    v_vatom(j,5) += 0.5*v5;
  }
}

/* ----------------------------------------------------------------------
   memory usage
------------------------------------------------------------------------- */

template<class DeviceType>
double PairSNAPKokkos<DeviceType>::memory_usage()
{
  double bytes = Pair::memory_usage();
  int n = atom->ntypes+1;
  bytes += n*n*sizeof(int);
  bytes += n*n*sizeof(double);
  bytes += 3*nmax*sizeof(double);
  bytes += nmax*sizeof(int);
  bytes += (2*ncoeffall)*sizeof(double);
  bytes += (ncoeff*3)*sizeof(double);
  bytes += snaKK.memory_usage();
  return bytes;
}
+2 −0
Original line number Diff line number Diff line
@@ -65,6 +65,8 @@ inline
inline
  T_INT size_thread_scratch_arrays();

  double memory_usage();

  int ncoeff;

  // functions for bispectrum coefficients
+18 −0
Original line number Diff line number Diff line
@@ -1269,3 +1269,21 @@ double SNAKokkos<DeviceType>::compute_dsfac(double r, double rcut)
  return 0.0;
}

/* ----------------------------------------------------------------------
   memory usage of arrays
------------------------------------------------------------------------- */

template<class DeviceType>
double SNAKokkos<DeviceType>::memory_usage()
{
  int jdim = twojmax + 1;
  double bytes;
  bytes = jdim * jdim * jdim * jdim * jdim * sizeof(double);
  bytes += 2 * jdim * jdim * jdim * sizeof(std::complex<double>);
  bytes += 2 * jdim * jdim * jdim * sizeof(double);
  bytes += jdim * jdim * jdim * 3 * sizeof(std::complex<double>);
  bytes += jdim * jdim * jdim * 3 * sizeof(double);
  bytes += ncoeff * sizeof(double);
  bytes += jdim * jdim * jdim * jdim * jdim * sizeof(std::complex<double>);
  return bytes;
}