Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eagerly invert plan on formation of AdjointPlan: correct eltype and remove output_size #113

Merged
merged 9 commits into from
Sep 4, 2023
34 changes: 29 additions & 5 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -680,15 +680,31 @@ adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p))
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
return (p \ x) / N
pinv = inv(p)
# Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x.
# Even if not, ensure that we do only one pass by combining the normalization with the plan.
if pinv isa ScaledPlan && isapprox(pinv.scale, N)
gaurav-arya marked this conversation as resolved.
Show resolved Hide resolved
scaled_pinv = pinv.p
else
scaled_pinv = (1/N) * pinv
end
return scaled_pinv * x
end

function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p)
N = normalization(T, size(p), dims)
halfdim = first(dims)
d = size(p, halfdim)
n = size(inv(p), halfdim)
pinv = inv(p)
n = size(pinv, halfdim)
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
if pinv isa ScaledPlan
N /= pinv.scale
gaurav-arya marked this conversation as resolved.
Show resolved Hide resolved
scaled_pinv = pinv.p
else
scaled_pinv = pinv
end
y = map(x, CartesianIndices(x)) do xj, j
i = j[halfdim]
yj = if i == 1 || (i == n && 2 * (i - 1) == d)
Expand All @@ -698,16 +714,24 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<
end
return yj
end
return p \ y
return scaled_pinv * y
end

function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(real(T), size(inv(p)), dims)
halfdim = first(dims)
n = size(p, halfdim)
d = size(inv(p), halfdim)
y = p \ x
pinv = inv(p)
d = size(pinv, halfdim)
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
if pinv isa ScaledPlan
N /= pinv.scale
scaled_pinv = pinv.p
else
scaled_pinv = pinv
end
y = scaled_pinv * x
z = map(y, CartesianIndices(y)) do yj, j
i = j[halfdim]
zj = if i == 1 || (i == n && 2 * (i - 1) == d)
Expand Down
Loading