Commit 4ea9dea1 authored by Stan Moore
Browse files

More Kokkos FFT refactor

parent ef30d0ed
Loading
Loading
Loading
Loading
+36 −36
Original line number Diff line number Diff line
@@ -85,10 +85,10 @@ FFT3dKokkos<DeviceType>::~FFT3dKokkos()
/* ---------------------------------------------------------------------- */

// Public entry point for the 3d FFT on scalar-typed views.
//
// Wraps d_in and d_out as FFT_DATA views built from their data() pointers,
// each with size()/2 elements, then forwards both views, flag and the member
// `plan` to fft_3d_kokkos().
//
//   d_in  - input view of FFT_SCALAR values
//   d_out - output view of FFT_SCALAR values
//   flag  - passed through unchanged; fft_3d_kokkos() tests flag == -1 to pick
//           the forward configuration
//
// NOTE(review): the size()/2 presumably reflects two FFT_SCALARs (re,im) per
// FFT_DATA element - confirm against the FFT_DATA definition.
//
// NOTE: in this commit view each changed statement is listed twice, first with
// the previous AT:: alias and then with the renamed FFT_AT:: alias; the two
// forms differ only in the alias name.
template<class DeviceType>
void FFT3dKokkos<DeviceType>::compute(typename AT::t_FFT_SCALAR_1d d_in, typename AT::t_FFT_SCALAR_1d d_out, int flag)
void FFT3dKokkos<DeviceType>::compute(typename FFT_AT::t_FFT_SCALAR_1d d_in, typename FFT_AT::t_FFT_SCALAR_1d d_out, int flag)
{
  // reinterpret the scalar buffers as FFT_DATA buffers with half as many elements
  typename AT::t_FFT_DATA_1d d_in_data((FFT_DATA*)d_in.data(),d_in.size()/2);
  typename AT::t_FFT_DATA_1d d_out_data((FFT_DATA*)d_out.data(),d_out.size()/2);
  typename FFT_AT::t_FFT_DATA_1d d_in_data((FFT_DATA*)d_in.data(),d_in.size()/2);
  typename FFT_AT::t_FFT_DATA_1d d_out_data((FFT_DATA*)d_out.data(),d_out.size()/2);

  // run the transform using the plan stored on this object
  fft_3d_kokkos(d_in_data,d_out_data,flag,plan);
}
@@ -96,9 +96,9 @@ void FFT3dKokkos<DeviceType>::compute(typename AT::t_FFT_SCALAR_1d d_in, typenam
/* ---------------------------------------------------------------------- */

// Public entry point that runs only the 1d FFT stage on a scalar-typed view.
//
// Wraps d_in as an FFT_DATA view built from its data() pointer with size()/2
// elements, then forwards it with nsize, flag and the member `plan` to
// fft_3d_1d_only_kokkos().
//
//   d_in  - view of FFT_SCALAR values, handed on as FFT_DATA
//   nsize - passed through unchanged to fft_3d_1d_only_kokkos()
//   flag  - passed through unchanged; fft_3d_1d_only_kokkos() tests flag == -1
//           to pick the forward configuration
//
// NOTE(review): the function name suggests this exists for timing runs -
// confirm against the callers.
//
// NOTE: in this commit view each changed statement is listed twice, first with
// the previous AT:: alias and then with the renamed FFT_AT:: alias; the two
// forms differ only in the alias name.
template<class DeviceType>
void FFT3dKokkos<DeviceType>::timing1d(typename AT::t_FFT_SCALAR_1d d_in, int nsize, int flag)
void FFT3dKokkos<DeviceType>::timing1d(typename FFT_AT::t_FFT_SCALAR_1d d_in, int nsize, int flag)
{
  // reinterpret the scalar buffer as an FFT_DATA buffer with half as many elements
  typename AT::t_FFT_DATA_1d d_in_data((FFT_DATA*)d_in.data(),d_in.size()/2);
  typename FFT_AT::t_FFT_DATA_1d d_in_data((FFT_DATA*)d_in.data(),d_in.size()/2);

  fft_3d_1d_only_kokkos(d_in_data,nsize,flag,plan);
}
@@ -137,11 +137,11 @@ template<class DeviceType>
struct norm_functor {
public:
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  typename AT::t_FFT_DATA_1d_um d_out;
  typedef FFTArrayTypes<DeviceType> FFT_AT;
  typename FFT_AT::t_FFT_DATA_1d_um d_out;
  int norm;

  norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
  norm_functor(typename FFT_AT::t_FFT_DATA_1d &d_out_, int norm_):
    d_out(d_out_),norm(norm_) {}

  KOKKOS_INLINE_FUNCTION
@@ -153,8 +153,8 @@ public:
#elif defined(FFT_MKL)
    d_out(i) *= norm;
#else // FFT_KISS
    d_out(i,0) *= norm;
    d_out(i,1) *= norm;
    d_out(i).re *= norm;
    d_out(i).im *= norm;
#endif
  }
};
@@ -164,14 +164,14 @@ template<class DeviceType>
struct kiss_fft_functor {
public:
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  typename AT::t_FFT_DATA_1d_um d_data,d_tmp;
  typedef FFTArrayTypes<DeviceType> FFT_AT;
  typename FFT_AT::t_FFT_DATA_1d_um d_data,d_tmp;
  kiss_fft_state_kokkos<DeviceType> st;
  int length;

