From 111eda5bda9819706c52f60723d4aa01e6664d89 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 4 Dec 2023 12:11:47 -0500 Subject: [PATCH] Increasing robustness of adjoint plan optimizations (#123) * Add some safeguards to prioritize type stability in adjoint plans * Test adjoint plans with float32's * Add subtype * Apply suggestions from code review * Rename dummy -> wrapper --- src/definitions.jl | 13 ++++--------- test/TestPlans.jl | 16 ++++++++++++++++ test/runtests.jl | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 6eafd81..06143ed 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -681,13 +681,8 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T} dims = fftdims(p) N = normalization(T, size(p), dims) pinv = inv(p) - # Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x. - # Even if not, ensure that we do only one pass by combining the normalization with the plan. - if pinv isa ScaledPlan && pinv.scale == N - return pinv.p * x - else - return (1/N * pinv) * x - end + # Ensure that we do only one pass over the array by combining the normalization with the plan. + return (inv(N) * pinv) * x end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} @@ -698,7 +693,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T< pinv = inv(p) n = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. - scale = pinv isa ScaledPlan ? pinv.scale / 2N : 1 / 2N + scale = pinv isa ScaledPlan ? pinv.scale / 2N : inv(2N) twoscale = 2 * scale unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv y = map(x, CartesianIndices(x)) do xj, j @@ -721,7 +716,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T pinv = inv(p) d = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. - scale = pinv isa ScaledPlan ? pinv.scale / N : 1 / N + scale = pinv isa ScaledPlan ? pinv.scale / N : inv(N) twoscale = 2 * scale unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv y = unscaled_pinv * x diff --git a/test/TestPlans.jl b/test/TestPlans.jl index 1961113..1c3459a 100644 --- a/test/TestPlans.jl +++ b/test/TestPlans.jl @@ -278,4 +278,20 @@ AbstractFFTs.plan_inv(p::InplaceTestPlan) = InplaceTestPlan(AbstractFFTs.plan_in # Don't cache inverse of inplace wrapper plan (only inverse of inner plan) Base.inv(p::InplaceTestPlan) = InplaceTestPlan(inv(p.plan)) +# A wrapper plan whose inverse is not an instance of AbstractFFTs.ScaledPlan, for testing purposes + +struct WrapperTestPlan{T,P<:Plan{T}} <: Plan{T} + plan::P +end + +Base.size(p::WrapperTestPlan) = size(p.plan) +Base.ndims(p::WrapperTestPlan) = ndims(p.plan) +AbstractFFTs.fftdims(p::WrapperTestPlan) = fftdims(p.plan) +AbstractFFTs.AdjointStyle(p::WrapperTestPlan) = AbstractFFTs.AdjointStyle(p.plan) + +Base.:*(p::WrapperTestPlan, x::AbstractArray) = p.plan * x + +AbstractFFTs.plan_inv(p::WrapperTestPlan) = WrapperTestPlan(AbstractFFTs.plan_inv(p.plan)) +Base.inv(p::WrapperTestPlan) = WrapperTestPlan(inv(p.plan)) + end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c8821c9..7d8475f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -145,6 +145,39 @@ end end end +@testset "Adjoint plan on single-precision" begin + # fft + p = plan_fft(zeros(ComplexF32, 3)) + u = rand(ComplexF32, 3) + @test eltype(p' * (p * u)) == eltype(u) + # rfft + p = plan_rfft(zeros(Float32, 3)) + u = rand(Float32, 3) + @test eltype(p' * (p * u)) == eltype(u) + # brfft + p = plan_brfft(zeros(ComplexF32, 3), 5) + u = rand(ComplexF32, 3) + @test eltype(p' * (p * u)) == eltype(u) +end + +@testset "Adjoint plan application when plan inverse is not a ScaledPlan" begin + # fft + p0 = plan_fft(zeros(ComplexF64, 3)) + p = TestPlans.WrapperTestPlan(p0) + u = rand(ComplexF64, 3) + @test p' * u ≈ p0' * u + # rfft + p0 = plan_rfft(zeros(3)) + p = TestPlans.WrapperTestPlan(p0) + u = rand(ComplexF64, 2) + @test p' * u ≈ p0' * u + # brfft + p0 = plan_brfft(zeros(ComplexF64, 3), 5) + p = TestPlans.WrapperTestPlan(p0) + u = rand(Float64, 5) + @test p' * u ≈ p0' * u +end + @testset "ChainRules" begin @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5))