Commit a86572f4 authored by Stan Moore's avatar Stan Moore
Browse files

Reduce memory churn in Kokkos package

parent f2c11727
Loading
Loading
Loading
Loading
+26 −28
Original line number Diff line number Diff line
@@ -46,7 +46,8 @@ CommKokkos::CommKokkos(LAMMPS *lmp) : CommBrick(lmp)
  if (sendlist) for (int i = 0; i < maxswap; i++) memory->destroy(sendlist[i]);
  memory->sfree(sendlist);
  sendlist = NULL;
  k_sendlist = ArrayTypes<LMPDeviceType>::tdual_int_2d();
  k_sendlist = DAT::tdual_int_2d();
  k_total_send = DAT::tdual_int_scalar("comm::k_total_send");

  // error check for disallow of OpenMP threads?

@@ -57,12 +58,12 @@ CommKokkos::CommKokkos(LAMMPS *lmp) : CommBrick(lmp)
  memory->destroy(buf_recv);
  buf_recv = NULL;

  k_exchange_sendlist = ArrayTypes<LMPDeviceType>::
  k_exchange_sendlist = DAT::
    tdual_int_1d("comm:k_exchange_sendlist",100);
  k_exchange_copylist = ArrayTypes<LMPDeviceType>::
  k_exchange_copylist = DAT::
    tdual_int_1d("comm:k_exchange_copylist",100);
  k_count = ArrayTypes<LMPDeviceType>::tdual_int_1d("comm:k_count",1);
  k_sendflag = ArrayTypes<LMPDeviceType>::tdual_int_1d("comm:k_sendflag",100);
  k_count = DAT::tdual_int_1d("comm:k_count",1);
  k_sendflag = DAT::tdual_int_1d("comm:k_sendflag",100);

  memory->destroy(maxsendlist);
  maxsendlist = NULL;