  kiss_fft_functor() {}

  kiss_fft_functor(typename AT::t_FFT_DATA_1d &d_data_,typename AT::t_FFT_DATA_1d &d_tmp_, kiss_fft_state_kokkos<DeviceType> &st_, int length_):
  kiss_fft_functor(typename FFT_AT::t_FFT_DATA_1d &d_data_,typename FFT_AT::t_FFT_DATA_1d &d_tmp_, kiss_fft_state_kokkos<DeviceType> &st_, int length_):
    d_data(d_data_),
    d_tmp(d_tmp_),
    st(st_)
@@ -188,11 +188,11 @@ public:
#endif

template<class DeviceType>
void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typename AT::t_FFT_DATA_1d d_out, int flag, struct fft_plan_3d_kokkos<DeviceType> *plan)
void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in, typename FFT_AT::t_FFT_DATA_1d d_out, int flag, struct fft_plan_3d_kokkos<DeviceType> *plan)
{
  int total,length;
  typename AT::t_FFT_DATA_1d d_data,d_copy;
  typename AT::t_FFT_SCALAR_1d d_in_scalar,d_data_scalar,d_out_scalar,d_copy_scalar,d_scratch_scalar;
  typename FFT_AT::t_FFT_DATA_1d d_data,d_copy;
  typename FFT_AT::t_FFT_SCALAR_1d d_in_scalar,d_data_scalar,d_out_scalar,d_copy_scalar,d_scratch_scalar;

  // pre-remap to prepare for 1st FFTs if needed
  // copy = loc for remap result
@@ -201,9 +201,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
    if (plan->pre_target == 0) d_copy = d_out;
    else d_copy = plan->d_copy;

     d_in_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_in.data(),d_in.size()*2);
     d_copy_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_copy.data(),d_copy.size()*2);
     d_scratch_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)plan->d_scratch.data(),plan->d_scratch.size()*2);
     d_in_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_in.data(),d_in.size()*2);
     d_copy_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_copy.data(),d_copy.size()*2);
     d_scratch_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)plan->d_scratch.data(),plan->d_scratch.size()*2);

    remapKK->remap_3d_kokkos(d_in_scalar, d_copy_scalar,
             d_scratch_scalar, plan->pre_plan);
