Commit 3c9098d2 authored by Stan Moore's avatar Stan Moore
Browse files

FFTW3 for Kokkos

parent 80846e3e
Loading
Loading
Loading
Loading
+18 −15
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@
#include "remap_kokkos.h"
#include "kokkos_type.h"
#include "error.h"
#include "kokkos.h"


using namespace LAMMPS_NS;
@@ -31,10 +32,10 @@ FFT3dKokkos<DeviceType>::FFT3dKokkos(LAMMPS *lmp, MPI_Comm comm, int nfast, int
             int in_klo, int in_khi,
             int out_ilo, int out_ihi, int out_jlo, int out_jhi,
             int out_klo, int out_khi,
             int scaled, int permute, int *nbuf, int usecollective,
             int nthreads) : 
             int scaled, int permute, int *nbuf, int usecollective) : 
  Pointers(lmp)
{
  int nthreads = lmp->kokkos->nthreads;
  plan = fft_3d_create_plan_kokkos(comm,nfast,nmid,nslow,
                            in_ilo,in_ihi,in_jlo,in_jhi,in_klo,in_khi,
                            out_ilo,out_ihi,out_jlo,out_jhi,out_klo,out_khi,
@@ -335,7 +336,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
    cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
  #else
  #elif defined(FFT_KISFFT)
    kiss_norm_functor<DeviceType> f(d_out,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
@@ -595,41 +596,41 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
  // and scaling normalization

#if defined(FFT_FFTW3)
  if (nthreads > 1)
    fftw_plan_with_nthreads(nthreads);
//  if (nthreads > 1)
//    fftw_plan_with_nthreads(nthreads);

  plan->plan_1D_fast_forward =
  plan->plan_fast_forward =
    fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1,
                       NULL,&nfast,1,plan->length1,
                       NULL,&nfast,1,plan->length1,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_1D_fast_backward =
  plan->plan_fast_backward =
    fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1,
                       NULL,&nfast,1,plan->length1,
                       NULL,&nfast,1,plan->length1,
                       FFTW_BACKWARD,FFTW_ESTIMATE);

  plan->plan_1D_mid_forward =
  plan->plan_mid_forward =
    fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2,
                       NULL,&nmid,1,plan->length2,
                       NULL,&nmid,1,plan->length2,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_1D_mid_backward =
  plan->plan_mid_backward =
    fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2,
                       NULL,&nmid,1,plan->length2,
                       NULL,&nmid,1,plan->length2,
                       FFTW_BACKWARD,FFTW_ESTIMATE);


  plan->plan_1D_slow_forward =
  plan->plan_slow_forward =
    fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3,
                       NULL,&slow,1,plan->length3,
                       NULL,&slow,1,plan->length3,
                       NULL,&nslow,1,plan->length3,
                       NULL,&nslow,1,plan->length3,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_1D_slow_backward =
  plan->plan_slow_backward =
    fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3,
                       NULL,&nslow,1,plan->length3,
                       NULL,&nslow,1,plan->length3,
@@ -705,7 +706,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_destroy_plan_kokkos(struct fft_plan_3d_kokk
  delete plan;
  delete remapKK;

#ifdef FFT_KISSFFT
  delete kissfftKK;
#endif
}

/* ----------------------------------------------------------------------
@@ -827,7 +830,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
    cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
  #else
  #elif defined(FFT_KISFFT)
    kiss_norm_functor<DeviceType> f(d_data,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
+7 −6
Original line number Diff line number Diff line
@@ -17,21 +17,22 @@
#include "pointers.h"
#include "kokkos_type.h"
#include "remap_kokkos.h"
#include "kissfft_kokkos.h"

#define FFT_PRECISION 2

typedef double FFT_SCALAR;

// if user set FFTW, it means FFTW3

#ifdef FFT_FFTW
#define FFT_FFTW3
#endif

#if defined(FFT_FFTW3)
  #include "fftw3.h"
  #if defined(FFT_MKL)
    #include "fftw/fftw3_mkl.h"
  #endif
  typedef fftw_complex FFT_DATA;
  #define FFTW_API(function)  fftw_ ## function
#elif defined(FFT_CUFFT)
  #include "cufft.h"
  void scale(double * ptr, double value, int n);
  typedef cufftDoubleComplex FFT_DATA;
#else
  #include "kissfft_kokkos.h"
+0 −48
Original line number Diff line number Diff line
@@ -19,8 +19,6 @@
#ifndef FFT_DATA_KOKKOS_H
#define FFT_DATA_KOKKOS_H

#include "kokkos_type.h"

// User-settable FFT precision

// FFT_PRECISION = 1 is single-precision complex (4-byte real, 4-byte imag)
@@ -41,55 +39,9 @@ typedef double FFT_SCALAR;
// Data types for single-precision complex

#if FFT_PRECISION == 1

// use a stripped down version of kiss fft as default fft

#ifndef FFT_KISSFFT
#define FFT_KISSFFT
#endif
#define kiss_fft_scalar_kokkos float
//typedef struct {
//    kiss_fft_scalar re;
//    kiss_fft_scalar im;
//} FFT_DATA;

// -------------------------------------------------------------------------

// Data types for double-precision complex

#elif FFT_PRECISION == 2

// use a stripped down version of kiss fft as default fft

#ifndef FFT_KISSFFT
#define FFT_KISSFFT
#endif
#define kiss_fft_scalar_kokkos double
//typedef struct {
//    kiss_fft_scalar re;
//    kiss_fft_scalar im;
//} FFT_DATA;

// -------------------------------------------------------------------------

#else
#error "FFT_PRECISION needs to be either 1 (=single) or 2 (=double)"
#endif

// -------------------------------------------------------------------------

#define MAXFACTORS 32
/* e.g. an fft of length 128 has 4 factors
 as far as kissfft is concerned: 4*4*4*2  */
template<class DeviceType>
struct kiss_fft_state_kokkos {
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  int nfft;
  int inverse;
  typename AT::t_int_64 d_factors;
  typename AT::t_FFT_DATA_1d d_twiddles;
  typename AT::t_FFT_DATA_1d d_scratch;
};

#endif
+18 −4
Original line number Diff line number Diff line
@@ -128,10 +128,10 @@
    do{ (c)[0] *= (s);\
        (c)[1] *= (s); }while(0)

#define  C_ADD( res, a,b)\
    do { \
            (res)[0]=(a)[0]+(b)[0];  (res)[1]=(a)[1]+(b)[1]; \
    }while(0)
//#define  C_ADD( res, a,b)\
//    do { \
//            (res)[0]=(a)[0]+(b)[0];  (res)[1]=(a)[1]+(b)[1]; \
//    }while(0)

#define  C_SUB( res, a,b)\
    do { \
@@ -166,6 +166,20 @@

namespace LAMMPS_NS {

#define MAXFACTORS 32
/* e.g. an fft of length 128 has 4 factors
 as far as kissfft is concerned: 4*4*4*2  */
template<class DeviceType>
struct kiss_fft_state_kokkos {
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  int nfft;
  int inverse;
  typename AT::t_int_64 d_factors;
  typename AT::t_FFT_DATA_1d d_twiddles;
  typename AT::t_FFT_DATA_1d d_scratch;
};

template<class DeviceType>
class KissFFTKokkos {
 public:
+2 −2
Original line number Diff line number Diff line
@@ -852,12 +852,12 @@ void PPPMKokkos<DeviceType>::allocate()
  fft1 = new FFT3dKokkos<DeviceType>(lmp,world,nx_pppm,ny_pppm,nz_pppm,
                         nxlo_fft,nxhi_fft,nylo_fft,nyhi_fft,nzlo_fft,nzhi_fft,
                         nxlo_fft,nxhi_fft,nylo_fft,nyhi_fft,nzlo_fft,nzhi_fft,
                         0,0,&tmp,collective_flag,lmp->kokkos->nthreads);
                         0,0,&tmp,collective_flag);

  fft2 = new FFT3dKokkos<DeviceType>(lmp,world,nx_pppm,ny_pppm,nz_pppm,
                         nxlo_fft,nxhi_fft,nylo_fft,nyhi_fft,nzlo_fft,nzhi_fft,
                         nxlo_in,nxhi_in,nylo_in,nyhi_in,nzlo_in,nzhi_in,
                         0,0,&tmp,collective_flag,lmp->kokkos->nthreads);
                         0,0,&tmp,collective_flag);
  remap = new RemapKokkos<DeviceType>(lmp,world,
                          nxlo_in,nxhi_in,nylo_in,nyhi_in,nzlo_in,nzhi_in,
                          nxlo_fft,nxhi_fft,nylo_fft,nyhi_fft,nzlo_fft,nzhi_fft,
Loading