Skip to content

Commit

Permalink
[nfft] fix direct trafo check
Browse files Browse the repository at this point in the history
  • Loading branch information
HugoStrand committed Nov 9, 2017
1 parent 058002a commit b535177
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions triqs/experimental/nfft_buf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace triqs {
using namespace triqs::gfs;

using dcomplex = std::complex<double>;

// NFFT buffer
template <int Rank> class nfft_buf_t {

Expand Down Expand Up @@ -74,24 +74,27 @@ namespace triqs {
// -- Default init
//nfft_init(plan_ptr.get(), Rank, buf_extents.ptr(), buf_size);

/// Round a 32-bit unsigned value up to the nearest power of two.
/// A value that is already a power of two is returned unchanged.
/// Note: for 32-bit unsigned int, v == 0 and v > 2^31 both yield 0 (the final increment wraps).
auto next_highest_power_of_two = [](unsigned int v) {
  // Smear the highest set bit of (v - 1) into every lower bit position, then add one.
  unsigned int smeared = v - 1;
  for (unsigned int shift = 1; shift <= 16; shift *= 2) smeared |= smeared >> shift;
  return smeared + 1;
};

/// Smallest power of two that is >= v, for a 32-bit unsigned v.
/// Note: yields 0 for v == 0 and for v > 2^31, because the final increment wraps around.
auto next_highest_power_of_two = [](unsigned int v) {
  unsigned int mask = v - 1;
  // Propagate the top set bit downwards using shifts of 1, 2, 4, 8 and 16.
  for (int step = 1; step < 32; step <<= 1) mask |= mask >> step;
  return mask + 1;
};

// Init nfft_plan
mini_vector<int, Rank> extents_fftw;
for (int i = 0; i < Rank; i++)
extents_fftw[i] = 2 * next_highest_power_of_two(buf_extents[i]);
mini_vector<int, Rank> extents_fftw;
for (int i = 0; i < Rank; i++) extents_fftw[i] = 2 * next_highest_power_of_two(buf_extents[i]);

unsigned nfft_flags = PRE_PHI_HUT | PRE_PSI | MALLOC_X | MALLOC_F_HAT |
MALLOC_F | FFTW_INIT | FFT_OUT_OF_PLACE | NFFT_SORT_NODES;
unsigned nfft_flags = PRE_PHI_HUT | PRE_PSI | MALLOC_X | MALLOC_F_HAT | MALLOC_F | FFTW_INIT | FFT_OUT_OF_PLACE | NFFT_SORT_NODES;
unsigned fftw_flags = FFTW_ESTIMATE | FFTW_DESTROY_INPUT;

int m = 6; // Truncation order for the window functions
nfft_init_guru(plan_ptr.get(), Rank, buf_extents.ptr(), buf_size,
extents_fftw.ptr(), m, nfft_flags, fftw_flags);
nfft_init_guru(plan_ptr.get(), Rank, buf_extents.ptr(), buf_size, extents_fftw.ptr(), m, nfft_flags, fftw_flags);
}

~nfft_buf_t() {
Expand Down Expand Up @@ -211,12 +214,13 @@ namespace triqs {
void do_nfft() {

// nfft_adjoint() uses a window function (Kaiser-Bessel by default)
// that cannot be constructed for plan_ptr->N_total < 12.
// In the small N_total case one has to call nfft_adjoint_direct() instead,
// which is also faster for the smaller N_total.
// that cannot be constructed when any plan_ptr->N[i] <= plan_ptr->m.
// In the small N[i] case one has to call nfft_adjoint_direct() instead,
// which is also faster for the smaller N[i].
//
// C.f. https://github.com/NFFT/nfft/issues/34
if (plan_ptr->N_total < 12) {
auto N_min = *std::min_element(plan_ptr->N, plan_ptr->N + plan_ptr->d);
if (N_min <= plan_ptr->m) {

// Execute transform
nfft_adjoint_direct(plan_ptr.get());
Expand Down

0 comments on commit b535177

Please sign in to comment.