@@ -229,8 +229,8 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
  #elif defined(FFT_CUFFT)
    cufftExec(plan->plan_fast,d_data.data(),d_data.data(),flag);
  #else
    typename AT::t_FFT_DATA_1d d_tmp =
     typename AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
    typename FFT_AT::t_FFT_DATA_1d d_tmp =
     typename FFT_AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
    kiss_fft_functor<DeviceType> f;
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_forward,length);
@@ -238,7 +238,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_backward,length);
    Kokkos::parallel_for(total/length,f);
    d_data = d_tmp;
    d_tmp = typename AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
    d_tmp = typename FFT_AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
  #endif


@@ -248,9 +248,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
  if (plan->mid1_target == 0) d_copy = d_out;
  else d_copy = plan->d_copy;

  d_data_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_data.data(),d_data.size()*2);
  d_copy_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_copy.data(),d_copy.size()*2);
  d_scratch_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)plan->d_scratch.data(),plan->d_scratch.size()*2);
  d_data_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_data.data(),d_data.size()*2);
  d_copy_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_copy.data(),d_copy.size()*2);
  d_scratch_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)plan->d_scratch.data(),plan->d_scratch.size()*2);

  remapKK->remap_3d_kokkos(d_data_scalar, d_copy_scalar,
           d_scratch_scalar, plan->mid1_plan);
@@ -281,7 +281,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_backward,length);
    Kokkos::parallel_for(total/length,f);
    d_data = d_tmp;
    d_tmp = typename AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
    d_tmp = typename FFT_AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
  #endif

  // 2nd mid-remap to prepare for 3rd FFTs
@@ -290,9 +290,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
  if (plan->mid2_target == 0) d_copy = d_out;
  else d_copy = plan->d_copy;

  d_data_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_data.data(),d_data.size()*2);
  d_copy_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_copy.data(),d_copy.size()*2);
  d_scratch_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)plan->d_scratch.data(),plan->d_scratch.size()*2);
  d_data_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_data.data(),d_data.size()*2);
  d_copy_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_copy.data(),d_copy.size()*2);
  d_scratch_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)plan->d_scratch.data(),plan->d_scratch.size()*2);

  remapKK->remap_3d_kokkos(d_data_scalar, d_copy_scalar,
           d_scratch_scalar, plan->mid2_plan);
@@ -330,9 +330,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
  // destination is always out

  if (plan->post_plan) {
    d_data_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_data.data(),d_data.size()*2);
    d_out_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_out.data(),d_out.size()*2);
    d_scratch_scalar = typename AT::t_FFT_SCALAR_1d((FFT_SCALAR*)plan->d_scratch.data(),plan->d_scratch.size()*2);
    d_data_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_data.data(),d_data.size()*2);
    d_out_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)d_out.data(),d_out.size()*2);
    d_scratch_scalar = typename FFT_AT::t_FFT_SCALAR_1d((FFT_SCALAR*)plan->d_scratch.data(),plan->d_scratch.size()*2);

    remapKK->remap_3d_kokkos(d_data_scalar, d_out_scalar,
             d_scratch_scalar, plan->post_plan);
@@ -588,11 +588,11 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
  *nbuf = copy_size + scratch_size;

  if (copy_size) {
    plan->d_copy = typename AT::t_FFT_DATA_1d("fft3d:copy",copy_size);
    plan->d_copy = typename FFT_AT::t_FFT_DATA_1d("fft3d:copy",copy_size);
  }

  if (scratch_size) {
    plan->d_scratch = typename AT::t_FFT_DATA_1d("fft3d:scratch",scratch_size);
    plan->d_scratch = typename FFT_AT::t_FFT_DATA_1d("fft3d:scratch",scratch_size);
  }

  // system specific pre-computation of 1d FFT coeffs
@@ -810,7 +810,7 @@ void FFT3dKokkos<DeviceType>::bifactor(int n, int *factor1, int *factor2)
------------------------------------------------------------------------- */