@@ -659,11 +660,11 @@ struct BuildBorderListFunctor {
  int iswap,maxsendlist;
  int nfirst,nlast,dim;
  typename AT::t_int_2d sendlist;
  typename AT::t_int_1d nsend;
  typename AT::t_int_scalar nsend;

  BuildBorderListFunctor(typename AT::tdual_x_array _x,
                         typename AT::tdual_int_2d _sendlist,
                         typename AT::tdual_int_1d _nsend,int _nfirst,
                         typename AT::tdual_int_scalar _nsend,int _nfirst,
                         int _nlast, int _dim,
                         X_FLOAT _lo, X_FLOAT _hi, int _iswap,
                         int _maxsendlist):
@@ -684,7 +685,7 @@ struct BuildBorderListFunctor {
    for (int i=teamstart + dev.team_rank(); i<teamend; i+=dev.team_size()) {
      if (x(i,dim) >= lo && x(i,dim) <= hi) mysend++;
    }
    const int my_store_pos = dev.team_scan(mysend,&nsend(0));
    const int my_store_pos = dev.team_scan(mysend,&nsend());

    if (my_store_pos+mysend < maxsendlist) {
    mysend = my_store_pos;
@@ -763,37 +764,34 @@ void CommKokkos::borders_device() {
      if (sendflag) {
        if (!bordergroup || ineed >= 2) {
          if (style == SINGLE) {
            typename ArrayTypes<DeviceType>::tdual_int_1d total_send("TS",1);
            total_send.h_view(0) = 0;
            if(exec_space == Device) {
              total_send.template modify<DeviceType>();
              total_send.template sync<LMPDeviceType>();
            }
            k_total_send.h_view() = 0;
            k_total_send.template modify<LMPHostType>();
            k_total_send.template sync<LMPDeviceType>();

            BuildBorderListFunctor<DeviceType> f(atomKK->k_x,k_sendlist,
                total_send,nfirst,nlast,dim,lo,hi,iswap,maxsendlist[iswap]);
                k_total_send,nfirst,nlast,dim,lo,hi,iswap,maxsendlist[iswap]);
            Kokkos::TeamPolicy<DeviceType> config((nlast-nfirst+127)/128,128);
            Kokkos::parallel_for(config,f);

            total_send.template modify<DeviceType>();
            total_send.template sync<LMPHostType>();
            k_total_send.template modify<DeviceType>();
            k_total_send.template sync<LMPHostType>();

            if(total_send.h_view(0) >= maxsendlist[iswap]) {
              grow_list(iswap,total_send.h_view(0));
            if(k_total_send.h_view() >= maxsendlist[iswap]) {
              grow_list(iswap,k_total_send.h_view());
              k_sendlist.modify<DeviceType>();
              total_send.h_view(0) = 0;
              k_total_send.h_view() = 0;
              if(exec_space == Device) {
                total_send.template modify<LMPHostType>();
                total_send.template sync<LMPDeviceType>();
                k_total_send.template modify<LMPHostType>();
                k_total_send.template sync<LMPDeviceType>();
              }
              BuildBorderListFunctor<DeviceType> f(atomKK->k_x,k_sendlist,
                  total_send,nfirst,nlast,dim,lo,hi,iswap,maxsendlist[iswap]);
                  k_total_send,nfirst,nlast,dim,lo,hi,iswap,maxsendlist[iswap]);
              Kokkos::TeamPolicy<DeviceType> config((nlast-nfirst+127)/128,128);
              Kokkos::parallel_for(config,f);
              total_send.template modify<DeviceType>();
              total_send.template sync<LMPHostType>();
              k_total_send.template modify<DeviceType>();
              k_total_send.template sync<LMPHostType>();
            }
            nsend = total_send.h_view(0);
            nsend = k_total_send.h_view();
          } else {
            error->all(FLERR,"Required border comm not yet "
                       "implemented with Kokkos");
@@ -961,7 +959,7 @@ void CommKokkos::grow_send_kokkos(int n, int flag, ExecutionSpace space)
    buf_send = k_buf_send.view<LMPHostType>().ptr_on_device();
  }
  else {
    k_buf_send = ArrayTypes<LMPDeviceType>::
    k_buf_send = DAT::
      tdual_xfloat_2d("comm:k_buf_send",maxsend_border,atom->avec->size_border);
    buf_send = k_buf_send.view<LMPHostType>().ptr_on_device();
  }
@@ -975,7 +973,7 @@ void CommKokkos::grow_recv_kokkos(int n, ExecutionSpace space)
{
  maxrecv = static_cast<int> (BUFFACTOR * n);
  int maxrecv_border = (maxrecv+BUFEXTRA+5)/atom->avec->size_border + 2;
  k_buf_recv = ArrayTypes<LMPDeviceType>::
  k_buf_recv = DAT::
    tdual_xfloat_2d("comm:k_buf_recv",maxrecv_border,atom->avec->size_border);
  buf_recv = k_buf_recv.view<LMPHostType>().ptr_on_device();
}
+1 −0
Original line number Diff line number Diff line
@@ -53,6 +53,7 @@ class CommKokkos : public CommBrick {

 protected:
  DAT::tdual_int_2d k_sendlist;
  DAT::tdual_int_scalar k_total_send;
  DAT::tdual_xfloat_2d k_buf_send,k_buf_recv;
  DAT::tdual_int_1d k_exchange_sendlist,k_exchange_copylist,k_sendflag;
  DAT::tdual_int_1d k_count;
+4 −2
Original line number Diff line number Diff line
@@ -88,12 +88,14 @@ void NPairKokkos<DeviceType,HALF_NEIGH,GHOST,TRI>::copy_stencil_info()

  int maxstencil = ns->get_maxstencil();

  if (maxstencil > k_stencil.dimension_0())
    k_stencil = DAT::tdual_int_1d("neighlist:stencil",maxstencil);
  for (int k = 0; k < maxstencil; k++)
    k_stencil.h_view(k) = ns->stencil[k];
    k_stencil.modify<LMPHostType>();
    k_stencil.sync<DeviceType>();
  if (GHOST) {
    if (maxstencil > k_stencilxyz.dimension_0())
      k_stencilxyz = DAT::tdual_int_1d_3("neighlist:stencilxyz",maxstencil);
    for (int k = 0; k < maxstencil; k++) {
      k_stencilxyz.h_view(k,0) = ns->stencilxyz[k][0];