Commit ca5aa1f9 authored by Stan Moore's avatar Stan Moore
Browse files

Add cuFFT norm functor

parent 0322ebd0
Loading
Loading
Loading
Loading
+36 −6
Original line number Diff line number Diff line
@@ -103,6 +103,28 @@ void FFT3dKokkos<DeviceType>::timing1d(typename AT::t_FFT_SCALAR_1d d_in, int ns
   plan         plan returned by previous call to fft_3d_create_plan
------------------------------------------------------------------------- */

#ifdef FFT_CUFFT
/* ----------------------------------------------------------------------
   Kokkos functor that rescales an FFT buffer viewed as a flat array of
   scalars: each invocation multiplies exactly one entry, d_out(i), by
   the stored factor "norm".

   d_out   scalar view whose entries are scaled in place
   norm    multiplicative factor applied to every visited entry

   The factor is held in the element type of the scalar view rather than
   in an int, so a non-integer factor passed by the caller is not
   truncated before the multiply.

   NOTE(review): one call touches one scalar entry, so the range handed to
   Kokkos::parallel_for must be counted in scalars - confirm at call sites.
------------------------------------------------------------------------- */
template<class DeviceType>
struct cufft_norm_functor {
public:
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  // element type of the scalar view, used for the scale factor as well
  typedef typename AT::t_FFT_SCALAR_1d_um::non_const_value_type scalar_type;
  typename AT::t_FFT_SCALAR_1d_um d_out;
  scalar_type norm;

  // Accepting scalar_type keeps integer arguments working (implicit
  // conversion) while preserving fractional factors.
  cufft_norm_functor(typename AT::t_FFT_SCALAR_1d &d_out_, scalar_type norm_):
    d_out(d_out_),
    norm(norm_)
    {}

  // Scale entry i of the output view in place.
  KOKKOS_INLINE_FUNCTION
  void operator() (const int &i) const {
    d_out(i) *= norm;
  }
};
#endif

#ifdef FFT_KISSFFT
template<class DeviceType>
struct kiss_fft_functor {
@@ -131,14 +153,14 @@ public:
};

template<class DeviceType>
struct norm_functor {
struct kiss_norm_functor {
public:
  typedef DeviceType device_type;
  typedef ArrayTypes<DeviceType> AT;
  typename AT::t_FFT_DATA_1d_um d_out;
  int norm;

  norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
  kiss_norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
    d_out(d_out_)
    {
      norm = norm_;
@@ -297,9 +319,13 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
    norm = plan->norm;
    num = plan->normnum;
  #if defined(FFT_CUFFT)
    //scale(ptr, norm, num); //////////////////////
    typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_norm_scalar = 
     typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
    cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
  #else
    norm_functor<DeviceType> f(d_out,norm);
    kiss_norm_functor<DeviceType> f(d_out,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
  #endif
@@ -757,9 +783,13 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
    FFT_SCALAR norm = plan->norm;
    num = MIN(plan->normnum,nsize);
  #if defined(FFT_CUFFT)
    //scale(ptr, norm, num); ///////////////
    typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_norm_scalar = 
     typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
    cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
  #else
    norm_functor<DeviceType> f(d_data,norm);
    kiss_norm_functor<DeviceType> f(d_data,norm);
    Kokkos::parallel_for(num,f);
    DeviceType::fence();
  #endif