Commit 5aae0956 authored by Stan Moore
Browse files

Add support for single precision

parent c515b7dc
Loading
Loading
Loading
Loading
+27 −29
Original line number Diff line number Diff line
@@ -222,11 +222,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ

  #if defined(FFT_FFTW3)
    if (flag == -1)
      fftw_execute_dft(plan->plan_fast_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
    else
      fftw_execute_dft(plan->plan_fast_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
    cufftExec(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    typename AT::t_FFT_DATA_1d d_tmp = 
     typename AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
@@ -263,11 +263,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ

  #if defined(FFT_FFTW3)
    if (flag == -1)
      fftw_execute_dft(plan->plan_mid_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
    else
      fftw_execute_dft(plan->plan_mid_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
    cufftExec(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length);
@@ -300,13 +300,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ

  #if defined(FFT_FFTW3)
    if (flag == -1)
      fftw_execute_dft(plan->plan_slow_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
    else
      fftw_execute_dft(plan->plan_slow_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());

      FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
    // CUDA
    cufftExec(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length);
@@ -600,38 +598,38 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
    // Use the precision-matched FFTW threads call (fftw_ or fftwf_ via
    // FFTW_API) so the thread count is applied to the same planner that
    // creates the plans; the double-precision call has no effect on
    // single-precision (fftwf_) plans.
    FFTW_API(plan_with_nthreads)(nthreads);

  plan->plan_fast_forward =
    fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1,
    FFTW_API(plan_many_dft)(1, &nfast,plan->total1/plan->length1,
                       NULL,&nfast,1,plan->length1,
                       NULL,&nfast,1,plan->length1,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_fast_backward =
    fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1,
    FFTW_API(plan_many_dft)(1, &nfast,plan->total1/plan->length1,
                       NULL,&nfast,1,plan->length1,
                       NULL,&nfast,1,plan->length1,
                       FFTW_BACKWARD,FFTW_ESTIMATE);

  plan->plan_mid_forward =
    fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2,
    FFTW_API(plan_many_dft)(1, &nmid,plan->total2/plan->length2,
                       NULL,&nmid,1,plan->length2,
                       NULL,&nmid,1,plan->length2,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_mid_backward =
    fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2,
    FFTW_API(plan_many_dft)(1, &nmid,plan->total2/plan->length2,
                       NULL,&nmid,1,plan->length2,
                       NULL,&nmid,1,plan->length2,
                       FFTW_BACKWARD,FFTW_ESTIMATE);


  plan->plan_slow_forward =
    fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3,
    FFTW_API(plan_many_dft)(1, &nslow,plan->total3/plan->length3,
                       NULL,&nslow,1,plan->length3,
                       NULL,&nslow,1,plan->length3,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_slow_backward =
    fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3,
    FFTW_API(plan_many_dft)(1, &nslow,plan->total3/plan->length3,
                       NULL,&nslow,1,plan->length3,
                       NULL,&nslow,1,plan->length3,
                       FFTW_BACKWARD,FFTW_ESTIMATE);
@@ -639,17 +637,17 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
  cufftPlanMany(&(plan->plan_fast), 1, &nfast,
    &nfast,1,plan->length1,
    &nfast,1,plan->length1,
    CUFFT_Z2Z,plan->total1/plan->length1);
    CUFFT_TYPE,plan->total1/plan->length1);

  cufftPlanMany(&(plan->plan_mid), 1, &nmid,
    &nmid,1,plan->length2,
    &nmid,1,plan->length2,
    CUFFT_Z2Z,plan->total2/plan->length2);
    CUFFT_TYPE,plan->total2/plan->length2);

  cufftPlanMany(&(plan->plan_slow), 1, &nslow,
    &nslow,1,plan->length3,
    &nslow,1,plan->length3,
    CUFFT_Z2Z,plan->total3/plan->length3);
    CUFFT_TYPE,plan->total3/plan->length3);
#else
  kissfftKK = new KissFFTKokkos<DeviceType>();

@@ -771,18 +769,18 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d

#if defined(FFT_FFTW3)
  if (flag == -1) {
    fftw_execute_dft(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    fftw_execute_dft(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    fftw_execute_dft(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  } else {
    fftw_execute_dft(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    fftw_execute_dft(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    fftw_execute_dft(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  }
#elif defined(FFT_CUFFT)
  cufftExecZ2Z(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExecZ2Z(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExecZ2Z(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
#else
  kiss_fft_functor<DeviceType> f;
  typename AT::t_FFT_DATA_1d d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
+17 −6
Original line number Diff line number Diff line
@@ -17,9 +17,7 @@
#include "pointers.h"
#include "kokkos_type.h"
#include "remap_kokkos.h"

#define FFT_PRECISION 2
typedef double FFT_SCALAR;
#include "fftdata_kokkos.h"

// if user set FFTW, it means FFTW3

@@ -29,11 +27,24 @@ typedef double FFT_SCALAR;

#if defined(FFT_FFTW3)
  #include "fftw3.h"
  // Pick the FFTW complex type and API prefix matching the FFT precision.
  // FFTW_API(name) token-pastes the prefix onto name: with FFT_SINGLE defined
  // FFTW_API(execute_dft) expands to fftwf_execute_dft, otherwise to
  // fftw_execute_dft.
  #ifdef FFT_SINGLE
    #define FFTW_API(name) fftwf_ ## name
    typedef fftwf_complex FFT_DATA;
  #else
    #define FFTW_API(name) fftw_ ## name
    typedef fftw_complex FFT_DATA;
  #endif
#elif defined(FFT_CUFFT)
  #include "cufft.h"
  // Pick the cuFFT complex type, transform-type constant and exec entry point
  // matching the FFT precision: the C2C names when FFT_SINGLE is defined,
  // the Z2Z names otherwise.
  #ifdef FFT_SINGLE
    typedef cufftComplex FFT_DATA;
    #define CUFFT_TYPE CUFFT_C2C
    #define cufftExec cufftExecC2C
  #else
    typedef cufftDoubleComplex FFT_DATA;
    #define CUFFT_TYPE CUFFT_Z2Z
    #define cufftExec cufftExecZ2Z
  #endif
#else
  #include "kissfft_kokkos.h"
  #ifndef FFT_KISSFFT
+1 −1
Original line number Diff line number Diff line
@@ -120,7 +120,7 @@ void RemapKokkos<DeviceType>::remap_3d_kokkos(typename AT::t_FFT_SCALAR_1d d_in,
  // post all recvs into scratch space

  for (irecv = 0; irecv < plan->nrecv; irecv++) {
    double* scratch = d_scratch.ptr_on_device() + plan->recv_bufloc[irecv];
    FFT_SCALAR* scratch = d_scratch.ptr_on_device() + plan->recv_bufloc[irecv];
    MPI_Irecv(scratch,plan->recv_size[irecv],
              MPI_FFT_SCALAR,plan->recv_proc[irecv],0,
              plan->comm,&plan->request[irecv]);