Commit 0322ebd0 authored by Stan Moore's avatar Stan Moore
Browse files

WIP: add cuFFT support

parent ab6be65a
Loading
Loading
Loading
Loading
+147 −46
Original line number Diff line number Diff line
@@ -12,10 +12,15 @@
------------------------------------------------------------------------- */

#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "fft3d_kokkos.h"
#include "remap_kokkos.h"
#include "kokkos_type.h"
#include "error.h"


using namespace LAMMPS_NS;

/* ---------------------------------------------------------------------- */
@@ -65,17 +70,6 @@ void FFT3dKokkos<DeviceType>::timing1d(typename AT::t_FFT_SCALAR_1d d_in, int ns
  fft_3d_1d_only_kokkos(d_in_data,nsize,flag,plan);
}

#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "fft3d_kokkos.h"
#include "remap_kokkos.h"

// include kissfft implementation

#include "kissfft_kokkos.h"

#define MIN(A,B) ((A) < (B) ? (A) : (B))
#define MAX(A,B) ((A) > (B) ? (A) : (B))

@@ -109,8 +103,7 @@ void FFT3dKokkos<DeviceType>::timing1d(typename AT::t_FFT_SCALAR_1d d_in, int ns
   plan         plan returned by previous call to fft_3d_create_plan
------------------------------------------------------------------------- */

#include "kokkos_type.h"

#ifdef FFT_KISSFFT
template<class DeviceType>
struct kiss_fft_functor {
public:
@@ -157,16 +150,20 @@ public:
      d_out(i,1) *= norm;
  }
};
#endif

const int forward  = 0;
const int backward = 1;

template<class DeviceType>
void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_in, typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_out, int flag, struct fft_plan_3d_kokkos<DeviceType> *plan)
{
  int i,total,length,offset,num;
  int i,total,length,offset,num,dim;
  FFT_SCALAR norm;
  typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_data,d_copy;
  typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
  typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_in_scalar,d_data_scalar,d_out_scalar,d_copy_scalar,d_scratch_scalar;
  kiss_fft_functor<DeviceType> f;

  int dir = (flag == -1) ? forward : backward;

  // pre-remap to prepare for 1st FFTs if needed
  // copy = loc for remap result
@@ -189,16 +186,25 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F

  total = plan->total1;
  length = plan->length1;
  dim = 0;

  #if defined(FFT_FFTW3)
    fftw_execute_dft(plan->plan_1D[dim][dir],d_data.data(),d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
    kiss_fft_functor<DeviceType> f;
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_forward,length);
    else
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_backward,length);
    Kokkos::parallel_for(total/length,f);
    DeviceType::fence();

    d_data = d_tmp;
    d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
  #endif


  // 1st mid-remap to prepare for 2nd FFTs
  // copy = loc for remap result
@@ -219,16 +225,22 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F

  total = plan->total2;
  length = plan->length2;
  dim = 1;

  #if defined(FFT_FFTW3)
    fftw_execute_dft(plan->plan_1D[dim][dir],d_data.data(),d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length);
    else
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_backward,length);
    Kokkos::parallel_for(total/length,f);
    DeviceType::fence();

    d_data = d_tmp;
    d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
  #endif

  // 2nd mid-remap to prepare for 3rd FFTs
  // copy = loc for remap result
@@ -249,15 +261,23 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F

  total = plan->total3;
  length = plan->length3;

  dim = 2;

  #if defined(FFT_FFTW3)
    fftw_execute_dft(plan->plan_1D[dim][dir],d_data.data(),d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
    // CUDA
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length);
    else
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_backward,length);
    Kokkos::parallel_for(total/length,f);
    DeviceType::fence();

    d_data = d_tmp;
  #endif


  // post-remap to put data in output format if needed
  // destination is always out
@@ -276,9 +296,13 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
  if (flag == 1 && plan->scaled) {
    norm = plan->norm;
    num = plan->normnum;
  #if defined(FFT_CUFFT)
    //scale(ptr, norm, num); //////////////////////
  #else
    norm_functor<DeviceType> f(d_out,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
  #endif
  }

}
@@ -323,6 +347,13 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
  int np1,np2,ip1,ip2;
  int list[50];
 
  #ifdef FFT_FFTW
  if (nthreads > 1) { /////////////
    std::cout << "fftw_init_threads: " << fftw_init_threads() << std::endl;;
    fftw_plan_with_nthreads(nthreads);
  }
  #endif

  // query MPI info

  MPI_Comm_rank(comm,&me);
