Skip to content

Commit

Permalink
Eagerly invert plan on formation of AdjointPlan: correct eltype and r…
Browse files Browse the repository at this point in the history
…emove output_size (#113)

* Eagerly compute inv(p) in adjoint plans, and use it to fix eltype and
remove output_size

* Remove output_size tests, replace with size test for all plans in TestUtils

* Ensure <= 1 pass over array in adjoint application

* Try to fix type stability issues

* Fix another type instability

* Replace division with multiplication in adjoint loop

* Switch isapprox to equality

Co-authored-by: Steven G. Johnson <stevenj@mit.edu>

* Lift out multiply-by-2

---------

Co-authored-by: Steven G. Johnson <stevenj@mit.edu>
  • Loading branch information
gaurav-arya and stevengj authored Sep 4, 2023
1 parent fae1170 commit be4aa9b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 67 deletions.
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ It is also relevant to implementers of FFT plans that wish to support adjoints.
```@docs
Base.adjoint
AbstractFFTs.AdjointStyle
AbstractFFTs.output_size
AbstractFFTs.adjoint_mul
AbstractFFTs.FFTAdjointStyle
AbstractFFTs.RFFTAdjointStyle
Expand Down
2 changes: 1 addition & 1 deletion docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ To define a new FFT implementation in your own module, you should

* To support adjoints in a new plan, define the trait [`AbstractFFTs.AdjointStyle`](@ref).
`AbstractFFTs` implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
To define a new adjoint style, define the methods [`AbstractFFTs.adjoint_mul`](@ref) and [`AbstractFFTs.output_size`](@ref).
To define a new adjoint style, define the method [`AbstractFFTs.adjoint_mul`](@ref).

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
Expand Down
5 changes: 3 additions & 2 deletions ext/AbstractFFTsTestExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ const TEST_CASES = (

function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
@test size(P) == size(x)
if !inplace_plan
@test P * _copy(x) x_transformed
@test P \ (P * _copy(x)) x
Expand All @@ -74,9 +75,9 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea
_copy = copy_input ? copy : identity
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
# test basic properties
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
@test eltype(P') === eltype(y)
@test (P')' === P # test adjoint of adjoint
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
@test size(P') == size(y) # test size of adjoint
# test correctness of adjoint and its inverse via the dot test
if !real_plan
@test dot(y, P * _copy(x)) dot(P' * _copy(y), x)
Expand Down
60 changes: 32 additions & 28 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)
output_size(p::ScaledPlan) = output_size(p.p)

fftdims(p::ScaledPlan) = fftdims(p.p)

Expand Down Expand Up @@ -640,20 +639,6 @@ Adjoint style for unitary transforms, whose adjoint equals their inverse.
"""
struct UnitaryAdjointStyle <: AdjointStyle end

"""
output_size(p::Plan, [dim])
Return the size of the output of a plan `p`, optionally at a specified dimension `dim`.
Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define `output_size(::Plan, ::AS)`.
"""
output_size(p::Plan) = output_size(p, AdjointStyle(p))
output_size(p::Plan, dim) = output_size(p)[dim]
output_size(p::Plan, ::FFTAdjointStyle) = size(p)
output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
AdjointPlan{T,P}(p) where {T,P} = new(p)
Expand All @@ -669,13 +654,15 @@ Return a plan that performs the adjoint operation of the original plan.
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
coverage of `Base.adjoint` in downstream implementations may be limited.
"""
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
# We eagerly form the plan inverse in the adjoint(p) call, which will be cached for subsequent calls.
# This is reasonable, as inv(p) would do the same, and necessary in order to compute the correct input
# type for the adjoint plan and encode it in its type.
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{eltype(inv(p)), typeof(p)}(p)
Base.adjoint(p::AdjointPlan) = p.p
# always have AdjointPlan inside ScaledPlan.
Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)

size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)
size(p::AdjointPlan) = size(inv(p.p))
fftdims(p::AdjointPlan) = fftdims(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)
Expand All @@ -693,40 +680,57 @@ adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p))
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
return (p \ x) / N
pinv = inv(p)
# Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x.
# Even if not, ensure that we do only one pass by combining the normalization with the plan.
if pinv isa ScaledPlan && pinv.scale == N
return pinv.p * x
else
return (1/N * pinv) * x
end
end

function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p)
N = normalization(T, size(p), dims)
halfdim = first(dims)
d = size(p, halfdim)
n = output_size(p, halfdim)
pinv = inv(p)
n = size(pinv, halfdim)
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
scale = pinv isa ScaledPlan ? pinv.scale / 2N : 1 / 2N
twoscale = 2 * scale
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
y = map(x, CartesianIndices(x)) do xj, j
i = j[halfdim]
yj = if i == 1 || (i == n && 2 * (i - 1) == d)
xj / N
xj * twoscale
else
xj / (2 * N)
xj * scale
end
return yj
end
return p \ y
return unscaled_pinv * y
end

function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(real(T), output_size(p), dims)
N = normalization(real(T), size(inv(p)), dims)
halfdim = first(dims)
n = size(p, halfdim)
d = output_size(p, halfdim)
y = p \ x
pinv = inv(p)
d = size(pinv, halfdim)
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
scale = pinv isa ScaledPlan ? pinv.scale / N : 1 / N
twoscale = 2 * scale
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
y = unscaled_pinv * x
z = map(y, CartesianIndices(y)) do yj, j
i = j[halfdim]
zj = if i == 1 || (i == n && 2 * (i - 1) == d)
yj / N
yj * scale
else
2 * yj / N
yj * twoscale
end
return zj
end
Expand Down
35 changes: 0 additions & 35 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,41 +118,6 @@ end
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

@testset "output size" begin
@testset "complex fft output size" begin
for x_shape in ((3,), (3, 4), (3, 4, 5))
N = length(x_shape)
real_x = randn(x_shape)
complex_x = randn(ComplexF64, x_shape)
for x in (real_x, complex_x)
for dims in unique((1, 1:N, N))
P = plan_fft(x, dims)
@test @inferred(AbstractFFTs.output_size(P)) == size(x)
@test AbstractFFTs.output_size(P') == size(x)
Pinv = plan_ifft(x)
@test AbstractFFTs.output_size(Pinv) == size(x)
@test AbstractFFTs.output_size(Pinv') == size(x)
end
end
end
end
@testset "real fft output size" begin
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
N = ndims(x)
for dims in unique((1, 1:N, N))
P = plan_rfft(x, dims)
Px_sz = size(P * x)
@test AbstractFFTs.output_size(P) == Px_sz
@test AbstractFFTs.output_size(P') == size(x)
y = randn(ComplexF64, Px_sz)
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
@test AbstractFFTs.output_size(Pinv') == size(y)
end
end
end
end

# Test that dims defaults to 1:ndims for fft-like functions
@testset "Default dims" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
Expand Down

0 comments on commit be4aa9b

Please sign in to comment.