Skip to content

Commit

Permalink
Throw error for in-place plans
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jul 4, 2023
1 parent 266c88f commit c332e89
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
16 changes: 14 additions & 2 deletions ext/AbstractFFTsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,18 @@ end

# plans
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
y = P * x
y = P * x
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))

Check warning on line 166 in ext/AbstractFFTsChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsChainRulesCoreExt.jl#L166

Added line #L166 was not covered by tests
end
Δy = P * Δx
return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
y = P * x
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))

Check warning on line 174 in ext/AbstractFFTsChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsChainRulesCoreExt.jl#L174

Added line #L174 was not covered by tests
end
project_x = ChainRulesCore.ProjectTo(x)
Pt = P'
function mul_plan_pullback(ȳ)
Expand All @@ -177,12 +183,18 @@ function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArra
end

function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
y = P * x
y = P * x
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))

Check warning on line 188 in ext/AbstractFFTsChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsChainRulesCoreExt.jl#L188

Added line #L188 was not covered by tests
end
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
y = P * x
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))

Check warning on line 196 in ext/AbstractFFTsChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsChainRulesCoreExt.jl#L196

Added line #L196 was not covered by tests
end
Pt = P'
scale = P.scale
project_x = ChainRulesCore.ProjectTo(x)
Expand Down
4 changes: 2 additions & 2 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ scale)
return p.p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
Expand All @@ -651,7 +651,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle)
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return scale ./ N .* (p.p \ x)
return convert(typeof(x), scale) ./ N .* (p.p \ x)
end

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
Expand Down

0 comments on commit c332e89

Please sign in to comment.