Commit a50563d5 authored by Stan Moore
Browse files

Fix issue with Kokkos FFTW3

parent 4ea9dea1
Loading
Loading
Loading
Loading
+21 −21
Original line number Diff line number Diff line
@@ -87,8 +87,8 @@ FFT3dKokkos<DeviceType>::~FFT3dKokkos()
// Reinterpret the scalar input and output views as FFT_DATA views and hand
// them to fft_3d_kokkos together with `flag` and `plan`.
//
// d_in, d_out : 1d views of FFT_SCALAR; each wrapped view is given
//               size()/2 elements.
// flag        : forwarded unchanged to fft_3d_kokkos.
//
// NOTE(review): this text is a diff rendering without +/- markers. The two
// (FFT_DATA*) declarations look like the pre-change lines and the two
// (FFT_DATA_POINTER) declarations like their replacements - confirm against
// the real file; only one pair can exist in compilable code.
template<class DeviceType>
void FFT3dKokkos<DeviceType>::compute(typename FFT_AT::t_FFT_SCALAR_1d d_in, typename FFT_AT::t_FFT_SCALAR_1d d_out, int flag)
{
  typename FFT_AT::t_FFT_DATA_1d d_in_data((FFT_DATA*)d_in.data(),d_in.size()/2);
  typename FFT_AT::t_FFT_DATA_1d d_out_data((FFT_DATA*)d_out.data(),d_out.size()/2);
  // Same wrapping, but the raw pointer is cast to FFT_DATA_POINTER instead.
  typename FFT_AT::t_FFT_DATA_1d d_in_data((FFT_DATA_POINTER)d_in.data(),d_in.size()/2);
  typename FFT_AT::t_FFT_DATA_1d d_out_data((FFT_DATA_POINTER)d_out.data(),d_out.size()/2);

  fft_3d_kokkos(d_in_data,d_out_data,flag,plan);
}
@@ -98,7 +98,7 @@ void FFT3dKokkos<DeviceType>::compute(typename FFT_AT::t_FFT_SCALAR_1d d_in, typ
// Reinterpret the scalar input view as an FFT_DATA view (size()/2 elements)
// and pass it to fft_3d_1d_only_kokkos along with `nsize`, `flag` and `plan`.
//
// NOTE(review): this text is a diff rendering without +/- markers. The
// (FFT_DATA*) declaration looks like the pre-change line and the
// (FFT_DATA_POINTER) declaration like its replacement - confirm against the
// real file.
template<class DeviceType>
void FFT3dKokkos<DeviceType>::timing1d(typename FFT_AT::t_FFT_SCALAR_1d d_in, int nsize, int flag)
{
  typename FFT_AT::t_FFT_DATA_1d d_in_data((FFT_DATA*)d_in.data(),d_in.size()/2);
  // Same wrapping, but the raw pointer is cast to FFT_DATA_POINTER instead.
  typename FFT_AT::t_FFT_DATA_1d d_in_data((FFT_DATA_POINTER)d_in.data(),d_in.size()/2);

  fft_3d_1d_only_kokkos(d_in_data,nsize,flag,plan);
}
@@ -223,11 +223,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
      DftiComputeBackward(plan->handle_fast,d_data.data());
  #elif defined(FFT_FFTW3)
    if (flag == -1)
      FFTW_API(execute_dft)(plan->plan_fast_forward,d_data.data(),d_data.data());
      FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    else
      FFTW_API(execute_dft)(plan->plan_fast_backward,d_data.data(),d_data.data());
      FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExec(plan->plan_fast,d_data.data(),d_data.data(),flag);
    cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  #else
    typename FFT_AT::t_FFT_DATA_1d d_tmp =
     typename FFT_AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
@@ -269,11 +269,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
      DftiComputeBackward(plan->handle_mid,d_data.data());
  #elif defined(FFT_FFTW3)
    if (flag == -1)
      FFTW_API(execute_dft)(plan->plan_mid_forward,d_data.data(),d_data.data());
      FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    else
      FFTW_API(execute_dft)(plan->plan_mid_backward,d_data.data(),d_data.data());
      FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExec(plan->plan_mid,d_data.data(),d_data.data(),flag);
    cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length);
@@ -311,11 +311,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
      DftiComputeBackward(plan->handle_slow,d_data.data());
  #elif defined(FFT_FFTW3)
    if (flag == -1)
      FFTW_API(execute_dft)(plan->plan_slow_forward,d_data.data(),d_data.data());
      FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    else
      FFTW_API(execute_dft)(plan->plan_slow_backward,d_data.data(),d_data.data());
      FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExec(plan->plan_slow,d_data.data(),d_data.data(),flag);
    cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length);
@@ -850,18 +850,18 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename FFT_AT::t_FFT_DATA_
  }
#elif defined(FFT_FFTW3)
  if (flag == -1) {
    FFTW_API(execute_dft)(plan->plan_fast_forward,d_data.data(),d_data.data());
    FFTW_API(execute_dft)(plan->plan_mid_forward,d_data.data(),d_data.data());
    FFTW_API(execute_dft)(plan->plan_slow_forward,d_data.data(),d_data.data());
    FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  } else {
    FFTW_API(execute_dft)(plan->plan_fast_backward,d_data.data(),d_data.data());
    FFTW_API(execute_dft)(plan->plan_mid_backward,d_data.data(),d_data.data());
    FFTW_API(execute_dft)(plan->plan_slow_backward,d_data.data(),d_data.data());
    FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  }
#elif defined(FFT_CUFFT)
  cufftExec(plan->plan_fast,d_data.data(),d_data.data(),flag);
  cufftExec(plan->plan_mid,d_data.data(),d_data.data(),flag);
  cufftExec(plan->plan_slow,d_data.data(),d_data.data(),flag);
  cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
#else
  kiss_fft_functor<DeviceType> f;
  typename FFT_AT::t_FFT_DATA_1d d_tmp = typename FFT_AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
+9 −1
Original line number Diff line number Diff line
@@ -120,6 +120,14 @@ typedef double FFT_SCALAR;
  #endif
#endif

// FFT_DATA_POINTER: pointer type to cast a raw scalar buffer to before it is
// wrapped in a t_FFT_DATA_1d view.
// (double[2]*) is not a 1D pointer, so the FFTW3 and cuFFT builds use a plain
// FFT_SCALAR* here; every other build keeps FFT_DATA*.
// NOTE(review): presumably FFT_DATA is an array type such as double[2] in the
// FFTW3 build - confirm that the same reasoning applies to the cuFFT build.
#if defined(FFT_FFTW3) || defined(FFT_CUFFT)
  typedef FFT_SCALAR* FFT_DATA_POINTER;
#else
  typedef FFT_DATA* FFT_DATA_POINTER;
#endif


#include "kokkos_type.h"

template <class DeviceType>
@@ -192,7 +200,7 @@ typedef struct FFTArrayTypes<LMPHostType> FFT_HAT;


#if defined(FFT_KISSFFT)
#include "kissfft_kokkos.h"
#include "kissfft_kokkos.h" // uses t_FFT_DATA_1d, needs to come last
#endif