Commit 384aef43 authored by stamoor's avatar stamoor
Browse files

Fixing Kokkos bugs

git-svn-id: svn://svn.icms.temple.edu/lammps-ro/trunk@14579 f3b2605a-c512-4ea7-a41b-209d697bcdaa
parent 753429e6
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -97,6 +97,7 @@ void AtomKokkos::allocate_type_arrays()
    mass = k_mass.h_view.ptr_on_device();
    mass_setflag = new int[ntypes+1];
    for (int itype = 1; itype <= ntypes; itype++) mass_setflag[itype] = 0;
    k_mass.modify<LMPHostType>();
  }
}

+2 −0
Original line number Diff line number Diff line
@@ -44,6 +44,7 @@ template<class DeviceType>
double ComputeTempKokkos<DeviceType>::compute_scalar()
{
  atomKK->sync(execution_space,datamask_read);
  atomKK->k_mass.sync<DeviceType>();

  invoked_scalar = update->ntimestep;

@@ -72,6 +73,7 @@ double ComputeTempKokkos<DeviceType>::compute_scalar()
  if (dof < 0.0 && natoms_temp > 0.0)
    error->all(FLERR,"Temperature compute degrees of freedom < 0");
  scalar *= tfactor;

  return scalar;
}

+184 −0
Original line number Diff line number Diff line
@@ -14,9 +14,15 @@
#include "domain_kokkos.h"
#include "atom_kokkos.h"
#include "atom_masks.h"
#include "error.h"
#include "force.h"
#include "kspace.h"

using namespace LAMMPS_NS;

#define BIG   1.0e20
#define SMALL 1.0e-4

/* ---------------------------------------------------------------------- */

DomainKokkos::DomainKokkos(LAMMPS *lmp) : Domain(lmp) {}
@@ -29,6 +35,184 @@ void DomainKokkos::init()
  Domain::init();
}

/* ----------------------------------------------------------------------
   reset global & local boxes due to global box boundary changes
   if shrink-wrapped, determine atom extent and reset boxlo/hi
   for triclinic, atoms must be in lamda coords (0-1) before reset_box is called
------------------------------------------------------------------------- */

template<class DeviceType>
struct DomainResetBoxFunctor{
public:
  typedef DeviceType device_type;
  typename ArrayTypes<DeviceType>::t_x_array x;

  struct value_type {
    double value[3][2] ;
  };

  DomainResetBoxFunctor(DAT::tdual_x_array _x):
    x(_x.view<DeviceType>()) {}

  KOKKOS_INLINE_FUNCTION
  void init(value_type &dst) const {
    dst.value[2][0] = dst.value[1][0] = dst.value[0][0] = BIG;
    dst.value[2][1] = dst.value[1][1] = dst.value[0][1] = -BIG;
  }

  KOKKOS_INLINE_FUNCTION
  void join(volatile value_type &dst,
             const volatile value_type &src) const {
    dst.value[0][0] = MIN(dst.value[0][0],src.value[0][0]);
    dst.value[0][1] = MAX(dst.value[0][1],src.value[0][1]);
    dst.value[1][0] = MIN(dst.value[1][0],src.value[1][0]);
    dst.value[1][1] = MAX(dst.value[1][1],src.value[1][1]);
    dst.value[2][0] = MIN(dst.value[2][0],src.value[2][0]);
    dst.value[2][1] = MAX(dst.value[2][1],src.value[2][1]);
  }

  KOKKOS_INLINE_FUNCTION
  void operator() (const int &i, value_type &dst) const {
    dst.value[0][0] = MIN(dst.value[0][0],x(i,0));
    dst.value[0][1] = MAX(dst.value[0][1],x(i,0));
    dst.value[1][0] = MIN(dst.value[1][0],x(i,1));
    dst.value[1][1] = MAX(dst.value[1][1],x(i,1));
    dst.value[2][0] = MIN(dst.value[2][0],x(i,2));
    dst.value[2][1] = MAX(dst.value[2][1],x(i,2));
  }
};

