From e2780de1c1aa3ca28952dfed2d00fdd5e692a58f Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Jul 2023 19:12:00 -0400 Subject: [PATCH 1/8] Eagerly compute inv(p) in adjoint plans, and use it to fix eltype and remove output_size --- docs/src/api.md | 1 - docs/src/implementations.md | 2 +- ext/AbstractFFTsTestExt.jl | 4 ++-- src/definitions.jl | 29 ++++++++--------------------- 4 files changed, 11 insertions(+), 25 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 713e62d..45e4ee2 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -38,7 +38,6 @@ It is also relevant to implementers of FFT plans that wish to support adjoints. ```@docs Base.adjoint AbstractFFTs.AdjointStyle -AbstractFFTs.output_size AbstractFFTs.adjoint_mul AbstractFFTs.FFTAdjointStyle AbstractFFTs.RFFTAdjointStyle diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 81deb76..8861621 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -35,7 +35,7 @@ To define a new FFT implementation in your own module, you should * To support adjoints in a new plan, define the trait [`AbstractFFTs.AdjointStyle`](@ref). `AbstractFFTs` implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref). - To define a new adjoint style, define the methods [`AbstractFFTs.adjoint_mul`](@ref) and [`AbstractFFTs.output_size`](@ref). + To define a new adjoint style, define the method [`AbstractFFTs.adjoint_mul`](@ref). The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``. diff --git a/ext/AbstractFFTsTestExt.jl b/ext/AbstractFFTsTestExt.jl index ccea93a..451c8c7 100644 --- a/ext/AbstractFFTsTestExt.jl +++ b/ext/AbstractFFTsTestExt.jl @@ -73,9 +73,9 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea _copy = copy_input ? copy : identity y = rand(eltype(P * _copy(x)), size(P * _copy(x))) # test basic properties - @test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110) + @test eltype(P') === eltype(y) @test (P')' === P # test adjoint of adjoint - @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint + @test size(P') == size(y) # test size of adjoint # test correctness of adjoint and its inverse via the dot test if !real_plan @test dot(y, P * _copy(x)) ≈ dot(P' * _copy(y), x) diff --git a/src/definitions.jl b/src/definitions.jl index 5dc703f..f6985e3 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -259,7 +259,6 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale) ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α) size(p::ScaledPlan) = size(p.p) -output_size(p::ScaledPlan) = output_size(p.p) fftdims(p::ScaledPlan) = fftdims(p.p) @@ -640,20 +639,6 @@ Adjoint style for unitary transforms, whose adjoint equals their inverse. """ struct UnitaryAdjointStyle <: AdjointStyle end -""" - output_size(p::Plan, [dim]) - -Return the size of the output of a plan `p`, optionally at a specified dimension `dim`. - -Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define `output_size(::Plan, ::AS)`. -""" -output_size(p::Plan) = output_size(p, AdjointStyle(p)) -output_size(p::Plan, dim) = output_size(p)[dim] -output_size(p::Plan, ::FFTAdjointStyle) = size(p) -output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p)) -output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) -output_size(p::Plan, ::UnitaryAdjointStyle) = size(p) - struct AdjointPlan{T,P<:Plan} <: Plan{T} p::P AdjointPlan{T,P}(p) where {T,P} = new(p) @@ -669,13 +654,15 @@ Return a plan that performs the adjoint operation of the original plan. Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`, coverage of `Base.adjoint` in downstream implementations may be limited. """ -Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p) +# We eagerly form the plan inverse in the adjoint(p) call, which will be cached for subsequent calls. +# This is reasonable, as inv(p) would do the same, and necessary in order to compute the correct input +# type for the adjoint plan and encode it in its type. +Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{eltype(inv(p)), typeof(p)}(p) Base.adjoint(p::AdjointPlan) = p.p # always have AdjointPlan inside ScaledPlan. Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale) -size(p::AdjointPlan) = output_size(p.p) -output_size(p::AdjointPlan) = size(p.p) +size(p::AdjointPlan) = size(inv(p.p)) fftdims(p::AdjointPlan) = fftdims(p.p) Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x) @@ -701,7 +688,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T< N = normalization(T, size(p), dims) halfdim = first(dims) d = size(p, halfdim) - n = output_size(p, halfdim) + n = size(inv(p), halfdim) scale = reshape( [(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) @@ -711,10 +698,10 @@ end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T} dims = fftdims(p) - N = normalization(real(T), output_size(p), dims) + N = normalization(real(T), size(inv(p)), dims) halfdim = first(dims) n = size(p, halfdim) - d = output_size(p, halfdim) + d = size(inv(p), halfdim) scale = reshape( [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) From 43cd52a8f4ca89c64a4bbcc050a98df53e0b33b2 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 29 Jul 2023 19:16:27 -0400 Subject: [PATCH 2/8] Remove output_size tests, replace with size test for all plans in TestUtils --- ext/AbstractFFTsTestExt.jl | 1 + test/runtests.jl | 35 ----------------------------------- 2 files changed, 1 insertion(+), 35 deletions(-) diff --git a/ext/AbstractFFTsTestExt.jl b/ext/AbstractFFTsTestExt.jl index 451c8c7..d75249d 100644 --- a/ext/AbstractFFTsTestExt.jl +++ b/ext/AbstractFFTsTestExt.jl @@ -54,6 +54,7 @@ const TEST_CASES = ( function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false) _copy = copy_input ? copy : identity + @test size(P) == size(x) if !inplace_plan @test P * _copy(x) ≈ x_transformed @test P \ (P * _copy(x)) ≈ x diff --git a/test/runtests.jl b/test/runtests.jl index fe74897..c8821c9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -118,41 +118,6 @@ end @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end -@testset "output size" begin - @testset "complex fft output size" begin - for x_shape in ((3,), (3, 4), (3, 4, 5)) - N = length(x_shape) - real_x = randn(x_shape) - complex_x = randn(ComplexF64, x_shape) - for x in (real_x, complex_x) - for dims in unique((1, 1:N, N)) - P = plan_fft(x, dims) - @test @inferred(AbstractFFTs.output_size(P)) == size(x) - @test AbstractFFTs.output_size(P') == size(x) - Pinv = plan_ifft(x) - @test AbstractFFTs.output_size(Pinv) == size(x) - @test AbstractFFTs.output_size(Pinv') == size(x) - end - end - end - end - @testset "real fft output size" begin - for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths - N = ndims(x) - for dims in unique((1, 1:N, N)) - P = plan_rfft(x, dims) - Px_sz = size(P * x) - @test AbstractFFTs.output_size(P) == Px_sz - @test AbstractFFTs.output_size(P') == size(x) - y = randn(ComplexF64, Px_sz) - Pinv = plan_irfft(y, size(x)[first(dims)], dims) - @test AbstractFFTs.output_size(Pinv) == size(Pinv * y) - @test AbstractFFTs.output_size(Pinv') == size(y) - end - end - end -end - # Test that dims defaults to 1:ndims for fft-like functions @testset "Default dims" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) From f9b5de31ff2d5ab9837fd302ca2f484d7db865fa Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 22 Aug 2023 16:50:33 -0400 Subject: [PATCH 3/8] Ensure <= 1 pass over array in adjoint application --- src/definitions.jl | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 5c6e3f2..5aca959 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -680,7 +680,15 @@ adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p)) function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T} dims = fftdims(p) N = normalization(T, size(p), dims) - return (p \ x) / N + pinv = inv(p) + # Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x. + # Even if not, ensure that we do only one pass by combining the normalization with the plan. + if pinv isa ScaledPlan && isapprox(pinv.scale, N) + scaled_pinv = pinv.p + else + scaled_pinv = (1/N) * pinv + end + return scaled_pinv * x end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} @@ -688,7 +696,15 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T< N = normalization(T, size(p), dims) halfdim = first(dims) d = size(p, halfdim) - n = size(inv(p), halfdim) + pinv = inv(p) + n = size(pinv, halfdim) + # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. + if pinv isa ScaledPlan + N /= pinv.scale + scaled_pinv = pinv.p + else + scaled_pinv = pinv + end y = map(x, CartesianIndices(x)) do xj, j i = j[halfdim] yj = if i == 1 || (i == n && 2 * (i - 1) == d) @@ -698,7 +714,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T< end return yj end - return p \ y + return scaled_pinv * y end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T} @@ -706,8 +722,16 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T N = normalization(real(T), size(inv(p)), dims) halfdim = first(dims) n = size(p, halfdim) - d = size(inv(p), halfdim) - y = p \ x + pinv = inv(p) + d = size(pinv, halfdim) + # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. + if pinv isa ScaledPlan + N /= pinv.scale + scaled_pinv = pinv.p + else + scaled_pinv = pinv + end + y = scaled_pinv * x z = map(y, CartesianIndices(y)) do yj, j i = j[halfdim] zj = if i == 1 || (i == n && 2 * (i - 1) == d) From 18631da82e465bdbf14cafa0d58423dbbc9ac17f Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 22 Aug 2023 17:13:22 -0400 Subject: [PATCH 4/8] Try to fix type stability issues --- src/definitions.jl | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 5aca959..eb68e88 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -681,30 +681,26 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T} dims = fftdims(p) N = normalization(T, size(p), dims) pinv = inv(p) - # Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x. + # Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x. # Even if not, ensure that we do only one pass by combining the normalization with the plan. - if pinv isa ScaledPlan && isapprox(pinv.scale, N) - scaled_pinv = pinv.p + scaled_pinv = if pinv isa ScaledPlan && isapprox(pinv.scale, N) + pinv.p else - scaled_pinv = (1/N) * pinv + (1/N) * pinv end return scaled_pinv * x end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} dims = fftdims(p) - N = normalization(T, size(p), dims) + _N = normalization(T, size(p), dims) halfdim = first(dims) d = size(p, halfdim) pinv = inv(p) n = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. - if pinv isa ScaledPlan - N /= pinv.scale - scaled_pinv = pinv.p - else - scaled_pinv = pinv - end + N = pinv isa ScaledPlan ? _N / pinv.scale : _N + scaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv y = map(x, CartesianIndices(x)) do xj, j i = j[halfdim] yj = if i == 1 || (i == n && 2 * (i - 1) == d) @@ -719,18 +715,14 @@ end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T} dims = fftdims(p) - N = normalization(real(T), size(inv(p)), dims) + _N = normalization(real(T), size(inv(p)), dims) halfdim = first(dims) n = size(p, halfdim) pinv = inv(p) d = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. - if pinv isa ScaledPlan - N /= pinv.scale - scaled_pinv = pinv.p - else - scaled_pinv = pinv - end + N = pinv isa ScaledPlan ? _N / pinv.scale : _N + scaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv y = scaled_pinv * x z = map(y, CartesianIndices(y)) do yj, j i = j[halfdim] From ee4b2253fba95762a4f5d734205fa338f1c98edb Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 22 Aug 2023 17:24:20 -0400 Subject: [PATCH 5/8] Fix another type instability --- src/definitions.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index eb68e88..efead20 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -683,12 +683,11 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T} pinv = inv(p) # Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x. # Even if not, ensure that we do only one pass by combining the normalization with the plan. - scaled_pinv = if pinv isa ScaledPlan && isapprox(pinv.scale, N) - pinv.p + if pinv isa ScaledPlan && isapprox(pinv.scale, N) + return pinv.p * x else - (1/N) * pinv + return (1/N * pinv) * x end - return scaled_pinv * x end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} From fe0f7f4d0677e019608183f65caae963aef95d14 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 22 Aug 2023 20:51:03 -0400 Subject: [PATCH 6/8] Replace division with multiplication in adjoint loop --- src/definitions.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index efead20..0ab9f16 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -692,43 +692,43 @@ end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} dims = fftdims(p) - _N = normalization(T, size(p), dims) + N = normalization(T, size(p), dims) halfdim = first(dims) d = size(p, halfdim) pinv = inv(p) n = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. - N = pinv isa ScaledPlan ? _N / pinv.scale : _N - scaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv + scale = pinv isa ScaledPlan ? pinv.scale / 2N : 1 / 2N + unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv y = map(x, CartesianIndices(x)) do xj, j i = j[halfdim] yj = if i == 1 || (i == n && 2 * (i - 1) == d) - xj / N + xj * scale * 2 else - xj / (2 * N) + xj * scale end return yj end - return scaled_pinv * y + return unscaled_pinv * y end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T} dims = fftdims(p) - _N = normalization(real(T), size(inv(p)), dims) + N = normalization(real(T), size(inv(p)), dims) halfdim = first(dims) n = size(p, halfdim) pinv = inv(p) d = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. - N = pinv isa ScaledPlan ? _N / pinv.scale : _N - scaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv - y = scaled_pinv * x + scale = pinv isa ScaledPlan ? pinv.scale / N : 1 / N + unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv + y = unscaled_pinv * x z = map(y, CartesianIndices(y)) do yj, j i = j[halfdim] zj = if i == 1 || (i == n && 2 * (i - 1) == d) - yj / N + yj * scale else - 2 * yj / N + yj * scale * 2 end return zj end From a115e97d83e8169dcaa90983b7ae16020cbabbc9 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 22 Aug 2023 20:51:39 -0400 Subject: [PATCH 7/8] Switch isapprox to equality Co-authored-by: Steven G. Johnson --- src/definitions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/definitions.jl b/src/definitions.jl index 0ab9f16..0375737 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -683,7 +683,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T} pinv = inv(p) # Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x. # Even if not, ensure that we do only one pass by combining the normalization with the plan. - if pinv isa ScaledPlan && isapprox(pinv.scale, N) + if pinv isa ScaledPlan && pinv.scale == N return pinv.p * x else return (1/N * pinv) * x From d194cbe0e903086dd3636b3833eb5fcd2fb90866 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 22 Aug 2023 20:56:02 -0400 Subject: [PATCH 8/8] Lift out multiply-by-2 --- src/definitions.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 0375737..6eafd81 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -699,11 +699,12 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T< n = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. scale = pinv isa ScaledPlan ? pinv.scale / 2N : 1 / 2N + twoscale = 2 * scale unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv y = map(x, CartesianIndices(x)) do xj, j i = j[halfdim] yj = if i == 1 || (i == n && 2 * (i - 1) == d) - xj * scale * 2 + xj * twoscale else xj * scale end @@ -721,6 +722,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T d = size(pinv, halfdim) # Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice. scale = pinv isa ScaledPlan ? pinv.scale / N : 1 / N + twoscale = 2 * scale unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv y = unscaled_pinv * x z = map(y, CartesianIndices(y)) do yj, j @@ -728,7 +730,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T zj = if i == 1 || (i == n && 2 * (i - 1) == d) yj * scale else - yj * scale * 2 + yj * twoscale end return zj end