Commit 031077b4 authored by Stefan Paquay's avatar Stefan Paquay
Browse files

Made enforce2d also set rotations to in-plane.

parent 962946ee
Loading
Loading
Loading
Loading
+92 −14
Original line number Diff line number Diff line
@@ -12,13 +12,16 @@
------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------
   Contributing authors: Stefan Paquay (Brandeis University)
   Contributing authors: Stefan Paquay & Matthew Peterson (Brandeis University)
------------------------------------------------------------------------- */

#include "atom_masks.h"
#include "atom_kokkos.h"
#include "comm.h"
#include "error.h"
#include "fix_enforce2d_kokkos.h"


using namespace LAMMPS_NS;


@@ -30,14 +33,21 @@ FixEnforce2DKokkos<DeviceType>::FixEnforce2DKokkos(LAMMPS *lmp, int narg, char *
  atomKK = (AtomKokkos *) atom;
  execution_space = ExecutionSpaceFromDevice<DeviceType>::space;

  datamask_read   = X_MASK | V_MASK | F_MASK | MASK_MASK;
  datamask_modify = X_MASK | V_MASK | F_MASK;
  datamask_read   = X_MASK | V_MASK | F_MASK | OMEGA_MASK | MASK_MASK;
  /* TORQUE_MASK | ANGMOM_MASK | */ // MASK_MASK;

  datamask_modify = X_MASK | V_MASK | F_MASK | OMEGA_MASK; // |
	  /* TORQUE_MASK | ANGMOM_MASK */ ;
}


template <class DeviceType>
void FixEnforce2DKokkos<DeviceType>::setup(int vflag)
{
  if( comm->me == 0 ){
    fprintf(screen, "omega, angmom and torque flags are %d, %d, %d\n",
            atomKK->omega_flag, atomKK->angmom_flag, atomKK->torque_flag );
  }
  post_force(vflag);
}

@@ -52,13 +62,71 @@ void FixEnforce2DKokkos<DeviceType>::post_force(int vflag)
  v = atomKK->k_v.view<DeviceType>();
  f = atomKK->k_f.view<DeviceType>();

  if( atomKK->omega_flag )
    omega  = atomKK->k_omega.view<DeviceType>();

  if( atomKK->angmom_flag )
    angmom = atomKK->k_angmom.view<DeviceType>();

  if( atomKK->torque_flag )
    torque = atomKK->k_torque.view<DeviceType>();


  mask = atomKK->k_mask.view<DeviceType>();

  int nlocal = atomKK->nlocal;
  if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst;

  FixEnforce2DKokkosPostForceFunctor<DeviceType> functor(this);
  int flag_mask = 0;
  if( atomKK->omega_flag ) flag_mask  |= 1;
  if( atomKK->angmom_flag ) flag_mask |= 2;
  if( atomKK->torque_flag ) flag_mask |= 4;

  switch( flag_mask ){
    case 0:{
      FixEnforce2DKokkosPostForceFunctor<DeviceType,0,0,0> functor(this);
      Kokkos::parallel_for(nlocal,functor);
      break;
    }
    case 1:{
      FixEnforce2DKokkosPostForceFunctor<DeviceType,1,0,0> functor(this);
      Kokkos::parallel_for(nlocal,functor);
      break;
    }
    case 2:{
      FixEnforce2DKokkosPostForceFunctor<DeviceType,0,1,0> functor(this);
      Kokkos::parallel_for(nlocal,functor);
      break;
    }
    case 3:{
      FixEnforce2DKokkosPostForceFunctor<DeviceType,1,1,0> functor(this);
      Kokkos::parallel_for(nlocal,functor);
      break;
    }
    case 4:{
      FixEnforce2DKokkosPostForceFunctor<DeviceType,0,0,1> functor(this);
      Kokkos::parallel_for(nlocal,functor);
      break;
    }
    case 5:{
      FixEnforce2DKokkosPostForceFunctor<DeviceType,1,0,1> functor(this);
      Kokkos::parallel_for(nlocal,functor);
      break;
    }
    case 6:{
      FixEnforce2DKokkosPostForceFunctor<DeviceType,0,1,1> functor(this);
      Kokkos::parallel_for(nlocal,functor);
      break;
    }
    case 7:{
      FixEnforce2DKokkosPostForceFunctor<DeviceType,1,1,1> functor(this);
      Kokkos::parallel_for(nlocal,functor);
      break;
    }
    default:
      error->all(FLERR, "flag_mask outside of what it should be");
  }


  // Probably sync here again?
  atomKK->sync(execution_space,datamask_read);
@@ -66,23 +134,33 @@ void FixEnforce2DKokkos<DeviceType>::post_force(int vflag)

  for (int m = 0; m < nfixlist; m++)
    flist[m]->enforce2d();


}


template <class DeviceType>
template <int omega_flag, int angmom_flag, int torque_flag>
void FixEnforce2DKokkos<DeviceType>::post_force_item( int i ) const
{

  if (mask[i] & groupbit){
    v(i,2) = 0;
    x(i,2) = 0;
    f(i,2) = 0;
    // x(i,2) = 0; // Enforce2d does not set x[2] to zero either... :/
    v(i,2) = 0.0;
    f(i,2) = 0.0;

    // Add for omega, angmom, torque...
    if(omega_flag){
      omega(i,0) = 0.0;
      omega(i,1) = 0.0;
    }

    if(angmom_flag){
      angmom(i,0) = 0.0;
      angmom(i,1) = 0.0;
    }

    if(torque_flag){
      torque(i,0) = 0.0;
      torque(i,1) = 0.0;
    }
  }
}


+11 −4
Original line number Diff line number Diff line
@@ -37,8 +37,9 @@ class FixEnforce2DKokkos : public FixEnforce2D {
  void setup(int);
  void post_force(int);

  template <int omega_flag, int angmom_flag, int torque_flag>
  KOKKOS_INLINE_FUNCTION
  void post_force_item(int) const;
  void post_force_item(const int i) const;

  // void min_setup(int);       Kokkos does not support minimization (yet)
  // void min_post_force(int);  Kokkos does not support minimization (yet)
@@ -50,20 +51,26 @@ class FixEnforce2DKokkos : public FixEnforce2D {
  typename ArrayTypes<DeviceType>::t_v_array v;
  typename ArrayTypes<DeviceType>::t_f_array f;

  typename ArrayTypes<DeviceType>::t_v_array omega;
  typename ArrayTypes<DeviceType>::t_v_array angmom;
  typename ArrayTypes<DeviceType>::t_f_array torque;

  typename ArrayTypes<DeviceType>::t_int_1d mask;
};


template <class DeviceType>
template <class DeviceType, int omega_flag, int angmom_flag, int torque_flag>
struct FixEnforce2DKokkosPostForceFunctor {
  typedef DeviceType device_type;
  FixEnforce2DKokkos<DeviceType> c;

  FixEnforce2DKokkosPostForceFunctor(FixEnforce2DKokkos<DeviceType>* c_ptr):
    c(*c_ptr) {c.cleanup_copy();};

  KOKKOS_INLINE_FUNCTION
  void operator()(const int i) const {
    c.post_force_item(i);
    // c.template? Really C++?
    c.template post_force_item <omega_flag, angmom_flag, torque_flag>(i);
  }
};