Commit fafd824b authored by Stan Moore's avatar Stan Moore
Browse files

Small tweaks

parent 06506b9f
Loading
Loading
Loading
Loading
+27 −26
Original line number Diff line number Diff line
@@ -192,11 +192,11 @@ public:
#endif

template<class DeviceType>
void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_in, typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_out, int flag, struct fft_plan_3d_kokkos<DeviceType> *plan)
void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typename AT::t_FFT_DATA_1d d_out, int flag, struct fft_plan_3d_kokkos<DeviceType> *plan)
{
  int total,length;
  typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_data,d_copy;
  typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_in_scalar,d_data_scalar,d_out_scalar,d_copy_scalar,d_scratch_scalar;
  typename AT::t_FFT_DATA_1d d_data,d_copy;
  typename AT::t_FFT_SCALAR_1d d_in_scalar,d_data_scalar,d_out_scalar,d_copy_scalar,d_scratch_scalar;

  // pre-remap to prepare for 1st FFTs if needed
  // copy = loc for remap result
@@ -205,9 +205,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
    if (plan->pre_target == 0) d_copy = d_out;
    else d_copy = plan->d_copy;

     d_in_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_in.data(),d_in.size());
     d_copy_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_copy.data(),d_copy.size());
     d_scratch_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(plan->d_scratch.data(),plan->d_scratch.size());
     d_in_scalar = typename AT::t_FFT_SCALAR_1d(d_in.data(),d_in.size());
     d_copy_scalar = typename AT::t_FFT_SCALAR_1d(d_copy.data(),d_copy.size());
     d_scratch_scalar = typename AT::t_FFT_SCALAR_1d(plan->d_scratch.data(),plan->d_scratch.size());

    remapKK->remap_3d_kokkos(d_in_scalar, d_copy_scalar, 
             d_scratch_scalar, plan->pre_plan);
@@ -228,7 +228,8 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
    typename AT::t_FFT_DATA_1d d_tmp = 
     typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
    kiss_fft_functor<DeviceType> f;
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_forward,length);
@@ -236,7 +237,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_backward,length);
    Kokkos::parallel_for(total/length,f);
    d_data = d_tmp;
    d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
    d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
  #endif


@@ -246,9 +247,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
  if (plan->mid1_target == 0) d_copy = d_out;
  else d_copy = plan->d_copy;

  d_data_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size()*2);
  d_copy_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_copy.data(),d_copy.size()*2);
  d_scratch_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(plan->d_scratch.data(),plan->d_scratch.size()*2);
  d_data_scalar = typename AT::t_FFT_SCALAR_1d(d_data.data(),d_data.size()*2);
  d_copy_scalar = typename AT::t_FFT_SCALAR_1d(d_copy.data(),d_copy.size()*2);
  d_scratch_scalar = typename AT::t_FFT_SCALAR_1d(plan->d_scratch.data(),plan->d_scratch.size()*2);

  remapKK->remap_3d_kokkos(d_data_scalar, d_copy_scalar, 
           d_scratch_scalar, plan->mid1_plan);
@@ -274,7 +275,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_backward,length);
    Kokkos::parallel_for(total/length,f);
    d_data = d_tmp;
    d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
    d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_in.dimension_0());
  #endif

  // 2nd mid-remap to prepare for 3rd FFTs
@@ -283,9 +284,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
  if (plan->mid2_target == 0) d_copy = d_out;
  else d_copy = plan->d_copy;

  d_data_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
  d_copy_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_copy.data(),d_copy.size());
  d_scratch_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(plan->d_scratch.data(),plan->d_scratch.size());
  d_data_scalar = typename AT::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
  d_copy_scalar = typename AT::t_FFT_SCALAR_1d(d_copy.data(),d_copy.size());
  d_scratch_scalar = typename AT::t_FFT_SCALAR_1d(plan->d_scratch.data(),plan->d_scratch.size());

  remapKK->remap_3d_kokkos(d_data_scalar, d_copy_scalar, 
           d_scratch_scalar, plan->mid2_plan);
@@ -320,9 +321,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
  // destination is always out

  if (plan->post_plan) {
    d_data_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
    d_out_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_out.data(),d_out.size());
    d_scratch_scalar = typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(plan->d_scratch.data(),plan->d_scratch.size());
    d_data_scalar = typename AT::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
    d_out_scalar = typename AT::t_FFT_SCALAR_1d(d_out.data(),d_out.size());
    d_scratch_scalar = typename AT::t_FFT_SCALAR_1d(plan->d_scratch.data(),plan->d_scratch.size());

    remapKK->remap_3d_kokkos(d_data_scalar, d_out_scalar, 
             d_scratch_scalar, plan->post_plan);
@@ -334,8 +335,8 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
    int norm = plan->norm;
    FFT_SCALAR num = plan->normnum;
  #if defined(FFT_CUFFT)
    typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_norm_scalar = 
     typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
    typename AT::t_FFT_SCALAR_1d d_norm_scalar = 
     typename AT::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
    cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
    Kokkos::parallel_for(num,f);
  #elif defined(FFT_KISSFFT)
@@ -584,11 +585,11 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
  *nbuf = copy_size + scratch_size;

  if (copy_size) {
    plan->d_copy = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft3d:copy",copy_size);
    plan->d_copy = typename AT::t_FFT_DATA_1d("fft3d:copy",copy_size);
  }

  if (scratch_size) {
    plan->d_scratch = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft3d:scratch",scratch_size);
    plan->d_scratch = typename AT::t_FFT_DATA_1d("fft3d:scratch",scratch_size);
  }

  // system specific pre-computation of 1d FFT coeffs
@@ -743,7 +744,7 @@ void FFT3dKokkos<DeviceType>::bifactor(int n, int *factor1, int *factor2)
------------------------------------------------------------------------- */

template<class DeviceType>
void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_data, int nsize, int flag,
void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d_data, int nsize, int flag,
                    struct fft_plan_3d_kokkos<DeviceType> *plan)
{
  // total = size of data needed in each dim
@@ -784,7 +785,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
  cufftExecZ2Z(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
#else
  kiss_fft_functor<DeviceType> f;
  typename ArrayTypes<DeviceType>::t_FFT_DATA_1d d_tmp = typename ArrayTypes<DeviceType>::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
  typename AT::t_FFT_DATA_1d d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
  if (flag == -1) {
    f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_fast_forward,length1);
    Kokkos::parallel_for(total1/length1,f);
@@ -813,8 +814,8 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
    FFT_SCALAR norm = plan->norm;
    int num = MIN(plan->normnum,nsize);
  #if defined(FFT_CUFFT)
    typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_norm_scalar = 
     typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
    typename AT::t_FFT_SCALAR_1d d_norm_scalar = 
     typename AT::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
    cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
    Kokkos::parallel_for(num,f);
  #elif defined(FFT_KISSFFT)