@@ -338,13 +369,13 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl

  plan = new struct fft_plan_3d_kokkos<DeviceType>;
  remapKK = new RemapKokkos<DeviceType>(lmp);
  kissfftKK = new KissFFTKokkos<DeviceType>();
  if (plan == NULL) return NULL;

  // remap from initial distribution to layout needed for 1st set of 1d FFTs
  // not needed if all procs own entire fast axis initially
  // first indices = distribution after 1st set of FFTs


  if (in_ilo == 0 && in_ihi == nfast-1) flag = 0;
  else flag = 1;

@@ -529,9 +560,41 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
    plan->d_scratch = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft3d:scratch",scratch_size);
  }

  
    kissfftKK = new KissFFTKokkos<DeviceType>();

  // system specific pre-computation of 1d FFT coeffs
  // and scaling normalization

#if defined(FFT_FFTW3)

  for (int dim = 0; dim < 3; dim++) {
    for (int dir = 0; dir < 2; dir++) {
      plan->plan_1D[dim][dir] =
        fftw_plan_many_dft(1, &n[dim],plan->totals[dim]/plan->lengths[dim],
                           NULL,&n[dim],1,plan->lengths[dim],
                           NULL,&n[dim],1,plan->lengths[dim],
                           (dir == forward) ? FFTW_FORWARD : FFTW_BACKWARD,
                           FFTW_ESTIMATE);
    }
  }

#elif defined(FFT_CUFFT)
  cufftPlanMany(&(plan->plan_fast), 1, &nfast,
    &nfast,1,plan->length1,
    &nfast,1,plan->length1,
    CUFFT_Z2Z,plan->total1/plan->length1);

  cufftPlanMany(&(plan->plan_mid), 1, &nmid,
    &nmid,1,plan->length2,
    &nmid,1,plan->length2,
    CUFFT_Z2Z,plan->total2/plan->length2);

  cufftPlanMany(&(plan->plan_slow), 1, &nslow,
    &nslow,1,plan->length3,
    &nslow,1,plan->length3,
    CUFFT_Z2Z,plan->total3/plan->length3);
#else
  plan->cfg_fast_forward = KissFFTKokkos<DeviceType>::kiss_fft_alloc_kokkos(nfast,0,NULL,NULL);
  plan->cfg_fast_backward = KissFFTKokkos<DeviceType>::kiss_fft_alloc_kokkos(nfast,1,NULL,NULL);

@@ -556,6 +619,7 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
    plan->cfg_slow_forward = KissFFTKokkos<DeviceType>::kiss_fft_alloc_kokkos(nslow,0,NULL,NULL);
    plan->cfg_slow_backward = KissFFTKokkos<DeviceType>::kiss_fft_alloc_kokkos(nslow,1,NULL,NULL);
  }
#endif

  if (scaled == 0)
    plan->scaled = 0;
@@ -583,9 +647,31 @@ void FFT3dKokkos<DeviceType>::fft_3d_destroy_plan_kokkos(struct fft_plan_3d_kokk

  delete plan;
  delete remapKK;

  delete kissfftKK;
}

/* ----------------------------------------------------------------------
   split n into a pair of integer factors that are as close to each other
   as possible
   factor1 receives the smaller factor, factor2 the larger one
------------------------------------------------------------------------- */

template<class DeviceType>
void FFT3dKokkos<DeviceType>::bifactor(int n, int *factor1, int *factor2)
{
  // begin at floor(sqrt(n)) and step downward: the first exact divisor
  // encountered is the largest one not exceeding sqrt(n), which yields
  // the most balanced factor pair

  int candidate = static_cast<int> (sqrt((double) n));

  while (candidate > 0) {
    if (n % candidate == 0) {
      *factor1 = candidate;
      *factor2 = n/candidate;
      return;
    }
    candidate--;
  }

  // reaching this point means the starting candidate was already < 1,
  // so no divisor was tried and both outputs are left untouched
}

