Commit 9a43a682 authored by Stan Moore's avatar Stan Moore
Browse files

Fix issues

parent f96609a0
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ if(PKG_KSPACE)
  else()
    set(FFT "KISS" CACHE STRING "FFT library for KSPACE package")
  endif()
  set(FFT_VALUES KISS FFTW MKL CUFFT)
  set(FFT_VALUES KISS FFTW3 MKL CUFFT)
  set_property(CACHE FFT PROPERTY STRINGS ${FFT_VALUES})
  validate_option(FFT FFT_VALUES)
  string(TOUPPER ${FFT} FFT)
+37 −30
Original line number Diff line number Diff line
@@ -224,11 +224,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
      DftiComputeBackward(plan->handle_fast,(FFT_DATA *)d_data.data());
  #elif defined(FFT_FFTW3)
    if (flag == -1)
      fftw_execute_dft(plan->plan_fast_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
    else
      fftw_execute_dft(plan->plan_fast_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
    cufftExec(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    typename AT::t_FFT_DATA_1d d_tmp = 
     typename AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
@@ -270,11 +270,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
      DftiComputeBackward(plan->handle_mid,(FFT_DATA *)d_data.data());
  #elif defined(FFT_FFTW3)
    if (flag == -1)
      fftw_execute_dft(plan->plan_mid_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
    else
      fftw_execute_dft(plan->plan_mid_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
    cufftExec(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length);
@@ -312,11 +312,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
      DftiComputeBackward(plan->handle_slow,(FFT_DATA *)d_data.data());
  #elif defined(FFT_FFTW3)
    if (flag == -1)
      fftw_execute_dft(plan->plan_slow_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
    else
      fftw_execute_dft(plan->plan_slow_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
      FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
  #elif defined(FFT_CUFFT)
    cufftExecZ2Z(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
    cufftExec(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
  #else
    if (flag == -1)
      f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length);
@@ -640,42 +640,44 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
  }

#elif defined(FFT_FFTW3)
  if (nthreads > 1)
    fftw_plan_with_nthreads(nthreads);
  if (nthreads > 1) {
    FFTW_API(init_threads)();
    FFTW_API(plan_with_nthreads)(nthreads);
  }

  plan->plan_fast_forward =
    fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1,
    FFTW_API(plan_many_dft)(1, &nfast,plan->total1/plan->length1,
                       NULL,&nfast,1,plan->length1,
                       NULL,&nfast,1,plan->length1,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_fast_backward =
    fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1,
    FFTW_API(plan_many_dft)(1, &nfast,plan->total1/plan->length1,
                       NULL,&nfast,1,plan->length1,
                       NULL,&nfast,1,plan->length1,
                       FFTW_BACKWARD,FFTW_ESTIMATE);

  plan->plan_mid_forward =
    fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2,
    FFTW_API(plan_many_dft)(1, &nmid,plan->total2/plan->length2,
                       NULL,&nmid,1,plan->length2,
                       NULL,&nmid,1,plan->length2,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_mid_backward =
    fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2,
    FFTW_API(plan_many_dft)(1, &nmid,plan->total2/plan->length2,
                       NULL,&nmid,1,plan->length2,
                       NULL,&nmid,1,plan->length2,
                       FFTW_BACKWARD,FFTW_ESTIMATE);


  plan->plan_slow_forward =
    fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3,
    FFTW_API(plan_many_dft)(1, &nslow,plan->total3/plan->length3,
                       NULL,&nslow,1,plan->length3,
                       NULL,&nslow,1,plan->length3,
                       FFTW_FORWARD,FFTW_ESTIMATE);

  plan->plan_slow_backward =
    fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3,
    FFTW_API(plan_many_dft)(1, &nslow,plan->total3/plan->length3,
                       NULL,&nslow,1,plan->length3,
                       NULL,&nslow,1,plan->length3,
                       FFTW_BACKWARD,FFTW_ESTIMATE);
@@ -683,17 +685,17 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
  cufftPlanMany(&(plan->plan_fast), 1, &nfast,
    &nfast,1,plan->length1,
    &nfast,1,plan->length1,
    CUFFT_Z2Z,plan->total1/plan->length1);
    CUFFT_TYPE,plan->total1/plan->length1);

  cufftPlanMany(&(plan->plan_mid), 1, &nmid,
    &nmid,1,plan->length2,
    &nmid,1,plan->length2,
    CUFFT_Z2Z,plan->total2/plan->length2);
    CUFFT_TYPE,plan->total2/plan->length2);

  cufftPlanMany(&(plan->plan_slow), 1, &nslow,
    &nslow,1,plan->length3,
    &nslow,1,plan->length3,
    CUFFT_Z2Z,plan->total3/plan->length3);
    CUFFT_TYPE,plan->total3/plan->length3);
#else
  kissfftKK = new KissFFTKokkos<DeviceType>();

@@ -758,6 +760,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_destroy_plan_kokkos(struct fft_plan_3d_kokk
  FFTW_API(destroy_plan)(plan->plan_mid_backward);
  FFTW_API(destroy_plan)(plan->plan_fast_forward);
  FFTW_API(destroy_plan)(plan->plan_fast_backward);




  FFTW_API(cleanup_threads)();
#elif defined (FFT_KISSFFT)
  delete kissfftKK;
#endif
@@ -839,18 +846,18 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d
  }
#elif defined(FFT_FFTW3)
  if (flag == -1) {
    fftw_execute_dft(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    fftw_execute_dft(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    fftw_execute_dft(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  } else {
    fftw_execute_dft(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    fftw_execute_dft(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    fftw_execute_dft(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
    FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
  }
#elif defined(FFT_CUFFT)
  cufftExecZ2Z(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExecZ2Z(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExecZ2Z(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
  cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
#else
  kiss_fft_functor<DeviceType> f;
  typename AT::t_FFT_DATA_1d d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
@@ -882,7 +889,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d
    FFT_SCALAR norm = plan->norm;
    int num = MIN(plan->normnum,nsize);

    norm_functor<DeviceType> f(d_out,norm);
    norm_functor<DeviceType> f(d_data,norm);
    Kokkos::parallel_for(num,f);
  }
}