From c332e896ef32cb5e144cf29bd99136509a1b087e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 4 Jul 2023 11:59:56 +0200 Subject: [PATCH] Throw error for in-place plans --- ext/AbstractFFTsChainRulesCoreExt.jl | 16 ++++++++++++++-- src/definitions.jl | 4 ++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index aa19724..277c78f 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -161,12 +161,18 @@ end # plans function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray) - y = P * x + y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end Δy = P * Δx return y, Δy end function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray) y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end project_x = ChainRulesCore.ProjectTo(x) Pt = P' function mul_plan_pullback(ȳ) @@ -177,12 +183,18 @@ function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArra end function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray) - y = P * x + y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end Δy = P * Δx .+ (ΔP.scale / P.scale) .* y return y, Δy end function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray) y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end Pt = P' scale = P.scale project_x = ChainRulesCore.ProjectTo(x) diff --git a/src/definitions.jl b/src/definitions.jl index 9ce8f0d..cb492c4 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -638,7 +638,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where [(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) ) - return p.p \ (x ./ scale) + return p.p \ (x ./ convert(typeof(x), scale)) end function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} @@ -651,7 +651,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) ) - return scale ./ N .* (p.p \ x) + return convert(typeof(x), scale) ./ N .* (p.p \ x) end # Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).