Commit 9fade740 authored by Stan Moore's avatar Stan Moore
Browse files

Fix issue with Kokkos FFT_CUFFT

parent a50563d5
Loading
Loading
Loading
Loading
+6 −6
Original line number Diff line number Diff line
@@ -227,7 +227,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
    else
      FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
    cufftExec(plan->plan_fast,d_data.data(),d_data.data(),flag);
  #else
    typename FFT_AT::t_FFT_DATA_1d d_tmp =
     typename FFT_AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
@@ -273,7 +273,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
    else
      FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
    cufftExec(plan->plan_mid,d_data.data(),d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length);
@@ -315,7 +315,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
    else
      FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
    cufftExec(plan->plan_slow,d_data.data(),d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length);
@@ -859,9 +859,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename FFT_AT::t_FFT_DATA_
    FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  }
#elif defined(FFT_CUFFT)
  cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_fast,d_data.data(),d_data.data(),flag);
  cufftExec(plan->plan_mid,d_data.data(),d_data.data(),flag);
  cufftExec(plan->plan_slow,d_data.data(),d_data.data(),flag);
#else
  kiss_fft_functor<DeviceType> f;
  typename FFT_AT::t_FFT_DATA_1d d_tmp = typename FFT_AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
+4 −3
Original line number Diff line number Diff line
@@ -11,6 +11,9 @@
   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

#include "kokkos_type.h"


#define MAX(A,B) ((A) > (B) ? (A) : (B))

// data types for 2d/3d FFTs
@@ -121,15 +124,13 @@ typedef double FFT_SCALAR;
#endif

// (double[2]*) is not a 1D pointer
#if defined(FFT_FFTW3) || defined(FFT_CUFFT)
#if defined(FFT_FFTW3)
  typedef FFT_SCALAR* FFT_DATA_POINTER;
#else
  typedef FFT_DATA* FFT_DATA_POINTER;
#endif


#include "kokkos_type.h"

template <class DeviceType>
struct FFTArrayTypes;