Commit d1e751d7 authored by Stan Moore's avatar Stan Moore
Browse files

Fix thread safety issue in fused forward comm

parent 08273c40
Loading
Loading
Loading
Loading
+22 −16
Original line number Diff line number Diff line
@@ -281,6 +281,7 @@ struct AtomVecKokkos_PackCommSelfFused {
  typename ArrayTypes<DeviceType>::t_int_1d_const _pbc_flag;
  typename ArrayTypes<DeviceType>::t_int_1d_const _firstrecv;
  typename ArrayTypes<DeviceType>::t_int_1d_const _sendnum_scan;
  typename ArrayTypes<DeviceType>::t_int_1d_const _g2l;
  X_FLOAT _xprd,_yprd,_zprd,_xy,_xz,_yz;

  AtomVecKokkos_PackCommSelfFused(
@@ -290,6 +291,7 @@ struct AtomVecKokkos_PackCommSelfFused {
      const typename DAT::tdual_int_1d &pbc_flag,
      const typename DAT::tdual_int_1d &firstrecv,
      const typename DAT::tdual_int_1d &sendnum_scan,
      const typename DAT::tdual_int_1d &g2l,
      const X_FLOAT &xprd, const X_FLOAT &yprd, const X_FLOAT &zprd,
      const X_FLOAT &xy, const X_FLOAT &xz, const X_FLOAT &yz):
      _x(x.view<DeviceType>()),_xw(x.view<DeviceType>()),
@@ -298,6 +300,7 @@ struct AtomVecKokkos_PackCommSelfFused {
      _pbc_flag(pbc_flag.view<DeviceType>()),
      _firstrecv(firstrecv.view<DeviceType>()),
      _sendnum_scan(sendnum_scan.view<DeviceType>()),
      _g2l(g2l.view<DeviceType>()),
      _xprd(xprd),_yprd(yprd),_zprd(zprd),
      _xy(xy),_xz(xz),_yz(yz) {};

@@ -310,42 +313,45 @@ struct AtomVecKokkos_PackCommSelfFused {
    if (iswap > 0)
      i = ii - _sendnum_scan[iswap-1];
      const int _nfirst = _firstrecv[iswap];
      const int nlocal = _firstrecv[0];

      int j = _list(iswap,i);
      if (j >= nlocal)
        j = _g2l(j-nlocal);

      const int j = _list(iswap,i);
      if (_pbc_flag(iswap) == 0) {
      if (_pbc_flag(ii) == 0) {
          _xw(i+_nfirst,0) = _x(j,0);
          _xw(i+_nfirst,1) = _x(j,1);
          _xw(i+_nfirst,2) = _x(j,2);
      } else {
        if (TRICLINIC == 0) {
          _xw(i+_nfirst,0) = _x(j,0) + _pbc(iswap,0)*_xprd;
          _xw(i+_nfirst,1) = _x(j,1) + _pbc(iswap,1)*_yprd;
          _xw(i+_nfirst,2) = _x(j,2) + _pbc(iswap,2)*_zprd;
          _xw(i+_nfirst,0) = _x(j,0) + _pbc(ii,0)*_xprd;
          _xw(i+_nfirst,1) = _x(j,1) + _pbc(ii,1)*_yprd;
          _xw(i+_nfirst,2) = _x(j,2) + _pbc(ii,2)*_zprd;
        } else {
          _xw(i+_nfirst,0) = _x(j,0) + _pbc(iswap,0)*_xprd + _pbc(iswap,5)*_xy + _pbc(iswap,4)*_xz;
          _xw(i+_nfirst,1) = _x(j,1) + _pbc(iswap,1)*_yprd + _pbc(iswap,3)*_yz;
          _xw(i+_nfirst,2) = _x(j,2) + _pbc(iswap,2)*_zprd;
          _xw(i+_nfirst,0) = _x(j,0) + _pbc(ii,0)*_xprd + _pbc(ii,5)*_xy + _pbc(ii,4)*_xz;
          _xw(i+_nfirst,1) = _x(j,1) + _pbc(ii,1)*_yprd + _pbc(ii,3)*_yz;
          _xw(i+_nfirst,2) = _x(j,2) + _pbc(ii,2)*_zprd;
        }
      }

  }
};

/* ---------------------------------------------------------------------- */

int AtomVecKokkos::pack_comm_self_fused(const int &n, const DAT::tdual_int_2d &list, const DAT::tdual_int_1d &sendnum_scan,
                                         const DAT::tdual_int_1d &firstrecv, const DAT::tdual_int_1d &pbc_flag, const DAT::tdual_int_2d &pbc) {
                                         const DAT::tdual_int_1d &firstrecv, const DAT::tdual_int_1d &pbc_flag, const DAT::tdual_int_2d &pbc,
                                         const DAT::tdual_int_1d &g2l) {
  if(commKK->forward_comm_on_host) {
    sync(Host,X_MASK);
    modified(Host,X_MASK);
    if(domain->triclinic) {
    struct AtomVecKokkos_PackCommSelfFused<LMPHostType,1> f(atomKK->k_x,list,pbc,pbc_flag,firstrecv,sendnum_scan,
    struct AtomVecKokkos_PackCommSelfFused<LMPHostType,1> f(atomKK->k_x,list,pbc,pbc_flag,firstrecv,sendnum_scan,g2l,
        domain->xprd,domain->yprd,domain->zprd,
        domain->xy,domain->xz,domain->yz);
    Kokkos::parallel_for(n,f);
    } else {
    struct AtomVecKokkos_PackCommSelfFused<LMPHostType,0> f(atomKK->k_x,list,pbc,pbc_flag,firstrecv,sendnum_scan,
    struct AtomVecKokkos_PackCommSelfFused<LMPHostType,0> f(atomKK->k_x,list,pbc,pbc_flag,firstrecv,sendnum_scan,g2l,
        domain->xprd,domain->yprd,domain->zprd,
        domain->xy,domain->xz,domain->yz);
    Kokkos::parallel_for(n,f);
@@ -354,12 +360,12 @@ int AtomVecKokkos::pack_comm_self_fused(const int &n, const DAT::tdual_int_2d &l
    sync(Device,X_MASK);
    modified(Device,X_MASK);
    if(domain->triclinic) {
    struct AtomVecKokkos_PackCommSelfFused<LMPDeviceType,1> f(atomKK->k_x,list,pbc,pbc_flag,firstrecv,sendnum_scan,
    struct AtomVecKokkos_PackCommSelfFused<LMPDeviceType,1> f(atomKK->k_x,list,pbc,pbc_flag,firstrecv,sendnum_scan,g2l,
        domain->xprd,domain->yprd,domain->zprd,
        domain->xy,domain->xz,domain->yz);
    Kokkos::parallel_for(n,f);
    } else {
    struct AtomVecKokkos_PackCommSelfFused<LMPDeviceType,0> f(atomKK->k_x,list,pbc,pbc_flag,firstrecv,sendnum_scan,
    struct AtomVecKokkos_PackCommSelfFused<LMPDeviceType,0> f(atomKK->k_x,list,pbc,pbc_flag,firstrecv,sendnum_scan,g2l,
        domain->xprd,domain->yprd,domain->zprd,
        domain->xy,domain->xz,domain->yz);
    Kokkos::parallel_for(n,f);
+2 −1
Original line number Diff line number Diff line
@@ -56,7 +56,8 @@ class AtomVecKokkos : public AtomVec {
                         const DAT::tdual_int_1d &sendnum_scan,
                         const DAT::tdual_int_1d &firstrecv,
                         const DAT::tdual_int_1d &pbc_flag,
                         const DAT::tdual_int_2d &pbc);
                         const DAT::tdual_int_2d &pbc,
                         const DAT::tdual_int_1d &g2l);

  virtual int
    pack_comm_kokkos(const int &n, const DAT::tdual_int_2d &list,
+46 −28
Original line number Diff line number Diff line
@@ -140,31 +140,6 @@ void CommKokkos::init()
    forward_comm_classic = true;
}

/* ---------------------------------------------------------------------- */

void CommKokkos::setup()
{
  CommBrick::setup();

  k_pbc_flag = DAT::tdual_int_1d("comm:pbc_flag",nswap);
  k_pbc = DAT::tdual_int_2d("comm:pbc",nswap,6);

  for (int iswap = 0; iswap < nswap; iswap++) {
    k_pbc_flag.h_view[iswap] = pbc_flag[iswap];
    k_pbc.h_view(iswap,0) = pbc[iswap][0];
    k_pbc.h_view(iswap,1) = pbc[iswap][1];
    k_pbc.h_view(iswap,2) = pbc[iswap][2];
    k_pbc.h_view(iswap,3) = pbc[iswap][3];
    k_pbc.h_view(iswap,4) = pbc[iswap][4];
    k_pbc.h_view(iswap,5) = pbc[iswap][5];
  }
  k_pbc_flag.modify<LMPHostType>();
  k_pbc.modify<LMPHostType>();

  k_pbc_flag.sync<LMPDeviceType>();
  k_pbc.sync<LMPDeviceType>();
}

/* ----------------------------------------------------------------------
   forward communication of atom coords every timestep
   other per-atom attributes may also be sent via pack/unpack routines
@@ -211,11 +186,12 @@ void CommKokkos::forward_comm_device(int dummy)
  k_sendlist.sync<DeviceType>();
  atomKK->sync(ExecutionSpaceFromDevice<DeviceType>::space,X_MASK);

  if (comm->nprocs == 1) {
  if (comm->nprocs == 1 && !ghost_velocity) {
    k_swap.sync<DeviceType>();
    k_swap2.sync<DeviceType>();
    k_pbc.sync<DeviceType>();
    n = avec->pack_comm_self_fused(totalsend,k_sendlist,k_sendnum_scan,
                    k_firstrecv,k_pbc_flag,k_pbc);
                    k_firstrecv,k_pbc_flag,k_pbc,k_g2l);
  } else {

  for (int iswap = 0; iswap < nswap; iswap++) {
@@ -783,7 +759,7 @@ void CommKokkos::borders()
    atomKK->modified(Host,ALL_MASK);
  }

  if (comm->nprocs == 1 && !forward_comm_classic)
  if (comm->nprocs == 1 && !ghost_velocity && !forward_comm_classic)
    copy_swap_info();
}

@@ -1092,7 +1068,49 @@ void CommKokkos::copy_swap_info()
  }
  totalsend = scan;

  int* list = NULL;
  memory->create(list,totalsend,"comm:list");
  if (totalsend > k_pbc.extent(0)) {
    k_pbc = DAT::tdual_int_2d("comm:pbc",totalsend,6);
    k_swap2 = DAT::tdual_int_2d("comm:swap2",2,totalsend);
    k_pbc_flag = Kokkos::subview(k_swap2,0,Kokkos::ALL);
    k_g2l = Kokkos::subview(k_swap2,1,Kokkos::ALL);
  }

  // create map of ghost atoms to local atoms
  // store periodic boundary transform from local to ghost

  for (int iswap = 0; iswap < nswap; iswap++) {
    for (int i = 0; i < sendnum[iswap]; i++) {
      int source = sendlist[iswap][i] - atom->nlocal;
      int dest = firstrecv[iswap] + i - atom->nlocal;
      k_pbc_flag.h_view(dest) = pbc_flag[iswap];
      k_pbc.h_view(dest,0) = pbc[iswap][0];
      k_pbc.h_view(dest,1) = pbc[iswap][1];
      k_pbc.h_view(dest,2) = pbc[iswap][2];
      k_pbc.h_view(dest,3) = pbc[iswap][3];
      k_pbc.h_view(dest,4) = pbc[iswap][4];
      k_pbc.h_view(dest,5) = pbc[iswap][5];
      k_g2l.h_view(dest) = atom->nlocal + source;

      if (source >= 0) {
        k_pbc_flag.h_view(dest) = k_pbc_flag.h_view(dest) || k_pbc_flag.h_view(source);
        k_pbc.h_view(dest,0) += k_pbc.h_view(source,0);
        k_pbc.h_view(dest,1) += k_pbc.h_view(source,1);
        k_pbc.h_view(dest,2) += k_pbc.h_view(source,2);
        k_pbc.h_view(dest,3) += k_pbc.h_view(source,3);
        k_pbc.h_view(dest,4) += k_pbc.h_view(source,4);
        k_pbc.h_view(dest,5) += k_pbc.h_view(source,5);
        k_g2l.h_view(dest) = k_g2l.h_view(source);
      }
    }
  }

  k_swap.modify<LMPHostType>();
  k_swap2.modify<LMPHostType>();
  k_pbc.modify<LMPHostType>();

  memory->destroy(list);
}

/* ----------------------------------------------------------------------
+2 −1
Original line number Diff line number Diff line
@@ -33,7 +33,6 @@ class CommKokkos : public CommBrick {
  CommKokkos(class LAMMPS *);
  ~CommKokkos();
  void init();
  void setup();

  void forward_comm(int dummy = 0);    // forward comm of atom coords
  void reverse_comm();                 // reverse comm of atom coords
@@ -66,8 +65,10 @@ class CommKokkos : public CommBrick {
  //double *buf_recv;                 // recv buffer for all comm

  DAT::tdual_int_2d k_swap;
  DAT::tdual_int_2d k_swap2;
  DAT::tdual_int_2d k_pbc;
  DAT::tdual_int_1d k_pbc_flag;
  DAT::tdual_int_1d k_g2l;
  DAT::tdual_int_1d k_firstrecv;
  DAT::tdual_int_1d k_sendnum_scan;
  int totalsend;