From 450bd993744e5b9a471e54d61a7350d7da225650 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 8 Mar 2024 17:29:26 +0100 Subject: [PATCH 1/9] Switch backends to ADTypes.jl --- Project.toml | 2 + README.md | 41 +++--- benchmark/Project.toml | 2 +- benchmark/benchmarks.jl | 55 ++----- benchmark/utils.jl | 128 +++++++++++++---- docs/Project.toml | 1 + docs/make.jl | 4 +- docs/src/backends.md | 73 ---------- docs/src/index.md | 41 +++--- docs/src/interface.md | 4 +- ...fferentiationInterfaceChainRulesCoreExt.jl | 48 +------ .../DifferentiationInterfaceEnzymeExt.jl | 6 +- .../forward.jl | 17 +-- .../reverse.jl | 20 +-- ext/DifferentiationInterfaceFiniteDiffExt.jl | 45 ++---- ext/DifferentiationInterfaceForwardDiffExt.jl | 37 ++--- ...tiationInterfacePolyesterForwardDiffExt.jl | 23 +-- ext/DifferentiationInterfaceReverseDiffExt.jl | 27 ++-- ext/DifferentiationInterfaceZygoteExt.jl | 34 +++-- src/DifferentiationInterface.jl | 18 +-- src/array_array.jl | 69 +++++++-- src/array_scalar.jl | 62 ++++++-- src/backends.jl | 122 +++++----------- src/backends_abstract.jl | 68 --------- src/pullback.jl | 20 +-- src/pushforward.jl | 20 +-- src/scalar_array.jl | 56 ++++++-- src/scalar_scalar.jl | 24 +++- src/utils.jl | 2 +- test/Project.toml | 2 +- test/backends.jl | 11 -- test/diffractor.jl | 15 -- test/enzyme_forward.jl | 10 +- test/enzyme_reverse.jl | 10 +- test/finitediff.jl | 8 +- test/forwarddiff.jl | 10 +- test/polyesterforwarddiff.jl | 8 +- test/reversediff.jl | 10 +- test/runtests.jl | 7 - test/utils.jl | 134 +++++++++++------- test/zygote.jl | 8 +- 41 files changed, 571 insertions(+), 731 deletions(-) delete mode 100644 docs/src/backends.md delete mode 100644 src/backends_abstract.jl delete mode 100644 test/backends.jl delete mode 100644 test/diffractor.jl diff --git a/Project.toml b/Project.toml index 06c8b9268..681abadcf 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Guillaume Dalle", "Adrian 
Hill"] version = "0.1.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -28,6 +29,7 @@ DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceZygoteExt = ["Zygote"] [compat] +ADTypes = "0.2.6" ChainRulesCore = "1.19" DiffResults = "1.1" DocStringExtensions = "0.9" diff --git a/README.md b/README.md index f76db1417..a7d9707a0 100644 --- a/README.md +++ b/README.md @@ -13,18 +13,20 @@ This package provides a backend-agnostic syntax to differentiate functions `f(x) It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). -## Example +## Supported backends -```jldoctest -julia> using DifferentiationInterface, Enzyme +We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): -julia> backend = EnzymeReverseBackend(); +- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) with `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))` +- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) with `AutoFiniteDiff()` +- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with `AutoForwardDiff()` +- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) with `AutoPolyesterForwardDiff(; chunksize=C)` +- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) with `AutoReverseDiff()` +- [Zygote.jl](https://github.com/FluxML/Zygote.jl) with `AutoZygote()` -julia> f(x) = sum(abs2, x); +We also support one more backend which is not yet part of ADTypes.jl (see [ADTypes.jl#21](https://github.com/SciML/ADTypes.jl/pull/21)): -julia> value_and_gradient(backend, f, [1., 2., 3.]) -(14.0, [2.0, 4.0, 6.0]) -``` +- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) with `AutoChainRules(ruleconfig)` 
## Design @@ -40,22 +42,15 @@ From these primitives, several utilities are defined, depending on the type of t | scalar input | derivative | multiderivative | | array input | gradient | jacobian | -## Supported backends - -Forward mode: - -- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) -- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) +## Example -Reverse mode: +```jldoctest +julia> import DifferentiationInterface, ADTypes, ForwardDiff -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -- [Zygote.jl](https://github.com/FluxML/Zygote.jl) -- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) +julia> backend = ADTypes.AutoForwardDiff(); -Experimental: +julia> f(x) = sum(abs2, x); -- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) -- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (currently broken due to [#277](https://github.com/JuliaDiff/Diffractor.jl/issues/277)) +julia> DifferentiationInterface.value_and_gradient(backend, f, [1., 2., 3.]) +(14.0, [2.0, 4.0, 6.0]) +``` diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 36f358793..5961d7e8a 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -1,9 +1,9 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 7c9145c51..e620b7bdd 100644 --- a/benchmark/benchmarks.jl +++ 
b/benchmark/benchmarks.jl @@ -1,3 +1,4 @@ +using ADTypes using BenchmarkTools using DifferentiationInterface using LinearAlgebra @@ -41,38 +42,16 @@ end ## Backends -forward_custom_backends = [ - EnzymeForwardBackend(; custom=true), - FiniteDiffBackend(; custom=true), - ForwardDiffBackend(; custom=true), - PolyesterForwardDiffBackend(4; custom=true), +all_backends = [ + AutoEnzyme(Val(:forward)), + AutoEnzyme(Val(:reverse)), + AutoFiniteDiff(), + AutoForwardDiff(), + AutoPolyesterForwardDiff(; chunksize=4), + AutoReverseDiff(), + AutoZygote(), ] -forward_fallback_backends = [ - EnzymeForwardBackend(; custom=false), - FiniteDiffBackend(; custom=false), - ForwardDiffBackend(; custom=false), -] - -reverse_custom_backends = [ - ZygoteBackend(; custom=true), - EnzymeReverseBackend(; custom=true), - ReverseDiffBackend(; custom=true), -] - -reverse_fallback_backends = [ - ZygoteBackend(; custom=false), - EnzymeReverseBackend(; custom=false), - ReverseDiffBackend(; custom=false), -] - -all_backends = vcat( - forward_custom_backends, - forward_fallback_backends, - reverse_custom_backends, - reverse_fallback_backends, -) - ## Suite function make_suite() @@ -83,11 +62,7 @@ function make_suite() for backend in all_backends add_derivative_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1) - end - for backend in forward_fallback_backends add_pushforward_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1) - end - for backend in reverse_fallback_backends add_pullback_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1) end @@ -97,11 +72,7 @@ function make_suite() for backend in all_backends add_multiderivative_benchmarks!(SUITE, backend, scalar_to_vector, 1, m) - end - for backend in forward_fallback_backends add_pushforward_benchmarks!(SUITE, backend, scalar_to_vector, 1, m) - end - for backend in reverse_fallback_backends add_pullback_benchmarks!(SUITE, backend, scalar_to_vector, 1, m) end end @@ -112,11 +83,7 @@ function make_suite() for backend in all_backends 
add_gradient_benchmarks!(SUITE, backend, vector_to_scalar, n, 1) - end - for backend in forward_fallback_backends add_pushforward_benchmarks!(SUITE, backend, vector_to_scalar, n, 1) - end - for backend in reverse_fallback_backends add_pullback_benchmarks!(SUITE, backend, vector_to_scalar, n, 1) end end @@ -127,11 +94,7 @@ function make_suite() for backend in all_backends add_jacobian_benchmarks!(SUITE, backend, vector_to_vector, n, m) - end - for backend in forward_fallback_backends add_pushforward_benchmarks!(SUITE, backend, vector_to_vector, n, m) - end - for backend in reverse_fallback_backends add_pullback_benchmarks!(SUITE, backend, vector_to_vector, n, m) end end diff --git a/benchmark/utils.jl b/benchmark/utils.jl index cea537f28..0a1d7f6bb 100644 --- a/benchmark/utils.jl +++ b/benchmark/utils.jl @@ -1,28 +1,43 @@ -using DifferentiationInterface +using ADTypes +using ADTypes: AbstractADType using BenchmarkTools +using DifferentiationInterface + +## Pretty printing with type piracy + +Base.string(::AutoEnzyme{Val{:forward}}) = "Enzyme (forward)" +Base.string(::AutoEnzyme{Val{:reverse}}) = "Enzyme (reverse)" +Base.string(::AutoFiniteDiff) = "FiniteDiff" +Base.string(::AutoForwardDiff) = "ForwardDiff" +Base.string(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff" +Base.string(::AutoReverseDiff) = "ReverseDiff" +Base.string(::AutoZygote) = "Zygote" + +## Benchmark suite function add_pushforward_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} x = n == 1 ? randn() : randn(n) dx = n == 1 ? randn() : randn(n) dy = m == 1 ? 
0.0 : zeros(m) - if autodiff_mode(backend) != :forward || !handles_types(backend, typeof(x), typeof(dy)) + if !isa(autodiff_mode(backend), Val{:forward}) || + !handles_types(backend, typeof(x), typeof(dy)) return nothing end - suite["value_and_pushforward"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_pushforward"][(n, m)]["$(string(backend))"] = @benchmarkable begin value_and_pushforward($backend, $f, $x, $dx) end - suite["value_and_pushforward!"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_pushforward!"][(n, m)]["$(string(backend))"] = @benchmarkable begin value_and_pushforward!($dy, $backend, $f, $x, $dx) end - suite["pushforward"][(n, m)][string(backend)] = @benchmarkable begin + suite["pushforward"][(n, m)]["$(string(backend))"] = @benchmarkable begin pushforward($backend, $f, $x, $dx) end - suite["pushforward!"][(n, m)][string(backend)] = @benchmarkable begin + suite["pushforward!"][(n, m)]["$(string(backend))"] = @benchmarkable begin pushforward!($dy, $backend, $f, $x, $dx) end @@ -30,27 +45,28 @@ function add_pushforward_benchmarks!( end function add_pullback_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} x = n == 1 ? randn() : randn(n) dx = n == 1 ? 0.0 : zeros(n) dy = m == 1 ? 
randn() : randn(m) - if autodiff_mode(backend) != :reverse || !handles_types(backend, typeof(x), typeof(dy)) + if !isa(autodiff_mode(backend), Val{:reverse}) || + !handles_types(backend, typeof(x), typeof(dy)) return nothing end - suite["value_and_pullback"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_pullback"][(n, m)]["$(string(backend))"] = @benchmarkable begin value_and_pullback($backend, $f, $x, $dy) end - suite["value_and_pullback!"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_pullback!"][(n, m)]["$(string(backend))"] = @benchmarkable begin value_and_pullback!($dx, $backend, $f, $x, $dy) end - suite["pullback"][(n, m)][string(backend)] = @benchmarkable begin + suite["pullback"][(n, m)]["$(string(backend))"] = @benchmarkable begin pullback($backend, $f, $x, $dy) end - suite["pullback!"][(n, m)][string(backend)] = @benchmarkable begin + suite["pullback!"][(n, m)]["$(string(backend))"] = @benchmarkable begin pullback!($dx, $backend, $f, $x, $dy) end @@ -58,7 +74,7 @@ function add_pullback_benchmarks!( end function add_derivative_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} @assert n == m == 1 if !handles_types(backend, Number, Number) @@ -67,19 +83,27 @@ function add_derivative_benchmarks!( x = randn() - suite["value_and_derivative"][(1, 1)][string(backend)] = @benchmarkable begin + suite["value_and_derivative"][(1, 1)]["$(string(backend))"] = @benchmarkable begin value_and_derivative($backend, $f, $x) end - suite["derivative"][(1, 1)][string(backend)] = @benchmarkable begin + suite["value_and_derivative"][(1, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + value_and_derivative(Val(:fallback), $backend, $f, $x) + end + + suite["derivative"][(1, 1)]["$(string(backend))"] = @benchmarkable begin derivative($backend, $f, $x) end + suite["derivative"][(1, 
1)]["$(string(backend)) - fallback"] = @benchmarkable begin + derivative(Val(:fallback), $backend, $f, $x) + end + return nothing end function add_multiderivative_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} @assert n == 1 if !handles_types(backend, Number, Vector) @@ -89,25 +113,39 @@ function add_multiderivative_benchmarks!( x = randn() multider = zeros(m) - suite["value_and_multiderivative"][(1, m)][string(backend)] = @benchmarkable begin + suite["value_and_multiderivative"][(1, m)]["$(string(backend))"] = @benchmarkable begin value_and_multiderivative($backend, $f, $x) end - suite["value_and_multiderivative!"][(1, m)][string(backend)] = @benchmarkable begin + suite["value_and_multiderivative!"][(1, m)]["$(string(backend))"] = @benchmarkable begin value_and_multiderivative!($multider, $backend, $f, $x) end - suite["multiderivative"][(1, m)][string(backend)] = @benchmarkable begin + suite["value_and_multiderivative"][(1, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + value_and_multiderivative(Val(:fallback), $backend, $f, $x) + end + suite["value_and_multiderivative!"][(1, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + value_and_multiderivative!(Val(:fallback), $multider, $backend, $f, $x) + end + + suite["multiderivative"][(1, m)]["$(string(backend))"] = @benchmarkable begin multiderivative($backend, $f, $x) end - suite["multiderivative!"][(1, m)][string(backend)] = @benchmarkable begin + suite["multiderivative!"][(1, m)]["$(string(backend))"] = @benchmarkable begin multiderivative!($multider, $backend, $f, $x) end + suite["multiderivative"][(1, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + multiderivative(Val(:fallback), $backend, $f, $x) + end + suite["multiderivative!"][(1, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + multiderivative!(Val(:fallback), $multider, 
$backend, $f, $x) + end + return nothing end function add_gradient_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} @assert m == 1 if !handles_types(backend, Vector, Number) @@ -117,25 +155,39 @@ function add_gradient_benchmarks!( x = randn(n) grad = zeros(n) - suite["value_and_gradient"][(n, 1)][string(backend)] = @benchmarkable begin + suite["value_and_gradient"][(n, 1)]["$(string(backend))"] = @benchmarkable begin value_and_gradient($backend, $f, $x) end - suite["value_and_gradient!"][(n, 1)][string(backend)] = @benchmarkable begin + suite["value_and_gradient!"][(n, 1)]["$(string(backend))"] = @benchmarkable begin value_and_gradient!($grad, $backend, $f, $x) end - suite["gradient"][(n, 1)][string(backend)] = @benchmarkable begin + suite["value_and_gradient"][(n, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + value_and_gradient(Val(:fallback), $backend, $f, $x) + end + suite["value_and_gradient!"][(n, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + value_and_gradient!(Val(:fallback), $grad, $backend, $f, $x) + end + + suite["gradient"][(n, 1)]["$(string(backend))"] = @benchmarkable begin gradient($backend, $f, $x) end - suite["gradient!"][(n, 1)][string(backend)] = @benchmarkable begin + suite["gradient!"][(n, 1)]["$(string(backend))"] = @benchmarkable begin gradient!($grad, $backend, $f, $x) end + suite["gradient"][(n, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + gradient(Val(:fallback), $backend, $f, $x) + end + suite["gradient!"][(n, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + gradient!(Val(:fallback), $grad, $backend, $f, $x) + end + return nothing end function add_jacobian_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} if 
!handles_types(backend, Vector, Vector) return nothing @@ -144,19 +196,33 @@ function add_jacobian_benchmarks!( x = randn(n) jac = zeros(m, n) - suite["value_and_jacobian"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_jacobian"][(n, m)]["$(string(backend))"] = @benchmarkable begin value_and_jacobian($backend, $f, $x) end - suite["value_and_jacobian!"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_jacobian!"][(n, m)]["$(string(backend))"] = @benchmarkable begin value_and_jacobian!($jac, $backend, $f, $x) end - suite["jacobian"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_jacobian"][(n, m)]["$(string(backend)) (:fallback)"] = @benchmarkable begin + value_and_jacobian(Val(:fallback), $backend, $f, $x) + end + suite["value_and_jacobian!"][(n, m)]["$(string(backend)) (:fallback)"] = @benchmarkable begin + value_and_jacobian!(Val(:fallback), $jac, $backend, $f, $x) + end + + suite["jacobian"][(n, m)]["$(string(backend))"] = @benchmarkable begin jacobian($backend, $f, $x) end - suite["jacobian!"][(n, m)][string(backend)] = @benchmarkable begin + suite["jacobian!"][(n, m)]["$(string(backend))"] = @benchmarkable begin jacobian!($jac, $backend, $f, $x) end + suite["jacobian"][(n, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + jacobian(Val(:fallback), $backend, $f, $x) + end + suite["jacobian!"][(n, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + jacobian!(Val(:fallback), $jac, $backend, $f, $x) + end + return nothing end diff --git a/docs/Project.toml b/docs/Project.toml index 7850516f4..5a20d1146 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" diff --git a/docs/make.jl b/docs/make.jl index 16c3a5566..29c31acf2 100644 --- 
a/docs/make.jl +++ b/docs/make.jl @@ -59,9 +59,7 @@ makedocs(; prettyurls=get(ENV, "CI", "false") == "true", canonical="https://gdalle.github.io/DifferentiationInterface.jl", ), - pages=[ - "Home" => "index.md", "Interface" => "interface.md", "Backends" => "backends.md" - ], + pages=["Home" => "index.md", "Interface" => "interface.md"], ) deploydocs(; repo="github.com/gdalle/DifferentiationInterface.jl", devbranch="main") diff --git a/docs/src/backends.md b/docs/src/backends.md deleted file mode 100644 index df9b2b271..000000000 --- a/docs/src/backends.md +++ /dev/null @@ -1,73 +0,0 @@ -```@meta -CollapsedDocStrings = true -``` - -# Backends - -## ChainRulesCore - -```@docs -ChainRulesForwardBackend -ChainRulesReverseBackend -``` - -```@autodocs -Modules = [ChainRulesCoreExt] -``` - -## Enzyme - -```@docs -EnzymeForwardBackend -EnzymeReverseBackend -``` - -```@autodocs -Modules = [EnzymeExt] -``` - -## FiniteDiff - -```@docs -FiniteDiffBackend -``` - -```@autodocs -Modules = [FiniteDiffExt] -``` - -## ForwardDiff - -```@docs -ForwardDiffBackend -``` - -```@autodocs -Modules = [ForwardDiffExt] -``` - -## PolyesterForwardDiff - -```@docs -PolyesterForwardDiffBackend -``` - -```@autodocs -Modules = [PolyesterForwardDiffExt] -``` - -## ReverseDiff - -```@docs -ReverseDiffBackend -``` - -```@autodocs -Modules = [ReverseDiffExt] -``` - -## Zygote - -```@autodocs -Modules = [ZygoteExt] -``` diff --git a/docs/src/index.md b/docs/src/index.md index 9cdd27b68..55f213137 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -17,18 +17,20 @@ This package provides a backend-agnostic syntax to differentiate functions `f(x) It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). 
-## Example +## Supported backends -```jldoctest -julia> using DifferentiationInterface, Enzyme +We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): -julia> backend = EnzymeReverseBackend(); +- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) with `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))` +- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) with `AutoFiniteDiff()` +- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with `AutoForwardDiff()` +- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) with `AutoPolyesterForwardDiff(; chunksize=C)` +- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) with `AutoReverseDiff()` +- [Zygote.jl](https://github.com/FluxML/Zygote.jl) with `AutoZygote()` -julia> f(x) = sum(abs2, x); +We also support one more backend which is not yet part of ADTypes.jl (see [ADTypes.jl#21](https://github.com/SciML/ADTypes.jl/pull/21)): -julia> value_and_gradient(backend, f, [1., 2., 3.]) -(14.0, [2.0, 4.0, 6.0]) -``` +- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) with `AutoChainRules(ruleconfig)` ## Design @@ -44,22 +46,15 @@ From these primitives, several utilities are defined, depending on the type of t | scalar input | derivative | multiderivative | | array input | gradient | jacobian | -## Supported backends - -Forward mode: - -- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) -- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) +## Example -Reverse mode: +```jldoctest +julia> import DifferentiationInterface, ADTypes, ForwardDiff -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -- [Zygote.jl](https://github.com/FluxML/Zygote.jl) -- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) +julia> backend = ADTypes.AutoForwardDiff(); -Experimental: +julia> f(x) 
= sum(abs2, x); -- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) -- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (currently broken due to [#277](https://github.com/JuliaDiff/Diffractor.jl/issues/277)) +julia> DifferentiationInterface.value_and_gradient(backend, f, [1., 2., 3.]) +(14.0, [2.0, 4.0, 6.0]) +``` diff --git a/docs/src/interface.md b/docs/src/interface.md index 6f25db112..991e51eda 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -55,11 +55,11 @@ Modules = [DifferentiationInterface] Pages = ["pullback.jl"] ``` -## Abstract backends +## Backends ```@autodocs Modules = [DifferentiationInterface] -Pages = ["backends_abstract.jl"] +Pages = ["backends.jl"] ``` ## Internals diff --git a/ext/DifferentiationInterfaceChainRulesCoreExt.jl b/ext/DifferentiationInterfaceChainRulesCoreExt.jl index 5e116c557..1a20a1918 100644 --- a/ext/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/ext/DifferentiationInterfaceChainRulesCoreExt.jl @@ -2,71 +2,37 @@ module DifferentiationInterfaceChainRulesCoreExt using ChainRulesCore: HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad -using DifferentiationInterface: ChainRulesForwardBackend, ChainRulesReverseBackend +using DifferentiationInterface: AutoChainRules, ruleconfig import DifferentiationInterface as DI using DocStringExtensions -ruleconfig(backend::ChainRulesForwardBackend) = backend.ruleconfig -ruleconfig(backend::ChainRulesReverseBackend) = backend.ruleconfig - update!(_old::Number, new::Number) = new update!(old, new) = old .= new -## Backend construction - -""" - ChainRulesForwardBackend(rc::RuleConfig; custom=true) - -Construct a [`ChainRulesForwardBackend`](@ref) from a [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) object that `HasForwardsMode`. 
- -## Example - -```julia -using Diffractor, DifferentiationInterface -backend = ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()) -``` -""" -function DI.ChainRulesForwardBackend(rc::RuleConfig{>:HasForwardsMode}; custom::Bool=true) - return ChainRulesForwardBackend{custom,typeof(rc)}(rc) -end - -""" - ChainRulesReverseBackend(rc::RuleConfig; custom=true) - -Construct a [`ChainRulesReverseBackend`](@ref) from a [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) object that `HasReverseMode`. - -## Example - -```julia -using Zygote, DifferentiationInterface -backend = ChainRulesReverseBackend(Zygote.ZygoteRuleConfig()) -``` -""" -function DI.ChainRulesReverseBackend(rc::RuleConfig{>:HasReverseMode}; custom::Bool=true) - return ChainRulesReverseBackend{custom,typeof(rc)}(rc) -end +const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}} +const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} ## Primitives -function DI.value_and_pushforward(backend::ChainRulesForwardBackend, f, x, dx) +function DI.value_and_pushforward(backend::AutoForwardChainRules, f, x, dx) rc = ruleconfig(backend) y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x) return y, new_dy end -function DI.value_and_pushforward!(dy, backend::ChainRulesForwardBackend, f, x, dx) +function DI.value_and_pushforward!(dy, backend::AutoForwardChainRules, f, x, dx) y, new_dy = DI.value_and_pushforward(backend, f, x, dx) return y, update!(dy, new_dy) end -function DI.value_and_pullback(backend::ChainRulesReverseBackend, f, x, dy) +function DI.value_and_pullback(backend::AutoReverseChainRules, f, x, dy) rc = ruleconfig(backend) y, pullback = rrule_via_ad(rc, f, x) _, new_dx = pullback(dy) return y, new_dx end -function DI.value_and_pullback!(dx, backend::ChainRulesReverseBackend, f, x, dy) +function DI.value_and_pullback!(dx, backend::AutoReverseChainRules, f, x, dy) y, new_dx = DI.value_and_pullback(backend, 
f, x, dy) return y, update!(dx, new_dx) end diff --git a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index fcbeda077..7c0a7a862 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -1,6 +1,6 @@ module DifferentiationInterfaceEnzymeExt -using DifferentiationInterface: EnzymeForwardBackend, EnzymeReverseBackend +using ADTypes: AutoEnzyme import DifferentiationInterface as DI using DocStringExtensions using Enzyme: @@ -15,9 +15,7 @@ using Enzyme: jacobian # Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type -function DI.basisarray( - ::Union{EnzymeForwardBackend,EnzymeReverseBackend}, a::AbstractArray{T}, i -) where {T} +function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T} b = zero(a) b[i] = one(T) return b diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward.jl b/ext/DifferentiationInterfaceEnzymeExt/forward.jl index 95aadfa51..5b1bd9e33 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/forward.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/forward.jl @@ -1,24 +1,17 @@ - -## Backend construction - -""" - EnzymeForwardBackend(; custom=true) - -Construct a [`EnzymeForwardBackend`](@ref). 
-""" -DI.EnzymeForwardBackend(; custom::Bool=true) = EnzymeForwardBackend{custom}() +const AutoForwardEnzyme = AutoEnzyme{Val{:forward}} +DI.autodiff_mode(::AutoForwardEnzyme) = Val{:forward}() ## Primitives function DI.value_and_pushforward!( - _dy::Y, ::EnzymeForwardBackend, f, x::X, dx + _dy::Y, ::AutoForwardEnzyme, f, x::X, dx ) where {X,Y<:Real} y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) return y, new_dy end function DI.value_and_pushforward!( - dy::Y, ::EnzymeForwardBackend, f, x::X, dx + dy::Y, ::AutoForwardEnzyme, f, x::X, dx ) where {X,Y<:AbstractArray} y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) dy .= new_dy @@ -27,7 +20,7 @@ end ## Utilities -function DI.value_and_jacobian(::EnzymeForwardBackend{true}, f, x::AbstractArray) +function DI.value_and_jacobian(::AutoForwardEnzyme, f, x::AbstractArray) y = f(x) jac = jacobian(Forward, f, x) # see https://github.com/EnzymeAD/Enzyme.jl/issues/1332 diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl index 053bd3e37..24787ab34 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl @@ -1,17 +1,11 @@ - -## Backend construction - -""" - EnzymeReverseBackend(; custom=true) - -Construct a [`EnzymeReverseBackend`](@ref). 
-""" -DI.EnzymeReverseBackend(; custom::Bool=true) = EnzymeReverseBackend{custom}() +const AutoReverseEnzyme = AutoEnzyme{Val{:reverse}} +DI.autodiff_mode(::AutoReverseEnzyme) = Val{:reverse}() +DI.handles_output_type(::AutoReverseEnzyme, ::Type{<:AbstractArray}) = false ## Primitives function DI.value_and_pullback!( - _dx, ::EnzymeReverseBackend, f, x::X, dy::Y + _dx, ::AutoReverseEnzyme, f, x::X, dy::Y ) where {X<:Number,Y<:Union{Real,Nothing}} der, y = autodiff(ReverseWithPrimal, f, Active, Active(x)) new_dx = dy * only(der) @@ -19,7 +13,7 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y + dx::X, ::AutoReverseEnzyme, f, x::X, dy::Y ) where {X<:AbstractArray,Y<:Union{Real,Nothing}} dx .= zero(eltype(dx)) _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx)) @@ -29,14 +23,14 @@ end ## Utilities -function DI.value_and_gradient(::EnzymeReverseBackend{true}, f, x::AbstractArray) +function DI.value_and_gradient(::AutoReverseEnzyme, f, x::AbstractArray) y = f(x) grad = gradient(Reverse, f, x) return y, grad end function DI.value_and_gradient!( - grad::AbstractArray, ::EnzymeReverseBackend{true}, f, x::AbstractArray + grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray ) y = f(x) gradient!(Reverse, grad, f, x) diff --git a/ext/DifferentiationInterfaceFiniteDiffExt.jl b/ext/DifferentiationInterfaceFiniteDiffExt.jl index a1c18c3fe..5ad2be8de 100644 --- a/ext/DifferentiationInterfaceFiniteDiffExt.jl +++ b/ext/DifferentiationInterfaceFiniteDiffExt.jl @@ -1,6 +1,6 @@ module DifferentiationInterfaceFiniteDiffExt -using DifferentiationInterface: FiniteDiffBackend +using ADTypes: AutoFiniteDiff import DifferentiationInterface as DI using DocStringExtensions using FiniteDiff: @@ -14,24 +14,11 @@ using LinearAlgebra: dot, mul! 
const FUNCTION_INPLACE = Val{true} const FUNCTION_NOT_INPLACE = Val{false} -## Backend construction - -""" - FiniteDiffBackend(::Type{fdtype}=Val{:central}; custom=true) - -Construct a [`FiniteDiffBackend`](@ref) with any finite difference type `fdtype` (`Val{:forward}` or `Val{:central}`). -""" -function DI.FiniteDiffBackend( - ::Type{fdtype}=Val{:central}; custom::Bool=true -) where {fdtype} - return FiniteDiffBackend{custom,fdtype}() -end - ## Primitives function DI.value_and_pushforward!( - dy::Y, ::FiniteDiffBackend{custom,fdtype}, f, x, dx -) where {Y<:Number,custom,fdtype} + dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx +) where {Y<:Number,fdtype} y = f(x) step(t::Number)::Number = f(x .+ t .* dx) new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y) @@ -39,8 +26,8 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::FiniteDiffBackend{custom,fdtype}, f, x, dx -) where {Y<:AbstractArray,custom,fdtype} + dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx +) where {Y<:AbstractArray,fdtype} y = f(x) step(t::Number)::AbstractArray = f(x .+ t .* dx) finite_difference_gradient!( @@ -51,56 +38,48 @@ end ## Utilities -function DI.value_and_derivative( - ::FiniteDiffBackend{true,fdtype}, f, x::Number -) where {fdtype} +function DI.value_and_derivative(::AutoFiniteDiff{fdtype}, f, x::Number) where {fdtype} y = f(x) der = finite_difference_derivative(f, x, fdtype, eltype(y), y) return y, der end function DI.value_and_multiderivative!( - multider::AbstractArray, ::FiniteDiffBackend{true,fdtype}, f, x::Number + multider::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::Number ) where {fdtype} y = f(x) finite_difference_gradient!(multider, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) return y, multider end -function DI.value_and_multiderivative( - ::FiniteDiffBackend{true,fdtype}, f, x::Number -) where {fdtype} +function DI.value_and_multiderivative(::AutoFiniteDiff{fdtype}, f, x::Number) where {fdtype} y = f(x) 
multider = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) return y, multider end function DI.value_and_gradient!( - grad::AbstractArray, ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray + grad::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray ) where {fdtype} y = f(x) finite_difference_gradient!(grad, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) return y, grad end -function DI.value_and_gradient( - ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray -) where {fdtype} +function DI.value_and_gradient(::AutoFiniteDiff{fdtype}, f, x::AbstractArray) where {fdtype} y = f(x) grad = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) return y, grad end -function DI.value_and_jacobian( - ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray -) where {fdtype} +function DI.value_and_jacobian(::AutoFiniteDiff{fdtype}, f, x::AbstractArray) where {fdtype} y = f(x) jac = finite_difference_jacobian(f, x, fdtype, eltype(y)) return y, jac end function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::FiniteDiffBackend{true}, f, x::AbstractArray + jac::AbstractMatrix, backend::AutoFiniteDiff, f, x::AbstractArray ) y, new_jac = DI.value_and_jacobian(backend, f, x) jac .= new_jac diff --git a/ext/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt.jl index 367fca5d9..ee45f530f 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt.jl @@ -1,6 +1,6 @@ module DifferentiationInterfaceForwardDiffExt -using DifferentiationInterface: ForwardDiffBackend +using ADTypes: AutoForwardDiff import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions @@ -18,19 +18,10 @@ using ForwardDiff: value using LinearAlgebra: mul! -## Backend construction - -""" - ForwardDiffBackend(; custom=true) - -Construct a [`ForwardDiffBackend`](@ref). 
-""" -DI.ForwardDiffBackend(; custom::Bool=true) = ForwardDiffBackend{custom}() - ## Primitives function DI.value_and_pushforward!( - _dy::Y, ::ForwardDiffBackend, f, x::X, dx + _dy::Y, ::AutoForwardDiff, f, x::X, dx ) where {X<:Real,Y<:Real} T = typeof(Tag(f, X)) xdual = Dual{T}(x, dx) @@ -41,7 +32,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::ForwardDiffBackend, f, x::X, dx + dy::Y, ::AutoForwardDiff, f, x::X, dx ) where {X<:Real,Y<:AbstractArray} T = typeof(Tag(f, X)) xdual = Dual{T}(x, dx) @@ -52,7 +43,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - _dy::Y, ::ForwardDiffBackend, f, x::X, dx + _dy::Y, ::AutoForwardDiff, f, x::X, dx ) where {X<:AbstractArray,Y<:Real} T = typeof(Tag(f, X)) # TODO: unsure xdual = Dual{T}.(x, dx) # TODO: allocation @@ -63,7 +54,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::ForwardDiffBackend, f, x::X, dx + dy::Y, ::AutoForwardDiff, f, x::X, dx ) where {X<:AbstractArray,Y<:AbstractArray} T = typeof(Tag(f, X)) # TODO: unsure xdual = Dual{T}.(x, dx) # TODO: allocation @@ -75,49 +66,45 @@ end ## Utilities (TODO: use DiffResults) -function DI.value_and_derivative(::ForwardDiffBackend{true}, f, x::Number) +function DI.value_and_derivative(::AutoForwardDiff, f, x::Number) y = f(x) der = derivative(f, x) return y, der end -function DI.value_and_multiderivative(::ForwardDiffBackend{true}, f, x::Number) +function DI.value_and_multiderivative(::AutoForwardDiff, f, x::Number) y = f(x) multider = derivative(f, x) return y, multider end function DI.value_and_multiderivative!( - multider::AbstractArray, ::ForwardDiffBackend{true}, f, x::Number + multider::AbstractArray, ::AutoForwardDiff, f, x::Number ) y = f(x) derivative!(multider, f, x) return y, multider end -function DI.value_and_gradient(::ForwardDiffBackend{true}, f, x::AbstractArray) +function DI.value_and_gradient(::AutoForwardDiff, f, x::AbstractArray) y = f(x) 
grad = gradient(f, x) return y, grad end -function DI.value_and_gradient!( - grad::AbstractArray, ::ForwardDiffBackend{true}, f, x::AbstractArray -) +function DI.value_and_gradient!(grad::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray) y = f(x) gradient!(grad, f, x) return y, grad end -function DI.value_and_jacobian(::ForwardDiffBackend{true}, f, x::AbstractArray) +function DI.value_and_jacobian(::AutoForwardDiff, f, x::AbstractArray) y = f(x) jac = jacobian(f, x) return y, jac end -function DI.value_and_jacobian!( - jac::AbstractMatrix, ::ForwardDiffBackend{true}, f, x::AbstractArray -) +function DI.value_and_jacobian!(jac::AbstractMatrix, ::AutoForwardDiff, f, x::AbstractArray) y = f(x) jacobian!(jac, f, x) return y, jac diff --git a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl index 25d54d6f9..6ff75178a 100644 --- a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -1,6 +1,6 @@ module DifferentiationInterfacePolyesterForwardDiffExt -using DifferentiationInterface: ForwardDiffBackend, PolyesterForwardDiffBackend +using ADTypes: AutoPolyesterForwardDiff, AutoForwardDiff import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions @@ -8,29 +8,16 @@ using ForwardDiff: Chunk using LinearAlgebra: mul! using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! -## Backend construction - -""" - PolyesterForwardDiffBackend(C; custom=true) - -Construct a [`PolyesterForwardDiffBackend`](@ref) with chunk size `C`. 
-""" -function DI.PolyesterForwardDiffBackend(C::Integer; custom::Bool=true) - return PolyesterForwardDiffBackend{custom,C}() -end - ## Primitives -function DI.value_and_pushforward!( - dy, ::PolyesterForwardDiffBackend{custom}, f, x, dx -) where {custom} - return DI.value_and_pushforward!(dy, ForwardDiffBackend{custom}(), f, x, dx) +function DI.value_and_pushforward!(dy, ::AutoPolyesterForwardDiff{C}, f, x, dx) where {C} + return DI.value_and_pushforward!(dy, AutoForwardDiff{C,Nothing}(nothing), f, x, dx) end ## Utilities function DI.value_and_gradient!( - grad::AbstractArray, ::PolyesterForwardDiffBackend{true,C}, f, x::AbstractArray + grad::AbstractArray, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray ) where {C} y = f(x) threaded_gradient!(f, grad, x, Chunk{C}()) @@ -38,7 +25,7 @@ function DI.value_and_gradient!( end function DI.value_and_jacobian!( - jac::AbstractMatrix, ::PolyesterForwardDiffBackend{true,C}, f, x::AbstractArray + jac::AbstractMatrix, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray ) where {C} y = f(x) threaded_jacobian!(f, jac, x, Chunk{C}()) diff --git a/ext/DifferentiationInterfaceReverseDiffExt.jl b/ext/DifferentiationInterfaceReverseDiffExt.jl index 99c5b325c..ea1d23462 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt.jl @@ -1,25 +1,20 @@ module DifferentiationInterfaceReverseDiffExt -using DifferentiationInterface: ReverseDiffBackend +using ADTypes: AutoReverseDiff import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions using LinearAlgebra: mul! using ReverseDiff: gradient, gradient!, jacobian, jacobian! -## Backend construction +## Limitations -""" - ReverseDiffBackend(; custom) - -Construct a [`ReverseDiffBackend`](@ref). 
-""" -DI.ReverseDiffBackend(; custom::Bool=true) = ReverseDiffBackend{custom}() +DI.handles_input_type(::AutoReverseDiff, ::Type{<:Number}) = false ## Primitives function DI.value_and_pullback!( - dx, ::ReverseDiffBackend, f, x::X, dy::Y + dx, ::AutoReverseDiff, f, x::X, dy::Y ) where {X<:AbstractArray,Y<:Real} res = DiffResults.DiffResult(zero(Y), dx) res = gradient!(res, f, x) @@ -29,7 +24,7 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx, ::ReverseDiffBackend, f, x::X, dy::Y + dx, ::AutoReverseDiff, f, x::X, dy::Y ) where {X<:AbstractArray,Y<:AbstractArray} res = DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x))) res = jacobian!(res, f, x) @@ -41,29 +36,25 @@ end ## Utilities (TODO: use DiffResults) -function DI.value_and_gradient(::ReverseDiffBackend{true}, f, x::AbstractArray) +function DI.value_and_gradient(::AutoReverseDiff, f, x::AbstractArray) y = f(x) grad = gradient(f, x) return y, grad end -function DI.value_and_gradient!( - grad::AbstractArray, ::ReverseDiffBackend{true}, f, x::AbstractArray -) +function DI.value_and_gradient!(grad::AbstractArray, ::AutoReverseDiff, f, x::AbstractArray) y = f(x) gradient!(grad, f, x) return y, grad end -function DI.value_and_jacobian(::ReverseDiffBackend{true}, f, x::AbstractArray) +function DI.value_and_jacobian(::AutoReverseDiff, f, x::AbstractArray) y = f(x) jac = jacobian(f, x) return y, jac end -function DI.value_and_jacobian!( - jac::AbstractMatrix, ::ReverseDiffBackend{true}, f, x::AbstractArray -) +function DI.value_and_jacobian!(jac::AbstractMatrix, ::AutoReverseDiff, f, x::AbstractArray) y = f(x) jacobian!(jac, f, x) return y, jac diff --git a/ext/DifferentiationInterfaceZygoteExt.jl b/ext/DifferentiationInterfaceZygoteExt.jl index 28c0a7998..bd8d8d333 100644 --- a/ext/DifferentiationInterfaceZygoteExt.jl +++ b/ext/DifferentiationInterfaceZygoteExt.jl @@ -1,48 +1,54 @@ module DifferentiationInterfaceZygoteExt -using DifferentiationInterface: 
ChainRulesReverseBackend, ZygoteBackend +using ADTypes: AutoZygote +using DifferentiationInterface: AutoChainRules import DifferentiationInterface as DI using DocStringExtensions using Zygote: ZygoteRuleConfig, gradient, jacobian, withgradient, withjacobian -## Backend construction +## Primitives -""" - ZygoteBackend(; custom=true) +const zygote_chainrules_backend = AutoChainRules(ZygoteRuleConfig()) -Enables the use of [Zygote.jl](https://github.com/FluxML/Zygote.jl) by constructing a [`ChainRulesReverseBackend`](@ref) from `ZygoteRuleConfig()`. -""" -DI.ZygoteBackend(; custom::Bool=true) = ChainRulesReverseBackend(ZygoteRuleConfig(); custom) +function DI.value_and_pushforward!(dy, ::AutoZygote, f, x, dx) + return DI.value_and_pushforward!(dy, zygote_chainrules_backend, f, x, dx) +end + +function DI.value_and_pushforward(::AutoZygote, f, x, dx) + return DI.value_and_pushforward(zygote_chainrules_backend, f, x, dx) +end -const ZygoteBackendType{custom} = ChainRulesReverseBackend{custom,<:ZygoteRuleConfig} +function DI.value_and_pullback!(dx, ::AutoZygote, f, x, dy) + return DI.value_and_pullback!(dx, zygote_chainrules_backend, f, x, dy) +end -function Base.show(io::IO, backend::ZygoteBackendType{custom}) where {custom} - return print(io, "ZygoteBackend{$(custom ? 
"custom" : "fallback")}()") +function DI.value_and_pullback(::AutoZygote, f, x, dy) + return DI.value_and_pullback(zygote_chainrules_backend, f, x, dy) end ## Utilities -function DI.value_and_gradient(::ZygoteBackendType{true}, f, x::AbstractArray) +function DI.value_and_gradient(::AutoZygote, f, x::AbstractArray) res = withgradient(f, x) return res.val, only(res.grad) end function DI.value_and_gradient!( - grad::AbstractArray, backend::ZygoteBackendType{true}, f, x::AbstractArray + grad::AbstractArray, backend::AutoZygote, f, x::AbstractArray ) y, new_grad = DI.value_and_gradient(backend, f, x) grad .= new_grad return y, grad end -function DI.value_and_jacobian(::ZygoteBackendType{true}, f, x::AbstractArray) +function DI.value_and_jacobian(::AutoZygote, f, x::AbstractArray) y = f(x) jac = jacobian(f, x) return y, only(jac) end function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::ZygoteBackendType{true}, f, x::AbstractArray + jac::AbstractMatrix, backend::AutoZygote, f, x::AbstractArray ) y, new_jac = DI.value_and_jacobian(backend, f, x) jac .= new_jac diff --git a/src/DifferentiationInterface.jl b/src/DifferentiationInterface.jl index 51e33d01c..d1c210297 100644 --- a/src/DifferentiationInterface.jl +++ b/src/DifferentiationInterface.jl @@ -9,10 +9,11 @@ $(EXPORTS) """ module DifferentiationInterface +using ADTypes: + AbstractADType, AbstractForwardMode, AbstractReverseMode, AbstractFiniteDifferencesMode using DocStringExtensions using FillArrays: OneElement -include("backends_abstract.jl") include("backends.jl") include("utils.jl") include("pushforward.jl") @@ -22,19 +23,10 @@ include("scalar_array.jl") include("array_scalar.jl") include("array_array.jl") -export AbstractBackend, AbstractForwardBackend, AbstractReverseBackend -export autodiff_mode, is_custom -export handles_input_type, handles_output_type, handles_types +export AutoChainRules -export ChainRulesForwardBackend, - ChainRulesReverseBackend, - EnzymeForwardBackend, - EnzymeReverseBackend, 
- FiniteDiffBackend, - ForwardDiffBackend, - PolyesterForwardDiffBackend, - ReverseDiffBackend, - ZygoteBackend +export autodiff_mode +export handles_input_type, handles_output_type, handles_types export value_and_pushforward!, value_and_pushforward export pushforward!, pushforward diff --git a/src/array_array.jl b/src/array_array.jl index 38e23d2a1..155dbb909 100644 --- a/src/array_array.jl +++ b/src/array_array.jl @@ -13,23 +13,39 @@ Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of a $JAC_NOTES """ function value_and_jacobian!( - jac::AbstractMatrix, backend::AbstractBackend, f, x::AbstractArray + jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray ) - y = f(x) + return value_and_jacobian!(Val{:fallback}(), jac, backend, f, x) +end + +function value_and_jacobian!( + implem::Val{:fallback}, + jac::AbstractMatrix, + backend::AbstractADType, + f, + x::AbstractArray, +) + return value_and_jacobian!(implem, autodiff_mode(backend), jac, backend, f, x) +end + +function check_jac(jac::AbstractMatrix, x::AbstractArray, y::AbstractArray) nx, ny = length(x), length(y) size(jac) != (ny, nx) && throw( DimensionMismatch("Size of Jacobian buffer doesn't match expected size ($ny, $nx)"), ) - return _value_and_jacobian!(jac, backend, f, x, y) + return nothing end -function _value_and_jacobian!( +function value_and_jacobian!( + ::Val{:fallback}, + ::Val{:forward}, jac::AbstractMatrix, - backend::AbstractForwardBackend, + backend::AbstractADType, f, x::AbstractArray, - y::AbstractArray, ) + y = f(x) + check_jac(jac, x, y) for (k, j) in enumerate(eachindex(IndexCartesian(), x)) dx_j = basisarray(backend, x, j) jac_col_j = reshape(view(jac, :, k), size(y)) @@ -38,13 +54,16 @@ function _value_and_jacobian!( return y, jac end -function _value_and_jacobian!( +function value_and_jacobian!( + ::Val{:fallback}, + ::Val{:reverse}, jac::AbstractMatrix, - backend::AbstractReverseBackend, + backend::AbstractADType, f, x::AbstractArray, - 
y::AbstractArray, ) + y = f(x) + check_jac(jac, x, y) for (k, i) in enumerate(eachindex(IndexCartesian(), y)) dy_i = basisarray(backend, y, i) jac_row_i = reshape(view(jac, k, :), size(x)) @@ -60,11 +79,17 @@ Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of a $JAC_NOTES """ -function value_and_jacobian(backend::AbstractBackend, f, x::AbstractArray) +function value_and_jacobian(backend::AbstractADType, f, x::AbstractArray) + return value_and_jacobian(Val{:fallback}(), backend, f, x) +end + +function value_and_jacobian( + implem::Val{:fallback}, backend::AbstractADType, f, x::AbstractArray +) y = f(x) T = promote_type(eltype(x), eltype(y)) jac = similar(y, T, length(y), length(x)) - return value_and_jacobian!(jac, backend, f, x) + return value_and_jacobian!(implem, jac, backend, f, x) end """ @@ -74,8 +99,18 @@ Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overw $JAC_NOTES """ -function jacobian!(jac::AbstractMatrix, backend::AbstractBackend, f, x::AbstractArray) - return last(value_and_jacobian!(jac, backend, f, x)) +function jacobian!(jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray) + return jacobian!(Val{:fallback}(), jac, backend, f, x) +end + +function jacobian!( + implem::Val{:fallback}, + jac::AbstractMatrix, + backend::AbstractADType, + f, + x::AbstractArray, +) + return last(value_and_jacobian!(implem, jac, backend, f, x)) end """ @@ -85,6 +120,10 @@ Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. 
$JAC_NOTES """ -function jacobian(backend::AbstractBackend, f, x::AbstractArray) - return last(value_and_jacobian(backend, f, x)) +function jacobian(backend::AbstractADType, f, x::AbstractArray) + return jacobian(Val{:fallback}(), backend, f, x) +end + +function jacobian(implem::Val{:fallback}, backend::AbstractADType, f, x::AbstractArray) + return last(value_and_jacobian(implem, backend, f, x)) end diff --git a/src/array_scalar.jl b/src/array_scalar.jl index 6617bbccc..ec525cadd 100644 --- a/src/array_scalar.jl +++ b/src/array_scalar.jl @@ -3,10 +3,29 @@ Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ -function value_and_gradient! end +function value_and_gradient!( + grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray +) + return value_and_gradient!(Val{:fallback}(), grad, backend, f, x) +end + +function value_and_gradient!( + implem::Val{:fallback}, + grad::AbstractArray, + backend::AbstractADType, + f, + x::AbstractArray, +) + return value_and_gradient!(implem, autodiff_mode(backend), grad, backend, f, x) +end function value_and_gradient!( - grad::AbstractArray, backend::AbstractForwardBackend, f, x::AbstractArray + ::Val{:fallback}, + ::Val{:forward}, + grad::AbstractArray, + backend::AbstractADType, + f, + x::AbstractArray, ) y = f(x) for j in eachindex(IndexCartesian(), x) @@ -17,7 +36,12 @@ function value_and_gradient!( end function value_and_gradient!( - grad::AbstractArray, backend::AbstractReverseBackend, f, x::AbstractArray + ::Val{:fallback}, + ::Val{:reverse}, + grad::AbstractArray, + backend::AbstractADType, + f, + x::AbstractArray, ) y = f(x) return y, pullback!(grad, backend, f, x, one(y)) @@ -28,9 +52,15 @@ end Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function. 
""" -function value_and_gradient(backend::AbstractBackend, f, x::AbstractArray) +function value_and_gradient(backend::AbstractADType, f, x::AbstractArray) + return value_and_gradient(Val{:fallback}(), backend, f, x) +end + +function value_and_gradient( + implem::Val{:fallback}, backend::AbstractADType, f, x::AbstractArray +) grad = similar(x) - return value_and_gradient!(grad, backend, f, x) + return value_and_gradient!(implem, grad, backend, f, x) end """ @@ -38,8 +68,18 @@ end Compute the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ -function gradient!(grad::AbstractArray, backend::AbstractBackend, f, x::AbstractArray) - return last(value_and_gradient!(grad, backend, f, x)) +function gradient!(grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray) + return gradient!(Val{:fallback}(), grad, backend, f, x) +end + +function gradient!( + implem::Val{:fallback}, + grad::AbstractArray, + backend::AbstractADType, + f, + x::AbstractArray, +) + return last(value_and_gradient!(implem, grad, backend, f, x)) end """ @@ -47,6 +87,10 @@ end Compute the gradient `grad = ∇f(x)` of an array-to-scalar function. """ -function gradient(backend::AbstractBackend, f, x::AbstractArray) - return last(value_and_gradient(backend, f, x)) +function gradient(backend::AbstractADType, f, x::AbstractArray) + return gradient(Val{:fallback}(), backend, f, x) +end + +function gradient(implem::Val{:fallback}, backend::AbstractADType, f, x::AbstractArray) + return last(value_and_gradient(implem, backend, f, x)) end diff --git a/src/backends.jl b/src/backends.jl index 81581e87d..6e2c9a551 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -1,115 +1,65 @@ -""" - ChainRulesForwardBackend <: AbstractForwardBackend - -Enables the use of forward mode AD packages based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). 
-""" -struct ChainRulesForwardBackend{custom,RC} <: AbstractForwardBackend{custom} - ruleconfig::RC -end - -function Base.show(io::IO, backend::ChainRulesForwardBackend{custom}) where {custom} - return print( - io, - "ChainRulesForwardBackend{$(custom ? "custom" : "fallback")}($(backend.ruleconfig))", - ) -end +## Additional backend -""" - ChainRulesReverseBackend <: AbstractReverseBackend +# TODO: remove once https://github.com/SciML/ADTypes.jl/pull/21 is merged -Enables the use of reverse mode AD packages based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). """ -struct ChainRulesReverseBackend{custom,RC} <: AbstractReverseBackend{custom} - ruleconfig::RC -end - -function Base.show(io::IO, backend::ChainRulesReverseBackend{custom}) where {custom} - return print( - io, - "ChainRulesReverseBackend{$(custom ? "custom" : "fallback")}($(backend.ruleconfig))", - ) -end + AutoChainRules{RC} -""" - FiniteDiffBackend <: AbstractForwardBackend +Enables the use of AD libraries based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). -Enables the use of [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl). -""" -struct FiniteDiffBackend{custom,fdtype} <: AbstractForwardBackend{custom} end +# Fields -function Base.show(io::IO, ::FiniteDiffBackend{custom,fdtype}) where {custom,fdtype} - return print(io, "FiniteDiffBackend{$(custom ? "custom" : "fallback"),$fdtype}()") -end +- `ruleconfig::RC`: a [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) object -""" - EnzymeForwardBackend <: AbstractForwardBackend +# Example -Enables the use of [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) in forward mode. 
+```julia +using DifferentiationInterface, Zygote +backend = AutoChainRules(Zygote.ZygoteRuleConfig()) +``` """ -struct EnzymeForwardBackend{custom} <: AbstractForwardBackend{custom} end - -function Base.show(io::IO, ::EnzymeForwardBackend{custom}) where {custom} - return print(io, "EnzymeForwardBackend{$(custom ? "custom" : "fallback")}()") +struct AutoChainRules{RC} <: AbstractADType + ruleconfig::RC end -""" - EnzymeReverseBackend <: AbstractReverseBackend +ruleconfig(backend::AutoChainRules) = backend.ruleconfig -Enables the use of [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) in reverse mode. +## Traits and access -!!! warning - This backend only works for scalar output. """ -struct EnzymeReverseBackend{custom} <: AbstractReverseBackend{custom} end + autodiff_mode(backend) -function Base.show(io::IO, ::EnzymeReverseBackend{custom}) where {custom} - return print(io, "EnzymeReverseBackend{$(custom ? "custom" : "fallback")}()") -end +Return `Val(:forward)` or `Val(:reverse)` in a statically predictable way. -""" - ForwardDiffBackend <: AbstractForwardBackend +This function must be overloaded for backends that do not inherit from `ADTypes.AbstractForwardMode` or `ADTypes.AbstractReverseMode` (e.g. because they support both forward and reverse). -Enables the use of [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). +We classify `ADTypes.AbstractFiniteDifferencesMode` as forward mode. """ -struct ForwardDiffBackend{custom} <: AbstractForwardBackend{custom} end - -function Base.show(io::IO, ::ForwardDiffBackend{custom}) where {custom} - return print(io, "ForwardDiffBackend{$(custom ? 
"custom" : "fallback")}()") -end +autodiff_mode(::AbstractForwardMode) = Val{:forward}() +autodiff_mode(::AbstractFiniteDifferencesMode) = Val{:forward}() +autodiff_mode(::AbstractReverseMode) = Val{:reverse}() """ - PolyesterForwardDiffBackend <: AbstractForwardBackend - -Enables the use of [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl), falling back on [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) if needed. + handles_input_type(backend, ::Type{X}) -!!! warning - This backend only works when the arrays are vectors. +Check if `backend` can differentiate functions with input type `X`. """ -struct PolyesterForwardDiffBackend{custom,C} <: AbstractForwardBackend{custom} end - -function Base.show(io::IO, ::PolyesterForwardDiffBackend{custom,C}) where {custom,C} - return print(io, "PolyesterForwardDiffBackend{$(custom ? "custom" : "fallback"),$C}()") -end +handles_input_type(::AbstractADType, ::Type{<:Number}) = true +handles_input_type(::AbstractADType, ::Type{<:AbstractArray}) = true """ - ReverseDiffBackend <: AbstractReverseBackend + handles_output_type(backend, ::Type{Y}) -Performs autodiff with [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl). +Check if `backend` can differentiate functions with output type `Y`. +""" +handles_output_type(::AbstractADType, ::Type{<:Number}) = true +handles_output_type(::AbstractADType, ::Type{<:AbstractArray}) = true -!!! warning - This backend only works for array input. """ -struct ReverseDiffBackend{custom} <: AbstractReverseBackend{custom} end + handles_types(backend, ::Type{X}, ::Type{Y}) -function Base.show(io::IO, ::ReverseDiffBackend{custom}) where {custom} - return print(io, "ReverseDiffBackend{$(custom ? "custom" : "fallback")}()") +Check if `backend` can differentiate functions with input type `X` and output type `Y`. 
+""" +function handles_types(backend::AbstractADType, ::Type{X}, ::Type{Y}) where {X,Y} + return handles_input_type(backend, X) && handles_output_type(backend, Y) end - -## Pseudo backends - -function ZygoteBackend end - -## Limitations - -handles_input_type(::ReverseDiffBackend, ::Type{<:Number}) = false -handles_output_type(::EnzymeReverseBackend, ::Type{<:AbstractArray}) = false diff --git a/src/backends_abstract.jl b/src/backends_abstract.jl deleted file mode 100644 index dbb9e540f..000000000 --- a/src/backends_abstract.jl +++ /dev/null @@ -1,68 +0,0 @@ - -""" - AbstractBackend - -Abstract type pointing to the AD package chosen by the user, which is called a "backend". - -## Custom - -When we say that a backend is "custom", it describes how the utilities (derivative, multiderivative, gradient and jacobian) are implemented: - -- Custom backends use specific routines defined in their package whenever they exist -- Non-custom backends use fallbacks defined in DifferentiationInterface.jl, which end up calling the pushforward or pullback -""" -abstract type AbstractBackend{custom} end - -""" - is_custom(backend) - -Return a boolean `custom` that describes how utilities (derivative, multiderivative, gradient and jacobian) are implemented. -""" -is_custom(::AbstractBackend{custom}) where {custom} = custom - -""" - handles_input_type(backend, ::Type{X}) - -Check if `backend` can differentiate functions with input type `X`. -""" -handles_input_type(::AbstractBackend, ::Type{<:Number}) = true -handles_input_type(::AbstractBackend, ::Type{<:AbstractArray}) = true - -""" - handles_output_type(backend, ::Type{Y}) - -Check if `backend` can differentiate functions with output type `Y`. -""" -handles_output_type(::AbstractBackend, ::Type{<:Number}) = true -handles_output_type(::AbstractBackend, ::Type{<:AbstractArray}) = true - -""" - handles_types(backend, ::Type{X}, ::Type{Y}) - -Check if `backend` can differentiate functions with input type `X` and output type `Y`. 
-""" -function handles_types(backend::AbstractBackend, ::Type{X}, ::Type{Y}) where {X,Y} - return handles_input_type(backend, X) && handles_output_type(backend, Y) -end - -""" - AbstractForwardBackend <: AbstractBackend - -Abstract subtype of [`AbstractBackend`](@ref) for forward mode AD packages. -""" -abstract type AbstractForwardBackend{custom} <: AbstractBackend{custom} end - -""" - AbstractReverseBackend <: AbstractBackend - -Abstract subtype of [`AbstractBackend`](@ref) for reverse mode AD packages. -""" -abstract type AbstractReverseBackend{custom} <: AbstractBackend{custom} end - -""" - autodiff_mode(backend) - -Return either `:forward` or `:reverse` depending on the mode of `backend`. -""" -autodiff_mode(::AbstractForwardBackend) = :forward -autodiff_mode(::AbstractReverseBackend) = :reverse diff --git a/src/pullback.jl b/src/pullback.jl index 501ff7c57..637524987 100644 --- a/src/pullback.jl +++ b/src/pullback.jl @@ -1,41 +1,41 @@ """ - value_and_pullback!(dx, backend::AbstractReverseBackend, f, x, dy) -> (y, dx) + value_and_pullback!(dx, backend, f, x, dy) -> (y, dx) Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. !!! info "Interface requirement" - This is the only required implementation for an [`AbstractReverseBackend`](@ref). + This is the only required implementation for a reverse mode backend. """ -function value_and_pullback!(dx, backend::AbstractReverseBackend, f, x, dy) +function value_and_pullback!(dx, backend::AbstractADType, f, x, dy) return error( - "Backend $backend is not loaded or does not support this type combination." 
+ "Backend $backend is not loaded or does not support this type combination: `typeof(x) = $(typeof(x))` and `typeof(y) = $(typeof(dy))`", ) end """ - value_and_pullback(backend::AbstractReverseBackend, f, x, dy) -> (y, dx) + value_and_pullback(backend, f, x, dy) -> (y, dx) Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`. """ -function value_and_pullback(backend::AbstractReverseBackend, f, x, dy) +function value_and_pullback(backend::AbstractADType, f, x, dy) dx = mysimilar(x) return value_and_pullback!(dx, backend, f, x, dy) end """ - pullback!(dx, backend::AbstractReverseBackend, f, x, dy) -> dx + pullback!(dx, backend, f, x, dy) -> dx Compute the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. """ -function pullback!(dx, backend::AbstractReverseBackend, f, x, dy) +function pullback!(dx, backend::AbstractADType, f, x, dy) return last(value_and_pullback!(dx, backend, f, x, dy)) end """ - pullback(backend::AbstractReverseBackend, f, x, dy) -> dx + pullback(backend, f, x, dy) -> dx Compute the vector-Jacobian product `dx = ∂f(x)' * dy`. """ -function pullback(backend::AbstractReverseBackend, f, x, dy) +function pullback(backend::AbstractADType, f, x, dy) return last(value_and_pullback(backend, f, x, dy)) end diff --git a/src/pushforward.jl b/src/pushforward.jl index e71b92197..a4f2a9c4e 100644 --- a/src/pushforward.jl +++ b/src/pushforward.jl @@ -1,41 +1,41 @@ """ - value_and_pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) -> (y, dy) + value_and_pushforward!(dy, backend, f, x, dx) -> (y, dy) Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. !!! info "Interface requirement" - This is the only required implementation for an [`AbstractForwardBackend`](@ref). + This is the only required implementation for a forward mode backend. 
""" -function value_and_pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) +function value_and_pushforward!(dy, backend::AbstractADType, f, x, dx) return error( - "Backend $backend is not loaded or does not support this type combination." + "Backend $backend is not loaded or does not support this type combination: `typeof(x) = $(typeof(x))` and `typeof(y) = $(typeof(dy))`", ) end """ - value_and_pushforward(backend::AbstractForwardBackend, f, x, dx) -> (y, dy) + value_and_pushforward(backend, f, x, dx) -> (y, dy) Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function value_and_pushforward(backend::AbstractForwardBackend, f, x, dx) +function value_and_pushforward(backend::AbstractADType, f, x, dx) dy = mysimilar(f(x)) return value_and_pushforward!(dy, backend, f, x, dx) end """ - pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) -> dy + pushforward!(dy, backend, f, x, dx) -> dy Compute the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. """ -function pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) +function pushforward!(dy, backend::AbstractADType, f, x, dx) return last(value_and_pushforward!(dy, backend, f, x, dx)) end """ - pushforward(backend::AbstractForwardBackend, f, x, dx) -> dy + pushforward(backend, f, x, dx) -> dy Compute the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function pushforward(backend::AbstractForwardBackend, f, x, dx) +function pushforward(backend::AbstractADType, f, x, dx) return last(value_and_pushforward(backend, f, x, dx)) end diff --git a/src/scalar_array.jl b/src/scalar_array.jl index 1b6aed54c..88fd3ed09 100644 --- a/src/scalar_array.jl +++ b/src/scalar_array.jl @@ -3,16 +3,38 @@ Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. """ -function value_and_multiderivative! 
end +function value_and_multiderivative!( + multider::AbstractArray, backend::AbstractADType, f, x::Number +) + return value_and_multiderivative!(Val{:fallback}(), multider, backend, f, x) +end + +function value_and_multiderivative!( + implem::Val{:fallback}, multider::AbstractArray, backend::AbstractADType, f, x::Number +) + return value_and_multiderivative!( + implem, autodiff_mode(backend), multider, backend, f, x + ) +end function value_and_multiderivative!( - multider::AbstractArray, backend::AbstractForwardBackend, f, x::Number + ::Val{:fallback}, + ::Val{:forward}, + multider::AbstractArray, + backend::AbstractADType, + f, + x::Number, ) return value_and_pushforward!(multider, backend, f, x, one(x)) end function value_and_multiderivative!( - multider::AbstractArray, backend::AbstractReverseBackend, f, x::Number + ::Val{:fallback}, + ::Val{:reverse}, + multider::AbstractArray, + backend::AbstractADType, + f, + x::Number, ) y = f(x) for i in eachindex(IndexCartesian(), y) @@ -27,9 +49,15 @@ end Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function value_and_multiderivative(backend::AbstractBackend, f, x::Number) +function value_and_multiderivative(backend::AbstractADType, f, x::Number) + return value_and_multiderivative(Val{:fallback}(), backend, f, x) +end + +function value_and_multiderivative( + implem::Val{:fallback}, backend::AbstractADType, f, x::Number +) multider = similar(f(x)) - return value_and_multiderivative!(multider, backend, f, x) + return value_and_multiderivative!(implem, multider, backend, f, x) end """ @@ -37,8 +65,14 @@ end Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. 
""" -function multiderivative!(multider::AbstractArray, backend::AbstractBackend, f, x::Number) - return last(value_and_multiderivative!(multider, backend, f, x)) +function multiderivative!(multider::AbstractArray, backend::AbstractADType, f, x::Number) + return multiderivative!(Val{:fallback}(), multider, backend, f, x) +end + +function multiderivative!( + implem::Val{:fallback}, multider::AbstractArray, backend::AbstractADType, f, x::Number +) + return last(value_and_multiderivative!(implem, multider, backend, f, x)) end """ @@ -46,6 +80,10 @@ end Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function multiderivative(backend::AbstractBackend, f, x::Number) - return last(value_and_multiderivative(backend, f, x)) +function multiderivative(backend::AbstractADType, f, x::Number) + return multiderivative(Val{:fallback}(), backend, f, x) +end + +function multiderivative(implem::Val{:fallback}, backend::AbstractADType, f, x::Number) + return last(value_and_multiderivative(implem, backend, f, x)) end diff --git a/src/scalar_scalar.jl b/src/scalar_scalar.jl index 7cf8c9c2a..9b9e16d5c 100644 --- a/src/scalar_scalar.jl +++ b/src/scalar_scalar.jl @@ -3,13 +3,23 @@ Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function. 
""" -function value_and_derivative end +function value_and_derivative(backend::AbstractADType, f, x::Number) + return value_and_derivative(Val{:fallback}(), backend::AbstractADType, f, x::Number) +end + +function value_and_derivative(implem::Val{:fallback}, backend::AbstractADType, f, x::Number) + return value_and_derivative(implem, autodiff_mode(backend), backend, f, x) +end -function value_and_derivative(backend::AbstractForwardBackend, f, x::Number) +function value_and_derivative( + ::Val{:fallback}, ::Val{:forward}, backend::AbstractADType, f, x::Number +) return value_and_pushforward!(one(x), backend, f, x, one(x)) end -function value_and_derivative(backend::AbstractReverseBackend, f, x::Number) +function value_and_derivative( + ::Val{:fallback}, ::Val{:reverse}, backend::AbstractADType, f, x::Number +) return value_and_pullback!(one(x), backend, f, x, one(x)) end @@ -18,6 +28,10 @@ end Compute the derivative `der = f'(x)` of a scalar-to-scalar function. """ -function derivative(backend::AbstractBackend, f, x::Number) - return last(value_and_derivative(backend, f, x)) +function derivative(backend::AbstractADType, f, x::Number) + return derivative(Val{:fallback}(), backend, f, x) +end + +function derivative(implem::Val{:fallback}, backend::AbstractADType, f, x::Number) + return last(value_and_derivative(implem, backend, f, x)) end diff --git a/src/utils.jl b/src/utils.jl index 8c4f9746a..4ad0de65b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,7 +8,7 @@ Construct the `i`-th stardard basis array in the vector space of `a` with elemen If an AD backend benefits from a more specialized basis array implementation, this function can be extended on the backend type. 
""" -basisarray(::AbstractBackend, a::AbstractArray, i) = basisarray(a, i) +basisarray(::AbstractADType, a::AbstractArray, i) = basisarray(a, i) function basisarray(a::AbstractArray{T,N}, i::CartesianIndex{N}) where {T,N} return OneElement(one(T), Tuple(i), axes(a)) diff --git a/test/Project.toml b/test/Project.toml index e23957ab6..7c5bbd3b0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,7 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/test/backends.jl b/test/backends.jl deleted file mode 100644 index b694a6a89..000000000 --- a/test/backends.jl +++ /dev/null @@ -1,11 +0,0 @@ -using DifferentiationInterface -using DifferentiationInterface: is_custom -using Test - -backend = ChainRulesForwardBackend{false,Nothing}(nothing) -@test !is_custom(backend) -@test autodiff_mode(backend) == :forward - -backend = ChainRulesReverseBackend{true,Nothing}(nothing) -@test is_custom(backend) -@test autodiff_mode(backend) == :reverse diff --git a/test/diffractor.jl b/test/diffractor.jl deleted file mode 100644 index 62b293e50..000000000 --- a/test/diffractor.jl +++ /dev/null @@ -1,15 +0,0 @@ -using DifferentiationInterface -using Diffractor: Diffractor - -# see https://github.com/JuliaDiff/Diffractor.jl/issues/277 - -@test_skip test_pushforward( - ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()), - scenarios; - type_stability=false, -); -@test_skip test_jacobian_and_friends( - ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()), - scenarios; - type_stability=false, -); diff --git a/test/enzyme_forward.jl b/test/enzyme_forward.jl index 9289a07dc..cb2b497b7 100644 --- a/test/enzyme_forward.jl +++ 
b/test/enzyme_forward.jl @@ -1,10 +1,8 @@ -using DifferentiationInterface +using ADTypes: AutoEnzyme using Enzyme: Enzyme -test_pushforward(EnzymeForwardBackend(), scenarios; type_stability=true); +test_pushforward(AutoEnzyme(Val(:forward)), scenarios; type_stability=true); +test_jacobian_and_friends(AutoEnzyme(Val(:forward)), scenarios; type_stability=true); test_jacobian_and_friends( - EnzymeForwardBackend(; custom=true), scenarios; type_stability=true -); -test_jacobian_and_friends( - EnzymeForwardBackend(; custom=false), scenarios; type_stability=true + AutoEnzyme(Val(:forward)), scenarios, Val(:fallback); type_stability=true ); diff --git a/test/enzyme_reverse.jl b/test/enzyme_reverse.jl index 022e64843..79965dacd 100644 --- a/test/enzyme_reverse.jl +++ b/test/enzyme_reverse.jl @@ -1,10 +1,8 @@ -using DifferentiationInterface +using ADTypes: AutoEnzyme using Enzyme: Enzyme -test_pullback(EnzymeReverseBackend(), scenarios; type_stability=true); +test_pullback(AutoEnzyme(Val(:reverse)), scenarios; type_stability=true); +test_jacobian_and_friends(AutoEnzyme(Val(:reverse)), scenarios; type_stability=true) test_jacobian_and_friends( - EnzymeReverseBackend(; custom=true), scenarios; type_stability=true -) -test_jacobian_and_friends( - EnzymeReverseBackend(; custom=false), scenarios; type_stability=true + AutoEnzyme(Val(:reverse)), scenarios, Val(:fallback); type_stability=true ) diff --git a/test/finitediff.jl b/test/finitediff.jl index b976b1f62..e0def3a32 100644 --- a/test/finitediff.jl +++ b/test/finitediff.jl @@ -1,6 +1,6 @@ -using DifferentiationInterface +using ADTypes: AutoFiniteDiff using FiniteDiff: FiniteDiff -test_pushforward(FiniteDiffBackend(), scenarios; type_stability=true); -test_jacobian_and_friends(FiniteDiffBackend(; custom=true), scenarios; type_stability=false); -test_jacobian_and_friends(FiniteDiffBackend(; custom=false), scenarios; type_stability=true); +test_pushforward(AutoFiniteDiff(), scenarios; type_stability=true); 
+test_jacobian_and_friends(AutoFiniteDiff(), scenarios; type_stability=false); +test_jacobian_and_friends(AutoFiniteDiff(), scenarios, Val(:fallback); type_stability=false); diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index f6344fc61..3f75b9dc7 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -1,10 +1,8 @@ -using DifferentiationInterface +using ADTypes: AutoForwardDiff using ForwardDiff: ForwardDiff -test_pushforward(ForwardDiffBackend(), scenarios; type_stability=true); +test_pushforward(AutoForwardDiff(), scenarios; type_stability=true); +test_jacobian_and_friends(AutoForwardDiff(), scenarios; type_stability=false); test_jacobian_and_friends( - ForwardDiffBackend(; custom=true), scenarios; type_stability=false -); -test_jacobian_and_friends( - ForwardDiffBackend(; custom=false), scenarios; type_stability=true + AutoForwardDiff(), scenarios, Val(:fallback); type_stability=false ); diff --git a/test/polyesterforwarddiff.jl b/test/polyesterforwarddiff.jl index 1a3a51233..408e69c55 100644 --- a/test/polyesterforwarddiff.jl +++ b/test/polyesterforwarddiff.jl @@ -1,14 +1,12 @@ -using DifferentiationInterface +using ADTypes: AutoPolyesterForwardDiff using PolyesterForwardDiff: PolyesterForwardDiff # see https://github.com/JuliaDiff/PolyesterForwardDiff.jl/issues/17 -test_pushforward( - PolyesterForwardDiffBackend(4; custom=true), scenarios; type_stability=false -); +test_pushforward(AutoPolyesterForwardDiff(; chunksize=4), scenarios; type_stability=false); test_jacobian_and_friends( - PolyesterForwardDiffBackend(4; custom=true), + AutoPolyesterForwardDiff(; chunksize=4), scenarios; input_type=Union{Number,AbstractVector}, output_type=Union{Number,AbstractVector}, diff --git a/test/reversediff.jl b/test/reversediff.jl index 748949ac6..80fb2a0d2 100644 --- a/test/reversediff.jl +++ b/test/reversediff.jl @@ -1,10 +1,8 @@ -using DifferentiationInterface +using ADTypes: AutoReverseDiff using ReverseDiff: ReverseDiff 
-test_pullback(ReverseDiffBackend(), scenarios; type_stability=false); +test_pullback(AutoReverseDiff(), scenarios; type_stability=false); +test_jacobian_and_friends(AutoReverseDiff(), scenarios; type_stability=false); test_jacobian_and_friends( - ReverseDiffBackend(; custom=true), scenarios; type_stability=false -); -test_jacobian_and_friends( - ReverseDiffBackend(; custom=false), scenarios; type_stability=false + AutoReverseDiff(), scenarios, Val(:fallback); type_stability=false ); diff --git a/test/runtests.jl b/test/runtests.jl index 50f0f19d2..628d366c9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,13 +25,6 @@ include("utils.jl") JET.test_package(DifferentiationInterface; target_defined_modules=true) end - @testset "Backend utilities" begin - include("backends.jl") - end - - @testset "Diffractor" begin - include("diffractor.jl") - end @testset "Enzyme (forward)" begin include("enzyme_forward.jl") end diff --git a/test/utils.jl b/test/utils.jl index 7fed9e7e2..067114ab0 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,6 +1,5 @@ +using ADTypes: AbstractADType using DifferentiationInterface -using DifferentiationInterface: - AbstractBackend, AbstractReverseBackend, AbstractForwardBackend using ForwardDiff: ForwardDiff using LinearAlgebra using JET @@ -150,17 +149,20 @@ scenarios = [ ## Test utilities function test_pushforward( - backend::AbstractForwardBackend, + backend::AbstractADType, scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, type_stability::Bool=true, ) + if autodiff_mode(backend) != Val(:forward) + return nothing + end scenarios = filter(scenarios) do s get_input_type(s) <: input_type && get_output_type(s) <: output_type end - @testset "Pushforward ($(is_custom(backend) ? 
"custom" : "fallback"))" begin + @testset "Pushforward" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -205,17 +207,20 @@ function test_pushforward( end function test_pullback( - backend::AbstractReverseBackend, + backend::AbstractADType, scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, type_stability::Bool=true, ) + if autodiff_mode(backend) != Val(:reverse) + return nothing + end scenarios = filter(scenarios) do s (get_input_type(s) <: input_type) && (get_output_type(s) <: output_type) end - @testset "Pullback ($(is_custom(backend) ? "custom" : "fallback"))" begin + @testset "Pullback" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -261,15 +266,17 @@ function test_pullback( end function test_derivative( - backend::AbstractBackend, - scenarios::Vector{<:Scenario}; + backend::AbstractADType, + scenarios::Vector{<:Scenario}, + implems...; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: Number) && (get_output_type(s) <: Number) end - @testset "Derivative ($(is_custom(backend) ? "custom" : "fallback"))" begin + testset_name = isempty(implems) ? 
"" : "(fallback)" + @testset "Derivative $testset_name" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -277,9 +284,9 @@ function test_derivative( @testset "$X -> $Y" begin (; f, x, y, der_true) = scenario - y_out1, der_out1 = value_and_derivative(backend, f, x) + y_out1, der_out1 = value_and_derivative(implems..., backend, f, x) - der_out2 = derivative(backend, f, x) + der_out2 = derivative(implems..., backend, f, x) @testset "Primal value" begin @test y_out1 ≈ y @@ -289,12 +296,12 @@ function test_derivative( @test der_out2 ≈ der_true rtol = 1e-3 end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_derivative(backend, f, x)) - @test iszero(@allocated derivative(backend, f, x)) + @test iszero(@allocated value_and_derivative(implems..., backend, f, x)) + @test iszero(@allocated derivative(implems..., backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_derivative(backend, f, x) - @test_opt derivative(backend, f, x) + @test_opt value_and_derivative(implems..., backend, f, x) + @test_opt derivative(implems..., backend, f, x) end end end @@ -302,15 +309,17 @@ function test_derivative( end function test_multiderivative( - backend::AbstractBackend, - scenarios::Vector{<:Scenario}; + backend::AbstractADType, + scenarios::Vector{<:Scenario}, + implems...; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: Number) && (get_output_type(s) <: AbstractArray) end - @testset "Multiderivative ($(is_custom(backend) ? "custom" : "fallback"))" begin + testset_name = isempty(implems) ? 
"" : "(fallback)" + @testset "Multiderivative $testset_name" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -318,15 +327,15 @@ function test_multiderivative( @testset "$X -> $Y" begin (; f, x, y, multider_true) = scenario - y_out1, multider_out1 = value_and_multiderivative(backend, f, x) + y_out1, multider_out1 = value_and_multiderivative(implems..., backend, f, x) multider_in2 = zero(multider_out1) y_out2, multider_out2 = value_and_multiderivative!( - multider_in2, backend, f, x + implems..., multider_in2, backend, f, x ) - multider_out3 = multiderivative(backend, f, x) + multider_out3 = multiderivative(implems..., backend, f, x) multider_in4 = zero(multider_out3) - multider_out4 = multiderivative!(multider_in4, backend, f, x) + multider_out4 = multiderivative!(implems..., multider_in4, backend, f, x) @testset "Primal value" begin @test y_out1 ≈ y @@ -344,13 +353,21 @@ function test_multiderivative( end allocs && @testset "Allocations" begin @test iszero( - @allocated value_and_multiderivative!(multider_in2, backend, f, x) + @allocated value_and_multiderivative!( + implems..., multider_in2, backend, f, x + ) + ) + @test iszero( + @allocated multiderivative!( + implems..., multider_in4, backend, f, x + ) ) - @test iszero(@allocated multiderivative!(multider_in4, backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_multiderivative!(multider_in2, backend, f, x) - @test_opt multiderivative!(multider_in4, backend, f, x) + @test_opt value_and_multiderivative!( + implems..., multider_in2, backend, f, x + ) + @test_opt multiderivative!(implems..., multider_in4, backend, f, x) end end end @@ -358,15 +375,17 @@ function test_multiderivative( end function test_gradient( - backend::AbstractBackend, - scenarios::Vector{<:Scenario}; + backend::AbstractADType, + scenarios::Vector{<:Scenario}, + implems...; allocs::Bool=false, type_stability::Bool=true, ) 
scenarios = filter(scenarios) do s (get_input_type(s) <: AbstractArray) && (get_output_type(s) <: Number) end - @testset "Gradient ($(is_custom(backend) ? "custom" : "fallback"))" begin + testset_name = isempty(implems) ? "" : "(fallback)" + @testset "Gradient $testset_name" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -374,13 +393,13 @@ function test_gradient( @testset "$X -> $Y" begin (; f, x, y, grad_true) = scenario - y_out1, grad_out1 = value_and_gradient(backend, f, x) + y_out1, grad_out1 = value_and_gradient(implems..., backend, f, x) grad_in2 = zero(grad_out1) - y_out2, grad_out2 = value_and_gradient!(grad_in2, backend, f, x) + y_out2, grad_out2 = value_and_gradient!(implems..., grad_in2, backend, f, x) - grad_out3 = gradient(backend, f, x) + grad_out3 = gradient(implems..., backend, f, x) grad_in4 = zero(grad_out3) - grad_out4 = gradient!(grad_in4, backend, f, x) + grad_out4 = gradient!(implems..., grad_in4, backend, f, x) @testset "Primal value" begin @test y_out1 ≈ y @@ -397,12 +416,14 @@ function test_gradient( end end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_gradient!(grad_in2, backend, f, x)) - @test iszero(@allocated gradient!(grad_in4, backend, f, x)) + @test iszero( + @allocated value_and_gradient!(implems..., grad_in2, backend, f, x) + ) + @test iszero(@allocated gradient!(implems..., grad_in4, backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_gradient!(grad_in2, backend, f, x) - @test_opt gradient!(grad_in4, backend, f, x) + @test_opt value_and_gradient!(implems..., grad_in2, backend, f, x) + @test_opt gradient!(implems..., grad_in4, backend, f, x) end end end @@ -410,15 +431,17 @@ function test_gradient( end function test_jacobian( - backend::AbstractBackend, - scenarios::Vector{<:Scenario}; + backend::AbstractADType, + scenarios::Vector{<:Scenario}, + implems...; 
allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: AbstractArray) && (get_output_type(s) <: AbstractArray) end - @testset "Jacobian ($(is_custom(backend) ? "custom" : "fallback"))" begin + testset_name = isempty(implems) ? "" : "(fallback)" + @testset "Jacobian $testset_name" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -426,13 +449,13 @@ function test_jacobian( @testset "$X -> $Y" begin (; f, x, y, jac_true) = scenario - y_out1, jac_out1 = value_and_jacobian(backend, f, x) + y_out1, jac_out1 = value_and_jacobian(implems..., backend, f, x) jac_in2 = zero(jac_out1) - y_out2, jac_out2 = value_and_jacobian!(jac_in2, backend, f, x) + y_out2, jac_out2 = value_and_jacobian!(implems..., jac_in2, backend, f, x) - jac_out3 = jacobian(backend, f, x) + jac_out3 = jacobian(implems..., backend, f, x) jac_in4 = zero(jac_out3) - jac_out4 = jacobian!(jac_in4, backend, f, x) + jac_out4 = jacobian!(implems..., jac_in4, backend, f, x) @testset "Primal value" begin @test y_out1 ≈ y @@ -449,12 +472,14 @@ function test_jacobian( end end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_jacobian!(jac_in2, backend, f, x)) - @test iszero(@allocated jacobian!(jac_in4, backend, f, x)) + @test iszero( + @allocated value_and_jacobian!(implems..., jac_in2, backend, f, x) + ) + @test iszero(@allocated jacobian!(implems..., jac_in4, backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_jacobian!(jac_in2, backend, f, x) - @test_opt jacobian!(jac_in4, backend, f, x) + @test_opt value_and_jacobian!(implems..., jac_in2, backend, f, x) + @test_opt jacobian!(implems..., jac_in4, backend, f, x) end end end @@ -462,8 +487,9 @@ function test_jacobian( end function test_jacobian_and_friends( - backend::AbstractBackend, - scenarios::Vector{<:Scenario}; + backend::AbstractADType, + 
scenarios::Vector{<:Scenario}, + implems...; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, @@ -472,9 +498,9 @@ function test_jacobian_and_friends( scenarios = filter(scenarios) do s (get_input_type(s) <: input_type) && (get_output_type(s) <: output_type) end - test_derivative(backend, scenarios; allocs, type_stability) - test_multiderivative(backend, scenarios; allocs, type_stability) - test_gradient(backend, scenarios; allocs, type_stability) - test_jacobian(backend, scenarios; allocs, type_stability) + test_derivative(backend, scenarios, implems...; allocs, type_stability) + test_multiderivative(backend, scenarios, implems...; allocs, type_stability) + test_gradient(backend, scenarios, implems...; allocs, type_stability) + test_jacobian(backend, scenarios, implems...; allocs, type_stability) return nothing end diff --git a/test/zygote.jl b/test/zygote.jl index 503dfc4c0..4839be695 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -1,6 +1,6 @@ -using DifferentiationInterface +using ADTypes: AutoZygote using Zygote: Zygote -test_pullback(ZygoteBackend(), scenarios; type_stability=false); -test_jacobian_and_friends(ZygoteBackend(; custom=true), scenarios; type_stability=false); -test_jacobian_and_friends(ZygoteBackend(; custom=false), scenarios; type_stability=false); +test_pullback(AutoZygote(), scenarios; type_stability=false); +test_jacobian_and_friends(AutoZygote(), scenarios; type_stability=false); +test_jacobian_and_friends(AutoZygote(), scenarios, Val(:fallback); type_stability=false); From f96dd52e7ddb7631b17d780e1b6f95ad3ef0a491 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 8 Mar 2024 17:36:28 +0100 Subject: [PATCH 2/9] No type piracy in printing --- benchmark/utils.jl | 88 +++++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/benchmark/utils.jl b/benchmark/utils.jl index 0a1d7f6bb..e5620e451 100644 --- a/benchmark/utils.jl +++ 
b/benchmark/utils.jl @@ -3,15 +3,15 @@ using ADTypes: AbstractADType using BenchmarkTools using DifferentiationInterface -## Pretty printing with type piracy +## Pretty printing -Base.string(::AutoEnzyme{Val{:forward}}) = "Enzyme (forward)" -Base.string(::AutoEnzyme{Val{:reverse}}) = "Enzyme (reverse)" -Base.string(::AutoFiniteDiff) = "FiniteDiff" -Base.string(::AutoForwardDiff) = "ForwardDiff" -Base.string(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff" -Base.string(::AutoReverseDiff) = "ReverseDiff" -Base.string(::AutoZygote) = "Zygote" +pretty(::AutoEnzyme{Val{:forward}}) = "Enzyme (forward)" +pretty(::AutoEnzyme{Val{:reverse}}) = "Enzyme (reverse)" +pretty(::AutoFiniteDiff) = "FiniteDiff" +pretty(::AutoForwardDiff) = "ForwardDiff" +pretty(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff" +pretty(::AutoReverseDiff) = "ReverseDiff" +pretty(::AutoZygote) = "Zygote" ## Benchmark suite @@ -27,17 +27,17 @@ function add_pushforward_benchmarks!( return nothing end - suite["value_and_pushforward"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_pushforward"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_pushforward($backend, $f, $x, $dx) end - suite["value_and_pushforward!"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_pushforward!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_pushforward!($dy, $backend, $f, $x, $dx) end - suite["pushforward"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["pushforward"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin pushforward($backend, $f, $x, $dx) end - suite["pushforward!"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["pushforward!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin pushforward!($dy, $backend, $f, $x, $dx) end @@ -56,17 +56,17 @@ function add_pullback_benchmarks!( return nothing end - suite["value_and_pullback"][(n, m)]["$(string(backend))"] = @benchmarkable begin + 
suite["value_and_pullback"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_pullback($backend, $f, $x, $dy) end - suite["value_and_pullback!"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_pullback!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_pullback!($dx, $backend, $f, $x, $dy) end - suite["pullback"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["pullback"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin pullback($backend, $f, $x, $dy) end - suite["pullback!"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["pullback!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin pullback!($dx, $backend, $f, $x, $dy) end @@ -83,19 +83,19 @@ function add_derivative_benchmarks!( x = randn() - suite["value_and_derivative"][(1, 1)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_derivative"][(1, 1)]["$(pretty(backend))"] = @benchmarkable begin value_and_derivative($backend, $f, $x) end - suite["value_and_derivative"][(1, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["value_and_derivative"][(1, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin value_and_derivative(Val(:fallback), $backend, $f, $x) end - suite["derivative"][(1, 1)]["$(string(backend))"] = @benchmarkable begin + suite["derivative"][(1, 1)]["$(pretty(backend))"] = @benchmarkable begin derivative($backend, $f, $x) end - suite["derivative"][(1, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["derivative"][(1, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin derivative(Val(:fallback), $backend, $f, $x) end @@ -113,31 +113,31 @@ function add_multiderivative_benchmarks!( x = randn() multider = zeros(m) - suite["value_and_multiderivative"][(1, m)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_multiderivative"][(1, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_multiderivative($backend, $f, $x) end - 
suite["value_and_multiderivative!"][(1, m)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_multiderivative!"][(1, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_multiderivative!($multider, $backend, $f, $x) end - suite["value_and_multiderivative"][(1, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["value_and_multiderivative"][(1, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin value_and_multiderivative(Val(:fallback), $backend, $f, $x) end - suite["value_and_multiderivative!"][(1, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["value_and_multiderivative!"][(1, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin value_and_multiderivative!(Val(:fallback), $multider, $backend, $f, $x) end - suite["multiderivative"][(1, m)]["$(string(backend))"] = @benchmarkable begin + suite["multiderivative"][(1, m)]["$(pretty(backend))"] = @benchmarkable begin multiderivative($backend, $f, $x) end - suite["multiderivative!"][(1, m)]["$(string(backend))"] = @benchmarkable begin + suite["multiderivative!"][(1, m)]["$(pretty(backend))"] = @benchmarkable begin multiderivative!($multider, $backend, $f, $x) end - suite["multiderivative"][(1, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["multiderivative"][(1, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin multiderivative(Val(:fallback), $backend, $f, $x) end - suite["multiderivative!"][(1, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["multiderivative!"][(1, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin multiderivative!(Val(:fallback), $multider, $backend, $f, $x) end @@ -155,31 +155,31 @@ function add_gradient_benchmarks!( x = randn(n) grad = zeros(n) - suite["value_and_gradient"][(n, 1)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_gradient"][(n, 1)]["$(pretty(backend))"] = @benchmarkable begin value_and_gradient($backend, $f, $x) end - 
suite["value_and_gradient!"][(n, 1)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_gradient!"][(n, 1)]["$(pretty(backend))"] = @benchmarkable begin value_and_gradient!($grad, $backend, $f, $x) end - suite["value_and_gradient"][(n, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["value_and_gradient"][(n, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin value_and_gradient(Val(:fallback), $backend, $f, $x) end - suite["value_and_gradient!"][(n, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["value_and_gradient!"][(n, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin value_and_gradient!(Val(:fallback), $grad, $backend, $f, $x) end - suite["gradient"][(n, 1)]["$(string(backend))"] = @benchmarkable begin + suite["gradient"][(n, 1)]["$(pretty(backend))"] = @benchmarkable begin gradient($backend, $f, $x) end - suite["gradient!"][(n, 1)]["$(string(backend))"] = @benchmarkable begin + suite["gradient!"][(n, 1)]["$(pretty(backend))"] = @benchmarkable begin gradient!($grad, $backend, $f, $x) end - suite["gradient"][(n, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["gradient"][(n, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin gradient(Val(:fallback), $backend, $f, $x) end - suite["gradient!"][(n, 1)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["gradient!"][(n, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin gradient!(Val(:fallback), $grad, $backend, $f, $x) end @@ -196,31 +196,31 @@ function add_jacobian_benchmarks!( x = randn(n) jac = zeros(m, n) - suite["value_and_jacobian"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_jacobian"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_jacobian($backend, $f, $x) end - suite["value_and_jacobian!"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["value_and_jacobian!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin 
value_and_jacobian!($jac, $backend, $f, $x) end - suite["value_and_jacobian"][(n, m)]["$(string(backend)) (:fallback)"] = @benchmarkable begin + suite["value_and_jacobian"][(n, m)]["$(pretty(backend)) (:fallback)"] = @benchmarkable begin value_and_jacobian(Val(:fallback), $backend, $f, $x) end - suite["value_and_jacobian!"][(n, m)]["$(string(backend)) (:fallback)"] = @benchmarkable begin + suite["value_and_jacobian!"][(n, m)]["$(pretty(backend)) (:fallback)"] = @benchmarkable begin value_and_jacobian!(Val(:fallback), $jac, $backend, $f, $x) end - suite["jacobian"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["jacobian"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin jacobian($backend, $f, $x) end - suite["jacobian!"][(n, m)]["$(string(backend))"] = @benchmarkable begin + suite["jacobian!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin jacobian!($jac, $backend, $f, $x) end - suite["jacobian"][(n, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["jacobian"][(n, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin jacobian(Val(:fallback), $backend, $f, $x) end - suite["jacobian!"][(n, m)]["$(string(backend)) - fallback"] = @benchmarkable begin + suite["jacobian!"][(n, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin jacobian!(Val(:fallback), $jac, $backend, $f, $x) end From 2212beb7e68e773be3577df057cae0b45cffefc4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 8 Mar 2024 17:38:40 +0100 Subject: [PATCH 3/9] No docs index --- docs/src/index.md | 60 ----------------------------------------------- 1 file changed, 60 deletions(-) delete mode 100644 docs/src/index.md diff --git a/docs/src/index.md b/docs/src/index.md deleted file mode 100644 index 55f213137..000000000 --- a/docs/src/index.md +++ /dev/null @@ -1,60 +0,0 @@ -```@meta -EditURL = "https://github.com/gdalle/DifferentiationInterface.jl/blob/main/README.md" -``` - -# DifferentiationInterface - 
-[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://gdalle.github.io/DifferentiationInterface.jl/dev/) -[![Build Status](https://github.com/gdalle/DifferentiationInterface.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/gdalle/DifferentiationInterface.jl/actions/workflows/CI.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/gdalle/DifferentiationInterface.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/gdalle/DifferentiationInterface.jl) -[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) - -An interface to various automatic differentiation backends in Julia. - -## Goal - -This package provides a backend-agnostic syntax to differentiate functions `f(x) = y`, where `x` and `y` are either numbers or abstract arrays. - -It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). - -## Supported backends - -We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): - -- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) with `AutoEnzyme(Val(:forward))` -- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) with `AutoFiniteDiff()` -- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with `AutoForwardDiff()` -- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) with `AutoPolyesterForwardDiff(; chunksize=C)` -- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) with `AutoReverseDiff()` -- [Zygote.jl](https://github.com/FluxML/Zygote.jl) with `AutoZygote()` - -We also support one more backend which is not yet part of ADTypes.jl (see [ADTypes.jl#21](https://github.com/SciML/ADTypes.jl/pull/21)): - -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) with `AutoChainRules(ruleconfig)` - -## Design - -Each backend must implement only one primitive: - -- forward mode: the 
pushforward, computing a Jacobian-vector product -- reverse mode: the pullback, computing a vector-Jacobian product - -From these primitives, several utilities are defined, depending on the type of the input and output: - -| | scalar output | array output | -| ------------ | ------------- | --------------- | -| scalar input | derivative | multiderivative | -| array input | gradient | jacobian | - -## Example - -```jldoctest -julia> import DifferentiationInterface, ADTypes, ForwardDiff - -julia> backend = ADTypes.AutoForwardDiff(); - -julia> f(x) = sum(abs2, x); - -julia> DifferentiationInterface.value_and_gradient(backend, f, [1., 2., 3.]) -(14.0, [2.0, 4.0, 6.0]) -``` From 367926d441cde5020d33ab63ecd07e296298e32f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 10 Mar 2024 19:41:39 +0100 Subject: [PATCH 4/9] Change implem specification, add Diffractor --- Project.toml | 14 +- README.md | 5 +- benchmark/Project.toml | 1 + benchmark/benchmarks.jl | 3 +- benchmark/utils.jl | 147 ++++------ docs/Project.toml | 1 + docs/make.jl | 6 +- docs/src/interface.md | 36 ++- ...fferentiationInterfaceChainRulesCoreExt.jl | 8 +- ext/DifferentiationInterfaceDiffractorExt.jl | 23 ++ .../DifferentiationInterfaceEnzymeExt.jl | 1 + .../forward.jl | 4 +- .../reverse.jl | 6 +- ext/DifferentiationInterfaceFiniteDiffExt.jl | 23 +- ext/DifferentiationInterfaceForwardDiffExt.jl | 19 +- ...tiationInterfacePolyesterForwardDiffExt.jl | 5 +- ext/DifferentiationInterfaceReverseDiffExt.jl | 13 +- ext/DifferentiationInterfaceZygoteExt.jl | 28 +- src/DifferentiationInterface.jl | 6 +- src/array_array.jl | 34 +-- src/array_scalar.jl | 34 +-- src/backends.jl | 21 +- src/custom.jl | 29 ++ src/implem.jl | 17 ++ src/mode.jl | 17 ++ src/scalar_array.jl | 34 +-- src/scalar_scalar.jl | 16 +- src/utils.jl | 3 + test/Project.toml | 1 + test/chainrules_forward.jl | 10 + test/chainrules_reverse.jl | 10 + test/diffractor.jl | 8 + test/enzyme_forward.jl | 7 +- 
test/enzyme_reverse.jl | 7 +- test/finitediff.jl | 7 +- test/forwarddiff.jl | 7 +- test/polyesterforwarddiff.jl | 2 + test/reversediff.jl | 7 +- test/runtests.jl | 12 +- test/scenarios.jl | 143 ++++++++++ test/utils.jl | 256 ++++-------------- test/zygote.jl | 5 +- 42 files changed, 584 insertions(+), 452 deletions(-) create mode 100644 ext/DifferentiationInterfaceDiffractorExt.jl create mode 100644 src/custom.jl create mode 100644 src/implem.jl create mode 100644 src/mode.jl create mode 100644 test/chainrules_forward.jl create mode 100644 test/chainrules_reverse.jl create mode 100644 test/diffractor.jl create mode 100644 test/scenarios.jl diff --git a/Project.toml b/Project.toml index 681abadcf..527505474 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,9 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -21,16 +23,26 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" +DifferentiationInterfaceDiffractorExt = [ + "Diffractor", + "AbstractDifferentiation", +] DifferentiationInterfaceEnzymeExt = "Enzyme" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] -DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] +DifferentiationInterfacePolyesterForwardDiffExt = [ + "PolyesterForwardDiff", + "ForwardDiff", + "DiffResults", +] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceZygoteExt = ["Zygote"] [compat] 
+AbstractDifferentiation = "0.6" ADTypes = "0.2.6" ChainRulesCore = "1.19" +Diffractor = "0.2" DiffResults = "1.1" DocStringExtensions = "0.9" Enzyme = "0.11" diff --git a/README.md b/README.md index a7d9707a0..70e34e67d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ This package provides a backend-agnostic syntax to differentiate functions `f(x) It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). -## Supported backends +## Compatibility We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): @@ -24,9 +24,10 @@ We support some of the backends defined by [ADTypes.jl](https://github.com/SciML - [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) with `AutoReverseDiff()` - [Zygote.jl](https://github.com/FluxML/Zygote.jl) with `AutoZygote()` -We also support one more backend which is not yet part of ADTypes.jl (see [ADTypes.jl#21](https://github.com/SciML/ADTypes.jl/pull/21)): +We also support two more backends which are not yet part of ADTypes.jl: - [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) with `AutoChainRules(ruleconfig)` +- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) with `AutoDiffractor()` ## Design diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 5961d7e8a..62de913cf 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -4,6 +4,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index e620b7bdd..9d6953cf2 100644 --- 
a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -8,7 +8,7 @@ using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using PolyesterForwardDiff: PolyesterForwardDiff using ReverseDiff: ReverseDiff -using Zygote: Zygote +using Zygote: Zygote, ZygoteRuleConfig ## Settings @@ -43,6 +43,7 @@ end ## Backends all_backends = [ + AutoChainRules(ZygoteRuleConfig()), AutoEnzyme(Val(:forward)), AutoEnzyme(Val(:reverse)), AutoFiniteDiff(), diff --git a/benchmark/utils.jl b/benchmark/utils.jl index e5620e451..08f08543d 100644 --- a/benchmark/utils.jl +++ b/benchmark/utils.jl @@ -2,9 +2,11 @@ using ADTypes using ADTypes: AbstractADType using BenchmarkTools using DifferentiationInterface +using DifferentiationInterface: CustomImplem, FallbackImplem, ForwardMode, ReverseMode ## Pretty printing +pretty(::AutoChainRules{<:ZygoteRuleConfig}) = "ChainRules (Zygote)" pretty(::AutoEnzyme{Val{:forward}}) = "Enzyme (forward)" pretty(::AutoEnzyme{Val{:reverse}}) = "Enzyme (reverse)" pretty(::AutoFiniteDiff) = "FiniteDiff" @@ -13,6 +15,9 @@ pretty(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff" pretty(::AutoReverseDiff) = "ReverseDiff" pretty(::AutoZygote) = "Zygote" +pretty(::CustomImplem) = "custom" +pretty(::FallbackImplem) = "fallback" + ## Benchmark suite function add_pushforward_benchmarks!( @@ -22,7 +27,7 @@ function add_pushforward_benchmarks!( dx = n == 1 ? randn() : randn(n) dy = m == 1 ? 0.0 : zeros(m) - if !isa(autodiff_mode(backend), Val{:forward}) || + if !isa(autodiff_mode(backend), ForwardMode) || !handles_types(backend, typeof(x), typeof(dy)) return nothing end @@ -51,7 +56,7 @@ function add_pullback_benchmarks!( dx = n == 1 ? 0.0 : zeros(n) dy = m == 1 ? 
randn() : randn(m) - if !isa(autodiff_mode(backend), Val{:reverse}) || + if !isa(autodiff_mode(backend), ReverseMode) || !handles_types(backend, typeof(x), typeof(dy)) return nothing end @@ -83,20 +88,16 @@ function add_derivative_benchmarks!( x = randn() - suite["value_and_derivative"][(1, 1)]["$(pretty(backend))"] = @benchmarkable begin - value_and_derivative($backend, $f, $x) - end - - suite["value_and_derivative"][(1, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - value_and_derivative(Val(:fallback), $backend, $f, $x) - end + for implem in (CustomImplem(), FallbackImplem()) + backend_implem = "$(pretty(backend)) - $(pretty(implem))" - suite["derivative"][(1, 1)]["$(pretty(backend))"] = @benchmarkable begin - derivative($backend, $f, $x) - end + suite["value_and_derivative"][(1, 1)][backend_implem] = @benchmarkable begin + value_and_derivative($implem, $backend, $f, $x) + end - suite["derivative"][(1, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - derivative(Val(:fallback), $backend, $f, $x) + suite["derivative"][(1, 1)][backend_implem] = @benchmarkable begin + derivative($implem, $backend, $f, $x) + end end return nothing @@ -113,32 +114,22 @@ function add_multiderivative_benchmarks!( x = randn() multider = zeros(m) - suite["value_and_multiderivative"][(1, m)]["$(pretty(backend))"] = @benchmarkable begin - value_and_multiderivative($backend, $f, $x) - end - suite["value_and_multiderivative!"][(1, m)]["$(pretty(backend))"] = @benchmarkable begin - value_and_multiderivative!($multider, $backend, $f, $x) - end - - suite["value_and_multiderivative"][(1, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - value_and_multiderivative(Val(:fallback), $backend, $f, $x) - end - suite["value_and_multiderivative!"][(1, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - value_and_multiderivative!(Val(:fallback), $multider, $backend, $f, $x) - end + for implem in (CustomImplem(), FallbackImplem()) + backend_implem = 
"$(pretty(backend)) - $(pretty(implem))" - suite["multiderivative"][(1, m)]["$(pretty(backend))"] = @benchmarkable begin - multiderivative($backend, $f, $x) - end - suite["multiderivative!"][(1, m)]["$(pretty(backend))"] = @benchmarkable begin - multiderivative!($multider, $backend, $f, $x) - end + suite["value_and_multiderivative"][(1, m)][backend_implem] = @benchmarkable begin + value_and_multiderivative($implem, $backend, $f, $x) + end + suite["value_and_multiderivative!"][(1, m)][backend_implem] = @benchmarkable begin + value_and_multiderivative!($implem, $multider, $backend, $f, $x) + end - suite["multiderivative"][(1, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - multiderivative(Val(:fallback), $backend, $f, $x) - end - suite["multiderivative!"][(1, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - multiderivative!(Val(:fallback), $multider, $backend, $f, $x) + suite["multiderivative"][(1, m)][backend_implem] = @benchmarkable begin + multiderivative($implem, $backend, $f, $x) + end + suite["multiderivative!"][(1, m)][backend_implem] = @benchmarkable begin + multiderivative!($implem, $multider, $backend, $f, $x) + end end return nothing @@ -155,32 +146,22 @@ function add_gradient_benchmarks!( x = randn(n) grad = zeros(n) - suite["value_and_gradient"][(n, 1)]["$(pretty(backend))"] = @benchmarkable begin - value_and_gradient($backend, $f, $x) - end - suite["value_and_gradient!"][(n, 1)]["$(pretty(backend))"] = @benchmarkable begin - value_and_gradient!($grad, $backend, $f, $x) - end + for implem in (CustomImplem(), FallbackImplem()) + backend_implem = "$(pretty(backend)) - $(pretty(implem))" - suite["value_and_gradient"][(n, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - value_and_gradient(Val(:fallback), $backend, $f, $x) - end - suite["value_and_gradient!"][(n, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - value_and_gradient!(Val(:fallback), $grad, $backend, $f, $x) - end - - suite["gradient"][(n, 
1)]["$(pretty(backend))"] = @benchmarkable begin - gradient($backend, $f, $x) - end - suite["gradient!"][(n, 1)]["$(pretty(backend))"] = @benchmarkable begin - gradient!($grad, $backend, $f, $x) - end + suite["value_and_gradient"][(n, 1)][backend_implem] = @benchmarkable begin + value_and_gradient($implem, $backend, $f, $x) + end + suite["value_and_gradient!"][(n, 1)][backend_implem] = @benchmarkable begin + value_and_gradient!($implem, $grad, $backend, $f, $x) + end - suite["gradient"][(n, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - gradient(Val(:fallback), $backend, $f, $x) - end - suite["gradient!"][(n, 1)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - gradient!(Val(:fallback), $grad, $backend, $f, $x) + suite["gradient"][(n, 1)][backend_implem] = @benchmarkable begin + gradient($implem, $backend, $f, $x) + end + suite["gradient!"][(n, 1)][backend_implem] = @benchmarkable begin + gradient!($implem, $grad, $backend, $f, $x) + end end return nothing @@ -196,32 +177,22 @@ function add_jacobian_benchmarks!( x = randn(n) jac = zeros(m, n) - suite["value_and_jacobian"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin - value_and_jacobian($backend, $f, $x) - end - suite["value_and_jacobian!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin - value_and_jacobian!($jac, $backend, $f, $x) - end - - suite["value_and_jacobian"][(n, m)]["$(pretty(backend)) (:fallback)"] = @benchmarkable begin - value_and_jacobian(Val(:fallback), $backend, $f, $x) - end - suite["value_and_jacobian!"][(n, m)]["$(pretty(backend)) (:fallback)"] = @benchmarkable begin - value_and_jacobian!(Val(:fallback), $jac, $backend, $f, $x) - end - - suite["jacobian"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin - jacobian($backend, $f, $x) - end - suite["jacobian!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin - jacobian!($jac, $backend, $f, $x) - end - - suite["jacobian"][(n, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - 
jacobian(Val(:fallback), $backend, $f, $x) - end - suite["jacobian!"][(n, m)]["$(pretty(backend)) - fallback"] = @benchmarkable begin - jacobian!(Val(:fallback), $jac, $backend, $f, $x) + for implem in (CustomImplem(), FallbackImplem()) + backend_implem = "$(pretty(backend)) - $(pretty(implem))" + + suite["value_and_jacobian"][(n, m)][backend_implem] = @benchmarkable begin + value_and_jacobian($implem, $backend, $f, $x) + end + suite["value_and_jacobian!"][(n, m)][backend_implem] = @benchmarkable begin + value_and_jacobian!($implem, $jac, $backend, $f, $x) + end + + suite["jacobian"][(n, m)][backend_implem] = @benchmarkable begin + jacobian($implem, $backend, $f, $x) + end + suite["jacobian!"][(n, m)][backend_implem] = @benchmarkable begin + jacobian!($implem, $jac, $backend, $f, $x) + end end return nothing diff --git a/docs/Project.toml b/docs/Project.toml index 5a20d1146..7d881d2b2 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" diff --git a/docs/make.jl b/docs/make.jl index 29c31acf2..6eb2f67bb 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,9 +1,10 @@ +using ADTypes using Base: get_extension using DifferentiationInterface import DifferentiationInterface as DI using Documenter -using DiffResults: DiffResults +using Diffractor: Diffractor using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff @@ -12,6 +13,7 @@ using ReverseDiff: ReverseDiff using Zygote: Zygote ChainRulesCoreExt = get_extension(DI, :DifferentiationInterfaceChainRulesCoreExt) +DiffractorExt = get_extension(DI, 
:DifferentiationInterfaceDiffractorExt) EnzymeExt = get_extension(DI, :DifferentiationInterfaceEnzymeExt) FiniteDiffExt = get_extension(DI, :DifferentiationInterfaceFiniteDiffExt) ForwardDiffExt = get_extension(DI, :DifferentiationInterfaceForwardDiffExt) @@ -45,7 +47,9 @@ end makedocs(; modules=[ DifferentiationInterface, + ADTypes, ChainRulesCoreExt, + DiffractorExt, EnzymeExt, FiniteDiffExt, ForwardDiffExt, diff --git a/docs/src/interface.md b/docs/src/interface.md index 991e51eda..17ec45fe2 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -57,14 +57,48 @@ Pages = ["pullback.jl"] ## Backends +### ADTypes.jl + +```@meta +CurrentModule = ADTypes +``` + +The following backends are defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): + +```@autodocs +Modules = [ADTypes] +``` + +Only a subset is supported by DifferentiationInterface.jl at the moment. + +### DifferentiationInterface.jl + +```@meta +CurrentModule = DifferentiationInterface +``` + +The following backends are defined by DifferentiationInterface.jl: + +```@autodocs +Modules = [DifferentiationInterface] +Pages = ["backends.jl"] +Order = [:type] +Private = false +``` + +### Input / output types + ```@autodocs Modules = [DifferentiationInterface] Pages = ["backends.jl"] +Order = [:function] +Private = false ``` ## Internals ```@autodocs Modules = [DifferentiationInterface] -Pages = ["utils.jl"] +Pages = ["implem.jl", "mode.jl", "utils.jl", "backends.jl"] +Public = false ``` diff --git a/ext/DifferentiationInterfaceChainRulesCoreExt.jl b/ext/DifferentiationInterfaceChainRulesCoreExt.jl index 1a20a1918..3deef027a 100644 --- a/ext/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/ext/DifferentiationInterfaceChainRulesCoreExt.jl @@ -2,16 +2,18 @@ module DifferentiationInterfaceChainRulesCoreExt using ChainRulesCore: HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad -using DifferentiationInterface: AutoChainRules, ruleconfig +using 
DifferentiationInterface: AutoChainRules, CustomImplem, update! import DifferentiationInterface as DI using DocStringExtensions -update!(_old::Number, new::Number) = new -update!(old, new) = old .= new +ruleconfig(backend::AutoChainRules) = backend.ruleconfig const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}} const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} +DI.autodiff_mode(::AutoForwardChainRules) = DI.ForwardMode() +DI.autodiff_mode(::AutoReverseChainRules) = DI.ReverseMode() + ## Primitives function DI.value_and_pushforward(backend::AutoForwardChainRules, f, x, dx) diff --git a/ext/DifferentiationInterfaceDiffractorExt.jl b/ext/DifferentiationInterfaceDiffractorExt.jl new file mode 100644 index 000000000..81c146e75 --- /dev/null +++ b/ext/DifferentiationInterfaceDiffractorExt.jl @@ -0,0 +1,23 @@ +module DifferentiationInterfaceDiffractorExt + +import AbstractDifferentiation as AD # public API for Diffractor +using DifferentiationInterface: AutoDiffractor, update! 
+import DifferentiationInterface as DI +using Diffractor: DiffractorForwardBackend +using DocStringExtensions + +DI.autodiff_mode(::AutoDiffractor) = DI.ForwardMode() + +function DI.value_and_pushforward(::AutoDiffractor, f, x, dx) + vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) + y, dy = vpff((dx,)) + return y, dy +end + +function DI.value_and_pushforward!(dy, ::AutoDiffractor, f, x, dx) + vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) + y, new_dy = vpff((dx,)) + return y, update!(dy, new_dy) +end + +end diff --git a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 7c0a7a862..a7c7f1688 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfaceEnzymeExt using ADTypes: AutoEnzyme +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DocStringExtensions using Enzyme: diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward.jl b/ext/DifferentiationInterfaceEnzymeExt/forward.jl index 5b1bd9e33..9a28138bc 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/forward.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/forward.jl @@ -1,5 +1,5 @@ const AutoForwardEnzyme = AutoEnzyme{Val{:forward}} -DI.autodiff_mode(::AutoForwardEnzyme) = Val{:forward}() +DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode() ## Primitives @@ -20,7 +20,7 @@ end ## Utilities -function DI.value_and_jacobian(::AutoForwardEnzyme, f, x::AbstractArray) +function DI.value_and_jacobian(::CustomImplem, ::AutoForwardEnzyme, f, x::AbstractArray) y = f(x) jac = jacobian(Forward, f, x) # see https://github.com/EnzymeAD/Enzyme.jl/issues/1332 diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl 
index 24787ab34..854e69d88 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl @@ -1,5 +1,5 @@ const AutoReverseEnzyme = AutoEnzyme{Val{:reverse}} -DI.autodiff_mode(::AutoReverseEnzyme) = Val{:reverse}() +DI.autodiff_mode(::AutoReverseEnzyme) = DI.ReverseMode() DI.handles_output_type(::AutoReverseEnzyme, ::Type{<:AbstractArray}) = false ## Primitives @@ -23,14 +23,14 @@ end ## Utilities -function DI.value_and_gradient(::AutoReverseEnzyme, f, x::AbstractArray) +function DI.value_and_gradient(::CustomImplem, ::AutoReverseEnzyme, f, x::AbstractArray) y = f(x) grad = gradient(Reverse, f, x) return y, grad end function DI.value_and_gradient!( - grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray + ::CustomImplem, grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray ) y = f(x) gradient!(Reverse, grad, f, x) diff --git a/ext/DifferentiationInterfaceFiniteDiffExt.jl b/ext/DifferentiationInterfaceFiniteDiffExt.jl index 5ad2be8de..b5814f5cb 100644 --- a/ext/DifferentiationInterfaceFiniteDiffExt.jl +++ b/ext/DifferentiationInterfaceFiniteDiffExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfaceFiniteDiffExt using ADTypes: AutoFiniteDiff +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DocStringExtensions using FiniteDiff: @@ -38,48 +39,56 @@ end ## Utilities -function DI.value_and_derivative(::AutoFiniteDiff{fdtype}, f, x::Number) where {fdtype} +function DI.value_and_derivative( + ::CustomImplem, ::AutoFiniteDiff{fdtype}, f, x::Number +) where {fdtype} y = f(x) der = finite_difference_derivative(f, x, fdtype, eltype(y), y) return y, der end function DI.value_and_multiderivative!( - multider::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::Number + ::CustomImplem, multider::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::Number ) where {fdtype} y = f(x) finite_difference_gradient!(multider, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) return y, 
multider end -function DI.value_and_multiderivative(::AutoFiniteDiff{fdtype}, f, x::Number) where {fdtype} +function DI.value_and_multiderivative( + ::CustomImplem, ::AutoFiniteDiff{fdtype}, f, x::Number +) where {fdtype} y = f(x) multider = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) return y, multider end function DI.value_and_gradient!( - grad::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray + ::CustomImplem, grad::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray ) where {fdtype} y = f(x) finite_difference_gradient!(grad, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) return y, grad end -function DI.value_and_gradient(::AutoFiniteDiff{fdtype}, f, x::AbstractArray) where {fdtype} +function DI.value_and_gradient( + ::CustomImplem, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray +) where {fdtype} y = f(x) grad = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) return y, grad end -function DI.value_and_jacobian(::AutoFiniteDiff{fdtype}, f, x::AbstractArray) where {fdtype} +function DI.value_and_jacobian( + ::CustomImplem, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray +) where {fdtype} y = f(x) jac = finite_difference_jacobian(f, x, fdtype, eltype(y)) return y, jac end function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::AutoFiniteDiff, f, x::AbstractArray + ::CustomImplem, jac::AbstractMatrix, backend::AutoFiniteDiff, f, x::AbstractArray ) y, new_jac = DI.value_and_jacobian(backend, f, x) jac .= new_jac diff --git a/ext/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt.jl index ee45f530f..5825a6098 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfaceForwardDiffExt using ADTypes: AutoForwardDiff +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DiffResults: DiffResults using 
DocStringExtensions @@ -66,45 +67,49 @@ end ## Utilities (TODO: use DiffResults) -function DI.value_and_derivative(::AutoForwardDiff, f, x::Number) +function DI.value_and_derivative(::CustomImplem, ::AutoForwardDiff, f, x::Number) y = f(x) der = derivative(f, x) return y, der end -function DI.value_and_multiderivative(::AutoForwardDiff, f, x::Number) +function DI.value_and_multiderivative(::CustomImplem, ::AutoForwardDiff, f, x::Number) y = f(x) multider = derivative(f, x) return y, multider end function DI.value_and_multiderivative!( - multider::AbstractArray, ::AutoForwardDiff, f, x::Number + ::CustomImplem, multider::AbstractArray, ::AutoForwardDiff, f, x::Number ) y = f(x) derivative!(multider, f, x) return y, multider end -function DI.value_and_gradient(::AutoForwardDiff, f, x::AbstractArray) +function DI.value_and_gradient(::CustomImplem, ::AutoForwardDiff, f, x::AbstractArray) y = f(x) grad = gradient(f, x) return y, grad end -function DI.value_and_gradient!(grad::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray) +function DI.value_and_gradient!( + ::CustomImplem, grad::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray +) y = f(x) gradient!(grad, f, x) return y, grad end -function DI.value_and_jacobian(::AutoForwardDiff, f, x::AbstractArray) +function DI.value_and_jacobian(::CustomImplem, ::AutoForwardDiff, f, x::AbstractArray) y = f(x) jac = jacobian(f, x) return y, jac end -function DI.value_and_jacobian!(jac::AbstractMatrix, ::AutoForwardDiff, f, x::AbstractArray) +function DI.value_and_jacobian!( + ::CustomImplem, jac::AbstractMatrix, ::AutoForwardDiff, f, x::AbstractArray +) y = f(x) jacobian!(jac, f, x) return y, jac diff --git a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl index 6ff75178a..868cf1d09 100644 --- a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -1,6 +1,7 @@ module 
DifferentiationInterfacePolyesterForwardDiffExt using ADTypes: AutoPolyesterForwardDiff, AutoForwardDiff +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions @@ -17,7 +18,7 @@ end ## Utilities function DI.value_and_gradient!( - grad::AbstractArray, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray + ::CustomImplem, grad::AbstractArray, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray ) where {C} y = f(x) threaded_gradient!(f, grad, x, Chunk{C}()) @@ -25,7 +26,7 @@ function DI.value_and_gradient!( end function DI.value_and_jacobian!( - jac::AbstractMatrix, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray + ::CustomImplem, jac::AbstractMatrix, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray ) where {C} y = f(x) threaded_jacobian!(f, jac, x, Chunk{C}()) diff --git a/ext/DifferentiationInterfaceReverseDiffExt.jl b/ext/DifferentiationInterfaceReverseDiffExt.jl index ea1d23462..d7d29a0de 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfaceReverseDiffExt using ADTypes: AutoReverseDiff +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions @@ -36,25 +37,29 @@ end ## Utilities (TODO: use DiffResults) -function DI.value_and_gradient(::AutoReverseDiff, f, x::AbstractArray) +function DI.value_and_gradient(::CustomImplem, ::AutoReverseDiff, f, x::AbstractArray) y = f(x) grad = gradient(f, x) return y, grad end -function DI.value_and_gradient!(grad::AbstractArray, ::AutoReverseDiff, f, x::AbstractArray) +function DI.value_and_gradient!( + ::CustomImplem, grad::AbstractArray, ::AutoReverseDiff, f, x::AbstractArray +) y = f(x) gradient!(grad, f, x) return y, grad end -function DI.value_and_jacobian(::AutoReverseDiff, f, x::AbstractArray) +function DI.value_and_jacobian(::CustomImplem, 
::AutoReverseDiff, f, x::AbstractArray) y = f(x) jac = jacobian(f, x) return y, jac end -function DI.value_and_jacobian!(jac::AbstractMatrix, ::AutoReverseDiff, f, x::AbstractArray) +function DI.value_and_jacobian!( + ::CustomImplem, jac::AbstractMatrix, ::AutoReverseDiff, f, x::AbstractArray +) y = f(x) jacobian!(jac, f, x) return y, jac diff --git a/ext/DifferentiationInterfaceZygoteExt.jl b/ext/DifferentiationInterfaceZygoteExt.jl index bd8d8d333..f864b8603 100644 --- a/ext/DifferentiationInterfaceZygoteExt.jl +++ b/ext/DifferentiationInterfaceZygoteExt.jl @@ -1,54 +1,50 @@ module DifferentiationInterfaceZygoteExt using ADTypes: AutoZygote -using DifferentiationInterface: AutoChainRules +using DifferentiationInterface: AutoChainRules, CustomImplem, update! import DifferentiationInterface as DI using DocStringExtensions -using Zygote: ZygoteRuleConfig, gradient, jacobian, withgradient, withjacobian +using Zygote: ZygoteRuleConfig, gradient, jacobian, pullback, withgradient, withjacobian ## Primitives const zygote_chainrules_backend = AutoChainRules(ZygoteRuleConfig()) -function DI.value_and_pushforward!(dy, ::AutoZygote, f, x, dx) - return DI.value_and_pushforward!(dy, zygote_chainrules_backend, f, x, dx) -end - -function DI.value_and_pushforward(::AutoZygote, f, x, dx) - return DI.value_and_pushforward(zygote_chainrules_backend, f, x, dx) -end - function DI.value_and_pullback!(dx, ::AutoZygote, f, x, dy) - return DI.value_and_pullback!(dx, zygote_chainrules_backend, f, x, dy) + y, back = pullback(f, x) + new_dx = only(back(dy)) + return y, update!(dx, new_dx) end function DI.value_and_pullback(::AutoZygote, f, x, dy) - return DI.value_and_pullback(zygote_chainrules_backend, f, x, dy) + y, back = pullback(f, x) + dx = only(back(dy)) + return y, dx end ## Utilities -function DI.value_and_gradient(::AutoZygote, f, x::AbstractArray) +function DI.value_and_gradient(::CustomImplem, ::AutoZygote, f, x::AbstractArray) res = withgradient(f, x) return res.val, 
only(res.grad) end function DI.value_and_gradient!( - grad::AbstractArray, backend::AutoZygote, f, x::AbstractArray + ::CustomImplem, grad::AbstractArray, backend::AutoZygote, f, x::AbstractArray ) y, new_grad = DI.value_and_gradient(backend, f, x) grad .= new_grad return y, grad end -function DI.value_and_jacobian(::AutoZygote, f, x::AbstractArray) +function DI.value_and_jacobian(::CustomImplem, ::AutoZygote, f, x::AbstractArray) y = f(x) jac = jacobian(f, x) return y, only(jac) end function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::AutoZygote, f, x::AbstractArray + ::CustomImplem, jac::AbstractMatrix, backend::AutoZygote, f, x::AbstractArray ) y, new_jac = DI.value_and_jacobian(backend, f, x) jac .= new_jac diff --git a/src/DifferentiationInterface.jl b/src/DifferentiationInterface.jl index d1c210297..c1a941cf4 100644 --- a/src/DifferentiationInterface.jl +++ b/src/DifferentiationInterface.jl @@ -15,6 +15,8 @@ using DocStringExtensions using FillArrays: OneElement include("backends.jl") +include("implem.jl") +include("mode.jl") include("utils.jl") include("pushforward.jl") include("pullback.jl") @@ -22,10 +24,10 @@ include("scalar_scalar.jl") include("scalar_array.jl") include("array_scalar.jl") include("array_array.jl") +include("custom.jl") -export AutoChainRules +export AutoChainRules, AutoDiffractor -export autodiff_mode export handles_input_type, handles_output_type, handles_types export value_and_pushforward!, value_and_pushforward diff --git a/src/array_array.jl b/src/array_array.jl index 155dbb909..df1218477 100644 --- a/src/array_array.jl +++ b/src/array_array.jl @@ -13,13 +13,7 @@ Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of a $JAC_NOTES """ function value_and_jacobian!( - jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray -) - return value_and_jacobian!(Val{:fallback}(), jac, backend, f, x) -end - -function value_and_jacobian!( - implem::Val{:fallback}, + implem::AbstractImplem, 
jac::AbstractMatrix, backend::AbstractADType, f, @@ -37,8 +31,8 @@ function check_jac(jac::AbstractMatrix, x::AbstractArray, y::AbstractArray) end function value_and_jacobian!( - ::Val{:fallback}, - ::Val{:forward}, + ::AbstractImplem, + ::ForwardMode, jac::AbstractMatrix, backend::AbstractADType, f, @@ -55,8 +49,8 @@ function value_and_jacobian!( end function value_and_jacobian!( - ::Val{:fallback}, - ::Val{:reverse}, + ::AbstractImplem, + ::ReverseMode, jac::AbstractMatrix, backend::AbstractADType, f, @@ -79,12 +73,8 @@ Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of a $JAC_NOTES """ -function value_and_jacobian(backend::AbstractADType, f, x::AbstractArray) - return value_and_jacobian(Val{:fallback}(), backend, f, x) -end - function value_and_jacobian( - implem::Val{:fallback}, backend::AbstractADType, f, x::AbstractArray + implem::AbstractImplem, backend::AbstractADType, f, x::AbstractArray ) y = f(x) T = promote_type(eltype(x), eltype(y)) @@ -99,12 +89,8 @@ Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overw $JAC_NOTES """ -function jacobian!(jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray) - return jacobian!(Val{:fallback}(), jac, backend, f, x) -end - function jacobian!( - implem::Val{:fallback}, + implem::AbstractImplem, jac::AbstractMatrix, backend::AbstractADType, f, @@ -120,10 +106,6 @@ Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. 
$JAC_NOTES """ -function jacobian(backend::AbstractADType, f, x::AbstractArray) - return jacobian(Val{:fallback}(), backend, f, x) -end - -function jacobian(implem::Val{:fallback}, backend::AbstractADType, f, x::AbstractArray) +function jacobian(implem::AbstractImplem, backend::AbstractADType, f, x::AbstractArray) return last(value_and_jacobian(implem, backend, f, x)) end diff --git a/src/array_scalar.jl b/src/array_scalar.jl index ec525cadd..95d981b10 100644 --- a/src/array_scalar.jl +++ b/src/array_scalar.jl @@ -4,13 +4,7 @@ Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ function value_and_gradient!( - grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray -) - return value_and_gradient!(Val{:fallback}(), grad, backend, f, x) -end - -function value_and_gradient!( - implem::Val{:fallback}, + implem::AbstractImplem, grad::AbstractArray, backend::AbstractADType, f, @@ -20,8 +14,8 @@ function value_and_gradient!( end function value_and_gradient!( - ::Val{:fallback}, - ::Val{:forward}, + ::AbstractImplem, + ::ForwardMode, grad::AbstractArray, backend::AbstractADType, f, @@ -36,8 +30,8 @@ function value_and_gradient!( end function value_and_gradient!( - ::Val{:fallback}, - ::Val{:reverse}, + ::AbstractImplem, + ::ReverseMode, grad::AbstractArray, backend::AbstractADType, f, @@ -52,12 +46,8 @@ end Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function. 
""" -function value_and_gradient(backend::AbstractADType, f, x::AbstractArray) - return value_and_gradient(Val{:fallback}(), backend, f, x) -end - function value_and_gradient( - implem::Val{:fallback}, backend::AbstractADType, f, x::AbstractArray + implem::AbstractImplem, backend::AbstractADType, f, x::AbstractArray ) grad = similar(x) return value_and_gradient!(implem, grad, backend, f, x) @@ -68,12 +58,8 @@ end Compute the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ -function gradient!(grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray) - return gradient!(Val{:fallback}(), grad, backend, f, x) -end - function gradient!( - implem::Val{:fallback}, + implem::AbstractImplem, grad::AbstractArray, backend::AbstractADType, f, @@ -87,10 +73,6 @@ end Compute the gradient `grad = ∇f(x)` of an array-to-scalar function. """ -function gradient(backend::AbstractADType, f, x::AbstractArray) - return gradient(Val{:fallback}(), backend, f, x) -end - -function gradient(implem::Val{:fallback}, backend::AbstractADType, f, x::AbstractArray) +function gradient(implem::AbstractImplem, backend::AbstractADType, f, x::AbstractArray) return last(value_and_gradient(implem, backend, f, x)) end diff --git a/src/backends.jl b/src/backends.jl index 6e2c9a551..12cbdd9a8 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -1,6 +1,6 @@ -## Additional backend +## Additional backends -# TODO: remove once https://github.com/SciML/ADTypes.jl/pull/21 is merged +# TODO: remove once https://github.com/SciML/ADTypes.jl/pull/21 is merged and released """ AutoChainRules{RC} @@ -22,22 +22,29 @@ struct AutoChainRules{RC} <: AbstractADType ruleconfig::RC end -ruleconfig(backend::AutoChainRules) = backend.ruleconfig +# TODO: remove this once https://github.com/SciML/ADTypes.jl/issues/27 is solved + +""" + AutoDiffractor + +Enables the use of [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl). 
+""" +struct AutoDiffractor <: AbstractADType end ## Traits and access """ autodiff_mode(backend) -Return `Val(:forward)` or `Val(:reverse)` in a statically predictable way. +Return `ForwardMode()` or `ReverseMode()` in a statically predictable way. This function must be overloaded for backends that do not inherit from `ADTypes.AbstractForwardMode` or `ADTypes.AbstractReverseMode` (e.g. because they support both forward and reverse). We classify `ADTypes.AbstractFiniteDifferencesMode` as forward mode. """ -autodiff_mode(::AbstractForwardMode) = Val{:forward}() -autodiff_mode(::AbstractFiniteDifferencesMode) = Val{:forward}() -autodiff_mode(::AbstractReverseMode) = Val{:reverse}() +autodiff_mode(::AbstractForwardMode) = ForwardMode() +autodiff_mode(::AbstractFiniteDifferencesMode) = ForwardMode() +autodiff_mode(::AbstractReverseMode) = ReverseMode() """ handles_input_type(backend, ::Type{X}) diff --git a/src/custom.jl b/src/custom.jl new file mode 100644 index 000000000..683e1cf8c --- /dev/null +++ b/src/custom.jl @@ -0,0 +1,29 @@ +for utility in [ + :value_and_derivative, + :value_and_multiderivative, + :value_and_gradient, + :value_and_jacobian, + :derivative, + :multiderivative, + :gradient, + :jacobian, +] + @eval $utility(backend::AbstractADType, f, x::Union{Number,AbstractArray}) = + $utility(CustomImplem(), backend, f, x) +end + +for utility! in [ + :value_and_multiderivative!, + :value_and_gradient!, + :value_and_jacobian!, + :multiderivative!, + :gradient!, + :jacobian!, +] + @eval $utility!( + storage::Union{Number,AbstractArray}, + backend::AbstractADType, + f, + x::Union{Number,AbstractArray}, + ) = $utility!(CustomImplem(), storage, backend, f, x) +end diff --git a/src/implem.jl b/src/implem.jl new file mode 100644 index 000000000..de9413e87 --- /dev/null +++ b/src/implem.jl @@ -0,0 +1,17 @@ +abstract type AbstractImplem end + +""" + CustomImplem + +Trait specifying that the custom utilities from the backend should be used as much as possible. 
+Used for internal dispatch only. +""" +struct CustomImplem <: AbstractImplem end + +""" + FallbackImplem + +Trait specifying that the fallback utilities from DifferentiationInterface.jl should be used as much as possible, until they call a pushforward or pullback. +Used for internal dispatch only. +""" +struct FallbackImplem <: AbstractImplem end diff --git a/src/mode.jl b/src/mode.jl new file mode 100644 index 000000000..365e93e5a --- /dev/null +++ b/src/mode.jl @@ -0,0 +1,17 @@ +abstract type AbstractMode end + +""" + ForwardMode + +Trait identifying forward mode AD backends. +Used for internal dispatch only. +""" +struct ForwardMode <: AbstractMode end + +""" + ReverseMode + +Trait identifying reverse mode AD backends. +Used for internal dispatch only. +""" +struct ReverseMode <: AbstractMode end diff --git a/src/scalar_array.jl b/src/scalar_array.jl index 88fd3ed09..af8b2f2ee 100644 --- a/src/scalar_array.jl +++ b/src/scalar_array.jl @@ -4,13 +4,7 @@ Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. 
""" function value_and_multiderivative!( - multider::AbstractArray, backend::AbstractADType, f, x::Number -) - return value_and_multiderivative!(Val{:fallback}(), multider, backend, f, x) -end - -function value_and_multiderivative!( - implem::Val{:fallback}, multider::AbstractArray, backend::AbstractADType, f, x::Number + implem::AbstractImplem, multider::AbstractArray, backend::AbstractADType, f, x::Number ) return value_and_multiderivative!( implem, autodiff_mode(backend), multider, backend, f, x @@ -18,8 +12,8 @@ function value_and_multiderivative!( end function value_and_multiderivative!( - ::Val{:fallback}, - ::Val{:forward}, + ::AbstractImplem, + ::ForwardMode, multider::AbstractArray, backend::AbstractADType, f, @@ -29,8 +23,8 @@ function value_and_multiderivative!( end function value_and_multiderivative!( - ::Val{:fallback}, - ::Val{:reverse}, + ::AbstractImplem, + ::ReverseMode, multider::AbstractArray, backend::AbstractADType, f, @@ -49,12 +43,8 @@ end Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function value_and_multiderivative(backend::AbstractADType, f, x::Number) - return value_and_multiderivative(Val{:fallback}(), backend, f, x) -end - function value_and_multiderivative( - implem::Val{:fallback}, backend::AbstractADType, f, x::Number + implem::AbstractImplem, backend::AbstractADType, f, x::Number ) multider = similar(f(x)) return value_and_multiderivative!(implem, multider, backend, f, x) @@ -65,12 +55,8 @@ end Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. 
""" -function multiderivative!(multider::AbstractArray, backend::AbstractADType, f, x::Number) - return multiderivative!(Val{:fallback}(), multider, backend, f, x) -end - function multiderivative!( - implem::Val{:fallback}, multider::AbstractArray, backend::AbstractADType, f, x::Number + implem::AbstractImplem, multider::AbstractArray, backend::AbstractADType, f, x::Number ) return last(value_and_multiderivative!(implem, multider, backend, f, x)) end @@ -80,10 +66,6 @@ end Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function multiderivative(backend::AbstractADType, f, x::Number) - return multiderivative(Val{:fallback}(), backend, f, x) -end - -function multiderivative(implem::Val{:fallback}, backend::AbstractADType, f, x::Number) +function multiderivative(implem::AbstractImplem, backend::AbstractADType, f, x::Number) return last(value_and_multiderivative(implem, backend, f, x)) end diff --git a/src/scalar_scalar.jl b/src/scalar_scalar.jl index 9b9e16d5c..fbd6f3631 100644 --- a/src/scalar_scalar.jl +++ b/src/scalar_scalar.jl @@ -3,22 +3,18 @@ Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function. 
""" -function value_and_derivative(backend::AbstractADType, f, x::Number) - return value_and_derivative(Val{:fallback}(), backend::AbstractADType, f, x::Number) -end - -function value_and_derivative(implem::Val{:fallback}, backend::AbstractADType, f, x::Number) +function value_and_derivative(implem::AbstractImplem, backend::AbstractADType, f, x::Number) return value_and_derivative(implem, autodiff_mode(backend), backend, f, x) end function value_and_derivative( - ::Val{:fallback}, ::Val{:forward}, backend::AbstractADType, f, x::Number + ::AbstractImplem, ::ForwardMode, backend::AbstractADType, f, x::Number ) return value_and_pushforward!(one(x), backend, f, x, one(x)) end function value_and_derivative( - ::Val{:fallback}, ::Val{:reverse}, backend::AbstractADType, f, x::Number + ::AbstractImplem, ::ReverseMode, backend::AbstractADType, f, x::Number ) return value_and_pullback!(one(x), backend, f, x, one(x)) end @@ -28,10 +24,6 @@ end Compute the derivative `der = f'(x)` of a scalar-to-scalar function. 
""" -function derivative(backend::AbstractADType, f, x::Number) - return derivative(Val{:fallback}(), backend, f, x) -end - -function derivative(implem::Val{:fallback}, backend::AbstractADType, f, x::Number) +function derivative(implem::AbstractImplem, backend::AbstractADType, f, x::Number) return last(value_and_derivative(implem, backend, f, x)) end diff --git a/src/utils.jl b/src/utils.jl index 4ad0de65b..ebcbd4f19 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -16,3 +16,6 @@ end mysimilar(x::Number) = zero(x) mysimilar(x::AbstractArray) = similar(x) + +update!(_old::Number, new::Number) = new +update!(old, new) = old .= new diff --git a/test/Project.toml b/test/Project.toml index 7c5bbd3b0..d903c0ef0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/test/chainrules_forward.jl b/test/chainrules_forward.jl new file mode 100644 index 000000000..f276239ad --- /dev/null +++ b/test/chainrules_forward.jl @@ -0,0 +1,10 @@ +using DifferentiationInterface: AutoChainRules, CustomImplem, FallbackImplem +using Diffractor: DiffractorRuleConfig + +test_pullback(AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false); +test_jacobian_and_friends( + CustomImplem(), AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false +); +test_jacobian_and_friends( + FallbackImplem(), AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false +); diff --git a/test/chainrules_reverse.jl b/test/chainrules_reverse.jl new file mode 100644 index 000000000..54e0ef506 --- /dev/null +++ b/test/chainrules_reverse.jl @@ -0,0 +1,10 @@ +using DifferentiationInterface:
AutoChainRules, CustomImplem, FallbackImplem +using Zygote: ZygoteRuleConfig + +test_pullback(AutoChainRules(ZygoteRuleConfig()), scenarios; type_stability=false); +test_jacobian_and_friends( + CustomImplem(), AutoChainRules(ZygoteRuleConfig()), scenarios; type_stability=false +); +test_jacobian_and_friends( + FallbackImplem(), AutoChainRules(ZygoteRuleConfig()), scenarios; type_stability=false +); diff --git a/test/diffractor.jl b/test/diffractor.jl new file mode 100644 index 000000000..664da1177 --- /dev/null +++ b/test/diffractor.jl @@ -0,0 +1,8 @@ +using DifferentiationInterface: AutoDiffractor, CustomImplem, FallbackImplem +using Diffractor: Diffractor + +test_pushforward(AutoDiffractor(), scenarios; type_stability=true); +test_jacobian_and_friends(CustomImplem(), AutoDiffractor(), scenarios; type_stability=true); +test_jacobian_and_friends( + FallbackImplem(), AutoDiffractor(), scenarios; type_stability=true +); diff --git a/test/enzyme_forward.jl b/test/enzyme_forward.jl index cb2b497b7..d50e28dd7 100644 --- a/test/enzyme_forward.jl +++ b/test/enzyme_forward.jl @@ -1,8 +1,11 @@ using ADTypes: AutoEnzyme +using DifferentiationInterface: CustomImplem, FallbackImplem using Enzyme: Enzyme test_pushforward(AutoEnzyme(Val(:forward)), scenarios; type_stability=true); -test_jacobian_and_friends(AutoEnzyme(Val(:forward)), scenarios; type_stability=true); test_jacobian_and_friends( - AutoEnzyme(Val(:forward)), scenarios, Val(:fallback); type_stability=true + CustomImplem(), AutoEnzyme(Val(:forward)), scenarios; type_stability=true +); +test_jacobian_and_friends( + FallbackImplem(), AutoEnzyme(Val(:forward)), scenarios; type_stability=true ); diff --git a/test/enzyme_reverse.jl b/test/enzyme_reverse.jl index 79965dacd..b3716215b 100644 --- a/test/enzyme_reverse.jl +++ b/test/enzyme_reverse.jl @@ -1,8 +1,11 @@ using ADTypes: AutoEnzyme +using DifferentiationInterface: CustomImplem, FallbackImplem using Enzyme: Enzyme test_pullback(AutoEnzyme(Val(:reverse)), scenarios; 
type_stability=true); -test_jacobian_and_friends(AutoEnzyme(Val(:reverse)), scenarios; type_stability=true) test_jacobian_and_friends( - AutoEnzyme(Val(:reverse)), scenarios, Val(:fallback); type_stability=true + CustomImplem(), AutoEnzyme(Val(:reverse)), scenarios; type_stability=true +) +test_jacobian_and_friends( + FallbackImplem(), AutoEnzyme(Val(:reverse)), scenarios; type_stability=true ) diff --git a/test/finitediff.jl b/test/finitediff.jl index e0def3a32..b809cb268 100644 --- a/test/finitediff.jl +++ b/test/finitediff.jl @@ -1,6 +1,9 @@ using ADTypes: AutoFiniteDiff +using DifferentiationInterface: CustomImplem, FallbackImplem using FiniteDiff: FiniteDiff test_pushforward(AutoFiniteDiff(), scenarios; type_stability=true); -test_jacobian_and_friends(AutoFiniteDiff(), scenarios; type_stability=false); -test_jacobian_and_friends(AutoFiniteDiff(), scenarios, Val(:fallback); type_stability=false); +test_jacobian_and_friends(CustomImplem(), AutoFiniteDiff(), scenarios; type_stability=false); +test_jacobian_and_friends( + FallbackImplem(), AutoFiniteDiff(), scenarios; type_stability=false +); diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index 3f75b9dc7..5a2c92c2b 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -1,8 +1,11 @@ using ADTypes: AutoForwardDiff +using DifferentiationInterface: CustomImplem, FallbackImplem using ForwardDiff: ForwardDiff test_pushforward(AutoForwardDiff(), scenarios; type_stability=true); -test_jacobian_and_friends(AutoForwardDiff(), scenarios; type_stability=false); test_jacobian_and_friends( - AutoForwardDiff(), scenarios, Val(:fallback); type_stability=false + CustomImplem(), AutoForwardDiff(), scenarios; type_stability=false +); +test_jacobian_and_friends( + FallbackImplem(), AutoForwardDiff(), scenarios; type_stability=false ); diff --git a/test/polyesterforwarddiff.jl b/test/polyesterforwarddiff.jl index 408e69c55..bf7ab2c4b 100644 --- a/test/polyesterforwarddiff.jl +++ b/test/polyesterforwarddiff.jl @@ -1,4 
+1,5 @@ using ADTypes: AutoPolyesterForwardDiff +using DifferentiationInterface: CustomImplem, FallbackImplem using PolyesterForwardDiff: PolyesterForwardDiff # see https://github.com/JuliaDiff/PolyesterForwardDiff.jl/issues/17 @@ -6,6 +7,7 @@ using PolyesterForwardDiff: PolyesterForwardDiff test_pushforward(AutoPolyesterForwardDiff(; chunksize=4), scenarios; type_stability=false); test_jacobian_and_friends( + CustomImplem(), AutoPolyesterForwardDiff(; chunksize=4), scenarios; input_type=Union{Number,AbstractVector}, diff --git a/test/reversediff.jl b/test/reversediff.jl index 80fb2a0d2..0bea69d92 100644 --- a/test/reversediff.jl +++ b/test/reversediff.jl @@ -1,8 +1,11 @@ using ADTypes: AutoReverseDiff +using DifferentiationInterface: CustomImplem, FallbackImplem using ReverseDiff: ReverseDiff test_pullback(AutoReverseDiff(), scenarios; type_stability=false); -test_jacobian_and_friends(AutoReverseDiff(), scenarios; type_stability=false); test_jacobian_and_friends( - AutoReverseDiff(), scenarios, Val(:fallback); type_stability=false + CustomImplem(), AutoReverseDiff(), scenarios; type_stability=false +); +test_jacobian_and_friends( + FallbackImplem(), AutoReverseDiff(), scenarios; type_stability=false ); diff --git a/test/runtests.jl b/test/runtests.jl index 628d366c9..64125e0c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,8 @@ using Test ## Utils -include("utils.jl") +include("scenarios.jl"); +include("utils.jl"); ## Main tests @@ -25,6 +26,15 @@ include("utils.jl") JET.test_package(DifferentiationInterface; target_defined_modules=true) end + @testset "ChainRules (forward)" begin + @test_skip include("chainrules_forward.jl") + end + @testset "ChainRules (reverse)" begin + include("chainrules_reverse.jl") + end + @testset "Diffractor (forward)" begin + @test_skip include("diffractor.jl") + end @testset "Enzyme (forward)" begin include("enzyme_forward.jl") end diff --git a/test/scenarios.jl b/test/scenarios.jl new file mode 100644 index 
000000000..2b4bb8989 --- /dev/null +++ b/test/scenarios.jl @@ -0,0 +1,143 @@ +using ForwardDiff: ForwardDiff +using LinearAlgebra +using Random: AbstractRNG, randn! +using StableRNGs + +## Test scenarios + +@kwdef struct Scenario{F,X,Y,D1,D2,D3,D4} + "function" + f::F + "argument" + x::X + "primal value" + y::Y + "pushforward seed" + dx::X + "pullback seed" + dy::Y + "pullback result" + dx_true::X + "pushforward result" + dy_true::Y + "derivative result" + der_true::D1 = nothing + "multiderivative result" + multider_true::D2 = nothing + "gradient result" + grad_true::D3 = nothing + "Jacobian result" + jac_true::D4 = nothing +end + +## Constructors + +function make_scenario(rng::AbstractRNG, f, x) + y = f(x) + return make_scenario(rng, f, x, y) +end + +function make_scenario(rng::AbstractRNG, f::F, x::X, y::Y) where {F,X<:Number,Y<:Number} + dx = randn(rng, X) + dy = randn(rng, Y) + der_true = ForwardDiff.derivative(f, x) + dx_true = der_true * dy + dy_true = der_true * dx + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, der_true) +end + +function make_scenario( + rng::AbstractRNG, f::F, x::X, y::Y +) where {F,X<:Number,Y<:AbstractArray} + dx = randn(rng, X) + dy = similar(y) + randn!(rng, dy) + multider_true = ForwardDiff.derivative(f, x) + dx_true = dot(multider_true, dy) + dy_true = multider_true .* dx + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, multider_true) +end + +function make_scenario( + rng::AbstractRNG, f::F, x::X, y::Y +) where {F,X<:AbstractArray,Y<:Number} + dx = similar(x) + randn!(rng, dx) + dy = randn(rng, Y) + grad_true = ForwardDiff.gradient(f, x) + dx_true = grad_true .* dy + dy_true = dot(grad_true, dx) + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, grad_true) +end + +function make_scenario( + rng::AbstractRNG, f::F, x::X, y::Y +) where {F,X<:AbstractArray,Y<:AbstractArray} + dx = similar(x) + randn!(rng, dx) + dy = similar(y) + randn!(rng, dy) + jac_true = ForwardDiff.jacobian(f, x) + dx_true = 
reshape(transpose(jac_true) * vec(dy), size(x)) + dy_true = reshape(jac_true * vec(dx), size(y)) + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, jac_true) +end + +## Access + +get_input_type(::Scenario{F,X}) where {F,X} = X +get_output_type(::Scenario{F,X,Y}) where {F,X,Y} = Y + +## Seed + +rng = StableRNG(63) + +## Scenarios + +f_scalar_scalar(x::Number)::Number = sin(x) +f_scalar_vector(x::Number)::AbstractVector = [sin(x), sin(2x)] +f_scalar_matrix(x::Number)::AbstractMatrix = hcat([sin(x) cos(x)], [sin(2x) cos(2x)]) + +function f_vector_scalar(x::AbstractVector)::Number + a = eachindex(x) + return sum(sin.(a .* x)) +end + +function f_matrix_scalar(x::AbstractMatrix)::Number + a, b = axes(x) + return sum(sin.(a .* x)) + sum(cos.(transpose(b) .* x)) +end + +function f_vector_vector(x::AbstractVector)::AbstractVector + a = eachindex(x) + return vcat(sin.(a .* x), cos.(a .* x)) +end + +function f_vector_matrix(x::AbstractVector)::AbstractMatrix + a = eachindex(x) + return hcat(sin.(a .* x), cos.(a .* x)) +end + +function f_matrix_vector(x::AbstractMatrix)::AbstractVector + a, b = axes(x) + return vcat(vec(sin.(a .* x)), vec(cos.(transpose(b) .* x))) +end + +function f_matrix_matrix(x::AbstractMatrix)::AbstractMatrix + a, b = axes(x) + return hcat(vec(sin.(a .* x)), vec(cos.(transpose(b) .* x))) +end + +## All + +scenarios = [ + make_scenario(rng, f_scalar_scalar, 1.0), + make_scenario(rng, f_scalar_vector, 1.0), + make_scenario(rng, f_scalar_matrix, 1.0), + make_scenario(rng, f_vector_scalar, [1.0, 2.0]), + make_scenario(rng, f_matrix_scalar, [1.0 2.0; 3.0 4.0]), + make_scenario(rng, f_vector_vector, [1.0, 2.0]), + make_scenario(rng, f_vector_matrix, [1.0, 2.0]), + make_scenario(rng, f_matrix_vector, [1.0 2.0; 3.0 4.0]), + make_scenario(rng, f_matrix_matrix, [1.0 2.0; 3.0 4.0]), +]; diff --git a/test/utils.jl b/test/utils.jl index 067114ab0..4b52b574e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,150 +1,12 @@ using ADTypes: AbstractADType using 
DifferentiationInterface -using ForwardDiff: ForwardDiff -using LinearAlgebra +using DifferentiationInterface: + AbstractImplem, CustomImplem, FallbackImplem, ForwardMode, ReverseMode, autodiff_mode using JET -using Random: AbstractRNG, randn! -using StableRNGs using Test -## Test scenarios - -@kwdef struct Scenario{F,X,Y,D1,D2,D3,D4} - "function" - f::F - "argument" - x::X - "primal value" - y::Y - "pushforward seed" - dx::X - "pullback seed" - dy::Y - "pullback result" - dx_true::X - "pushforward result" - dy_true::Y - "derivative result" - der_true::D1 = nothing - "multiderivative result" - multider_true::D2 = nothing - "gradient result" - grad_true::D3 = nothing - "Jacobian result" - jac_true::D4 = nothing -end - -## Constructors - -function make_scenario(rng::AbstractRNG, f, x) - y = f(x) - return make_scenario(rng, f, x, y) -end - -function make_scenario(rng::AbstractRNG, f::F, x::X, y::Y) where {F,X<:Number,Y<:Number} - dx = randn(rng, X) - dy = randn(rng, Y) - der_true = ForwardDiff.derivative(f, x) - dx_true = der_true * dy - dy_true = der_true * dx - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, der_true) -end - -function make_scenario( - rng::AbstractRNG, f::F, x::X, y::Y -) where {F,X<:Number,Y<:AbstractArray} - dx = randn(rng, X) - dy = similar(y) - randn!(rng, dy) - multider_true = ForwardDiff.derivative(f, x) - dx_true = dot(multider_true, dy) - dy_true = multider_true .* dx - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, multider_true) -end - -function make_scenario( - rng::AbstractRNG, f::F, x::X, y::Y -) where {F,X<:AbstractArray,Y<:Number} - dx = similar(x) - randn!(rng, dx) - dy = randn(rng, Y) - grad_true = ForwardDiff.gradient(f, x) - dx_true = grad_true .* dy - dy_true = dot(grad_true, dx) - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, grad_true) -end - -function make_scenario( - rng::AbstractRNG, f::F, x::X, y::Y -) where {F,X<:AbstractArray,Y<:AbstractArray} - dx = similar(x) - randn!(rng, dx) - dy = similar(y) - 
randn!(rng, dy) - jac_true = ForwardDiff.jacobian(f, x) - dx_true = reshape(transpose(jac_true) * vec(dy), size(x)) - dy_true = reshape(jac_true * vec(dx), size(y)) - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, jac_true) -end - -## Access - -get_input_type(::Scenario{F,X}) where {F,X} = X -get_output_type(::Scenario{F,X,Y}) where {F,X,Y} = Y - -## Seed - -rng = StableRNG(63) - -## Scenarios - -f_scalar_scalar(x::Number)::Number = sin(x) -f_scalar_vector(x::Number)::AbstractVector = [sin(x), sin(2x)] -f_scalar_matrix(x::Number)::AbstractMatrix = hcat([sin(x) cos(x)], [sin(2x) cos(2x)]) - -function f_vector_scalar(x::AbstractVector)::Number - a = eachindex(x) - return sum(sin.(a .* x)) -end - -function f_matrix_scalar(x::AbstractMatrix)::Number - a, b = axes(x) - return sum(sin.(a .* x)) + sum(cos.(transpose(b) .* x)) -end - -function f_vector_vector(x::AbstractVector)::AbstractVector - a = eachindex(x) - return vcat(sin.(a .* x), cos.(a .* x)) -end - -function f_vector_matrix(x::AbstractVector)::AbstractMatrix - a = eachindex(x) - return hcat(sin.(a .* x), cos.(a .* x)) -end - -function f_matrix_vector(x::AbstractMatrix)::AbstractVector - a, b = axes(x) - return vcat(vec(sin.(a .* x)), vec(cos.(transpose(b) .* x))) -end - -function f_matrix_matrix(x::AbstractMatrix)::AbstractMatrix - a, b = axes(x) - return hcat(vec(sin.(a .* x)), vec(cos.(transpose(b) .* x))) -end - -## All - -scenarios = [ - make_scenario(rng, f_scalar_scalar, 1.0), - make_scenario(rng, f_scalar_vector, 1.0), - make_scenario(rng, f_scalar_matrix, 1.0), - make_scenario(rng, f_vector_scalar, [1.0, 2.0]), - make_scenario(rng, f_matrix_scalar, [1.0 2.0; 3.0 4.0]), - make_scenario(rng, f_vector_vector, [1.0, 2.0]), - make_scenario(rng, f_vector_matrix, [1.0, 2.0]), - make_scenario(rng, f_matrix_vector, [1.0 2.0; 3.0 4.0]), - make_scenario(rng, f_matrix_matrix, [1.0 2.0; 3.0 4.0]), -]; +pretty(::CustomImplem) = "custom" +pretty(::FallbackImplem) = "fallback" ## Test utilities @@ -156,7 +18,7 @@ 
function test_pushforward( allocs::Bool=false, type_stability::Bool=true, ) - if autodiff_mode(backend) != Val(:forward) + if !isa(autodiff_mode(backend), ForwardMode) return nothing end scenarios = filter(scenarios) do s @@ -214,7 +76,7 @@ function test_pullback( allocs::Bool=false, type_stability::Bool=true, ) - if autodiff_mode(backend) != Val(:reverse) + if !isa(autodiff_mode(backend), ReverseMode) return nothing end scenarios = filter(scenarios) do s @@ -266,17 +128,16 @@ function test_pullback( end function test_derivative( + implem::AbstractImplem, backend::AbstractADType, - scenarios::Vector{<:Scenario}, - implems...; + scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: Number) && (get_output_type(s) <: Number) end - testset_name = isempty(implems) ? "" : "(fallback)" - @testset "Derivative $testset_name" begin + @testset "Derivative ⁻ $(pretty(implem))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -284,9 +145,9 @@ function test_derivative( @testset "$X -> $Y" begin (; f, x, y, der_true) = scenario - y_out1, der_out1 = value_and_derivative(implems..., backend, f, x) + y_out1, der_out1 = value_and_derivative(implem, backend, f, x) - der_out2 = derivative(implems..., backend, f, x) + der_out2 = derivative(implem, backend, f, x) @testset "Primal value" begin @test y_out1 ≈ y @@ -296,12 +157,12 @@ function test_derivative( @test der_out2 ≈ der_true rtol = 1e-3 end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_derivative(implems..., backend, f, x)) - @test iszero(@allocated derivative(implems..., backend, f, x)) + @test iszero(@allocated value_and_derivative(implem, backend, f, x)) + @test iszero(@allocated derivative(implem, backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_derivative(implems..., backend, f, x) - 
@test_opt derivative(implems..., backend, f, x) + @test_opt value_and_derivative(implem, backend, f, x) + @test_opt derivative(implem, backend, f, x) end end end @@ -309,17 +170,16 @@ function test_derivative( end function test_multiderivative( + implem::AbstractImplem, backend::AbstractADType, - scenarios::Vector{<:Scenario}, - implems...; + scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: Number) && (get_output_type(s) <: AbstractArray) end - testset_name = isempty(implems) ? "" : "(fallback)" - @testset "Multiderivative $testset_name" begin + @testset "Multiderivative ⁻ $(pretty(implem))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -327,15 +187,15 @@ function test_multiderivative( @testset "$X -> $Y" begin (; f, x, y, multider_true) = scenario - y_out1, multider_out1 = value_and_multiderivative(implems..., backend, f, x) + y_out1, multider_out1 = value_and_multiderivative(implem, backend, f, x) multider_in2 = zero(multider_out1) y_out2, multider_out2 = value_and_multiderivative!( - implems..., multider_in2, backend, f, x + implem, multider_in2, backend, f, x ) - multider_out3 = multiderivative(implems..., backend, f, x) + multider_out3 = multiderivative(implem, backend, f, x) multider_in4 = zero(multider_out3) - multider_out4 = multiderivative!(implems..., multider_in4, backend, f, x) + multider_out4 = multiderivative!(implem, multider_in4, backend, f, x) @testset "Primal value" begin @test y_out1 ≈ y @@ -354,20 +214,16 @@ function test_multiderivative( allocs && @testset "Allocations" begin @test iszero( @allocated value_and_multiderivative!( - implems..., multider_in2, backend, f, x + implem, multider_in2, backend, f, x ) ) @test iszero( - @allocated multiderivative!( - implems..., multider_in4, backend, f, x - ) + @allocated multiderivative!(implem, multider_in4, backend, f, x) ) end 
type_stability && @testset "Type stability" begin - @test_opt value_and_multiderivative!( - implems..., multider_in2, backend, f, x - ) - @test_opt multiderivative!(implems..., multider_in4, backend, f, x) + @test_opt value_and_multiderivative!(implem, multider_in2, backend, f, x) + @test_opt multiderivative!(implem, multider_in4, backend, f, x) end end end @@ -375,17 +231,16 @@ function test_multiderivative( end function test_gradient( + implem::AbstractImplem, backend::AbstractADType, - scenarios::Vector{<:Scenario}, - implems...; + scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: AbstractArray) && (get_output_type(s) <: Number) end - testset_name = isempty(implems) ? "" : "(fallback)" - @testset "Gradient $testset_name" begin + @testset "Gradient ⁻ $(pretty(implem))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -393,13 +248,13 @@ function test_gradient( @testset "$X -> $Y" begin (; f, x, y, grad_true) = scenario - y_out1, grad_out1 = value_and_gradient(implems..., backend, f, x) + y_out1, grad_out1 = value_and_gradient(implem, backend, f, x) grad_in2 = zero(grad_out1) - y_out2, grad_out2 = value_and_gradient!(implems..., grad_in2, backend, f, x) + y_out2, grad_out2 = value_and_gradient!(implem, grad_in2, backend, f, x) - grad_out3 = gradient(implems..., backend, f, x) + grad_out3 = gradient(implem, backend, f, x) grad_in4 = zero(grad_out3) - grad_out4 = gradient!(implems..., grad_in4, backend, f, x) + grad_out4 = gradient!(implem, grad_in4, backend, f, x) @testset "Primal value" begin @test y_out1 ≈ y @@ -417,13 +272,13 @@ function test_gradient( end allocs && @testset "Allocations" begin @test iszero( - @allocated value_and_gradient!(implems..., grad_in2, backend, f, x) + @allocated value_and_gradient!(implem, grad_in2, backend, f, x) ) - @test iszero(@allocated gradient!(implems..., 
grad_in4, backend, f, x)) + @test iszero(@allocated gradient!(implem, grad_in4, backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_gradient!(implems..., grad_in2, backend, f, x) - @test_opt gradient!(implems..., grad_in4, backend, f, x) + @test_opt value_and_gradient!(implem, grad_in2, backend, f, x) + @test_opt gradient!(implem, grad_in4, backend, f, x) end end end @@ -431,17 +286,16 @@ function test_gradient( end function test_jacobian( + implem::AbstractImplem, backend::AbstractADType, - scenarios::Vector{<:Scenario}, - implems...; + scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: AbstractArray) && (get_output_type(s) <: AbstractArray) end - testset_name = isempty(implems) ? "" : "(fallback)" - @testset "Jacobian $testset_name" begin + @testset "Jacobian ⁻ $(pretty(implem))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -449,13 +303,13 @@ function test_jacobian( @testset "$X -> $Y" begin (; f, x, y, jac_true) = scenario - y_out1, jac_out1 = value_and_jacobian(implems..., backend, f, x) + y_out1, jac_out1 = value_and_jacobian(implem, backend, f, x) jac_in2 = zero(jac_out1) - y_out2, jac_out2 = value_and_jacobian!(implems..., jac_in2, backend, f, x) + y_out2, jac_out2 = value_and_jacobian!(implem, jac_in2, backend, f, x) - jac_out3 = jacobian(implems..., backend, f, x) + jac_out3 = jacobian(implem, backend, f, x) jac_in4 = zero(jac_out3) - jac_out4 = jacobian!(implems..., jac_in4, backend, f, x) + jac_out4 = jacobian!(implem, jac_in4, backend, f, x) @testset "Primal value" begin @test y_out1 ≈ y @@ -472,14 +326,12 @@ function test_jacobian( end end allocs && @testset "Allocations" begin - @test iszero( - @allocated value_and_jacobian!(implems..., jac_in2, backend, f, x) - ) - @test iszero(@allocated jacobian!(implems..., jac_in4, backend, f, x)) + 
@test iszero(@allocated value_and_jacobian!(implem, jac_in2, backend, f, x)) + @test iszero(@allocated jacobian!(implem, jac_in4, backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_jacobian!(implems..., jac_in2, backend, f, x) - @test_opt jacobian!(implems..., jac_in4, backend, f, x) + @test_opt value_and_jacobian!(implem, jac_in2, backend, f, x) + @test_opt jacobian!(implem, jac_in4, backend, f, x) end end end @@ -487,9 +339,9 @@ function test_jacobian( end function test_jacobian_and_friends( + implem::AbstractImplem, backend::AbstractADType, - scenarios::Vector{<:Scenario}, - implems...; + scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, @@ -498,9 +350,9 @@ function test_jacobian_and_friends( scenarios = filter(scenarios) do s (get_input_type(s) <: input_type) && (get_output_type(s) <: output_type) end - test_derivative(backend, scenarios, implems...; allocs, type_stability) - test_multiderivative(backend, scenarios, implems...; allocs, type_stability) - test_gradient(backend, scenarios, implems...; allocs, type_stability) - test_jacobian(backend, scenarios, implems...; allocs, type_stability) + test_derivative(implem, backend, scenarios; allocs, type_stability) + test_multiderivative(implem, backend, scenarios; allocs, type_stability) + test_gradient(implem, backend, scenarios; allocs, type_stability) + test_jacobian(implem, backend, scenarios; allocs, type_stability) return nothing end diff --git a/test/zygote.jl b/test/zygote.jl index 4839be695..12c4c18bf 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -1,6 +1,7 @@ using ADTypes: AutoZygote +using DifferentiationInterface: CustomImplem, FallbackImplem using Zygote: Zygote test_pullback(AutoZygote(), scenarios; type_stability=false); -test_jacobian_and_friends(AutoZygote(), scenarios; type_stability=false); -test_jacobian_and_friends(AutoZygote(), scenarios, Val(:fallback); type_stability=false); 
+test_jacobian_and_friends(CustomImplem(), AutoZygote(), scenarios; type_stability=false); +test_jacobian_and_friends(FallbackImplem(), AutoZygote(), scenarios; type_stability=false); From 83fa9771afc00f5fb8f5eb76af82d7a57cf3d86d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 10 Mar 2024 19:55:08 +0100 Subject: [PATCH 5/9] Try to fix docs --- README.md | 2 +- docs/make.jl | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 70e34e67d..fa3190c02 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ We support some of the backends defined by [ADTypes.jl](https://github.com/SciML We also support two more backends which are not yet part of ADTypes.jl: - [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) with `AutoChainRules(ruleconfig)` -- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) with `AutoDiffractor()` +- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) with `AutoDiffractor()` (broken for now) ## Design diff --git a/docs/make.jl b/docs/make.jl index 6eb2f67bb..b5a7cbf32 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,10 +1,9 @@ -using ADTypes using Base: get_extension using DifferentiationInterface import DifferentiationInterface as DI using Documenter -using Diffractor: Diffractor +using ADTypes using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff @@ -13,7 +12,6 @@ using ReverseDiff: ReverseDiff using Zygote: Zygote ChainRulesCoreExt = get_extension(DI, :DifferentiationInterfaceChainRulesCoreExt) -DiffractorExt = get_extension(DI, :DifferentiationInterfaceDiffractorExt) EnzymeExt = get_extension(DI, :DifferentiationInterfaceEnzymeExt) FiniteDiffExt = get_extension(DI, :DifferentiationInterfaceFiniteDiffExt) ForwardDiffExt = get_extension(DI, :DifferentiationInterfaceForwardDiffExt) @@ -49,7 +47,6 @@ makedocs(; DifferentiationInterface, ADTypes, ChainRulesCoreExt, - DiffractorExt, 
EnzymeExt, FiniteDiffExt, ForwardDiffExt, @@ -62,6 +59,7 @@ makedocs(; format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", canonical="https://gdalle.github.io/DifferentiationInterface.jl", + edit_link="main", ), pages=["Home" => "index.md", "Interface" => "interface.md"], ) From c2fce697662b3eff2673c6e10f4fc3d75dd8a1a4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:08:18 +0100 Subject: [PATCH 6/9] Fix --- README.md | 2 +- benchmark/utils.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fa3190c02..cd0084312 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ From these primitives, several utilities are defined, depending on the type of t ## Example -```jldoctest +```julia julia> import DifferentiationInterface, ADTypes, ForwardDiff julia> backend = ADTypes.AutoForwardDiff(); diff --git a/benchmark/utils.jl b/benchmark/utils.jl index 08f08543d..4c85d7e28 100644 --- a/benchmark/utils.jl +++ b/benchmark/utils.jl @@ -3,6 +3,7 @@ using ADTypes: AbstractADType using BenchmarkTools using DifferentiationInterface using DifferentiationInterface: CustomImplem, FallbackImplem, ForwardMode, ReverseMode +using DifferentiationInterface: autodiff_mode ## Pretty printing From abbd41401370ca4b57f0097e83e8ddd831d2ec34 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 10 Mar 2024 21:06:26 +0100 Subject: [PATCH 7/9] Fix docs --- README.md | 2 +- docs/make.jl | 2 +- docs/src/interface.md | 12 ++---------- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index cd0084312..fa3190c02 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ From these primitives, several utilities are defined, depending on the type of t ## Example -```julia +```jldoctest julia> import DifferentiationInterface, ADTypes, ForwardDiff julia> backend = ADTypes.AutoForwardDiff(); diff 
--git a/docs/make.jl b/docs/make.jl index b5a7cbf32..d22311330 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,7 +24,7 @@ ZygoteExt = get_extension(DI, :DifferentiationInterfaceZygoteExt) DocMeta.setdocmeta!( DifferentiationInterface, :DocTestSetup, - :(using DifferentiationInterface); + :(using DifferentiationInterface, ADTypes); recursive=true, ) diff --git a/docs/src/interface.md b/docs/src/interface.md index 17ec45fe2..3a1dbef9d 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -59,24 +59,16 @@ Pages = ["pullback.jl"] ### ADTypes.jl -```@meta -CurrentModule = ADTypes -``` - The following backends are defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): -```@autodocs -Modules = [ADTypes] +```@docs +AbstractADType ``` Only a subset is supported by DifferentiationInterface.jl at the moment. ### DifferentiationInterface.jl -```@meta -CurrentModule = DifferentiationInterface -``` - The following backends are defined by DifferentiationInterface.jl: ```@autodocs From b5db0d2b20ccd1180fad4bd1f59e64274087c91f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:33:51 +0100 Subject: [PATCH 8/9] Put implem and mode at the end --- README.md | 2 +- benchmark/benchmarks.jl | 3 +- benchmark/utils.jl | 33 +++---- ...fferentiationInterfaceChainRulesCoreExt.jl | 20 +++-- ext/DifferentiationInterfaceDiffractorExt.jl | 9 +- .../forward.jl | 12 ++- .../reverse.jl | 19 +++- ext/DifferentiationInterfaceFiniteDiffExt.jl | 51 ++++++++--- ext/DifferentiationInterfaceForwardDiffExt.jl | 45 +++++++--- ...tiationInterfacePolyesterForwardDiffExt.jl | 22 ++++- ext/DifferentiationInterfaceReverseDiffExt.jl | 34 +++++-- ext/DifferentiationInterfaceZygoteExt.jl | 38 ++++++-- src/DifferentiationInterface.jl | 1 - src/array_array.jl | 45 +++++----- src/array_scalar.jl | 45 +++++----- src/custom.jl | 29 ------ src/pullback.jl | 24 ++--- src/pushforward.jl | 24 ++--- src/scalar_array.jl | 45 +++++----- 
src/scalar_scalar.jl | 26 +++--- test/chainrules_forward.jl | 2 +- test/diffractor.jl | 6 +- test/runtests.jl | 2 +- test/scenarios.jl | 34 ++----- test/utils.jl | 88 ++++++++++++------- 25 files changed, 392 insertions(+), 267 deletions(-) delete mode 100644 src/custom.jl diff --git a/README.md b/README.md index fa3190c02..70e34e67d 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ We support some of the backends defined by [ADTypes.jl](https://github.com/SciML We also support two more backends which are not yet part of ADTypes.jl: - [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) with `AutoChainRules(ruleconfig)` -- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) with `AutoDiffractor()` (broken for now) +- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) with `AutoDiffractor()` ## Design diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 9d6953cf2..8020248d0 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -44,6 +44,7 @@ end all_backends = [ AutoChainRules(ZygoteRuleConfig()), + AutoDiffractor(), AutoEnzyme(Val(:forward)), AutoEnzyme(Val(:reverse)), AutoFiniteDiff(), @@ -108,7 +109,7 @@ include("utils.jl") SUITE = make_suite() # Run benchmarks locally -# results = BenchmarkTools.run(SUITE; verbose=true) +results = BenchmarkTools.run(SUITE; verbose=true) # Compare commits locally # using BenchmarkCI; BenchmarkCI.judge(baseline="origin/main"); BenchmarkCI.displayjudgement() diff --git a/benchmark/utils.jl b/benchmark/utils.jl index 4c85d7e28..bd3ca2c37 100644 --- a/benchmark/utils.jl +++ b/benchmark/utils.jl @@ -5,9 +5,12 @@ using DifferentiationInterface using DifferentiationInterface: CustomImplem, FallbackImplem, ForwardMode, ReverseMode using DifferentiationInterface: autodiff_mode +const NO_EXTRAS = nothing + ## Pretty printing -pretty(::AutoChainRules{<:ZygoteRuleConfig}) = "ChainRules (Zygote)" +pretty(::AutoChainRules{<:ZygoteRuleConfig}) = "ChainRules{Zygote}" 
+pretty(::AutoDiffractor) = "Diffractor (forward)" pretty(::AutoEnzyme{Val{:forward}}) = "Enzyme (forward)" pretty(::AutoEnzyme{Val{:reverse}}) = "Enzyme (reverse)" pretty(::AutoFiniteDiff) = "FiniteDiff" @@ -93,11 +96,11 @@ function add_derivative_benchmarks!( backend_implem = "$(pretty(backend)) - $(pretty(implem))" suite["value_and_derivative"][(1, 1)][backend_implem] = @benchmarkable begin - value_and_derivative($implem, $backend, $f, $x) + value_and_derivative($backend, $f, $x, $NO_EXTRAS, $implem) end suite["derivative"][(1, 1)][backend_implem] = @benchmarkable begin - derivative($implem, $backend, $f, $x) + derivative($backend, $f, $x, $NO_EXTRAS, $implem) end end @@ -119,17 +122,17 @@ function add_multiderivative_benchmarks!( backend_implem = "$(pretty(backend)) - $(pretty(implem))" suite["value_and_multiderivative"][(1, m)][backend_implem] = @benchmarkable begin - value_and_multiderivative($implem, $backend, $f, $x) + value_and_multiderivative($backend, $f, $x, $NO_EXTRAS, $implem) end suite["value_and_multiderivative!"][(1, m)][backend_implem] = @benchmarkable begin - value_and_multiderivative!($implem, $multider, $backend, $f, $x) + value_and_multiderivative!($multider, $backend, $f, $x, $NO_EXTRAS, $implem) end suite["multiderivative"][(1, m)][backend_implem] = @benchmarkable begin - multiderivative($implem, $backend, $f, $x) + multiderivative($backend, $f, $x, $NO_EXTRAS, $implem) end suite["multiderivative!"][(1, m)][backend_implem] = @benchmarkable begin - multiderivative!($implem, $multider, $backend, $f, $x) + multiderivative!($multider, $backend, $f, $x, $NO_EXTRAS, $implem) end end @@ -151,17 +154,17 @@ function add_gradient_benchmarks!( backend_implem = "$(pretty(backend)) - $(pretty(implem))" suite["value_and_gradient"][(n, 1)][backend_implem] = @benchmarkable begin - value_and_gradient($implem, $backend, $f, $x) + value_and_gradient($backend, $f, $x, $NO_EXTRAS, $implem) end suite["value_and_gradient!"][(n, 1)][backend_implem] = @benchmarkable 
begin - value_and_gradient!($implem, $grad, $backend, $f, $x) + value_and_gradient!($grad, $backend, $f, $x, $NO_EXTRAS, $implem) end suite["gradient"][(n, 1)][backend_implem] = @benchmarkable begin - gradient($implem, $backend, $f, $x) + gradient($backend, $f, $x, $NO_EXTRAS, $implem) end suite["gradient!"][(n, 1)][backend_implem] = @benchmarkable begin - gradient!($implem, $grad, $backend, $f, $x) + gradient!($grad, $backend, $f, $x, $NO_EXTRAS, $implem) end end @@ -182,17 +185,17 @@ function add_jacobian_benchmarks!( backend_implem = "$(pretty(backend)) - $(pretty(implem))" suite["value_and_jacobian"][(n, m)][backend_implem] = @benchmarkable begin - value_and_jacobian($implem, $backend, $f, $x) + value_and_jacobian($backend, $f, $x, $NO_EXTRAS, $implem) end suite["value_and_jacobian!"][(n, m)][backend_implem] = @benchmarkable begin - value_and_jacobian!($implem, $jac, $backend, $f, $x) + value_and_jacobian!($jac, $backend, $f, $x, $NO_EXTRAS, $implem) end suite["jacobian"][(n, m)][backend_implem] = @benchmarkable begin - jacobian($implem, $backend, $f, $x) + jacobian($backend, $f, $x, $NO_EXTRAS, $implem) end suite["jacobian!"][(n, m)][backend_implem] = @benchmarkable begin - jacobian!($implem, $jac, $backend, $f, $x) + jacobian!($jac, $backend, $f, $x, $NO_EXTRAS, $implem) end end diff --git a/ext/DifferentiationInterfaceChainRulesCoreExt.jl b/ext/DifferentiationInterfaceChainRulesCoreExt.jl index 3deef027a..c2d212ed8 100644 --- a/ext/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/ext/DifferentiationInterfaceChainRulesCoreExt.jl @@ -16,26 +16,34 @@ DI.autodiff_mode(::AutoReverseChainRules) = DI.ReverseMode() ## Primitives -function DI.value_and_pushforward(backend::AutoForwardChainRules, f, x, dx) +function DI.value_and_pushforward( + backend::AutoForwardChainRules, f, x, dx, extras::Nothing=nothing +) rc = ruleconfig(backend) y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x) return y, new_dy end -function DI.value_and_pushforward!(dy, 
backend::AutoForwardChainRules, f, x, dx) - y, new_dy = DI.value_and_pushforward(backend, f, x, dx) +function DI.value_and_pushforward!( + dy, backend::AutoForwardChainRules, f, x, dx, extras=nothing +) + y, new_dy = DI.value_and_pushforward(backend, f, x, dx, extras) return y, update!(dy, new_dy) end -function DI.value_and_pullback(backend::AutoReverseChainRules, f, x, dy) +function DI.value_and_pullback( + backend::AutoReverseChainRules, f, x, dy, extras::Nothing=nothing +) rc = ruleconfig(backend) y, pullback = rrule_via_ad(rc, f, x) _, new_dx = pullback(dy) return y, new_dx end -function DI.value_and_pullback!(dx, backend::AutoReverseChainRules, f, x, dy) - y, new_dx = DI.value_and_pullback(backend, f, x, dy) +function DI.value_and_pullback!( + dx, backend::AutoReverseChainRules, f, x, dy, extras=nothing +) + y, new_dx = DI.value_and_pullback(backend, f, x, dy, extras) return y, update!(dx, new_dx) end diff --git a/ext/DifferentiationInterfaceDiffractorExt.jl b/ext/DifferentiationInterfaceDiffractorExt.jl index 81c146e75..333fdb167 100644 --- a/ext/DifferentiationInterfaceDiffractorExt.jl +++ b/ext/DifferentiationInterfaceDiffractorExt.jl @@ -1,20 +1,21 @@ module DifferentiationInterfaceDiffractorExt import AbstractDifferentiation as AD # public API for Diffractor -using DifferentiationInterface: AutoDiffractor, update! +using DifferentiationInterface: AutoChainRules, AutoDiffractor, update! 
import DifferentiationInterface as DI -using Diffractor: DiffractorForwardBackend +using Diffractor: DiffractorForwardBackend, DiffractorRuleConfig using DocStringExtensions DI.autodiff_mode(::AutoDiffractor) = DI.ForwardMode() +DI.autodiff_mode(::AutoChainRules{<:DiffractorRuleConfig}) = DI.ForwardMode() -function DI.value_and_pushforward(::AutoDiffractor, f, x, dx) +function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing=nothing) vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) y, dy = vpff((dx,)) return y, dy end -function DI.value_and_pushforward!(dy, ::AutoDiffractor, f, x, dx) +function DI.value_and_pushforward!(dy, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing) vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) y, new_dy = vpff((dx,)) return y, update!(dy, new_dy) diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward.jl b/ext/DifferentiationInterfaceEnzymeExt/forward.jl index 9a28138bc..d4f06a2a6 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/forward.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/forward.jl @@ -4,14 +4,14 @@ DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode() ## Primitives function DI.value_and_pushforward!( - _dy::Y, ::AutoForwardEnzyme, f, x::X, dx + _dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing ) where {X,Y<:Real} y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) return y, new_dy end function DI.value_and_pushforward!( - dy::Y, ::AutoForwardEnzyme, f, x::X, dx + dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing ) where {X,Y<:AbstractArray} y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) dy .= new_dy @@ -20,7 +20,13 @@ end ## Utilities -function DI.value_and_jacobian(::CustomImplem, ::AutoForwardEnzyme, f, x::AbstractArray) +function DI.value_and_jacobian( + ::AutoForwardEnzyme, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) jac = 
jacobian(Forward, f, x) # see https://github.com/EnzymeAD/Enzyme.jl/issues/1332 diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl index 854e69d88..2702aecbc 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl @@ -5,7 +5,7 @@ DI.handles_output_type(::AutoReverseEnzyme, ::Type{<:AbstractArray}) = false ## Primitives function DI.value_and_pullback!( - _dx, ::AutoReverseEnzyme, f, x::X, dy::Y + _dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing ) where {X<:Number,Y<:Union{Real,Nothing}} der, y = autodiff(ReverseWithPrimal, f, Active, Active(x)) new_dx = dy * only(der) @@ -13,7 +13,7 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx::X, ::AutoReverseEnzyme, f, x::X, dy::Y + dx::X, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:Union{Real,Nothing}} dx .= zero(eltype(dx)) _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx)) @@ -23,14 +23,25 @@ end ## Utilities -function DI.value_and_gradient(::CustomImplem, ::AutoReverseEnzyme, f, x::AbstractArray) +function DI.value_and_gradient( + ::AutoReverseEnzyme, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) grad = gradient(Reverse, f, x) return y, grad end function DI.value_and_gradient!( - ::CustomImplem, grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray + grad::AbstractArray, + ::AutoReverseEnzyme, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) gradient!(Reverse, grad, f, x) diff --git a/ext/DifferentiationInterfaceFiniteDiffExt.jl b/ext/DifferentiationInterfaceFiniteDiffExt.jl index b5814f5cb..0f4b47cb7 100644 --- a/ext/DifferentiationInterfaceFiniteDiffExt.jl +++ b/ext/DifferentiationInterfaceFiniteDiffExt.jl @@ -18,7 +18,7 @@ const FUNCTION_NOT_INPLACE = Val{false} ## Primitives function 
DI.value_and_pushforward!( - dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx + dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing ) where {Y<:Number,fdtype} y = f(x) step(t::Number)::Number = f(x .+ t .* dx) @@ -27,7 +27,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx + dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing ) where {Y<:AbstractArray,fdtype} y = f(x) step(t::Number)::AbstractArray = f(x .+ t .* dx) @@ -40,7 +40,11 @@ end ## Utilities function DI.value_and_derivative( - ::CustomImplem, ::AutoFiniteDiff{fdtype}, f, x::Number + ::AutoFiniteDiff{fdtype}, + f, + x::Number, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) der = finite_difference_derivative(f, x, fdtype, eltype(y), y) @@ -48,7 +52,12 @@ function DI.value_and_derivative( end function DI.value_and_multiderivative!( - ::CustomImplem, multider::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::Number + multider::AbstractArray, + ::AutoFiniteDiff{fdtype}, + f, + x::Number, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) finite_difference_gradient!(multider, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -56,7 +65,11 @@ function DI.value_and_multiderivative!( end function DI.value_and_multiderivative( - ::CustomImplem, ::AutoFiniteDiff{fdtype}, f, x::Number + ::AutoFiniteDiff{fdtype}, + f, + x::Number, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) multider = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -64,7 +77,12 @@ function DI.value_and_multiderivative( end function DI.value_and_gradient!( - ::CustomImplem, grad::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray + grad::AbstractArray, + ::AutoFiniteDiff{fdtype}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) 
finite_difference_gradient!(grad, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -72,7 +90,11 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - ::CustomImplem, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray + ::AutoFiniteDiff{fdtype}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) grad = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -80,7 +102,11 @@ function DI.value_and_gradient( end function DI.value_and_jacobian( - ::CustomImplem, ::AutoFiniteDiff{fdtype}, f, x::AbstractArray + ::AutoFiniteDiff{fdtype}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) jac = finite_difference_jacobian(f, x, fdtype, eltype(y)) @@ -88,9 +114,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - ::CustomImplem, jac::AbstractMatrix, backend::AutoFiniteDiff, f, x::AbstractArray + jac::AbstractMatrix, + backend::AutoFiniteDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + implem::CustomImplem=CustomImplem(), ) - y, new_jac = DI.value_and_jacobian(backend, f, x) + y, new_jac = DI.value_and_jacobian(backend, f, x, extras, implem) jac .= new_jac return y, jac end diff --git a/ext/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt.jl index 5825a6098..8e66b76db 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt.jl @@ -22,7 +22,7 @@ using LinearAlgebra: mul! 
## Primitives function DI.value_and_pushforward!( - _dy::Y, ::AutoForwardDiff, f, x::X, dx + _dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing ) where {X<:Real,Y<:Real} T = typeof(Tag(f, X)) xdual = Dual{T}(x, dx) @@ -33,7 +33,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::AutoForwardDiff, f, x::X, dx + dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing ) where {X<:Real,Y<:AbstractArray} T = typeof(Tag(f, X)) xdual = Dual{T}(x, dx) @@ -44,7 +44,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - _dy::Y, ::AutoForwardDiff, f, x::X, dx + _dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:Real} T = typeof(Tag(f, X)) # TODO: unsure xdual = Dual{T}.(x, dx) # TODO: allocation @@ -55,7 +55,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::AutoForwardDiff, f, x::X, dx + dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:AbstractArray} T = typeof(Tag(f, X)) # TODO: unsure xdual = Dual{T}.(x, dx) # TODO: allocation @@ -67,48 +67,71 @@ end ## Utilities (TODO: use DiffResults) -function DI.value_and_derivative(::CustomImplem, ::AutoForwardDiff, f, x::Number) +function DI.value_and_derivative( + ::AutoForwardDiff, f, x::Number, extras::Nothing, ::CustomImplem=CustomImplem() +) y = f(x) der = derivative(f, x) return y, der end -function DI.value_and_multiderivative(::CustomImplem, ::AutoForwardDiff, f, x::Number) +function DI.value_and_multiderivative( + ::AutoForwardDiff, f, x::Number, extras::Nothing, ::CustomImplem=CustomImplem() +) y = f(x) multider = derivative(f, x) return y, multider end function DI.value_and_multiderivative!( - ::CustomImplem, multider::AbstractArray, ::AutoForwardDiff, f, x::Number + multider::AbstractArray, + ::AutoForwardDiff, + f, + x::Number, + extras::Nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) derivative!(multider, 
f, x) return y, multider end -function DI.value_and_gradient(::CustomImplem, ::AutoForwardDiff, f, x::AbstractArray) +function DI.value_and_gradient( + ::AutoForwardDiff, f, x::AbstractArray, extras::Nothing, ::CustomImplem=CustomImplem() +) y = f(x) grad = gradient(f, x) return y, grad end function DI.value_and_gradient!( - ::CustomImplem, grad::AbstractArray, ::AutoForwardDiff, f, x::AbstractArray + grad::AbstractArray, + ::AutoForwardDiff, + f, + x::AbstractArray, + extras::Nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) gradient!(grad, f, x) return y, grad end -function DI.value_and_jacobian(::CustomImplem, ::AutoForwardDiff, f, x::AbstractArray) +function DI.value_and_jacobian( + ::AutoForwardDiff, f, x::AbstractArray, extras::Nothing, ::CustomImplem=CustomImplem() +) y = f(x) jac = jacobian(f, x) return y, jac end function DI.value_and_jacobian!( - ::CustomImplem, jac::AbstractMatrix, ::AutoForwardDiff, f, x::AbstractArray + jac::AbstractMatrix, + ::AutoForwardDiff, + f, + x::AbstractArray, + extras::Nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) jacobian!(jac, f, x) diff --git a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl index 868cf1d09..074e96091 100644 --- a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -11,14 +11,23 @@ using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! 
## Primitives -function DI.value_and_pushforward!(dy, ::AutoPolyesterForwardDiff{C}, f, x, dx) where {C} - return DI.value_and_pushforward!(dy, AutoForwardDiff{C,Nothing}(nothing), f, x, dx) +function DI.value_and_pushforward!( + dy, ::AutoPolyesterForwardDiff{C}, f, x, dx, extras::Nothing=nothing +) where {C} + return DI.value_and_pushforward!( + dy, AutoForwardDiff{C,Nothing}(nothing), f, x, dx, extras + ) end ## Utilities function DI.value_and_gradient!( - ::CustomImplem, grad::AbstractArray, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray + grad::AbstractArray, + ::AutoPolyesterForwardDiff{C}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {C} y = f(x) threaded_gradient!(f, grad, x, Chunk{C}()) @@ -26,7 +35,12 @@ function DI.value_and_gradient!( end function DI.value_and_jacobian!( - ::CustomImplem, jac::AbstractMatrix, ::AutoPolyesterForwardDiff{C}, f, x::AbstractArray + jac::AbstractMatrix, + ::AutoPolyesterForwardDiff{C}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {C} y = f(x) threaded_jacobian!(f, jac, x, Chunk{C}()) diff --git a/ext/DifferentiationInterfaceReverseDiffExt.jl b/ext/DifferentiationInterfaceReverseDiffExt.jl index d7d29a0de..c5061584b 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt.jl @@ -15,7 +15,7 @@ DI.handles_input_type(::AutoReverseDiff, ::Type{<:Number}) = false ## Primitives function DI.value_and_pullback!( - dx, ::AutoReverseDiff, f, x::X, dy::Y + dx, ::AutoReverseDiff, f, x::X, dy::Y, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:Real} res = DiffResults.DiffResult(zero(Y), dx) res = gradient!(res, f, x) @@ -25,7 +25,7 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx, ::AutoReverseDiff, f, x::X, dy::Y + dx, ::AutoReverseDiff, f, x::X, dy::Y, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:AbstractArray} res = 
DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x))) res = jacobian!(res, f, x) @@ -37,28 +37,50 @@ end ## Utilities (TODO: use DiffResults) -function DI.value_and_gradient(::CustomImplem, ::AutoReverseDiff, f, x::AbstractArray) +function DI.value_and_gradient( + ::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) grad = gradient(f, x) return y, grad end function DI.value_and_gradient!( - ::CustomImplem, grad::AbstractArray, ::AutoReverseDiff, f, x::AbstractArray + grad::AbstractArray, + ::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) gradient!(grad, f, x) return y, grad end -function DI.value_and_jacobian(::CustomImplem, ::AutoReverseDiff, f, x::AbstractArray) +function DI.value_and_jacobian( + ::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) jac = jacobian(f, x) return y, jac end function DI.value_and_jacobian!( - ::CustomImplem, jac::AbstractMatrix, ::AutoReverseDiff, f, x::AbstractArray + jac::AbstractMatrix, + ::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) jacobian!(jac, f, x) diff --git a/ext/DifferentiationInterfaceZygoteExt.jl b/ext/DifferentiationInterfaceZygoteExt.jl index f864b8603..d6312d080 100644 --- a/ext/DifferentiationInterfaceZygoteExt.jl +++ b/ext/DifferentiationInterfaceZygoteExt.jl @@ -10,13 +10,13 @@ using Zygote: ZygoteRuleConfig, gradient, jacobian, pullback, withgradient, with const zygote_chainrules_backend = AutoChainRules(ZygoteRuleConfig()) -function DI.value_and_pullback!(dx, ::AutoZygote, f, x, dy) +function DI.value_and_pullback!(dx, ::AutoZygote, f, x, dy, extras::Nothing=nothing) y, back = pullback(f, x) new_dx = only(back(dy)) return y, update!(dx, new_dx) end -function DI.value_and_pullback(::AutoZygote, f, x, dy) +function 
DI.value_and_pullback(::AutoZygote, f, x, dy, extras::Nothing=nothing) y, back = pullback(f, x) dx = only(back(dy)) return y, dx @@ -24,29 +24,51 @@ end ## Utilities -function DI.value_and_gradient(::CustomImplem, ::AutoZygote, f, x::AbstractArray) +function DI.value_and_gradient( + ::AutoZygote, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) res = withgradient(f, x) return res.val, only(res.grad) end function DI.value_and_gradient!( - ::CustomImplem, grad::AbstractArray, backend::AutoZygote, f, x::AbstractArray + grad::AbstractArray, + backend::AutoZygote, + f, + x::AbstractArray, + extras=nothing, + implem::CustomImplem=CustomImplem(), ) - y, new_grad = DI.value_and_gradient(backend, f, x) + y, new_grad = DI.value_and_gradient(backend, f, x, extras, implem) grad .= new_grad return y, grad end -function DI.value_and_jacobian(::CustomImplem, ::AutoZygote, f, x::AbstractArray) +function DI.value_and_jacobian( + ::AutoZygote, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) jac = jacobian(f, x) return y, only(jac) end function DI.value_and_jacobian!( - ::CustomImplem, jac::AbstractMatrix, backend::AutoZygote, f, x::AbstractArray + jac::AbstractMatrix, + backend::AutoZygote, + f, + x::AbstractArray, + extras::Nothing=nothing, + implem::CustomImplem=CustomImplem(), ) - y, new_jac = DI.value_and_jacobian(backend, f, x) + y, new_jac = DI.value_and_jacobian(backend, f, x, extras, implem) jac .= new_jac return y, jac end diff --git a/src/DifferentiationInterface.jl b/src/DifferentiationInterface.jl index c1a941cf4..b8be910cc 100644 --- a/src/DifferentiationInterface.jl +++ b/src/DifferentiationInterface.jl @@ -24,7 +24,6 @@ include("scalar_scalar.jl") include("scalar_array.jl") include("array_scalar.jl") include("array_array.jl") -include("custom.jl") export AutoChainRules, AutoDiffractor diff --git a/src/array_array.jl b/src/array_array.jl index df1218477..c9607e032 100644 --- 
a/src/array_array.jl +++ b/src/array_array.jl @@ -6,20 +6,21 @@ This function acts as if the input and output had been flattened with `vec`. """ """ - value_and_jacobian!(jac, backend, f, x) -> (y, jac) + value_and_jacobian!(jac, backend, f, x, [extras]) -> (y, jac) Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overwriting `jac` if possible. $JAC_NOTES """ function value_and_jacobian!( - implem::AbstractImplem, jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, + extras=nothing, + implem::AbstractImplem=CustomImplem(), ) - return value_and_jacobian!(implem, autodiff_mode(backend), jac, backend, f, x) + return value_and_jacobian!(jac, backend, f, x, extras, implem, autodiff_mode(backend)) end function check_jac(jac::AbstractMatrix, x::AbstractArray, y::AbstractArray) @@ -31,81 +32,77 @@ function check_jac(jac::AbstractMatrix, x::AbstractArray, y::AbstractArray) end function value_and_jacobian!( - ::AbstractImplem, - ::ForwardMode, jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, + extras, + ::AbstractImplem, + ::ForwardMode, ) y = f(x) check_jac(jac, x, y) for (k, j) in enumerate(eachindex(IndexCartesian(), x)) dx_j = basisarray(backend, x, j) jac_col_j = reshape(view(jac, :, k), size(y)) - pushforward!(jac_col_j, backend, f, x, dx_j) + pushforward!(jac_col_j, backend, f, x, dx_j, extras) end return y, jac end function value_and_jacobian!( - ::AbstractImplem, - ::ReverseMode, jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, + extras, + ::AbstractImplem, + ::ReverseMode, ) y = f(x) check_jac(jac, x, y) for (k, i) in enumerate(eachindex(IndexCartesian(), y)) dy_i = basisarray(backend, y, i) jac_row_i = reshape(view(jac, k, :), size(x)) - pullback!(jac_row_i, backend, f, x, dy_i) + pullback!(jac_row_i, backend, f, x, dy_i, extras) end return y, jac end """ - value_and_jacobian(backend, f, x) -> (y, jac) + value_and_jacobian(backend, f, x, [extras]) -> (y, jac) 
Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. $JAC_NOTES """ -function value_and_jacobian( - implem::AbstractImplem, backend::AbstractADType, f, x::AbstractArray -) +function value_and_jacobian(backend::AbstractADType, f, x::AbstractArray, args...) y = f(x) T = promote_type(eltype(x), eltype(y)) jac = similar(y, T, length(y), length(x)) - return value_and_jacobian!(implem, jac, backend, f, x) + return value_and_jacobian!(jac, backend, f, x, args...) end """ - jacobian!(jac, backend, f, x) -> jac + jacobian!(jac, backend, f, x, [extras]) -> jac Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overwriting `jac` if possible. $JAC_NOTES """ function jacobian!( - implem::AbstractImplem, - jac::AbstractMatrix, - backend::AbstractADType, - f, - x::AbstractArray, + jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, args... ) - return last(value_and_jacobian!(implem, jac, backend, f, x)) + return last(value_and_jacobian!(jac, backend, f, x, args...)) end """ - jacobian(backend, f, x) -> jac + jacobian(backend, f, x, [extras]) -> jac Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. $JAC_NOTES """ -function jacobian(implem::AbstractImplem, backend::AbstractADType, f, x::AbstractArray) - return last(value_and_jacobian(implem, backend, f, x)) +function jacobian(backend::AbstractADType, f, x::AbstractArray, args...) + return last(value_and_jacobian(backend, f, x, args...)) end diff --git a/src/array_scalar.jl b/src/array_scalar.jl index 95d981b10..98aca4abe 100644 --- a/src/array_scalar.jl +++ b/src/array_scalar.jl @@ -1,78 +1,75 @@ """ - value_and_gradient!(grad, backend, f, x) -> (y, grad) + value_and_gradient!(grad, backend, f, x, [extras]) -> (y, grad) Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. 
""" function value_and_gradient!( - implem::AbstractImplem, grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, + extras=nothing, + implem::AbstractImplem=CustomImplem(), ) - return value_and_gradient!(implem, autodiff_mode(backend), grad, backend, f, x) + return value_and_gradient!(grad, backend, f, x, extras, implem, autodiff_mode(backend)) end function value_and_gradient!( - ::AbstractImplem, - ::ForwardMode, grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, + extras, + ::AbstractImplem, + ::ForwardMode, ) y = f(x) for j in eachindex(IndexCartesian(), x) dx_j = basisarray(backend, x, j) - grad[j] = pushforward!(grad[j], backend, f, x, dx_j) + grad[j] = pushforward!(grad[j], backend, f, x, dx_j, extras) end return y, grad end function value_and_gradient!( - ::AbstractImplem, - ::ReverseMode, grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, + extras, + ::AbstractImplem, + ::ReverseMode, ) y = f(x) - return y, pullback!(grad, backend, f, x, one(y)) + return y, pullback!(grad, backend, f, x, one(y), extras) end """ - value_and_gradient(backend, f, x) -> (y, grad) + value_and_gradient(backend, f, x, [extras]) -> (y, grad) Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function. """ -function value_and_gradient( - implem::AbstractImplem, backend::AbstractADType, f, x::AbstractArray -) +function value_and_gradient(backend::AbstractADType, f, x::AbstractArray, args...) grad = similar(x) - return value_and_gradient!(implem, grad, backend, f, x) + return value_and_gradient!(grad, backend, f, x, args...) end """ - gradient!(grad, backend, f, x) -> grad + gradient!(grad, backend, f, x, [extras]) -> grad Compute the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. 
""" function gradient!( - implem::AbstractImplem, - grad::AbstractArray, - backend::AbstractADType, - f, - x::AbstractArray, + grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, args... ) - return last(value_and_gradient!(implem, grad, backend, f, x)) + return last(value_and_gradient!(grad, backend, f, x, args...)) end """ - gradient(backend, f, x) -> grad + gradient(backend, f, x, [extras]) -> grad Compute the gradient `grad = ∇f(x)` of an array-to-scalar function. """ -function gradient(implem::AbstractImplem, backend::AbstractADType, f, x::AbstractArray) - return last(value_and_gradient(implem, backend, f, x)) +function gradient(backend::AbstractADType, f, x::AbstractArray, args...) + return last(value_and_gradient(backend, f, x, args...)) end diff --git a/src/custom.jl b/src/custom.jl deleted file mode 100644 index 683e1cf8c..000000000 --- a/src/custom.jl +++ /dev/null @@ -1,29 +0,0 @@ -for utility in [ - :value_and_derivative, - :value_and_multiderivative, - :value_and_gradient, - :value_and_jacobian, - :derivative, - :multiderivative, - :gradient, - :jacobian, -] - @eval $utility(backend::AbstractADType, f, x::Union{Number,AbstractArray}) = - $utility(CustomImplem(), backend, f, x) -end - -for utility! in [ - :value_and_multiderivative!, - :value_and_gradient!, - :value_and_jacobian!, - :multiderivative!, - :gradient!, - :jacobian!, -] - @eval $utility!( - storage::Union{Number,AbstractArray}, - backend::AbstractADType, - f, - x::Union{Number,AbstractArray}, - ) = $utility!(CustomImplem(), storage, backend, f, x) -end diff --git a/src/pullback.jl b/src/pullback.jl index 637524987..bb496b392 100644 --- a/src/pullback.jl +++ b/src/pullback.jl @@ -1,41 +1,41 @@ """ - value_and_pullback!(dx, backend, f, x, dy) -> (y, dx) + value_and_pullback!(dx, backend, f, x, dy, [extras]) -> (y, dx) Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. !!! 
info "Interface requirement" This is the only required implementation for a reverse mode backend. """ -function value_and_pullback!(dx, backend::AbstractADType, f, x, dy) +function value_and_pullback!(dx, backend::AbstractADType, f, x, dy, extras=nothing) return error( - "Backend $backend is not loaded or does not support this type combination: `typeof(x) = $(typeof(x))` and `typeof(y) = $(typeof(dy))`", + "The package for $(typeof(backend)) is not loaded, or the backend does not support this type combination: `typeof(x) = $(typeof(x))` and `typeof(y) = $(typeof(dy))`", ) end """ - value_and_pullback(backend, f, x, dy) -> (y, dx) + value_and_pullback(backend, f, x, dy, [extras]) -> (y, dx) Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`. """ -function value_and_pullback(backend::AbstractADType, f, x, dy) +function value_and_pullback(backend::AbstractADType, f, x, dy, extras=nothing) dx = mysimilar(x) - return value_and_pullback!(dx, backend, f, x, dy) + return value_and_pullback!(dx, backend, f, x, dy, extras) end """ - pullback!(dx, backend, f, x, dy) -> dx + pullback!(dx, backend, f, x, dy, [extras]) -> dx Compute the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. """ -function pullback!(dx, backend::AbstractADType, f, x, dy) - return last(value_and_pullback!(dx, backend, f, x, dy)) +function pullback!(dx, backend::AbstractADType, f, x, dy, extras=nothing) + return last(value_and_pullback!(dx, backend, f, x, dy, extras)) end """ - pullback(backend, f, x, dy) -> dx + pullback(backend, f, x, dy, [extras]) -> dx Compute the vector-Jacobian product `dx = ∂f(x)' * dy`. 
""" -function pullback(backend::AbstractADType, f, x, dy) - return last(value_and_pullback(backend, f, x, dy)) +function pullback(backend::AbstractADType, f, x, dy, extras=nothing) + return last(value_and_pullback(backend, f, x, dy, extras)) end diff --git a/src/pushforward.jl b/src/pushforward.jl index a4f2a9c4e..0b1836f6c 100644 --- a/src/pushforward.jl +++ b/src/pushforward.jl @@ -1,41 +1,41 @@ """ - value_and_pushforward!(dy, backend, f, x, dx) -> (y, dy) + value_and_pushforward!(dy, backend, f, x, dx, [extras]) -> (y, dy) Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. !!! info "Interface requirement" This is the only required implementation for a forward mode backend. """ -function value_and_pushforward!(dy, backend::AbstractADType, f, x, dx) +function value_and_pushforward!(dy, backend::AbstractADType, f, x, dx, extras=nothing) return error( - "Backend $backend is not loaded or does not support this type combination: `typeof(x) = $(typeof(x))` and `typeof(y) = $(typeof(dy))`", + "The package for $(typeof(backend)) is not loaded, or the backend does not support this type combination: `typeof(x) = $(typeof(x))` and `typeof(y) = $(typeof(dy))`", ) end """ - value_and_pushforward(backend, f, x, dx) -> (y, dy) + value_and_pushforward(backend, f, x, dx, [extras]) -> (y, dy) Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function value_and_pushforward(backend::AbstractADType, f, x, dx) +function value_and_pushforward(backend::AbstractADType, f, x, dx, extras=nothing) dy = mysimilar(f(x)) - return value_and_pushforward!(dy, backend, f, x, dx) + return value_and_pushforward!(dy, backend, f, x, dx, extras) end """ - pushforward!(dy, backend, f, x, dx) -> dy + pushforward!(dy, backend, f, x, dx, [extras]) -> dy Compute the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. 
""" -function pushforward!(dy, backend::AbstractADType, f, x, dx) - return last(value_and_pushforward!(dy, backend, f, x, dx)) +function pushforward!(dy, backend::AbstractADType, f, x, dx, extras=nothing) + return last(value_and_pushforward!(dy, backend, f, x, dx, extras)) end """ - pushforward(backend, f, x, dx) -> dy + pushforward(backend, f, x, dx, [extras]) -> dy Compute the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function pushforward(backend::AbstractADType, f, x, dx) - return last(value_and_pushforward(backend, f, x, dx)) +function pushforward(backend::AbstractADType, f, x, dx, extras=nothing) + return last(value_and_pushforward(backend, f, x, dx, extras)) end diff --git a/src/scalar_array.jl b/src/scalar_array.jl index af8b2f2ee..c1a7d43e4 100644 --- a/src/scalar_array.jl +++ b/src/scalar_array.jl @@ -1,71 +1,76 @@ """ - value_and_multiderivative!(multider, backend, f, x) -> (y, multider) + value_and_multiderivative!(multider, backend, f, x, [extras]) -> (y, multider) Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. 
""" function value_and_multiderivative!( - implem::AbstractImplem, multider::AbstractArray, backend::AbstractADType, f, x::Number + multider::AbstractArray, + backend::AbstractADType, + f, + x::Number, + extras=nothing, + implem::AbstractImplem=CustomImplem(), ) return value_and_multiderivative!( - implem, autodiff_mode(backend), multider, backend, f, x + multider, backend, f, x, extras, implem, autodiff_mode(backend) ) end function value_and_multiderivative!( - ::AbstractImplem, - ::ForwardMode, multider::AbstractArray, backend::AbstractADType, f, x::Number, + extras, + ::AbstractImplem, + ::ForwardMode, ) - return value_and_pushforward!(multider, backend, f, x, one(x)) + return value_and_pushforward!(multider, backend, f, x, one(x), extras) end function value_and_multiderivative!( - ::AbstractImplem, - ::ReverseMode, multider::AbstractArray, backend::AbstractADType, f, x::Number, + extras, + ::AbstractImplem, + ::ReverseMode, ) y = f(x) for i in eachindex(IndexCartesian(), y) dy_i = basisarray(backend, y, i) - multider[i] = pullback!(multider[i], backend, f, x, dy_i) + multider[i] = pullback!(multider[i], backend, f, x, dy_i, extras) end return y, multider end """ - value_and_multiderivative(backend, f, x) -> (y, multider) + value_and_multiderivative(backend, f, x, [extras]) -> (y, multider) Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function value_and_multiderivative( - implem::AbstractImplem, backend::AbstractADType, f, x::Number -) +function value_and_multiderivative(backend::AbstractADType, f, x::Number, args...) multider = similar(f(x)) - return value_and_multiderivative!(implem, multider, backend, f, x) + return value_and_multiderivative!(multider, backend, f, x, args...) 
end """ - multiderivative!(multider, backend, f, x) -> multider + multiderivative!(multider, backend, f, x, [extras]) -> multider Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. """ function multiderivative!( - implem::AbstractImplem, multider::AbstractArray, backend::AbstractADType, f, x::Number + multider::AbstractArray, backend::AbstractADType, f, x::Number, args... ) - return last(value_and_multiderivative!(implem, multider, backend, f, x)) + return last(value_and_multiderivative!(multider, backend, f, x, args...)) end """ - multiderivative(backend, f, x) -> multider + multiderivative(backend, f, x, [extras]) -> multider Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function multiderivative(implem::AbstractImplem, backend::AbstractADType, f, x::Number) - return last(value_and_multiderivative(implem, backend, f, x)) +function multiderivative(backend::AbstractADType, f, x::Number, args...) + return last(value_and_multiderivative(backend, f, x, args...)) end diff --git a/src/scalar_scalar.jl b/src/scalar_scalar.jl index fbd6f3631..1b3c9c349 100644 --- a/src/scalar_scalar.jl +++ b/src/scalar_scalar.jl @@ -1,29 +1,35 @@ """ - value_and_derivative(backend, f, x) -> (y, der) + value_and_derivative(backend, f, x, [extras]) -> (y, der) Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function. 
""" -function value_and_derivative(implem::AbstractImplem, backend::AbstractADType, f, x::Number) - return value_and_derivative(implem, autodiff_mode(backend), backend, f, x) +function value_and_derivative( + backend::AbstractADType, + f, + x::Number, + extras=nothing, + implem::AbstractImplem=CustomImplem(), +) + return value_and_derivative(backend, f, x, extras, implem, autodiff_mode(backend)) end function value_and_derivative( - ::AbstractImplem, ::ForwardMode, backend::AbstractADType, f, x::Number + backend::AbstractADType, f, x::Number, extras, ::AbstractImplem, ::ForwardMode ) - return value_and_pushforward!(one(x), backend, f, x, one(x)) + return value_and_pushforward!(one(x), backend, f, x, one(x), extras) end function value_and_derivative( - ::AbstractImplem, ::ReverseMode, backend::AbstractADType, f, x::Number + backend::AbstractADType, f, x::Number, extras, ::AbstractImplem, ::ReverseMode ) - return value_and_pullback!(one(x), backend, f, x, one(x)) + return value_and_pullback!(one(x), backend, f, x, one(x), extras) end """ - derivative(backend, f, x) -> der + derivative(backend, f, x, [extras]) -> der Compute the derivative `der = f'(x)` of a scalar-to-scalar function. """ -function derivative(implem::AbstractImplem, backend::AbstractADType, f, x::Number) - return last(value_and_derivative(implem, backend, f, x)) +function derivative(backend::AbstractADType, f, x::Number, args...) 
+ return last(value_and_derivative(backend, f, x, args...)) end diff --git a/test/chainrules_forward.jl b/test/chainrules_forward.jl index f276239ad..3f2d2e549 100644 --- a/test/chainrules_forward.jl +++ b/test/chainrules_forward.jl @@ -1,7 +1,7 @@ using DifferentiationInterface: AutoChainRules, CustomImplem, FallbackImplem using Diffractor: DiffractorRuleConfig -test_pullback(AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false); +test_pushforward(AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false); test_jacobian_and_friends( CustomImplem(), AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false ); diff --git a/test/diffractor.jl b/test/diffractor.jl index 664da1177..ad9c56c57 100644 --- a/test/diffractor.jl +++ b/test/diffractor.jl @@ -1,8 +1,8 @@ using DifferentiationInterface: AutoDiffractor, CustomImplem, FallbackImplem using Diffractor: Diffractor -test_pushforward(AutoDiffractor(), scenarios; type_stability=true); -test_jacobian_and_friends(CustomImplem(), AutoDiffractor(), scenarios; type_stability=true); +test_pushforward(AutoDiffractor(), scenarios; type_stability=false); +test_jacobian_and_friends(CustomImplem(), AutoDiffractor(), scenarios; type_stability=false); test_jacobian_and_friends( - FallbackImplem(), AutoDiffractor(), scenarios; type_stability=true + FallbackImplem(), AutoDiffractor(), scenarios; type_stability=false ); diff --git a/test/runtests.jl b/test/runtests.jl index 64125e0c0..8eb10b614 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,7 +33,7 @@ include("utils.jl"); include("chainrules_reverse.jl") end @testset "Diffractor (forward)" begin - @test_skip include("diffractor.jl") + include("diffractor.jl") end @testset "Enzyme (forward)" begin include("enzyme_forward.jl") diff --git a/test/scenarios.jl b/test/scenarios.jl index 2b4bb8989..2cf4e8813 100644 --- a/test/scenarios.jl +++ b/test/scenarios.jl @@ -95,38 +95,18 @@ rng = StableRNG(63) ## Scenarios 
f_scalar_scalar(x::Number)::Number = sin(x) + f_scalar_vector(x::Number)::AbstractVector = [sin(x), sin(2x)] f_scalar_matrix(x::Number)::AbstractMatrix = hcat([sin(x) cos(x)], [sin(2x) cos(2x)]) -function f_vector_scalar(x::AbstractVector)::Number - a = eachindex(x) - return sum(sin.(a .* x)) -end - -function f_matrix_scalar(x::AbstractMatrix)::Number - a, b = axes(x) - return sum(sin.(a .* x)) + sum(cos.(transpose(b) .* x)) -end - -function f_vector_vector(x::AbstractVector)::AbstractVector - a = eachindex(x) - return vcat(sin.(a .* x), cos.(a .* x)) -end +f_vector_scalar(x::AbstractVector)::Number = sum(sin, x) +f_matrix_scalar(x::AbstractMatrix)::Number = sum(sin, x) -function f_vector_matrix(x::AbstractVector)::AbstractMatrix - a = eachindex(x) - return hcat(sin.(a .* x), cos.(a .* x)) -end +f_vector_vector(x::AbstractVector)::AbstractVector = vcat(sin.(x), cos.(x)) +f_vector_matrix(x::AbstractVector)::AbstractMatrix = hcat(sin.(x), cos.(x)) -function f_matrix_vector(x::AbstractMatrix)::AbstractVector - a, b = axes(x) - return vcat(vec(sin.(a .* x)), vec(cos.(transpose(b) .* x))) -end - -function f_matrix_matrix(x::AbstractMatrix)::AbstractMatrix - a, b = axes(x) - return hcat(vec(sin.(a .* x)), vec(cos.(transpose(b) .* x))) -end +f_matrix_vector(x::AbstractMatrix)::AbstractVector = vcat(vec(sin.(x)), vec(cos.(x))) +f_matrix_matrix(x::AbstractMatrix)::AbstractMatrix = hcat(vec(sin.(x)), vec(cos.(x))) ## All diff --git a/test/utils.jl b/test/utils.jl index 4b52b574e..052cc0cd8 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -8,6 +8,8 @@ using Test pretty(::CustomImplem) = "custom" pretty(::FallbackImplem) = "fallback" +const NO_EXTRAS = nothing + ## Test utilities function test_pushforward( @@ -145,9 +147,9 @@ function test_derivative( @testset "$X -> $Y" begin (; f, x, y, der_true) = scenario - y_out1, der_out1 = value_and_derivative(implem, backend, f, x) + y_out1, der_out1 = value_and_derivative(backend, f, x, NO_EXTRAS, implem) - der_out2 = 
derivative(implem, backend, f, x) + der_out2 = derivative(backend, f, x, NO_EXTRAS, implem) @testset "Primal value" begin @test y_out1 ≈ y @@ -157,12 +159,14 @@ function test_derivative( @test der_out2 ≈ der_true rtol = 1e-3 end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_derivative(implem, backend, f, x)) - @test iszero(@allocated derivative(implem, backend, f, x)) + @test iszero( + @allocated value_and_derivative(backend, f, x, NO_EXTRAS, implem) + ) + @test iszero(@allocated derivative(backend, f, x, NO_EXTRAS, implem)) end type_stability && @testset "Type stability" begin - @test_opt value_and_derivative(implem, backend, f, x) - @test_opt derivative(implem, backend, f, x) + @test_opt value_and_derivative(backend, f, x, NO_EXTRAS, implem) + @test_opt derivative(backend, f, x, NO_EXTRAS, implem) end end end @@ -187,15 +191,19 @@ function test_multiderivative( @testset "$X -> $Y" begin (; f, x, y, multider_true) = scenario - y_out1, multider_out1 = value_and_multiderivative(implem, backend, f, x) + y_out1, multider_out1 = value_and_multiderivative( + backend, f, x, NO_EXTRAS, implem + ) multider_in2 = zero(multider_out1) y_out2, multider_out2 = value_and_multiderivative!( - implem, multider_in2, backend, f, x + multider_in2, backend, f, x, NO_EXTRAS, implem ) - multider_out3 = multiderivative(implem, backend, f, x) + multider_out3 = multiderivative(backend, f, x, NO_EXTRAS, implem) multider_in4 = zero(multider_out3) - multider_out4 = multiderivative!(implem, multider_in4, backend, f, x) + multider_out4 = multiderivative!( + multider_in4, backend, f, x, NO_EXTRAS, implem + ) @testset "Primal value" begin @test y_out1 ≈ y @@ -214,16 +222,22 @@ function test_multiderivative( allocs && @testset "Allocations" begin @test iszero( @allocated value_and_multiderivative!( - implem, multider_in2, backend, f, x + multider_in2, backend, f, x, NO_EXTRAS, implem ) ) @test iszero( - @allocated multiderivative!(implem, multider_in4, backend, f, x) + 
@allocated multiderivative!( + multider_in4, backend, f, x, NO_EXTRAS, implem + ) ) end type_stability && @testset "Type stability" begin - @test_opt value_and_multiderivative!(implem, multider_in2, backend, f, x) - @test_opt multiderivative!(implem, multider_in4, backend, f, x) + @test_opt value_and_multiderivative!( + multider_in2, backend, f, x, NO_EXTRAS, implem + ) + @test_opt multiderivative!( + multider_in4, backend, f, x, NO_EXTRAS, implem + ) end end end @@ -248,13 +262,15 @@ function test_gradient( @testset "$X -> $Y" begin (; f, x, y, grad_true) = scenario - y_out1, grad_out1 = value_and_gradient(implem, backend, f, x) + y_out1, grad_out1 = value_and_gradient(backend, f, x, NO_EXTRAS, implem) grad_in2 = zero(grad_out1) - y_out2, grad_out2 = value_and_gradient!(implem, grad_in2, backend, f, x) + y_out2, grad_out2 = value_and_gradient!( + grad_in2, backend, f, x, NO_EXTRAS, implem + ) - grad_out3 = gradient(implem, backend, f, x) + grad_out3 = gradient(backend, f, x, NO_EXTRAS, implem) grad_in4 = zero(grad_out3) - grad_out4 = gradient!(implem, grad_in4, backend, f, x) + grad_out4 = gradient!(grad_in4, backend, f, x, NO_EXTRAS, implem) @testset "Primal value" begin @test y_out1 ≈ y @@ -272,13 +288,17 @@ function test_gradient( end allocs && @testset "Allocations" begin @test iszero( - @allocated value_and_gradient!(implem, grad_in2, backend, f, x) + @allocated value_and_gradient!( + grad_in2, backend, f, x, NO_EXTRAS, implem + ) + ) + @test iszero( + @allocated gradient!(grad_in4, backend, f, x, NO_EXTRAS, implem) ) - @test iszero(@allocated gradient!(implem, grad_in4, backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_gradient!(implem, grad_in2, backend, f, x) - @test_opt gradient!(implem, grad_in4, backend, f, x) + @test_opt value_and_gradient!(grad_in2, backend, f, x, NO_EXTRAS, implem) + @test_opt gradient!(grad_in4, backend, f, x, NO_EXTRAS, implem) end end end @@ -303,13 +323,15 @@ function test_jacobian( 
@testset "$X -> $Y" begin (; f, x, y, jac_true) = scenario - y_out1, jac_out1 = value_and_jacobian(implem, backend, f, x) + y_out1, jac_out1 = value_and_jacobian(backend, f, x, NO_EXTRAS, implem) jac_in2 = zero(jac_out1) - y_out2, jac_out2 = value_and_jacobian!(implem, jac_in2, backend, f, x) + y_out2, jac_out2 = value_and_jacobian!( + jac_in2, backend, f, x, NO_EXTRAS, implem + ) - jac_out3 = jacobian(implem, backend, f, x) + jac_out3 = jacobian(backend, f, x, NO_EXTRAS, implem) jac_in4 = zero(jac_out3) - jac_out4 = jacobian!(implem, jac_in4, backend, f, x) + jac_out4 = jacobian!(jac_in4, backend, f, x, NO_EXTRAS, implem) @testset "Primal value" begin @test y_out1 ≈ y @@ -326,12 +348,18 @@ function test_jacobian( end end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_jacobian!(implem, jac_in2, backend, f, x)) - @test iszero(@allocated jacobian!(implem, jac_in4, backend, f, x)) + @test iszero( + @allocated value_and_jacobian!( + jac_in2, backend, f, x, NO_EXTRAS, implem + ) + ) + @test iszero( + @allocated jacobian!(jac_in4, backend, f, x, NO_EXTRAS, implem) + ) end type_stability && @testset "Type stability" begin - @test_opt value_and_jacobian!(implem, jac_in2, backend, f, x) - @test_opt jacobian!(implem, jac_in4, backend, f, x) + @test_opt value_and_jacobian!(jac_in2, backend, f, x, NO_EXTRAS, implem) + @test_opt jacobian!(jac_in4, backend, f, x, NO_EXTRAS, implem) end end end From d611833343f41a6bd187b541e2a37f84b63e3203 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 11 Mar 2024 17:03:56 +0100 Subject: [PATCH 9/9] Add Diffractor import to benchmarks --- benchmark/benchmarks.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 8020248d0..f3569ae29 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -3,6 +3,7 @@ using BenchmarkTools using DifferentiationInterface using LinearAlgebra +using Diffractor: 
Diffractor using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff