Commit f8e2543c authored by Stan Moore's avatar Stan Moore
Browse files

Move FFT data types out of kokkos_type.h

parent 1851a9f7
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -160,6 +160,7 @@ action kissfft_kokkos.h kissfft.h
action kokkos.cpp
action kokkos.h
action kokkos_base.h
action kokkos_base_fft.h fft3d.h
action kokkos_few.h
action kokkos_type.h
action memory_kokkos.h
+0 −1
Original line number Diff line number Diff line
@@ -21,7 +21,6 @@
#include <math.h>
#include "fft3d_kokkos.h"
#include "remap_kokkos.h"
#include "kokkos_type.h"
#include "error.h"
#include "kokkos.h"

+2 −4
Original line number Diff line number Diff line
@@ -15,11 +15,9 @@
#define LMP_FFT3D_KOKKOS_H

#include "pointers.h"
#include "kokkos_type.h"
#include "remap_kokkos.h"
#include "fftdata_kokkos.h"


namespace LAMMPS_NS {

// -------------------------------------------------------------------------
@@ -29,7 +27,7 @@ namespace LAMMPS_NS {
template<class DeviceType>
struct fft_plan_3d_kokkos {
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  typedef FFTArrayTypes<DeviceType> AT;

  struct remap_plan_3d_kokkos<DeviceType> *pre_plan;       // remap from input -> 1st FFTs
  struct remap_plan_3d_kokkos<DeviceType> *mid1_plan;      // remap from 1st -> 2nd FFTs
@@ -75,7 +73,7 @@ template<class DeviceType>
class FFT3dKokkos : protected Pointers {
 public:
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  typedef FFTArrayTypes<DeviceType> AT;

  FFT3dKokkos(class LAMMPS *, MPI_Comm,
        int,int,int,int,int,int,int,int,int,int,int,int,int,int,int,
+63 −0
Original line number Diff line number Diff line
@@ -121,4 +121,67 @@ typedef double FFT_SCALAR;
  #endif
#endif


template <class DeviceType>
struct FFTArrayTypes;

template <>
struct FFTArrayTypes<LMPDeviceType> {

typedef Kokkos::
  DualView<FFT_SCALAR*, Kokkos::LayoutRight, LMPDeviceType> tdual_FFT_SCALAR_1d;
typedef tdual_FFT_SCALAR_1d::t_dev t_FFT_SCALAR_1d;
typedef tdual_FFT_SCALAR_1d::t_dev_um t_FFT_SCALAR_1d_um;

typedef Kokkos::DualView<FFT_SCALAR**,Kokkos::LayoutRight,LMPDeviceType> tdual_FFT_SCALAR_2d;
typedef tdual_FFT_SCALAR_2d::t_dev t_FFT_SCALAR_2d;

typedef Kokkos::DualView<FFT_SCALAR**[3],Kokkos::LayoutRight,LMPDeviceType> tdual_FFT_SCALAR_2d_3;
typedef tdual_FFT_SCALAR_2d_3::t_dev t_FFT_SCALAR_2d_3;

typedef Kokkos::DualView<FFT_SCALAR***,Kokkos::LayoutRight,LMPDeviceType> tdual_FFT_SCALAR_3d;
typedef tdual_FFT_SCALAR_3d::t_dev t_FFT_SCALAR_3d;

typedef Kokkos::
  DualView<FFT_DATA*, Kokkos::LayoutRight, LMPDeviceType> tdual_FFT_DATA_1d;
typedef tdual_FFT_DATA_1d::t_dev t_FFT_DATA_1d;
typedef tdual_FFT_DATA_1d::t_dev_um t_FFT_DATA_1d_um;

typedef Kokkos::
  DualView<int*, LMPDeviceType::array_layout, LMPDeviceType> tdual_int_64;
typedef tdual_int_64::t_dev t_int_64;
typedef tdual_int_64::t_dev_um t_int_64_um;

#ifdef KOKKOS_ENABLE_CUDA
template <>
struct FFTArrayTypes<LMPHostType> {

//Kspace

typedef Kokkos::
  DualView<FFT_SCALAR*, Kokkos::LayoutRight, LMPDeviceType> tdual_FFT_SCALAR_1d;
typedef tdual_FFT_SCALAR_1d::t_host t_FFT_SCALAR_1d;
typedef tdual_FFT_SCALAR_1d::t_host_um t_FFT_SCALAR_1d_um;

typedef Kokkos::DualView<FFT_SCALAR**,Kokkos::LayoutRight,LMPDeviceType> tdual_FFT_SCALAR_2d;
typedef tdual_FFT_SCALAR_2d::t_host t_FFT_SCALAR_2d;

typedef Kokkos::DualView<FFT_SCALAR**[3],Kokkos::LayoutRight,LMPDeviceType> tdual_FFT_SCALAR_2d_3;
typedef tdual_FFT_SCALAR_2d_3::t_host t_FFT_SCALAR_2d_3;

typedef Kokkos::DualView<FFT_SCALAR***,Kokkos::LayoutRight,LMPDeviceType> tdual_FFT_SCALAR_3d;
typedef tdual_FFT_SCALAR_3d::t_host t_FFT_SCALAR_3d;

typedef Kokkos::
  DualView<FFT_DATA*, Kokkos::LayoutRight, LMPDeviceType> tdual_FFT_DATA_1d;
typedef tdual_FFT_DATA_1d::t_host t_FFT_DATA_1d;
typedef tdual_FFT_DATA_1d::t_host_um t_FFT_DATA_1d_um;

typedef Kokkos::
  DualView<int*, LMPDeviceType::array_layout, LMPDeviceType> tdual_int_64;
typedef tdual_int_64::t_host t_int_64;
typedef tdual_int_64::t_host_um t_int_64_um;

};

#endif
+2 −2
Original line number Diff line number Diff line
@@ -517,7 +517,7 @@ void GridCommKokkos<DeviceType>::forward_comm(KSpace *kspace, int which)
  k_packlist.sync<DeviceType>();
  k_unpacklist.sync<DeviceType>();

  KokkosBase* kspaceKKBase = dynamic_cast<KokkosBase*>(kspace);
  KokkosBaseFFT* kspaceKKBase = dynamic_cast<KokkosBaseFFT*>(kspace);

  for (int m = 0; m < nswap; m++) {
    if (swap[m].sendproc == me)
@@ -567,7 +567,7 @@ void GridCommKokkos<DeviceType>::reverse_comm(KSpace *kspace, int which)
  k_packlist.sync<DeviceType>();
  k_unpacklist.sync<DeviceType>();

  KokkosBase* kspaceKKBase = dynamic_cast<KokkosBase*>(kspace);
  KokkosBaseFFT* kspaceKKBase = dynamic_cast<KokkosBaseFFT*>(kspace);

  for (int m = nswap-1; m >= 0; m--) {
    if (swap[m].recvproc == me)
Loading