/* ----------------------------------------------------------------------
   perform just the 1d FFTs needed by a 3d FFT, no data movement
   used for timing purposes
@@ -602,8 +688,6 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
                    struct fft_plan_3d_kokkos<DeviceType> *plan)
{
  int i,total,length,offset,num;
  typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
  kiss_fft_functor<DeviceType> f;

  // total = size of data needed in each dim
  // length = length of 1d FFT in each dim
@@ -627,6 +711,18 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
  // perform 1d FFTs in each of 3 dimensions
  // data is just an array of 0.0

  int dir = (flag == 1) ? forward : backward;

#if defined(FFT_FFTW3)
  for (int dim = 0; dim < 3; dim++)
    fftw_execute_dft(plan->plan_1D[dim][dir],data,data);
#elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
    cufftExecZ2Z(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
    cufftExecZ2Z(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
#else
  kiss_fft_functor<DeviceType> f;
  typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
  if (flag == -1) {
    f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_forward,length1);
    Kokkos::parallel_for(total1/length1,f);
@@ -652,6 +748,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
    Kokkos::parallel_for(total3/length3,f);
    DeviceType::fence();
  }
#endif

  // scaling if required
  // limit num to size of data
@@ -659,9 +756,13 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
  if (flag == 1 && plan->scaled) {
    FFT_SCALAR norm = plan->norm;
    num = MIN(plan->normnum,nsize);
  #if defined(FFT_CUFFT)
    //scale(ptr, norm, num); ///////////////
  #else
    norm_functor<DeviceType> f(d_data,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
  #endif
  }
}

+38 −1
Original line number Diff line number Diff line
@@ -15,11 +15,31 @@
#define LMP_FFT3D_KOKKOS_H

#include "pointers.h"
#include "fft3d.h"
#include "kokkos_type.h"
#include "remap_kokkos.h"
#include "kissfft_kokkos.h"

// FFT precision selector, fixed to 2 in this header
// NOTE(review): presumably 2 denotes double precision, matching the
// FFT_SCALAR typedef below -- confirm against the non-Kokkos fft3d.h
#define FFT_PRECISION 2

// scalar type used for FFT data and normalization factors
typedef double FFT_SCALAR;

// compile-time selection of the 1d FFT backend:
//   FFT_FFTW3  -> FFTW3 (optionally the MKL wrapper header if FFT_MKL is set)
//   FFT_CUFFT  -> cuFFT
//   otherwise  -> bundled KISS FFT; FFT_KISSFFT is defined so other code
//                 can test for this fallback with #ifdef

#if defined(FFT_FFTW3)
  #include "fftw3.h"
  #if defined(FFT_MKL)
    #include "fftw/fftw3_mkl.h"
  #endif
  // complex element type handed to the FFTW3 calls
  typedef fftw_complex FFT_DATA;
#elif defined(FFT_CUFFT)
  #include "cufft.h"
  // declared here only; no definition is visible in this header
  // NOTE(review): presumably rescales n values at ptr by value -- confirm
  // against its definition before relying on it
  void scale(double * ptr, double value, int n);
  // complex element type handed to the cuFFT calls
  typedef cufftDoubleComplex FFT_DATA;
#else
  // NOTE(review): this branch adds no FFT_DATA typedef of its own;
  // presumably one of the included headers supplies it -- confirm
  #include "kissfft_kokkos.h"
  #ifndef FFT_KISSFFT
  #define FFT_KISSFFT
  #endif
#endif

namespace LAMMPS_NS {

// -------------------------------------------------------------------------
@@ -46,12 +66,25 @@ struct fft_plan_3d_kokkos {
  double norm;                      // normalization factor for rescaling

                                    // system specific 1d FFT info
#if defined(FFT_FFTW3)
  FFTW_API(plan) plan_fast_forward;
  FFTW_API(plan) plan_fast_backward;
  FFTW_API(plan) plan_mid_forward;
  FFTW_API(plan) plan_mid_backward;
  FFTW_API(plan) plan_slow_forward;
  FFTW_API(plan) plan_slow_backward;
#elif defined(FFT_CUFFT)
  cufftHandle plan_fast;
  cufftHandle plan_mid;
  cufftHandle plan_slow;
#else
  kiss_fft_state_kokkos<DeviceType> cfg_fast_forward;
  kiss_fft_state_kokkos<DeviceType> cfg_fast_backward;
  kiss_fft_state_kokkos<DeviceType> cfg_mid_forward;
  kiss_fft_state_kokkos<DeviceType> cfg_mid_backward;
  kiss_fft_state_kokkos<DeviceType> cfg_slow_forward;
  kiss_fft_state_kokkos<DeviceType> cfg_slow_backward;
#endif
};

template<class DeviceType>
@@ -70,7 +103,9 @@ class FFT3dKokkos : protected Pointers {
 private:
  struct fft_plan_3d_kokkos<DeviceType> *plan;
  RemapKokkos<DeviceType> *remapKK;
#ifdef FFT_KISSFFT
  KissFFTKokkos<DeviceType> *kissfftKK;
#endif

  void fft_3d_kokkos(typename AT::t_FFT_DATA_1d, typename AT::t_FFT_DATA_1d, int, struct fft_plan_3d_kokkos<DeviceType> *);

@@ -82,6 +117,8 @@ class FFT3dKokkos : protected Pointers {
  void fft_3d_destroy_plan_kokkos(struct fft_plan_3d_kokkos<DeviceType> *);

  void fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d, int, int, struct fft_plan_3d_kokkos<DeviceType> *);

  void bifactor(int, int *, int *);
};

}
+42 −42
Original line number Diff line number Diff line
@@ -174,20 +174,20 @@ t_scalar3<Scalar> operator *
  return t_scalar3<Scalar>(a.x*b,a.y*b,a.z*b);
}

#if !defined(__CUDACC__) && !defined(__VECTOR_TYPES_H__)
  struct double2 {
    double x, y;
  };
  struct float2 {
    float x, y;
  };
  struct float4 {
    float x, y, z, w;
  };
  struct double4 {
    double x, y, z, w;
  };
#endif
//#if !defined(__CUDACC__) && !defined(__VECTOR_TYPES_H__)
//  struct double2 {
//    double x, y;
//  };
//  struct float2 {
//    float x, y;
//  };
//  struct float4 {
//    float x, y, z, w;
//  };
//  struct double4 {
//    double x, y, z, w;
//  };
//#endif
// set LMPHostype and LMPDeviceType from Kokkos Default Types
typedef Kokkos::DefaultExecutionSpace LMPDeviceType;
typedef Kokkos::HostSpace::execution_space LMPHostType;
@@ -310,14 +310,14 @@ public:
#endif
#if PRECISION==1
typedef float LMP_FLOAT;
typedef float2 LMP_FLOAT2;
typedef lmp_float3 LMP_FLOAT3;
typedef float4 LMP_FLOAT4;
//typedef float2 LMP_FLOAT2;
//typedef lmp_float3 LMP_FLOAT3;
//typedef float4 LMP_FLOAT4;
#else
typedef double LMP_FLOAT;
typedef double2 LMP_FLOAT2;
typedef lmp_double3 LMP_FLOAT3;
typedef double4 LMP_FLOAT4;
//typedef double2 LMP_FLOAT2;
//typedef lmp_double3 LMP_FLOAT3;
//typedef double4 LMP_FLOAT4;
#endif

#ifndef PREC_FORCE
@@ -326,14 +326,14 @@ typedef double4 LMP_FLOAT4;

#if PREC_FORCE==1
typedef float F_FLOAT;
typedef float2 F_FLOAT2;
typedef lmp_float3 F_FLOAT3;
typedef float4 F_FLOAT4;
//typedef float2 F_FLOAT2;
//typedef lmp_float3 F_FLOAT3;
//typedef float4 F_FLOAT4;
#else
typedef double F_FLOAT;
typedef double2 F_FLOAT2;
typedef lmp_double3 F_FLOAT3;
typedef double4 F_FLOAT4;
//typedef double2 F_FLOAT2;
//typedef lmp_double3 F_FLOAT3;
//typedef double4 F_FLOAT4;
#endif

#ifndef PREC_ENERGY
@@ -342,12 +342,12 @@ typedef double4 F_FLOAT4;

#if PREC_ENERGY==1
typedef float E_FLOAT;
typedef float2 E_FLOAT2;
typedef float4 E_FLOAT4;
//typedef float2 E_FLOAT2;
//typedef float4 E_FLOAT4;
#else
typedef double E_FLOAT;
typedef double2 E_FLOAT2;
typedef double4 E_FLOAT4;
//typedef double2 E_FLOAT2;
//typedef double4 E_FLOAT4;
#endif

struct s_EV_FLOAT {
@@ -500,12 +500,12 @@ typedef struct s_FEV_FLOAT FEV_FLOAT;

#if PREC_POS==1
typedef float X_FLOAT;
typedef float2 X_FLOAT2;
typedef float4 X_FLOAT4;
//typedef float2 X_FLOAT2;
//typedef float4 X_FLOAT4;
#else
typedef double X_FLOAT;
typedef double2 X_FLOAT2;
typedef double4 X_FLOAT4;
//typedef double2 X_FLOAT2;
//typedef double4 X_FLOAT4;
#endif

#ifndef PREC_VELOCITIES
@@ -514,22 +514,22 @@ typedef double4 X_FLOAT4;

#if PREC_VELOCITIES==1
typedef float V_FLOAT;
typedef float2 V_FLOAT2;
typedef float4 V_FLOAT4;
//typedef float2 V_FLOAT2;
//typedef float4 V_FLOAT4;
#else
typedef double V_FLOAT;
typedef double2 V_FLOAT2;
typedef double4 V_FLOAT4;
//typedef double2 V_FLOAT2;
//typedef double4 V_FLOAT4;
#endif

#if PREC_KSPACE==1
typedef float K_FLOAT;
typedef float2 K_FLOAT2;
typedef float4 K_FLOAT4;
//typedef float2 K_FLOAT2;
//typedef float4 K_FLOAT4;
#else
typedef double K_FLOAT;
typedef double2 K_FLOAT2;
typedef double4 K_FLOAT4;
//typedef double2 K_FLOAT2;
//typedef double4 K_FLOAT4;
#endif

typedef int T_INT;