Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Oct 21, 2023
1 parent d718f90 commit bb52a06
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 10 deletions.
14 changes: 4 additions & 10 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -681,14 +681,8 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
pinv = inv(p)
# Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x.
# Even if not, ensure that we do only one pass by combining the normalization with the plan.
RT = AbstractArray{<:T} # canonicalize eltype of returned array to be extra safe about type stability
if pinv isa ScaledPlan && pinv.scale == N
return convert(RT, pinv.p * x)
else
return convert(RT, (inv(N) * pinv) * x)
end
# Ensure that we do only one pass over the array by combining the normalization with the plan.
return (inv(N) * pinv) * x
end

function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
Expand All @@ -699,7 +693,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<
pinv = inv(p)
n = size(pinv, halfdim)
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
scale = pinv isa ScaledPlan ? pinv.scale / 2N : one(pinv.scale) / 2N
scale = pinv isa ScaledPlan ? pinv.scale / 2N : inv(2N)
twoscale = 2 * scale
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
y = map(x, CartesianIndices(x)) do xj, j
Expand All @@ -722,7 +716,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T
pinv = inv(p)
d = size(pinv, halfdim)
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
scale = pinv isa ScaledPlan ? pinv.scale / N : one(pinv.scale) / N
scale = pinv isa ScaledPlan ? pinv.scale / N : inv(N)
twoscale = 2 * scale
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
y = unscaled_pinv * x
Expand Down
16 changes: 16 additions & 0 deletions test/TestPlans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,20 @@ AbstractFFTs.plan_inv(p::InplaceTestPlan) = InplaceTestPlan(AbstractFFTs.plan_in
# Don't cache inverse of inplace wrapper plan (only inverse of inner plan)
Base.inv(p::InplaceTestPlan) = InplaceTestPlan(inv(p.plan))

# A dummy plan whose inverse is not AbstractFFTs.ScaledPlan, for testing purposes

struct DummyTestPlan{T,P<:Plan{T}} <: Plan{T}
plan::P
end

Base.size(p::DummyTestPlan) = size(p.plan)
Base.ndims(p::DummyTestPlan) = ndims(p.plan)
AbstractFFTs.fftdims(p::DummyTestPlan) = fftdims(p.plan)
AbstractFFTs.AdjointStyle(p::DummyTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

Base.:*(p::DummyTestPlan, x::AbstractArray) = p.plan * x

AbstractFFTs.plan_inv(p::DummyTestPlan) = DummyTestPlan(AbstractFFTs.plan_inv(p.plan))
Base.inv(p::DummyTestPlan) = DummyTestPlan(inv(p.plan))

end
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,24 @@ end
@test eltype(p' * (p * u)) == eltype(u)
end

@testset "Adjoint plan application when plan inverse is not a ScaledPlan" begin
# fft
p0 = plan_fft(zeros(ComplexF64, 3))
p = TestPlans.DummyTestPlan(p0)
u = rand(ComplexF64, 3)
@test p' * u p0' * u
# rfft
p0 = plan_rfft(zeros(3))
p = TestPlans.DummyTestPlan(p0)
u = rand(ComplexF64, 2)
@test p' * u p0' * u
# brfft
p0 = plan_brfft(zeros(ComplexF64, 3), 5)
p = TestPlans.DummyTestPlan(p0)
u = rand(Float64, 5)
@test p' * u p0' * u
end

@testset "ChainRules" begin
@testset "shift functions" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
Expand Down

0 comments on commit bb52a06

Please sign in to comment.