Skip to content

Commit

Permalink
Increasing robustness of adjoint plan optimizations (#123)
Browse files Browse the repository at this point in the history
* Add some safeguards to prioritize type stability in adjoint plans

* Test adjoint plans with float32's

* Add subtype

* Apply suggestions from code review

* Rename dummy -> wrapper
  • Loading branch information
gaurav-arya authored Dec 4, 2023
1 parent a67bf15 commit 111eda5
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
13 changes: 4 additions & 9 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -681,13 +681,8 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
pinv = inv(p)
# Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x.
# Even if not, ensure that we do only one pass by combining the normalization with the plan.
if pinv isa ScaledPlan && pinv.scale == N
return pinv.p * x
else
return (1/N * pinv) * x
end
# Ensure that we do only one pass over the array by combining the normalization with the plan.
return (inv(N) * pinv) * x
end

function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
Expand All @@ -698,7 +693,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<
pinv = inv(p)
n = size(pinv, halfdim)
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
scale = pinv isa ScaledPlan ? pinv.scale / 2N : 1 / 2N
scale = pinv isa ScaledPlan ? pinv.scale / 2N : inv(2N)
twoscale = 2 * scale
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
y = map(x, CartesianIndices(x)) do xj, j
Expand All @@ -721,7 +716,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T
pinv = inv(p)
d = size(pinv, halfdim)
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
scale = pinv isa ScaledPlan ? pinv.scale / N : 1 / N
scale = pinv isa ScaledPlan ? pinv.scale / N : inv(N)
twoscale = 2 * scale
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
y = unscaled_pinv * x
Expand Down
16 changes: 16 additions & 0 deletions test/TestPlans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,20 @@ AbstractFFTs.plan_inv(p::InplaceTestPlan) = InplaceTestPlan(AbstractFFTs.plan_in
# Don't cache inverse of inplace wrapper plan (only inverse of inner plan)
Base.inv(p::InplaceTestPlan) = InplaceTestPlan(inv(p.plan))

# A wrapper plan whose inverse is not an instance of AbstractFFTs.ScaledPlan, for testing purposes

struct WrapperTestPlan{T,P<:Plan{T}} <: Plan{T}
plan::P
end

Base.size(p::WrapperTestPlan) = size(p.plan)
Base.ndims(p::WrapperTestPlan) = ndims(p.plan)
AbstractFFTs.fftdims(p::WrapperTestPlan) = fftdims(p.plan)
AbstractFFTs.AdjointStyle(p::WrapperTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

Base.:*(p::WrapperTestPlan, x::AbstractArray) = p.plan * x

AbstractFFTs.plan_inv(p::WrapperTestPlan) = WrapperTestPlan(AbstractFFTs.plan_inv(p.plan))
Base.inv(p::WrapperTestPlan) = WrapperTestPlan(inv(p.plan))

end
33 changes: 33 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,39 @@ end
end
end

@testset "Adjoint plan on single-precision" begin
# fft
p = plan_fft(zeros(ComplexF32, 3))
u = rand(ComplexF32, 3)
@test eltype(p' * (p * u)) == eltype(u)
# rfft
p = plan_rfft(zeros(Float32, 3))
u = rand(Float32, 3)
@test eltype(p' * (p * u)) == eltype(u)
# brfft
p = plan_brfft(zeros(ComplexF32, 3), 5)
u = rand(ComplexF32, 3)
@test eltype(p' * (p * u)) == eltype(u)
end

@testset "Adjoint plan application when plan inverse is not a ScaledPlan" begin
# fft
p0 = plan_fft(zeros(ComplexF64, 3))
p = TestPlans.WrapperTestPlan(p0)
u = rand(ComplexF64, 3)
@test p' * u p0' * u
# rfft
p0 = plan_rfft(zeros(3))
p = TestPlans.WrapperTestPlan(p0)
u = rand(ComplexF64, 2)
@test p' * u p0' * u
# brfft
p0 = plan_brfft(zeros(ComplexF64, 3), 5)
p = TestPlans.WrapperTestPlan(p0)
u = rand(Float64, 5)
@test p' * u p0' * u
end

@testset "ChainRules" begin
@testset "shift functions" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
Expand Down

0 comments on commit 111eda5

Please sign in to comment.