Chain rules for FFT plans via AdjointPlans #67

Merged (28 commits) on Jul 5, 2023
Commits
ad71816
Implement AdjointPlans
gaurav-arya Jun 9, 2022
c91ad50
Implement chain rules for FFT plans
gaurav-arya Jun 9, 2022
061eef9
Test plan adjoints and AD rules
gaurav-arya Jun 9, 2022
497ff4d
Apply suggestions from adjoint plan code review
gaurav-arya Jun 9, 2022
5d5c06c
Include irrft_dim in RealInverseProjectionStyle
gaurav-arya Jun 9, 2022
ef84edf
update to new fftdims interface
gaurav-arya Jul 1, 2022
d7ff394
fix broken tests
gaurav-arya Jul 1, 2022
aa8e575
Explicitly don't support mul! for adjoint plans
gaurav-arya Jul 1, 2022
9d99886
Document adjoint plans
gaurav-arya Jul 1, 2022
ac7c78c
remove incorrectly thrown error
gaurav-arya Jul 1, 2022
8474141
Update adjoint plan docs
gaurav-arya Jul 14, 2022
769c090
Update adjoint docs
gaurav-arya Jul 14, 2022
3ed83df
Fix typos
gaurav-arya Jul 14, 2022
552d49f
tweak adjoint doc string
gaurav-arya Jul 14, 2022
1e9ece2
Tweaks to adjoint description
gaurav-arya Jul 15, 2022
8ddfa97
Immutable AdjointPlan
gaurav-arya Jul 16, 2022
87758c8
Add rules and tests for ScaledPlan
gaurav-arya Aug 6, 2022
09b8b38
Apply suggestions from code review
gaurav-arya Aug 16, 2022
d967aa2
More tweaks to address code review
gaurav-arya Aug 16, 2022
2a423e2
Restrict to T<:Real for rfft adjoint
gaurav-arya Aug 16, 2022
eedba14
Get type T correct for test irfft
gaurav-arya Aug 16, 2022
25bb86b
Test complex input when appropriate for adjoint tests
gaurav-arya Aug 16, 2022
2a2d685
Merge remote-tracking branch 'origin/master' into adjoint
gaurav-arya Aug 28, 2022
fe3b06a
Add plan_inv implementation for adjoint plan and test it
gaurav-arya Aug 28, 2022
266c88f
Merge branch 'master' into adjoint
devmotion Jun 30, 2023
403ce47
Apply suggestions from code review
devmotion Jul 4, 2023
e137ae3
Apply suggestions from code review
devmotion Jul 5, 2023
e601347
Test in-place plans
devmotion Jul 5, 2023
3 changes: 2 additions & 1 deletion Project.toml
@@ -19,9 +19,10 @@ julia = "^1.0"
[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"]
test = ["ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]
23 changes: 1 addition & 22 deletions README.md
@@ -16,26 +16,5 @@ This allows multiple FFT packages to co-exist with the same underlying `fft(x)`

## Developer information

To define a new FFT implementation in your own module, you should
To define a new FFT implementation in your own module, see [defining a new implementation](https://juliamath.github.io/AbstractFFTs.jl/stable/implementations/#Defining-a-new-implementation).

* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`.
This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.

* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method.
This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.

* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`.

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

The normalization convention for your FFT should be that it computes $y_k = \sum_j \exp(-2 \pi i \cdot \frac{j k}{n}) x_j$
for a transform of length $n$, and the "backwards" (unnormalized inverse) transform computes the same thing but with
$\exp(+2 \pi i \cdot \frac{j k}{n})$.
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -1,4 +1,5 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[compat]
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -20,6 +20,7 @@ AbstractFFTs.plan_rfft
AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
41 changes: 28 additions & 13 deletions docs/src/implementations.md
@@ -11,16 +11,31 @@ The following packages extend the functionality provided by AbstractFFTs:

## Defining a new implementation

Implementations should implement `LinearAlgebra.mul!(Y, plan, X)` (or
`A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) so as to support
pre-allocated output arrays.
We don't define `*` in terms of `mul!` generically here, however, because
of subtleties for in-place and real FFT plans.

To support `inv`, `\`, and `ldiv!(y, plan, x)`, we require `Plan` subtypes
to have a `pinv::Plan` field, which caches the inverse plan, and which should be
initially undefined.
They should also implement `plan_inv(p)` to construct the inverse of a plan `p`.

Implementations only need to provide the unnormalized backwards FFT,
similar to FFTW, and we do the scaling generically to get the inverse FFT.
To define a new FFT implementation in your own module, you should

* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`.
This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.

* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` method.
This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.

* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`.
Implementations only need to provide the unnormalized backwards FFT, similar to FFTW, and we do the scaling generically
to get the inverse FFT.

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
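
To make the checklist above concrete, here is a minimal, self-contained sketch of a backend. It is illustrative only and not part of this PR: the names NaiveDFTPlan/NaiveBDFTPlan are hypothetical, it assumes no other backend (e.g. FFTW.jl) is loaded, and it reuses the package-internal helpers AbstractFFTs.normalization and AbstractFFTs.ScaledPlan in the same way the code added by this PR does.

using AbstractFFTs, LinearAlgebra

# Forward and unnormalized-backward plans; `pinv` is left undefined so the package can cache the inverse plan.
mutable struct NaiveDFTPlan{T} <: AbstractFFTs.Plan{T}
    n::Int
    region
    pinv::AbstractFFTs.Plan
    NaiveDFTPlan{T}(n, region) where {T} = new{T}(n, region)
end
mutable struct NaiveBDFTPlan{T} <: AbstractFFTs.Plan{T}
    n::Int
    region
    pinv::AbstractFFTs.Plan
    NaiveBDFTPlan{T}(n, region) where {T} = new{T}(n, region)
end

Base.size(p::Union{NaiveDFTPlan,NaiveBDFTPlan}) = (p.n,)
AbstractFFTs.fftdims(p::Union{NaiveDFTPlan,NaiveBDFTPlan}) = p.region

AbstractFFTs.plan_fft(x::AbstractVector{T}, region; kws...) where {T<:Complex} =
    NaiveDFTPlan{T}(length(x), region)
AbstractFFTs.plan_bfft(x::AbstractVector{T}, region; kws...) where {T<:Complex} =
    NaiveBDFTPlan{T}(length(x), region)

# The inverse of the forward plan is the unnormalized backward plan scaled by 1/n (and vice versa).
AbstractFFTs.plan_inv(p::NaiveDFTPlan{T}) where {T} =
    AbstractFFTs.ScaledPlan(NaiveBDFTPlan{T}(p.n, p.region),
                            AbstractFFTs.normalization(real(T), size(p), fftdims(p)))
AbstractFFTs.plan_inv(p::NaiveBDFTPlan{T}) where {T} =
    AbstractFFTs.ScaledPlan(NaiveDFTPlan{T}(p.n, p.region),
                            AbstractFFTs.normalization(real(T), size(p), fftdims(p)))

# Naive O(n^2) DFT kernels following the normalization convention above.
function LinearAlgebra.mul!(y::AbstractVector, p::NaiveDFTPlan, x::AbstractVector)
    n = p.n
    for k in 0:n-1
        y[k+1] = sum(x[j+1] * cis(-2π * j * k / n) for j in 0:n-1)
    end
    return y
end
function LinearAlgebra.mul!(y::AbstractVector, p::NaiveBDFTPlan, x::AbstractVector)
    n = p.n
    for k in 0:n-1
        y[k+1] = sum(x[j+1] * cis(2π * j * k / n) for j in 0:n-1)
    end
    return y
end
Base.:*(p::Union{NaiveDFTPlan,NaiveBDFTPlan}, x::AbstractVector) = LinearAlgebra.mul!(similar(x), p, x)

# No dimensions are halved, so adjoint support only needs the trivial trait.
AbstractFFTs.ProjectionStyle(::Union{NaiveDFTPlan,NaiveBDFTPlan}) = AbstractFFTs.NoProjectionStyle()

# Quick round-trip check; the generic machinery supplies fft, ifft, and \ from the methods above.
xs = randn(ComplexF64, 8)
@assert ifft(fft(xs)) ≈ xs
@assert plan_fft(xs) \ (plan_fft(xs) * xs) ≈ xs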
50 changes: 50 additions & 0 deletions ext/AbstractFFTsChainRulesCoreExt.jl
@@ -159,4 +159,54 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
return y, ifftshift_pullback
end

# plans
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
y = P * x
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
Review thread:
Contributor: Could you add tests that this error (and the others below) are thrown?
Member: That's a bit more involved if we actually want to probe it with FFT plans, since currently the test suite does not contain any in-place test plans. I guess the easiest option would be to just re-use the existing out-of-place plans and define in-place updates on top of their implementations.
Member: I added tests 🙂

end
Δy = P * Δx
return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
y = P * x
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
end
project_x = ChainRulesCore.ProjectTo(x)
Pt = P'
function mul_plan_pullback(ȳ)
x̄ = project_x(Pt * ȳ)
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
end
return y, mul_plan_pullback
end

function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
y = P * x
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
end
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
Review thread:
Contributor: Is it inconsistent at all that here we use the tangent of the scale part of P but none of the tangent of the wrapped plan?
Member: Hm, it seems plans are assumed to be constant (AFAICT from the initial version of the PR) but the scaling might change?
Contributor: I guess there's probably never a good reason a user would want (co)tangents for a Plan. In almost every case the scale of a Plan is just a constant that a user would likewise never want a (co)tangent for, but perhaps there is one user out there who does, so I can see the point in this.

return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
y = P * x
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
end
Pt = P'
scale = P.scale
project_x = ChainRulesCore.ProjectTo(x)
project_scale = ChainRulesCore.ProjectTo(scale)
function mul_scaledplan_pullback(ȳ)
Review thread:

Contributor: It would be nice if the mul!(y, p, x, a, b) API was supported by AbstractFFTs, because then ChainRules could also define an inplaceable thunk here, and Enzyme rules could avoid an allocation, but maybe that is outside the scope of this PR.

Contributor: That would require the FFT plan to support fused mul!, which isn't guaranteed. To create a fallback implementation, the plan must cache y:

cache = get_cache(plan)   # plan-owned scratch buffer (hypothetical helper)
copy!(cache, y)           # save the old value of y
mul!(y, plan, x)          # y = plan * x
axpby!(b, cache, a, y)    # y = a * (plan * x) + b * y_old

Feels out of scope for this PR.

Member: I don't think that mul! has to guarantee allocation-free or fused computations (but maybe I'm wrong). Usually, ! only indicates that some (usually but not necessarily the first) argument is updated in-place, but sometimes other arguments are updated as well and/or the update is not allocation-free.

Contributor: It is my understanding that LinearAlgebra.mul! is allocation-free. That is what gives it a performance advantage over Base.*. To my knowledge, no mutating LinearAlgebra routine allocates a copy of the base array.

Member: > it is my understanding that LinearAlgebra.mul! is allocation-free.

I quickly checked the Julia repo, and there are a few open issues that show that at least in practice such a guarantee does not exist: JuliaLang/julia#49332, JuliaLang/julia#46865. Arguably these are just bugs, but on the other hand the docstring of mul! also does not make any such guarantees.

Contributor: For both cases, the allocation size is independent of the array size, indicating that the arrays themselves are not being allocated. It looks like a spurious size-tuple allocation to me.

Examples:

julia> versioninfo()
Julia Version 1.9.1
Commit 147bdf428cd (2023-06-07 08:27 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M2

JuliaLang/julia#49332

julia> using LinearAlgebra, BenchmarkTools

julia> A = rand(ComplexF64,4,4,1000,1000);

julia> B = similar(A);

julia> a,b = (view(B,:,:,1,1),view(A,:,:,1,1));

julia> @btime mul!($b,$a,$a); # 4x4 * 4x4
  311.283 ns (10 allocations: 608 bytes)

julia> A = rand(ComplexF64,128,128,10,10);

julia> B = similar(A);

julia> a,b = (view(B,:,:,1,1),view(A,:,:,1,1));

julia> @btime mul!($b,$a,$a); # 128x128 * 128x128
  170.542 μs (10 allocations: 608 bytes)

JuliaLang/julia#46865

julia> N = 5_000;

julia> A = rand(N, N); B = rand(N, N); C = rand(N, N);

julia> @time mul!(C, A, B, true, true);
  1.729141 seconds (1 allocation: 16 bytes)

julia> @time mul!(C, A, B);
  1.637079 seconds

julia> @time A * B; # allocates N x N array
  1.421422 seconds (2 allocations: 190.735 MiB, 0.13% gc time)

Member: That was my understanding from skimming through the issues, and why I wrote that arguably these could be considered bugs. My main point: there are no guarantees in Julia regarding allocation; the language and the JIT compiler do not enforce any contracts, so it is only possible to document interfaces and trust people to implement them accordingly. But in the case of mul!, no such guarantees are documented.

Contributor: I do think it makes sense for AbstractFFTs to ultimately support downstream packages implementing either 3-arg or 5-arg mul!, with each defaulting to the other (yes, that risks a stack overflow, but if implementing one of them is required, then no overflow can occur). But I also think this needn't happen in this PR.

x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ))
scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale)))
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
return ChainRulesCore.NoTangent(), plan_tangent, x̄
end
return y, mul_scaledplan_pullback
end

end # module

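A side note on the ScaledPlan rule above: since y = scale * (p * x), we have ∂y/∂scale = y / scale, so the scalar cotangent is sum(conj(y ./ scale) .* ȳ) = dot(y, ȳ) / conj(scale), which is exactly the thunked scale_tangent. The sketch below is not from the PR; it sanity-checks the plain-plan rule against the defining adjoint identity, assuming FFTW.jl as the backend, Julia 1.9+ so this ChainRulesCore extension loads, and that the backend's plans support the adjoint trait introduced in this PR (the adjoint docstring later in this diff notes that downstream coverage may initially be limited).

using AbstractFFTs, FFTW, ChainRulesCore, LinearAlgebra

x = randn(ComplexF64, 8)
P = plan_fft(x)
y, back = ChainRulesCore.rrule(*, P, x)
@assert y ≈ P * x

# The pullback applies the adjoint plan, so dot(ȳ, P * Δx) == dot(x̄, Δx) for any Δx.
ȳ = randn(ComplexF64, 8)
_, _, x̄ = back(ȳ)
Δx = randn(ComplexF64, 8)
@assert dot(ȳ, P * Δx) ≈ dot(ChainRulesCore.unthunk(x̄), Δx)

# In-place plans are explicitly rejected by the rules above.
P! = plan_fft!(copy(x))
threw = try
    ChainRulesCore.rrule(*, P!, copy(x)); false
catch err
    err isa ArgumentError
end
@assert threw
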
79 changes: 79 additions & 0 deletions src/definitions.jl
@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T

# size(p) should return the size of the input array for p
size(p::Plan, d) = size(p)[d]
output_size(p::Plan, d) = output_size(p)[d]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

@@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)
output_size(p::ScaledPlan) = output_size(p.p)

fftdims(p::ScaledPlan) = fftdims(p.p)

@@ -578,3 +580,80 @@ Pre-plan an optimized real-input unnormalized transform, similar to
the same as for [`brfft`](@ref).
"""
plan_brfft

##############################################################################

struct NoProjectionStyle end
struct RealProjectionStyle end
struct RealInverseProjectionStyle
dim::Int
end
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
AdjointPlan{T,P}(p) where {T,P} = new(p)
end

"""
(p::Plan)'
adjoint(p::Plan)

Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of
the original plan. Note that this differs from the corresponding backwards plan in the case of real
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).

!!! note
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
coverage of `Base.adjoint` in downstream implementations may be limited.
"""
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
Base.adjoint(p::AdjointPlan) = p.p
# always have AdjointPlan inside ScaledPlan.
Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)

size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
Review thread:
Member: We don't need real here, in contrast to RealInverseProjectionStyle below?
gaurav-arya (Contributor, Author), Aug 16, 2022: Yes, that's right, since we should only expect an rfft plan to operate on real arrays. I've added T<:Real to make this clear.
Also, regarding the use of real(T) for the RealInverseProjectionStyle: I could possibly just match AdjointPlan{Complex{T}} and then use T instead of real(T), since we probably should expect an irfft to operate on complex arrays. (The test plans were actually getting T wrong, i.e. T<:Real for the inverse of an rfft, but I've fixed that in eedba14.) However, real(T) seems a little safer in case someone ever wants to write a specialized irfft plan that accepts only real inputs.

halfdim = first(dims)
d = size(p.p, halfdim)
n = output_size(p.p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
halfdim = first(dims)
n = size(p.p, halfdim)
d = output_size(p.p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
end

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
inv(p::AdjointPlan) = adjoint(inv(p.p))
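
A brief usage sketch for the adjoint machinery above, again assuming FFTW.jl as a backend whose plans implement the trait (illustrative, not taken from the PR): for a complex FFT plan the NoProjectionStyle branch makes the adjoint coincide with the unnormalized backwards transform, and the defining adjoint identity holds.

using AbstractFFTs, FFTW, LinearAlgebra

x = randn(ComplexF64, 16)
P = plan_fft(x)
y = randn(ComplexF64, 16)

@assert P' * y ≈ bfft(y)                 # NoProjectionStyle branch: (P \ y) / N with N = 1/n
@assert dot(P * x, y) ≈ dot(x, P' * y)   # defining property of the adjoint
@assert (P')' === P                      # adjoint of an adjoint unwraps the original plan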