void DomainKokkos::reset_box()
{
  // perform shrink-wrapping
  // compute extent of atoms on this proc
  // for triclinic, this is done in lamda space

  atomKK->sync(Device,X_MASK);

  if (nonperiodic == 2) {

    int nlocal = atom->nlocal;

    DomainResetBoxFunctor<LMPDeviceType>::value_type result;

    DomainResetBoxFunctor<LMPDeviceType>
      f(atomKK->k_x);
    Kokkos::parallel_reduce(nlocal,f,result);
    LMPDeviceType::fence();

    double (*extent)[2] = result.value;
    double all[3][2];

    // compute extent across all procs
    // flip sign of MIN to do it in one Allreduce MAX

    extent[0][0] = -extent[0][0];
    extent[1][0] = -extent[1][0];
    extent[2][0] = -extent[2][0];

    MPI_Allreduce(extent,all,6,MPI_DOUBLE,MPI_MAX,world);

    // for triclinic, convert back to box coords before changing box

    if (triclinic) lamda2x(atom->nlocal);

    // in shrink-wrapped dims, set box by atom extent
    // if minimum set, enforce min box size settings
    // for triclinic, convert lamda extent to box coords, then set box lo/hi
    // decided NOT to do the next comment - don't want to sneakily change tilt
    // for triclinic, adjust tilt factors if 2nd dim is shrink-wrapped,
    //   so that displacement in 1st dim stays the same

    if (triclinic == 0) {
      if (xperiodic == 0) {
        if (boundary[0][0] == 2) boxlo[0] = -all[0][0] - small[0];
        else if (boundary[0][0] == 3)
          boxlo[0] = MIN(-all[0][0]-small[0],minxlo);
        if (boundary[0][1] == 2) boxhi[0] = all[0][1] + small[0];
        else if (boundary[0][1] == 3) boxhi[0] = MAX(all[0][1]+small[0],minxhi);
        if (boxlo[0] > boxhi[0]) error->all(FLERR,"Illegal simulation box");
      }
      if (yperiodic == 0) {
        if (boundary[1][0] == 2) boxlo[1] = -all[1][0] - small[1];
        else if (boundary[1][0] == 3)
          boxlo[1] = MIN(-all[1][0]-small[1],minylo);
        if (boundary[1][1] == 2) boxhi[1] = all[1][1] + small[1];
        else if (boundary[1][1] == 3) boxhi[1] = MAX(all[1][1]+small[1],minyhi);
        if (boxlo[1] > boxhi[1]) error->all(FLERR,"Illegal simulation box");
      }
      if (zperiodic == 0) {
        if (boundary[2][0] == 2) boxlo[2] = -all[2][0] - small[2];
        else if (boundary[2][0] == 3)
          boxlo[2] = MIN(-all[2][0]-small[2],minzlo);
        if (boundary[2][1] == 2) boxhi[2] = all[2][1] + small[2];
        else if (boundary[2][1] == 3) boxhi[2] = MAX(all[2][1]+small[2],minzhi);
        if (boxlo[2] > boxhi[2]) error->all(FLERR,"Illegal simulation box");
      }

    } else {
      double lo[3],hi[3];
      if (xperiodic == 0) {
        lo[0] = -all[0][0]; lo[1] = 0.0; lo[2] = 0.0;
        Domain::lamda2x(lo,lo);
        hi[0] = all[0][1]; hi[1] = 0.0; hi[2] = 0.0;
        Domain::lamda2x(hi,hi);
        if (boundary[0][0] == 2) boxlo[0] = lo[0] - small[0];
        else if (boundary[0][0] == 3) boxlo[0] = MIN(lo[0]-small[0],minxlo);
        if (boundary[0][1] == 2) boxhi[0] = hi[0] + small[0];
        else if (boundary[0][1] == 3) boxhi[0] = MAX(hi[0]+small[0],minxhi);
        if (boxlo[0] > boxhi[0]) error->all(FLERR,"Illegal simulation box");
      }
      if (yperiodic == 0) {
        lo[0] = 0.0; lo[1] = -all[1][0]; lo[2] = 0.0;
        Domain::lamda2x(lo,lo);
        hi[0] = 0.0; hi[1] = all[1][1]; hi[2] = 0.0;
        Domain::lamda2x(hi,hi);
        if (boundary[1][0] == 2) boxlo[1] = lo[1] - small[1];
        else if (boundary[1][0] == 3) boxlo[1] = MIN(lo[1]-small[1],minylo);
        if (boundary[1][1] == 2) boxhi[1] = hi[1] + small[1];
        else if (boundary[1][1] == 3) boxhi[1] = MAX(hi[1]+small[1],minyhi);
        if (boxlo[1] > boxhi[1]) error->all(FLERR,"Illegal simulation box");
        //xy *= (boxhi[1]-boxlo[1]) / yprd;
      }
      if (zperiodic == 0) {
        lo[0] = 0.0; lo[1] = 0.0; lo[2] = -all[2][0];
        Domain::lamda2x(lo,lo);
        hi[0] = 0.0; hi[1] = 0.0; hi[2] = all[2][1];
        Domain::lamda2x(hi,hi);
        if (boundary[2][0] == 2) boxlo[2] = lo[2] - small[2];
        else if (boundary[2][0] == 3) boxlo[2] = MIN(lo[2]-small[2],minzlo);
        if (boundary[2][1] == 2) boxhi[2] = hi[2] + small[2];
        else if (boundary[2][1] == 3) boxhi[2] = MAX(hi[2]+small[2],minzhi);
        if (boxlo[2] > boxhi[2]) error->all(FLERR,"Illegal simulation box");
        //xz *= (boxhi[2]-boxlo[2]) / xprd;
        //yz *= (boxhi[2]-boxlo[2]) / yprd;
      }
    }
  }

  // reset box whether shrink-wrapping or not

  set_global_box();
  set_local_box();

  // if shrink-wrapped & kspace is defined (i.e. using MSM), call setup()
  // also call init() (to test for compatibility) ?

  if (nonperiodic == 2 && force->kspace) {
    //force->kspace->init();
    force->kspace->setup();
  }

  // if shrink-wrapped & triclinic, re-convert to lamda coords for new box
  // re-invoke pbc() b/c x2lamda result can be outside [0,1] due to roundoff

  if (nonperiodic == 2 && triclinic) {
    x2lamda(atom->nlocal);
    pbc();
  }
}

/* ---------------------------------------------------------------------- */

template<class DeviceType, int PERIODIC, int DEFORM_VREMAP>
+1 −0
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ class DomainKokkos : public Domain {
  DomainKokkos(class LAMMPS *);
  ~DomainKokkos() {}
  void init();
  void reset_box();
  void pbc();
  void remap_all();
  void image_flip(int, int, int);