From fe0f7f4d0677e019608183f65caae963aef95d14 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 22 Aug 2023 20:51:03 -0400 Subject: [PATCH] Replace division with multiplication in adjoint loop --- src/definitions.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index efead20..0ab9f16 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -692,43 +692,43 @@ end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} dims = fftdims(p) - _N = normalization(T, size(p), dims) + N = normalization(T, size(p), dims) halfdim = first(dims) d = size(p, halfdim) pinv = inv(p) n = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. - N = pinv isa ScaledPlan ? _N / pinv.scale : _N - scaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv + scale = pinv isa ScaledPlan ? pinv.scale / 2N : 1 / 2N + unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv y = map(x, CartesianIndices(x)) do xj, j i = j[halfdim] yj = if i == 1 || (i == n && 2 * (i - 1) == d) - xj / N + xj * scale * 2 else - xj / (2 * N) + xj * scale end return yj end - return scaled_pinv * y + return unscaled_pinv * y end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T} dims = fftdims(p) - _N = normalization(real(T), size(inv(p)), dims) + N = normalization(real(T), size(inv(p)), dims) halfdim = first(dims) n = size(p, halfdim) pinv = inv(p) d = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. - N = pinv isa ScaledPlan ? _N / pinv.scale : _N - scaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv - y = scaled_pinv * x + scale = pinv isa ScaledPlan ? pinv.scale / N : 1 / N + unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv + y = unscaled_pinv * x z = map(y, CartesianIndices(y)) do yj, j i = j[halfdim] zj = if i == 1 || (i == n && 2 * (i - 1) == d) - yj / N + yj * scale else - 2 * yj / N + yj * scale * 2 end return zj end