diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 204bf8f81b..c2718abf5d 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -45,8 +45,15 @@ void FFT_Bundle::initfft(int nx_in, if (this->precision=="single") { - #ifndef __ENABLE_FLOAT_FFTW - float_define = false; + #if not defined (__ENABLE_FLOAT_FFTW) + if (this->device == "cpu"){ + float_define = false; + } + #endif + #if defined(__CUDA) || defined (__ROCM) + if (this->device == "gpu"){ + float_flag = float_define; + } #endif float_flag = float_define; double_flag = true;