template<class DeviceType>
void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d_data, int nsize, int flag,
void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename FFT_AT::t_FFT_DATA_1d d_data, int nsize, int flag,
                    struct fft_plan_3d_kokkos<DeviceType> *plan)
{
  // total = size of data needed in each dim
@@ -864,7 +864,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d
  cufftExec(plan->plan_slow,d_data.data(),d_data.data(),flag);
#else
  kiss_fft_functor<DeviceType> f;
  typename AT::t_FFT_DATA_1d d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
  typename FFT_AT::t_FFT_DATA_1d d_tmp = typename FFT_AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
  if (flag == -1) {
    f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_forward,length1);
    Kokkos::parallel_for(total1/length1,f);
+8 −8
Original line number Diff line number Diff line
@@ -27,14 +27,14 @@ namespace LAMMPS_NS {
template<class DeviceType>
struct fft_plan_3d_kokkos {
  typedef DeviceType device_type;
  typedef FFTArrayTypes<DeviceType> AT;
  typedef FFTArrayTypes<DeviceType> FFT_AT;

  struct remap_plan_3d_kokkos<DeviceType> *pre_plan;       // remap from input -> 1st FFTs
  struct remap_plan_3d_kokkos<DeviceType> *mid1_plan;      // remap from 1st -> 2nd FFTs
  struct remap_plan_3d_kokkos<DeviceType> *mid2_plan;      // remap from 2nd -> 3rd FFTs
  struct remap_plan_3d_kokkos<DeviceType> *post_plan;      // remap from 3rd FFTs -> output
  typename AT::t_FFT_DATA_1d d_copy;                   // memory for remap results (if needed)
  typename AT::t_FFT_DATA_1d d_scratch;                // scratch space for remaps
  typename FFT_AT::t_FFT_DATA_1d d_copy;                   // memory for remap results (if needed)
  typename FFT_AT::t_FFT_DATA_1d d_scratch;                // scratch space for remaps
  int total1,total2,total3;         // # of 1st,2nd,3rd FFTs (times length)
  int length1,length2,length3;      // length of 1st,2nd,3rd FFTs
  int pre_target;                   // where to put remap results
@@ -73,14 +73,14 @@ template<class DeviceType>
class FFT3dKokkos : protected Pointers {
 public:
  typedef DeviceType device_type;
  typedef FFTArrayTypes<DeviceType> AT;
  typedef FFTArrayTypes<DeviceType> FFT_AT;

  FFT3dKokkos(class LAMMPS *, MPI_Comm,
        int,int,int,int,int,int,int,int,int,int,int,int,int,int,int,
        int,int,int *,int);
  ~FFT3dKokkos();
  void compute(typename AT::t_FFT_SCALAR_1d, typename AT::t_FFT_SCALAR_1d, int);
  void timing1d(typename AT::t_FFT_SCALAR_1d, int, int);
  void compute(typename FFT_AT::t_FFT_SCALAR_1d, typename FFT_AT::t_FFT_SCALAR_1d, int);
  void timing1d(typename FFT_AT::t_FFT_SCALAR_1d, int, int);

 private:
  struct fft_plan_3d_kokkos<DeviceType> *plan;
@@ -90,7 +90,7 @@ class FFT3dKokkos : protected Pointers {
  KissFFTKokkos<DeviceType> *kissfftKK;
#endif

  void fft_3d_kokkos(typename AT::t_FFT_DATA_1d, typename AT::t_FFT_DATA_1d, int, struct fft_plan_3d_kokkos<DeviceType> *);
  void fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d, typename FFT_AT::t_FFT_DATA_1d, int, struct fft_plan_3d_kokkos<DeviceType> *);

  struct fft_plan_3d_kokkos<DeviceType> *fft_3d_create_plan_kokkos(MPI_Comm, int, int, int,
                                         int, int, int, int, int,
@@ -99,7 +99,7 @@ class FFT3dKokkos : protected Pointers {

  void fft_3d_destroy_plan_kokkos(struct fft_plan_3d_kokkos<DeviceType> *);

  void fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d, int, int, struct fft_plan_3d_kokkos<DeviceType> *);
  void fft_3d_1d_only_kokkos(typename FFT_AT::t_FFT_DATA_1d, int, int, struct fft_plan_3d_kokkos<DeviceType> *);

  void bifactor(int, int *, int *);
};
+12 −1
Original line number Diff line number Diff line
@@ -106,7 +106,6 @@ typedef double FFT_SCALAR;
    typedef cufftDoubleComplex FFT_DATA;
  #endif
#else
  #include "kissfft_kokkos.h"
  #if defined(FFT_SINGLE)
    #define kiss_fft_scalar float
  #else
@@ -121,6 +120,7 @@ typedef double FFT_SCALAR;
  #endif
#endif

#include "kokkos_type.h"

template <class DeviceType>
struct FFTArrayTypes;
@@ -152,6 +152,8 @@ typedef Kokkos::
typedef tdual_int_64::t_dev t_int_64;
typedef tdual_int_64::t_dev_um t_int_64_um;

};

#ifdef KOKKOS_ENABLE_CUDA
template <>
struct FFTArrayTypes<LMPHostType> {
@@ -185,4 +187,13 @@ typedef tdual_int_64::t_host_um t_int_64_um;
};
#endif

// Shorthand names for the FFTArrayTypes specializations on LMPDeviceType and
// LMPHostType (e.g. GridCommKokkos declares its buffers as
// FFT_DAT::tdual_FFT_SCALAR_1d).
typedef struct FFTArrayTypes<LMPDeviceType> FFT_DAT;
typedef struct FFTArrayTypes<LMPHostType> FFT_HAT;


#if defined(FFT_KISSFFT)
#include "kissfft_kokkos.h"
#endif


#endif
+3 −3
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@
#include "kspace.h"
#include "memory_kokkos.h"
#include "error.h"
#include "kokkos_base.h"
#include "kokkos_base_fft.h"
#include "kokkos.h"

using namespace LAMMPS_NS;
@@ -502,9 +502,9 @@ void GridCommKokkos<DeviceType>::setup()
  }
  nbuf *= MAX(nforward,nreverse);
  //memory->create(buf1,nbuf,"Commgrid:buf1");
  k_buf1 = DAT::tdual_FFT_SCALAR_1d("Commgrid:buf1",nbuf);
  k_buf1 = FFT_DAT::tdual_FFT_SCALAR_1d("Commgrid:buf1",nbuf);
  //memory->create(buf2,nbuf,"Commgrid:buf2");
  k_buf2 = DAT::tdual_FFT_SCALAR_1d("Commgrid:buf2",nbuf);
  k_buf2 = FFT_DAT::tdual_FFT_SCALAR_1d("Commgrid:buf2",nbuf);
}

/* ----------------------------------------------------------------------
+4 −2
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@

#include "pointers.h"
#include "kokkos_type.h"
#include "fftdata_kokkos.h"

#ifdef FFT_SINGLE
typedef float FFT_SCALAR;
@@ -32,6 +33,7 @@ class GridCommKokkos : protected Pointers {
 public:
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  typedef FFTArrayTypes<DeviceType> FFT_AT;

  GridCommKokkos(class LAMMPS *, MPI_Comm, int, int,
           int, int, int, int, int, int,
@@ -70,8 +72,8 @@ class GridCommKokkos : protected Pointers {

  int nbuf;
  //FFT_SCALAR *buf1,*buf2;
  DAT::tdual_FFT_SCALAR_1d k_buf1;
  DAT::tdual_FFT_SCALAR_1d k_buf2;
  FFT_DAT::tdual_FFT_SCALAR_1d k_buf1;
  FFT_DAT::tdual_FFT_SCALAR_1d k_buf2;

  struct Swap {
    int sendproc;       // proc to send to for forward comm
Loading