Commit 17bdb57b authored by Dan Ibanez's avatar Dan Ibanez
Browse files

try PairTableKokkos : public PairTable

realize that there is a lot of copy-paste
in this codebase.
parent 62dea1bb
Loading
Loading
Loading
Loading
+1 −262
Original line number Diff line number Diff line
@@ -59,17 +59,8 @@ PairTableKokkos<DeviceType>::PairTableKokkos(LAMMPS *lmp) : Pair(lmp)
template<class DeviceType>
PairTableKokkos<DeviceType>::~PairTableKokkos()
{
/*  for (int m = 0; m < ntables; m++) free_table(&tables[m]);
  memory->sfree(tables);

  if (allocated) {
    memory->destroy(setflag);
    memory->destroy(cutsq);
    memory->destroy(tabindex);
  }*/
  delete h_table;
  delete d_table;

}

/* ---------------------------------------------------------------------- */
@@ -190,26 +181,15 @@ compute_fpair(const F_FLOAT& rsq, const int& i, const int&j, const int& itype, c
  union_int_float_t rsq_lookup;
  double fpair;
  const int tidx = d_table_const.tabindex(itype,jtype);
  //const Table* const tb = &tables[tabindex[itype][jtype]];

  //if (rsq < d_table_const.innersq(tidx))
  //  error->one(FLERR,"Pair distance < table inner cutoff");

  if (Specialisation::TabStyle == LOOKUP) {
    const int itable = static_cast<int> ((rsq - d_table_const.innersq(tidx)) * d_table_const.invdelta(tidx));
    //if (itable >= tlm1)
    //  error->one(FLERR,"Pair distance > table outer cutoff");
    fpair = d_table_const.f(tidx,itable);
  } else if (Specialisation::TabStyle == LINEAR) {
    const int itable = static_cast<int> ((rsq - d_table_const.innersq(tidx)) * d_table_const.invdelta(tidx));
    //if (itable >= tlm1)
    //  error->one(FLERR,"Pair distance > table outer cutoff");
    const double fraction = (rsq - d_table_const.rsq(tidx,itable)) * d_table_const.invdelta(tidx);
    fpair = d_table_const.f(tidx,itable) + fraction*d_table_const.df(tidx,itable);
  } else if (Specialisation::TabStyle == SPLINE) {
    const int itable = static_cast<int> ((rsq - d_table_const.innersq(tidx)) * d_table_const.invdelta(tidx));
    //if (itable >= tlm1)
    //  error->one(FLERR,"Pair distance > table outer cutoff");
    const double b = (rsq - d_table_const.rsq(tidx,itable)) * d_table_const.invdelta(tidx);
    const double a = 1.0 - b;
    fpair = a * d_table_const.f(tidx,itable) + b * d_table_const.f(tidx,itable+1) +
@@ -235,26 +215,15 @@ compute_evdwl(const F_FLOAT& rsq, const int& i, const int&j, const int& itype, c
  double evdwl;
  union_int_float_t rsq_lookup;
  const int tidx = d_table_const.tabindex(itype,jtype);
  //const Table* const tb = &tables[tabindex[itype][jtype]];

  //if (rsq < d_table_const.innersq(tidx))
  //  error->one(FLERR,"Pair distance < table inner cutoff");

  if (Specialisation::TabStyle == LOOKUP) {
    const int itable = static_cast<int> ((rsq - d_table_const.innersq(tidx)) * d_table_const.invdelta(tidx));
    //if (itable >= tlm1)
    //  error->one(FLERR,"Pair distance > table outer cutoff");
    evdwl = d_table_const.e(tidx,itable);
  } else if (Specialisation::TabStyle == LINEAR) {
    const int itable = static_cast<int> ((rsq - d_table_const.innersq(tidx)) * d_table_const.invdelta(tidx));
    //if (itable >= tlm1)
    //  error->one(FLERR,"Pair distance > table outer cutoff");
    const double fraction = (rsq - d_table_const.rsq(tidx,itable)) * d_table_const.invdelta(tidx);
    evdwl = d_table_const.e(tidx,itable) + fraction*d_table_const.de(tidx,itable);
  } else if (Specialisation::TabStyle == SPLINE) {
    const int itable = static_cast<int> ((rsq - d_table_const.innersq(tidx)) * d_table_const.invdelta(tidx));
    //if (itable >= tlm1)
    //  error->one(FLERR,"Pair distance > table outer cutoff");
    const double b = (rsq - d_table_const.rsq(tidx,itable)) * d_table_const.invdelta(tidx);
    const double a = 1.0 - b;
    evdwl = a * d_table_const.e(tidx,itable) + b * d_table_const.e(tidx,itable+1) +
@@ -383,18 +352,10 @@ void PairTableKokkos<DeviceType>::create_kokkos_tables()
template<class DeviceType>
void PairTableKokkos<DeviceType>::allocate()
{
  allocated = 1;
  const int nt = atom->ntypes + 1;

  memory->create(setflag,nt,nt,"pair:setflag");
  memory->create_kokkos(d_table->cutsq,h_table->cutsq,cutsq,nt,nt,"pair:cutsq");
  memory->create_kokkos(d_table->tabindex,h_table->tabindex,tabindex,nt,nt,"pair:tabindex");
  PairTable::allocate();

  d_table_const.cutsq = d_table->cutsq;
  d_table_const.tabindex = d_table->tabindex;
  memset(&setflag[0][0],0,nt*nt*sizeof(int));
  memset(&cutsq[0][0],0,nt*nt*sizeof(double));
  memset(&tabindex[0][0],0,nt*nt*sizeof(int));
}

/* ----------------------------------------------------------------------
@@ -1082,143 +1043,6 @@ double PairTableKokkos<DeviceType>::splint(double *xa, double *ya, double *y2a,
  return y;
}

/* ----------------------------------------------------------------------
   proc 0 writes to restart file
------------------------------------------------------------------------- */

template<class DeviceType>
void PairTableKokkos<DeviceType>::write_restart(FILE *fp)
{
  write_restart_settings(fp);
}

/* ----------------------------------------------------------------------
   proc 0 reads from restart file, bcasts
------------------------------------------------------------------------- */

template<class DeviceType>
void PairTableKokkos<DeviceType>::read_restart(FILE *fp)
{
  read_restart_settings(fp);
  allocate();
}

/* ----------------------------------------------------------------------
   proc 0 writes to restart file
------------------------------------------------------------------------- */

template<class DeviceType>
void PairTableKokkos<DeviceType>::write_restart_settings(FILE *fp)
{
  fwrite(&tabstyle,sizeof(int),1,fp);
  fwrite(&tablength,sizeof(int),1,fp);
  fwrite(&ewaldflag,sizeof(int),1,fp);
  fwrite(&pppmflag,sizeof(int),1,fp);
  fwrite(&msmflag,sizeof(int),1,fp);
  fwrite(&dispersionflag,sizeof(int),1,fp);
  fwrite(&tip4pflag,sizeof(int),1,fp);
}

/* ----------------------------------------------------------------------
   proc 0 reads from restart file, bcasts
------------------------------------------------------------------------- */

template<class DeviceType>
void PairTableKokkos<DeviceType>::read_restart_settings(FILE *fp)
{
  if (comm->me == 0) {
    fread(&tabstyle,sizeof(int),1,fp);
    fread(&tablength,sizeof(int),1,fp);
    fread(&ewaldflag,sizeof(int),1,fp);
    fread(&pppmflag,sizeof(int),1,fp);
    fread(&msmflag,sizeof(int),1,fp);
    fread(&dispersionflag,sizeof(int),1,fp);
    fread(&tip4pflag,sizeof(int),1,fp);
  }
  MPI_Bcast(&tabstyle,1,MPI_INT,0,world);
  MPI_Bcast(&tablength,1,MPI_INT,0,world);
  MPI_Bcast(&ewaldflag,1,MPI_INT,0,world);
  MPI_Bcast(&pppmflag,1,MPI_INT,0,world);
  MPI_Bcast(&msmflag,1,MPI_INT,0,world);
  MPI_Bcast(&dispersionflag,1,MPI_INT,0,world);
  MPI_Bcast(&tip4pflag,1,MPI_INT,0,world);
}

/* ---------------------------------------------------------------------- */

template<class DeviceType>
double PairTableKokkos<DeviceType>::single(int i, int j, int itype, int jtype, double rsq,
                         double factor_coul, double factor_lj,
                         double &fforce)
{
  int itable;
  double fraction,value,a,b,phi;
  int tlm1 = tablength - 1;

  Table *tb = &tables[tabindex[itype][jtype]];
  if (rsq < tb->innersq) error->one(FLERR,"Pair distance < table inner cutoff");

  if (tabstyle == LOOKUP) {
    itable = static_cast<int> ((rsq-tb->innersq) * tb->invdelta);
    if (itable >= tlm1) error->one(FLERR,"Pair distance > table outer cutoff");
    fforce = factor_lj * tb->f[itable];
  } else if (tabstyle == LINEAR) {
    itable = static_cast<int> ((rsq-tb->innersq) * tb->invdelta);
    if (itable >= tlm1) error->one(FLERR,"Pair distance > table outer cutoff");
    fraction = (rsq - tb->rsq[itable]) * tb->invdelta;
    value = tb->f[itable] + fraction*tb->df[itable];
    fforce = factor_lj * value;
  } else if (tabstyle == SPLINE) {
    itable = static_cast<int> ((rsq-tb->innersq) * tb->invdelta);
    if (itable >= tlm1) error->one(FLERR,"Pair distance > table outer cutoff");
    b = (rsq - tb->rsq[itable]) * tb->invdelta;
    a = 1.0 - b;
    value = a * tb->f[itable] + b * tb->f[itable+1] +
      ((a*a*a-a)*tb->f2[itable] + (b*b*b-b)*tb->f2[itable+1]) *
      tb->deltasq6;
    fforce = factor_lj * value;
  } else {
    union_int_float_t rsq_lookup;
    rsq_lookup.f = rsq;
    itable = rsq_lookup.i & tb->nmask;
    itable >>= tb->nshiftbits;
    fraction = (rsq_lookup.f - tb->rsq[itable]) * tb->drsq[itable];
    value = tb->f[itable] + fraction*tb->df[itable];
    fforce = factor_lj * value;
  }

  if (tabstyle == LOOKUP)
    phi = tb->e[itable];
  else if (tabstyle == LINEAR || tabstyle == BITMAP)
    phi = tb->e[itable] + fraction*tb->de[itable];
  else
    phi = a * tb->e[itable] + b * tb->e[itable+1] +
      ((a*a*a-a)*tb->e2[itable] + (b*b*b-b)*tb->e2[itable+1]) * tb->deltasq6;
  return factor_lj*phi;
}

/* ----------------------------------------------------------------------
   return the Coulomb cutoff for tabled potentials
   called by KSpace solvers which require that all pairwise cutoffs be the same
   loop over all tables not just those indexed by tabindex[i][j] since
     no way to know which tables are active since pair::init() not yet called
------------------------------------------------------------------------- */

template<class DeviceType>
void *PairTableKokkos<DeviceType>::extract(const char *str, int &dim)
{
  if (strcmp(str,"cut_coul") != 0) return NULL;
  if (ntables == 0) error->all(FLERR,"All pair coeffs are not set");

  double cut_coul = tables[0].cut;
  for (int m = 1; m < ntables; m++)
    if (tables[m].cut != cut_coul)
      error->all(FLERR,
                 "Pair table cutoffs must all be equal to use with KSpace");
  dim = 0;
  return &tables[0].cut;
}

template<class DeviceType>
void PairTableKokkos<DeviceType>::init_style()
{
@@ -1246,91 +1070,6 @@ void PairTableKokkos<DeviceType>::init_style()
  }
}

/*
template <class DeviceType> template<int NEIGHFLAG>
KOKKOS_INLINE_FUNCTION
void PairTableKokkos<DeviceType>::
ev_tally(EV_FLOAT &ev, const int &i, const int &j, const F_FLOAT &fpair,
         const F_FLOAT &delx, const F_FLOAT &dely, const F_FLOAT &delz) const
{
  const int EFLAG = eflag;
  const int NEWTON_PAIR = newton_pair;
  const int VFLAG = vflag_either;

  if (EFLAG) {
    if (eflag_atom) {
      E_FLOAT epairhalf = 0.5 * (ev.evdwl + ev.ecoul);
      if (NEWTON_PAIR || i < nlocal) eatom[i] += epairhalf;
      if (NEWTON_PAIR || j < nlocal) eatom[j] += epairhalf;
    }
  }

  if (VFLAG) {
    const E_FLOAT v0 = delx*delx*fpair;
    const E_FLOAT v1 = dely*dely*fpair;
    const E_FLOAT v2 = delz*delz*fpair;
    const E_FLOAT v3 = delx*dely*fpair;
    const E_FLOAT v4 = delx*delz*fpair;
    const E_FLOAT v5 = dely*delz*fpair;

    if (vflag_global) {
      if (NEIGHFLAG) {
        if (NEWTON_PAIR) {
          ev.v[0] += v0;
          ev.v[1] += v1;
          ev.v[2] += v2;
          ev.v[3] += v3;
          ev.v[4] += v4;
          ev.v[5] += v5;
        } else {
          if (i < nlocal) {
            ev.v[0] += 0.5*v0;
            ev.v[1] += 0.5*v1;
            ev.v[2] += 0.5*v2;
            ev.v[3] += 0.5*v3;
            ev.v[4] += 0.5*v4;
            ev.v[5] += 0.5*v5;
          }
          if (j < nlocal) {
            ev.v[0] += 0.5*v0;
            ev.v[1] += 0.5*v1;
            ev.v[2] += 0.5*v2;
            ev.v[3] += 0.5*v3;
            ev.v[4] += 0.5*v4;
            ev.v[5] += 0.5*v5;
          }
        }
      } else {
        ev.v[0] += 0.5*v0;
        ev.v[1] += 0.5*v1;
        ev.v[2] += 0.5*v2;
        ev.v[3] += 0.5*v3;
        ev.v[4] += 0.5*v4;
        ev.v[5] += 0.5*v5;
      }
    }

    if (vflag_atom) {
      if (NEWTON_PAIR || i < nlocal) {
        d_vatom(i,0) += 0.5*v0;
        d_vatom(i,1) += 0.5*v1;
        d_vatom(i,2) += 0.5*v2;
        d_vatom(i,3) += 0.5*v3;
        d_vatom(i,4) += 0.5*v4;
        d_vatom(i,5) += 0.5*v5;
      }
      if (NEWTON_PAIR || (NEIGHFLAG && j < nlocal)) {
        d_vatom(j,0) += 0.5*v0;
        d_vatom(j,1) += 0.5*v1;
        d_vatom(j,2) += 0.5*v2;
        d_vatom(j,3) += 0.5*v3;
        d_vatom(j,4) += 0.5*v4;
        d_vatom(j,5) += 0.5*v5;
      }
    }
  }
}
*/
template<class DeviceType>
void PairTableKokkos<DeviceType>::cleanup_copy() {
  // WHY needed: this prevents parent copy from deallocating any arrays
+3 −20
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ PairStyle(table/kk/host,PairTableKokkos<LMPHostType>)
#ifndef LMP_PAIR_TABLE_KOKKOS_H
#define LMP_PAIR_TABLE_KOKKOS_H

#include "pair.h"
#include "pair_table.h"
#include "pair_kokkos.h"
#include "neigh_list_kokkos.h"
#include "atom_kokkos.h"
@@ -38,7 +38,7 @@ template <class DeviceType, int NEIGHFLAG, int TABSTYLE>
class PairTableComputeFunctor;

template<class DeviceType>
class PairTableKokkos : public Pair {
class PairTableKokkos : public PairTable {
 public:

  enum {EnabledNeighFlags=FULL|HALFTHREAD|HALF|N2};
@@ -53,28 +53,15 @@ class PairTableKokkos : public Pair {
  template<int TABSTYLE>
  void compute_style(int, int);

  /*template<int EVFLAG, int NEIGHFLAG, int NEWTON_PAIR, int TABSTYLE>
  KOKKOS_FUNCTION
  EV_FLOAT compute_item(const int& i,
                        const NeighListKokkos<DeviceType> &list) const;
*/
  void settings(int, char **);
  void coeff(int, char **);
  double init_one(int, int);
  void write_restart(FILE *);
  void read_restart(FILE *);
  void write_restart_settings(FILE *);
  void read_restart_settings(FILE *);
  double single(int, int, int, int, double, double, double, double &);
  void *extract(const char *, int &);

  void init_style();


 protected:
  enum{LOOKUP,LINEAR,SPLINE,BITMAP};

  int tabstyle,tablength;
  /*struct TableDeviceConst {
    typename ArrayTypes<DeviceType>::t_ffloat_2d_randomread cutsq;
    typename ArrayTypes<DeviceType>::t_int_2d_randomread tabindex;
@@ -116,19 +103,15 @@ class PairTableKokkos : public Pair {
    double innersq,delta,invdelta,deltasq6;
    double *rsq,*drsq,*e,*de,*f,*df,*e2,*f2;
  };
  int ntables;
  Table *tables;
  TableDeviceConst d_table_const;
  TableDevice* d_table;
  TableHost* h_table;

  int **tabindex;
  F_FLOAT m_cutsq[MAX_TYPES_STACKPARAMS+1][MAX_TYPES_STACKPARAMS+1];

  typename ArrayTypes<DeviceType>::t_ffloat_2d d_cutsq;

  void allocate();
  void read_table(Table *, char *, char *);
  virtual void allocate();
  void param_extract(Table *, char *);
  void bcast_table(Table *);
  void spline_table(Table *);