diff --git a/Project.toml b/Project.toml index 06c8b9268..527505474 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,15 @@ authors = ["Guillaume Dalle", "Adrian Hill"] version = "0.1.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -20,15 +23,26 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" +DifferentiationInterfaceDiffractorExt = [ + "Diffractor", + "AbstractDifferentiation", +] DifferentiationInterfaceEnzymeExt = "Enzyme" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] -DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] +DifferentiationInterfacePolyesterForwardDiffExt = [ + "PolyesterForwardDiff", + "ForwardDiff", + "DiffResults", +] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceZygoteExt = ["Zygote"] [compat] +AbstractDifferentiation = "0.6" +ADTypes = "0.2.6" ChainRulesCore = "1.19" +Diffractor = "0.2" DiffResults = "1.1" DocStringExtensions = "0.9" Enzyme = "0.11" diff --git a/README.md b/README.md index f76db1417..70e34e67d 100644 --- a/README.md +++ b/README.md @@ -13,18 +13,21 @@ This package provides a backend-agnostic syntax to differentiate functions `f(x) It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). -## Example +## Compatibility -```jldoctest -julia> using DifferentiationInterface, Enzyme +We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): -julia> backend = EnzymeReverseBackend(); +- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) with `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))` +- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) with `AutoFiniteDiff()` +- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with `AutoForwardDiff()` +- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) with `AutoPolyesterForwardDiff(; chunksize=C)` +- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) with `AutoReverseDiff()` +- [Zygote.jl](https://github.com/FluxML/Zygote.jl) with `AutoZygote()` -julia> f(x) = sum(abs2, x); +We also support two more backends which are not yet part of ADTypes.jl: -julia> value_and_gradient(backend, f, [1., 2., 3.]) -(14.0, [2.0, 4.0, 6.0]) -``` +- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) with `AutoChainRules(ruleconfig)` +- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) with `AutoDiffractor()` ## Design @@ -40,22 +43,15 @@ From these primitives, several utilities are defined, depending on the type of t | scalar input | derivative | multiderivative | | array input | gradient | jacobian | -## Supported backends - -Forward mode: - -- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) -- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) +## Example -Reverse mode: +```jldoctest +julia> import DifferentiationInterface, ADTypes, ForwardDiff -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -- [Zygote.jl](https://github.com/FluxML/Zygote.jl) -- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) +julia> backend = ADTypes.AutoForwardDiff(); -Experimental: +julia> f(x) = sum(abs2, x); -- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) -- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (currently broken due to [#277](https://github.com/JuliaDiff/Diffractor.jl/issues/277)) +julia> DifferentiationInterface.value_and_gradient(backend, f, [1., 2., 3.]) +(14.0, [2.0, 4.0, 6.0]) +``` diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 36f358793..62de913cf 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 7c9145c51..f3569ae29 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,13 +1,15 @@ +using ADTypes using BenchmarkTools using DifferentiationInterface using LinearAlgebra +using Diffractor: Diffractor using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using PolyesterForwardDiff: PolyesterForwardDiff using ReverseDiff: ReverseDiff -using Zygote: Zygote +using Zygote: Zygote, ZygoteRuleConfig ## Settings @@ -41,38 +43,18 @@ end ## Backends -forward_custom_backends = [ - EnzymeForwardBackend(; custom=true), - FiniteDiffBackend(; custom=true), - ForwardDiffBackend(; custom=true), - PolyesterForwardDiffBackend(4; custom=true), +all_backends = [ + AutoChainRules(ZygoteRuleConfig()), + AutoDiffractor(), + AutoEnzyme(Val(:forward)), + AutoEnzyme(Val(:reverse)), + AutoFiniteDiff(), + AutoForwardDiff(), + AutoPolyesterForwardDiff(; chunksize=4), + AutoReverseDiff(), + AutoZygote(), ] -forward_fallback_backends = [ - EnzymeForwardBackend(; custom=false), - FiniteDiffBackend(; custom=false), - ForwardDiffBackend(; custom=false), -] - -reverse_custom_backends = [ - ZygoteBackend(; custom=true), - EnzymeReverseBackend(; custom=true), - ReverseDiffBackend(; custom=true), -] - -reverse_fallback_backends = [ - ZygoteBackend(; custom=false), - EnzymeReverseBackend(; custom=false), - ReverseDiffBackend(; custom=false), -] - -all_backends = vcat( - forward_custom_backends, - forward_fallback_backends, - reverse_custom_backends, - reverse_fallback_backends, -) - ## Suite function make_suite() @@ -83,11 +65,7 @@ function make_suite() for backend in all_backends add_derivative_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1) - end - for backend in forward_fallback_backends add_pushforward_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1) - end - for backend in reverse_fallback_backends add_pullback_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1) end @@ -97,11 +75,7 @@ function make_suite() for backend in all_backends add_multiderivative_benchmarks!(SUITE, backend, scalar_to_vector, 1, m) - end - for backend in forward_fallback_backends add_pushforward_benchmarks!(SUITE, backend, scalar_to_vector, 1, m) - end - for backend in reverse_fallback_backends add_pullback_benchmarks!(SUITE, backend, scalar_to_vector, 1, m) end end @@ -112,11 +86,7 @@ function make_suite() for backend in all_backends add_gradient_benchmarks!(SUITE, backend, vector_to_scalar, n, 1) - end - for backend in forward_fallback_backends add_pushforward_benchmarks!(SUITE, backend, vector_to_scalar, n, 1) - end - for backend in reverse_fallback_backends add_pullback_benchmarks!(SUITE, backend, vector_to_scalar, n, 1) end end @@ -127,11 +97,7 @@ function make_suite() for backend in all_backends add_jacobian_benchmarks!(SUITE, backend, vector_to_vector, n, m) - end - for backend in forward_fallback_backends add_pushforward_benchmarks!(SUITE, backend, vector_to_vector, n, m) - end - for backend in reverse_fallback_backends add_pullback_benchmarks!(SUITE, backend, vector_to_vector, n, m) end end @@ -144,7 +110,7 @@ include("utils.jl") SUITE = make_suite() # Run benchmarks locally -# results = BenchmarkTools.run(SUITE; verbose=true) +results = BenchmarkTools.run(SUITE; verbose=true) # Compare commits locally # using BenchmarkCI; BenchmarkCI.judge(baseline="origin/main"); BenchmarkCI.displayjudgement() diff --git a/benchmark/utils.jl b/benchmark/utils.jl index cea537f28..bd3ca2c37 100644 --- a/benchmark/utils.jl +++ b/benchmark/utils.jl @@ -1,28 +1,52 @@ -using DifferentiationInterface +using ADTypes +using ADTypes: AbstractADType using BenchmarkTools +using DifferentiationInterface +using DifferentiationInterface: CustomImplem, FallbackImplem, ForwardMode, ReverseMode +using DifferentiationInterface: autodiff_mode + +const NO_EXTRAS = nothing + +## Pretty printing + +pretty(::AutoChainRules{<:ZygoteRuleConfig}) = "ChainRules{Zygote}" +pretty(::AutoDiffractor) = "Diffractor (forward)" +pretty(::AutoEnzyme{Val{:forward}}) = "Enzyme (forward)" +pretty(::AutoEnzyme{Val{:reverse}}) = "Enzyme (reverse)" +pretty(::AutoFiniteDiff) = "FiniteDiff" +pretty(::AutoForwardDiff) = "ForwardDiff" +pretty(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff" +pretty(::AutoReverseDiff) = "ReverseDiff" +pretty(::AutoZygote) = "Zygote" + +pretty(::CustomImplem) = "custom" +pretty(::FallbackImplem) = "fallback" + +## Benchmark suite function add_pushforward_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} x = n == 1 ? randn() : randn(n) dx = n == 1 ? randn() : randn(n) dy = m == 1 ? 0.0 : zeros(m) - if autodiff_mode(backend) != :forward || !handles_types(backend, typeof(x), typeof(dy)) + if !isa(autodiff_mode(backend), ForwardMode) || + !handles_types(backend, typeof(x), typeof(dy)) return nothing end - suite["value_and_pushforward"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_pushforward"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_pushforward($backend, $f, $x, $dx) end - suite["value_and_pushforward!"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_pushforward!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_pushforward!($dy, $backend, $f, $x, $dx) end - suite["pushforward"][(n, m)][string(backend)] = @benchmarkable begin + suite["pushforward"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin pushforward($backend, $f, $x, $dx) end - suite["pushforward!"][(n, m)][string(backend)] = @benchmarkable begin + suite["pushforward!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin pushforward!($dy, $backend, $f, $x, $dx) end @@ -30,27 +54,28 @@ function add_pushforward_benchmarks!( end function add_pullback_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} x = n == 1 ? randn() : randn(n) dx = n == 1 ? 0.0 : zeros(n) dy = m == 1 ? randn() : randn(m) - if autodiff_mode(backend) != :reverse || !handles_types(backend, typeof(x), typeof(dy)) + if !isa(autodiff_mode(backend), ReverseMode) || + !handles_types(backend, typeof(x), typeof(dy)) return nothing end - suite["value_and_pullback"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_pullback"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_pullback($backend, $f, $x, $dy) end - suite["value_and_pullback!"][(n, m)][string(backend)] = @benchmarkable begin + suite["value_and_pullback!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin value_and_pullback!($dx, $backend, $f, $x, $dy) end - suite["pullback"][(n, m)][string(backend)] = @benchmarkable begin + suite["pullback"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin pullback($backend, $f, $x, $dy) end - suite["pullback!"][(n, m)][string(backend)] = @benchmarkable begin + suite["pullback!"][(n, m)]["$(pretty(backend))"] = @benchmarkable begin pullback!($dx, $backend, $f, $x, $dy) end @@ -58,7 +83,7 @@ function add_pullback_benchmarks!( end function add_derivative_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} @assert n == m == 1 if !handles_types(backend, Number, Number) @@ -67,19 +92,23 @@ function add_derivative_benchmarks!( x = randn() - suite["value_and_derivative"][(1, 1)][string(backend)] = @benchmarkable begin - value_and_derivative($backend, $f, $x) - end + for implem in (CustomImplem(), FallbackImplem()) + backend_implem = "$(pretty(backend)) - $(pretty(implem))" + + suite["value_and_derivative"][(1, 1)][backend_implem] = @benchmarkable begin + value_and_derivative($backend, $f, $x, $NO_EXTRAS, $implem) + end - suite["derivative"][(1, 1)][string(backend)] = @benchmarkable begin - derivative($backend, $f, $x) + suite["derivative"][(1, 1)][backend_implem] = @benchmarkable begin + derivative($backend, $f, $x, $NO_EXTRAS, $implem) + end end return nothing end function add_multiderivative_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} @assert n == 1 if !handles_types(backend, Number, Vector) @@ -89,25 +118,29 @@ function add_multiderivative_benchmarks!( x = randn() multider = zeros(m) - suite["value_and_multiderivative"][(1, m)][string(backend)] = @benchmarkable begin - value_and_multiderivative($backend, $f, $x) - end - suite["value_and_multiderivative!"][(1, m)][string(backend)] = @benchmarkable begin - value_and_multiderivative!($multider, $backend, $f, $x) - end + for implem in (CustomImplem(), FallbackImplem()) + backend_implem = "$(pretty(backend)) - $(pretty(implem))" - suite["multiderivative"][(1, m)][string(backend)] = @benchmarkable begin - multiderivative($backend, $f, $x) - end - suite["multiderivative!"][(1, m)][string(backend)] = @benchmarkable begin - multiderivative!($multider, $backend, $f, $x) + suite["value_and_multiderivative"][(1, m)][backend_implem] = @benchmarkable begin + value_and_multiderivative($backend, $f, $x, $NO_EXTRAS, $implem) + end + suite["value_and_multiderivative!"][(1, m)][backend_implem] = @benchmarkable begin + value_and_multiderivative!($multider, $backend, $f, $x, $NO_EXTRAS, $implem) + end + + suite["multiderivative"][(1, m)][backend_implem] = @benchmarkable begin + multiderivative($backend, $f, $x, $NO_EXTRAS, $implem) + end + suite["multiderivative!"][(1, m)][backend_implem] = @benchmarkable begin + multiderivative!($multider, $backend, $f, $x, $NO_EXTRAS, $implem) + end end return nothing end function add_gradient_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} @assert m == 1 if !handles_types(backend, Vector, Number) @@ -117,25 +150,29 @@ function add_gradient_benchmarks!( x = randn(n) grad = zeros(n) - suite["value_and_gradient"][(n, 1)][string(backend)] = @benchmarkable begin - value_and_gradient($backend, $f, $x) - end - suite["value_and_gradient!"][(n, 1)][string(backend)] = @benchmarkable begin - value_and_gradient!($grad, $backend, $f, $x) - end + for implem in (CustomImplem(), FallbackImplem()) + backend_implem = "$(pretty(backend)) - $(pretty(implem))" - suite["gradient"][(n, 1)][string(backend)] = @benchmarkable begin - gradient($backend, $f, $x) - end - suite["gradient!"][(n, 1)][string(backend)] = @benchmarkable begin - gradient!($grad, $backend, $f, $x) + suite["value_and_gradient"][(n, 1)][backend_implem] = @benchmarkable begin + value_and_gradient($backend, $f, $x, $NO_EXTRAS, $implem) + end + suite["value_and_gradient!"][(n, 1)][backend_implem] = @benchmarkable begin + value_and_gradient!($grad, $backend, $f, $x, $NO_EXTRAS, $implem) + end + + suite["gradient"][(n, 1)][backend_implem] = @benchmarkable begin + gradient($backend, $f, $x, $NO_EXTRAS, $implem) + end + suite["gradient!"][(n, 1)][backend_implem] = @benchmarkable begin + gradient!($grad, $backend, $f, $x, $NO_EXTRAS, $implem) + end end return nothing end function add_jacobian_benchmarks!( - suite::BenchmarkGroup, backend::AbstractBackend, f::F, n::Integer, m::Integer + suite::BenchmarkGroup, backend::AbstractADType, f::F, n::Integer, m::Integer ) where {F} if !handles_types(backend, Vector, Vector) return nothing @@ -144,18 +181,22 @@ function add_jacobian_benchmarks!( x = randn(n) jac = zeros(m, n) - suite["value_and_jacobian"][(n, m)][string(backend)] = @benchmarkable begin - value_and_jacobian($backend, $f, $x) - end - suite["value_and_jacobian!"][(n, m)][string(backend)] = @benchmarkable begin - value_and_jacobian!($jac, $backend, $f, $x) - end - - suite["jacobian"][(n, m)][string(backend)] = @benchmarkable begin - jacobian($backend, $f, $x) - end - suite["jacobian!"][(n, m)][string(backend)] = @benchmarkable begin - jacobian!($jac, $backend, $f, $x) + for implem in (CustomImplem(), FallbackImplem()) + backend_implem = "$(pretty(backend)) - $(pretty(implem))" + + suite["value_and_jacobian"][(n, m)][backend_implem] = @benchmarkable begin + value_and_jacobian($backend, $f, $x, $NO_EXTRAS, $implem) + end + suite["value_and_jacobian!"][(n, m)][backend_implem] = @benchmarkable begin + value_and_jacobian!($jac, $backend, $f, $x, $NO_EXTRAS, $implem) + end + + suite["jacobian"][(n, m)][backend_implem] = @benchmarkable begin + jacobian($backend, $f, $x, $NO_EXTRAS, $implem) + end + suite["jacobian!"][(n, m)][backend_implem] = @benchmarkable begin + jacobian!($jac, $backend, $f, $x, $NO_EXTRAS, $implem) + end end return nothing diff --git a/docs/Project.toml b/docs/Project.toml index 7850516f4..7d881d2b2 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,7 +1,9 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" diff --git a/docs/make.jl b/docs/make.jl index 16c3a5566..d22311330 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,7 +3,7 @@ using DifferentiationInterface import DifferentiationInterface as DI using Documenter -using DiffResults: DiffResults +using ADTypes using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff @@ -24,7 +24,7 @@ ZygoteExt = get_extension(DI, :DifferentiationInterfaceZygoteExt) DocMeta.setdocmeta!( DifferentiationInterface, :DocTestSetup, - :(using DifferentiationInterface); + :(using DifferentiationInterface, ADTypes); recursive=true, ) @@ -45,6 +45,7 @@ end makedocs(; modules=[ DifferentiationInterface, + ADTypes, ChainRulesCoreExt, EnzymeExt, FiniteDiffExt, @@ -58,10 +59,9 @@ makedocs(; format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", canonical="https://gdalle.github.io/DifferentiationInterface.jl", + edit_link="main", ), - pages=[ - "Home" => "index.md", "Interface" => "interface.md", "Backends" => "backends.md" - ], + pages=["Home" => "index.md", "Interface" => "interface.md"], ) deploydocs(; repo="github.com/gdalle/DifferentiationInterface.jl", devbranch="main") diff --git a/docs/src/backends.md b/docs/src/backends.md deleted file mode 100644 index df9b2b271..000000000 --- a/docs/src/backends.md +++ /dev/null @@ -1,73 +0,0 @@ -```@meta -CollapsedDocStrings = true -``` - -# Backends - -## ChainRulesCore - -```@docs -ChainRulesForwardBackend -ChainRulesReverseBackend -``` - -```@autodocs -Modules = [ChainRulesCoreExt] -``` - -## Enzyme - -```@docs -EnzymeForwardBackend -EnzymeReverseBackend -``` - -```@autodocs -Modules = [EnzymeExt] -``` - -## FiniteDiff - -```@docs -FiniteDiffBackend -``` - -```@autodocs -Modules = [FiniteDiffExt] -``` - -## ForwardDiff - -```@docs -ForwardDiffBackend -``` - -```@autodocs -Modules = [ForwardDiffExt] -``` - -## PolyesterForwardDiff - -```@docs -PolyesterForwardDiffBackend -``` - -```@autodocs -Modules = [PolyesterForwardDiffExt] -``` - -## ReverseDiff - -```@docs -ReverseDiffBackend -``` - -```@autodocs -Modules = [ReverseDiffExt] -``` - -## Zygote - -```@autodocs -Modules = [ZygoteExt] -``` diff --git a/docs/src/index.md b/docs/src/index.md deleted file mode 100644 index 9cdd27b68..000000000 --- a/docs/src/index.md +++ /dev/null @@ -1,65 +0,0 @@ -```@meta -EditURL = "https://github.com/gdalle/DifferentiationInterface.jl/blob/main/README.md" -``` - -# DifferentiationInterface - -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://gdalle.github.io/DifferentiationInterface.jl/dev/) -[![Build Status](https://github.com/gdalle/DifferentiationInterface.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/gdalle/DifferentiationInterface.jl/actions/workflows/CI.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/gdalle/DifferentiationInterface.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/gdalle/DifferentiationInterface.jl) -[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) - -An interface to various automatic differentiation backends in Julia. - -## Goal - -This package provides a backend-agnostic syntax to differentiate functions `f(x) = y`, where `x` and `y` are either numbers or abstract arrays. - -It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). - -## Example - -```jldoctest -julia> using DifferentiationInterface, Enzyme - -julia> backend = EnzymeReverseBackend(); - -julia> f(x) = sum(abs2, x); - -julia> value_and_gradient(backend, f, [1., 2., 3.]) -(14.0, [2.0, 4.0, 6.0]) -``` - -## Design - -Each backend must implement only one primitive: - -- forward mode: the pushforward, computing a Jacobian-vector product -- reverse mode: the pullback, computing a vector-Jacobian product - -From these primitives, several utilities are defined, depending on the type of the input and output: - -| | scalar output | array output | -| ------------ | ------------- | --------------- | -| scalar input | derivative | multiderivative | -| array input | gradient | jacobian | - -## Supported backends - -Forward mode: - -- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) -- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) - -Reverse mode: - -- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -- [Zygote.jl](https://github.com/FluxML/Zygote.jl) -- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) - -Experimental: - -- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) -- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (currently broken due to [#277](https://github.com/JuliaDiff/Diffractor.jl/issues/277)) diff --git a/docs/src/interface.md b/docs/src/interface.md index 6f25db112..3a1dbef9d 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -55,16 +55,42 @@ Modules = [DifferentiationInterface] Pages = ["pullback.jl"] ``` -## Abstract backends +## Backends + +### ADTypes.jl + +The following backends are defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl): + +```@docs +AbstractADType +``` + +Only a subset is supported by DifferentiationInterface.jl at the moment. + +### DifferentiationInterface.jl + +The following backends are defined by DifferentiationInterface.jl: + +```@autodocs +Modules = [DifferentiationInterface] +Pages = ["backends.jl"] +Order = [:type] +Private = false +``` + +### Input / output types ```@autodocs Modules = [DifferentiationInterface] -Pages = ["backends_abstract.jl"] +Pages = ["backends.jl"] +Order = [:function] +Private = false ``` ## Internals ```@autodocs Modules = [DifferentiationInterface] -Pages = ["utils.jl"] +Pages = ["implem.jl", "mode.jl", "utils.jl", "backends.jl"] +Public = false ``` diff --git a/ext/DifferentiationInterfaceChainRulesCoreExt.jl b/ext/DifferentiationInterfaceChainRulesCoreExt.jl index 5e116c557..c2d212ed8 100644 --- a/ext/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/ext/DifferentiationInterfaceChainRulesCoreExt.jl @@ -2,72 +2,48 @@ module DifferentiationInterfaceChainRulesCoreExt using ChainRulesCore: HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad -using DifferentiationInterface: ChainRulesForwardBackend, ChainRulesReverseBackend +using DifferentiationInterface: AutoChainRules, CustomImplem, update! import DifferentiationInterface as DI using DocStringExtensions -ruleconfig(backend::ChainRulesForwardBackend) = backend.ruleconfig -ruleconfig(backend::ChainRulesReverseBackend) = backend.ruleconfig +ruleconfig(backend::AutoChainRules) = backend.ruleconfig -update!(_old::Number, new::Number) = new -update!(old, new) = old .= new +const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}} +const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} -## Backend construction - -""" - ChainRulesForwardBackend(rc::RuleConfig; custom=true) - -Construct a [`ChainRulesForwardBackend`](@ref) from a [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) object that `HasForwardsMode`. - -## Example - -```julia -using Diffractor, DifferentiationInterface -backend = ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()) -``` -""" -function DI.ChainRulesForwardBackend(rc::RuleConfig{>:HasForwardsMode}; custom::Bool=true) - return ChainRulesForwardBackend{custom,typeof(rc)}(rc) -end - -""" - ChainRulesReverseBackend(rc::RuleConfig; custom=true) - -Construct a [`ChainRulesReverseBackend`](@ref) from a [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) object that `HasReverseMode`. - -## Example - -```julia -using Zygote, DifferentiationInterface -backend = ChainRulesReverseBackend(Zygote.ZygoteRuleConfig()) -``` -""" -function DI.ChainRulesReverseBackend(rc::RuleConfig{>:HasReverseMode}; custom::Bool=true) - return ChainRulesReverseBackend{custom,typeof(rc)}(rc) -end +DI.autodiff_mode(::AutoForwardChainRules) = DI.ForwardMode() +DI.autodiff_mode(::AutoReverseChainRules) = DI.ReverseMode() ## Primitives -function DI.value_and_pushforward(backend::ChainRulesForwardBackend, f, x, dx) +function DI.value_and_pushforward( + backend::AutoForwardChainRules, f, x, dx, extras::Nothing=nothing +) rc = ruleconfig(backend) y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x) return y, new_dy end -function DI.value_and_pushforward!(dy, backend::ChainRulesForwardBackend, f, x, dx) - y, new_dy = DI.value_and_pushforward(backend, f, x, dx) +function DI.value_and_pushforward!( + dy, backend::AutoForwardChainRules, f, x, dx, extras=nothing +) + y, new_dy = DI.value_and_pushforward(backend, f, x, dx, extras) return y, update!(dy, new_dy) end -function DI.value_and_pullback(backend::ChainRulesReverseBackend, f, x, dy) +function DI.value_and_pullback( + backend::AutoReverseChainRules, f, x, dy, extras::Nothing=nothing +) rc = ruleconfig(backend) y, pullback = rrule_via_ad(rc, f, x) _, new_dx = pullback(dy) return y, new_dx end -function DI.value_and_pullback!(dx, backend::ChainRulesReverseBackend, f, x, dy) - y, new_dx = DI.value_and_pullback(backend, f, x, dy) +function DI.value_and_pullback!( + dx, backend::AutoReverseChainRules, f, x, dy, extras=nothing +) + y, new_dx = DI.value_and_pullback(backend, f, x, dy, extras) return y, update!(dx, new_dx) end diff --git a/ext/DifferentiationInterfaceDiffractorExt.jl b/ext/DifferentiationInterfaceDiffractorExt.jl new file mode 100644 index 000000000..333fdb167 --- /dev/null +++ b/ext/DifferentiationInterfaceDiffractorExt.jl @@ -0,0 +1,24 @@ +module DifferentiationInterfaceDiffractorExt + +import AbstractDifferentiation as AD # public API for Diffractor +using DifferentiationInterface: AutoChainRules, AutoDiffractor, update! +import DifferentiationInterface as DI +using Diffractor: DiffractorForwardBackend, DiffractorRuleConfig +using DocStringExtensions + +DI.autodiff_mode(::AutoDiffractor) = DI.ForwardMode() +DI.autodiff_mode(::AutoChainRules{<:DiffractorRuleConfig}) = DI.ForwardMode() + +function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing=nothing) + vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) + y, dy = vpff((dx,)) + return y, dy +end + +function DI.value_and_pushforward!(dy, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing) + vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x) + y, new_dy = vpff((dx,)) + return y, update!(dy, new_dy) +end + +end diff --git a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index fcbeda077..a7c7f1688 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfaceEnzymeExt -using DifferentiationInterface: EnzymeForwardBackend, EnzymeReverseBackend +using ADTypes: AutoEnzyme +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DocStringExtensions using Enzyme: @@ -15,9 +16,7 @@ using Enzyme: jacobian # Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type -function DI.basisarray( - ::Union{EnzymeForwardBackend,EnzymeReverseBackend}, a::AbstractArray{T}, i -) where {T} +function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T} b = zero(a) b[i] = one(T) return b diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward.jl b/ext/DifferentiationInterfaceEnzymeExt/forward.jl index 95aadfa51..d4f06a2a6 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/forward.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/forward.jl @@ -1,24 +1,17 @@ - -## Backend construction - -""" - EnzymeForwardBackend(; custom=true) - -Construct a [`EnzymeForwardBackend`](@ref). -""" -DI.EnzymeForwardBackend(; custom::Bool=true) = EnzymeForwardBackend{custom}() +const AutoForwardEnzyme = AutoEnzyme{Val{:forward}} +DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode() ## Primitives function DI.value_and_pushforward!( - _dy::Y, ::EnzymeForwardBackend, f, x::X, dx + _dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing ) where {X,Y<:Real} y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) return y, new_dy end function DI.value_and_pushforward!( - dy::Y, ::EnzymeForwardBackend, f, x::X, dx + dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing ) where {X,Y<:AbstractArray} y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) dy .= new_dy @@ -27,7 +20,13 @@ end ## Utilities -function DI.value_and_jacobian(::EnzymeForwardBackend{true}, f, x::AbstractArray) +function DI.value_and_jacobian( + ::AutoForwardEnzyme, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) jac = jacobian(Forward, f, x) # see https://github.com/EnzymeAD/Enzyme.jl/issues/1332 diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl index 053bd3e37..2702aecbc 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl @@ -1,17 +1,11 @@ - -## Backend construction - -""" - EnzymeReverseBackend(; custom=true) - -Construct a [`EnzymeReverseBackend`](@ref). -""" -DI.EnzymeReverseBackend(; custom::Bool=true) = EnzymeReverseBackend{custom}() +const AutoReverseEnzyme = AutoEnzyme{Val{:reverse}} +DI.autodiff_mode(::AutoReverseEnzyme) = DI.ReverseMode() +DI.handles_output_type(::AutoReverseEnzyme, ::Type{<:AbstractArray}) = false ## Primitives function DI.value_and_pullback!( - _dx, ::EnzymeReverseBackend, f, x::X, dy::Y + _dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing ) where {X<:Number,Y<:Union{Real,Nothing}} der, y = autodiff(ReverseWithPrimal, f, Active, Active(x)) new_dx = dy * only(der) @@ -19,7 +13,7 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y + dx::X, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:Union{Real,Nothing}} dx .= zero(eltype(dx)) _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx)) @@ -29,14 +23,25 @@ end ## Utilities -function DI.value_and_gradient(::EnzymeReverseBackend{true}, f, x::AbstractArray) +function DI.value_and_gradient( + ::AutoReverseEnzyme, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) grad = gradient(Reverse, f, x) return y, grad end function DI.value_and_gradient!( - grad::AbstractArray, ::EnzymeReverseBackend{true}, f, x::AbstractArray + grad::AbstractArray, + ::AutoReverseEnzyme, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) gradient!(Reverse, grad, f, x) diff --git a/ext/DifferentiationInterfaceFiniteDiffExt.jl b/ext/DifferentiationInterfaceFiniteDiffExt.jl index a1c18c3fe..0f4b47cb7 100644 --- a/ext/DifferentiationInterfaceFiniteDiffExt.jl +++ b/ext/DifferentiationInterfaceFiniteDiffExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfaceFiniteDiffExt -using DifferentiationInterface: FiniteDiffBackend +using ADTypes: AutoFiniteDiff +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DocStringExtensions using FiniteDiff: @@ -14,24 +15,11 @@ using LinearAlgebra: dot, mul! const FUNCTION_INPLACE = Val{true} const FUNCTION_NOT_INPLACE = Val{false} -## Backend construction - -""" - FiniteDiffBackend(::Type{fdtype}=Val{:central}; custom=true) - -Construct a [`FiniteDiffBackend`](@ref) with any finite difference type `fdtype` (`Val{:forward}` or `Val{:central}`). -""" -function DI.FiniteDiffBackend( - ::Type{fdtype}=Val{:central}; custom::Bool=true -) where {fdtype} - return FiniteDiffBackend{custom,fdtype}() -end - ## Primitives function DI.value_and_pushforward!( - dy::Y, ::FiniteDiffBackend{custom,fdtype}, f, x, dx -) where {Y<:Number,custom,fdtype} + dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing +) where {Y<:Number,fdtype} y = f(x) step(t::Number)::Number = f(x .+ t .* dx) new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y) @@ -39,8 +27,8 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::FiniteDiffBackend{custom,fdtype}, f, x, dx -) where {Y<:AbstractArray,custom,fdtype} + dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing +) where {Y<:AbstractArray,fdtype} y = f(x) step(t::Number)::AbstractArray = f(x .+ t .* dx) finite_difference_gradient!( @@ -52,7 +40,11 @@ end ## Utilities function DI.value_and_derivative( - ::FiniteDiffBackend{true,fdtype}, f, x::Number + ::AutoFiniteDiff{fdtype}, + f, + x::Number, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) der = finite_difference_derivative(f, x, fdtype, eltype(y), y) @@ -60,7 +52,12 @@ function DI.value_and_derivative( end function DI.value_and_multiderivative!( - multider::AbstractArray, ::FiniteDiffBackend{true,fdtype}, f, x::Number + multider::AbstractArray, + ::AutoFiniteDiff{fdtype}, + f, + x::Number, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) finite_difference_gradient!(multider, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -68,7 +65,11 @@ function DI.value_and_multiderivative!( end function DI.value_and_multiderivative( - ::FiniteDiffBackend{true,fdtype}, f, x::Number + ::AutoFiniteDiff{fdtype}, + f, + x::Number, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) multider = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -76,7 +77,12 @@ function DI.value_and_multiderivative( end function DI.value_and_gradient!( - grad::AbstractArray, ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray + grad::AbstractArray, + ::AutoFiniteDiff{fdtype}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) finite_difference_gradient!(grad, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -84,7 +90,11 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray + ::AutoFiniteDiff{fdtype}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) grad = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) @@ -92,7 +102,11 @@ function DI.value_and_gradient( end function DI.value_and_jacobian( - ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray + ::AutoFiniteDiff{fdtype}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {fdtype} y = f(x) jac = finite_difference_jacobian(f, x, fdtype, eltype(y)) @@ -100,9 +114,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::FiniteDiffBackend{true}, f, x::AbstractArray + jac::AbstractMatrix, + backend::AutoFiniteDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + implem::CustomImplem=CustomImplem(), ) - y, new_jac = DI.value_and_jacobian(backend, f, x) + y, new_jac = DI.value_and_jacobian(backend, f, x, extras, implem) jac .= new_jac return y, jac end diff --git a/ext/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt.jl index 367fca5d9..8e66b76db 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfaceForwardDiffExt -using DifferentiationInterface: ForwardDiffBackend +using ADTypes: AutoForwardDiff +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions @@ -18,19 +19,10 @@ using ForwardDiff: value using LinearAlgebra: mul! -## Backend construction - -""" - ForwardDiffBackend(; custom=true) - -Construct a [`ForwardDiffBackend`](@ref). -""" -DI.ForwardDiffBackend(; custom::Bool=true) = ForwardDiffBackend{custom}() - ## Primitives function DI.value_and_pushforward!( - _dy::Y, ::ForwardDiffBackend, f, x::X, dx + _dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing ) where {X<:Real,Y<:Real} T = typeof(Tag(f, X)) xdual = Dual{T}(x, dx) @@ -41,7 +33,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::ForwardDiffBackend, f, x::X, dx + dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing ) where {X<:Real,Y<:AbstractArray} T = typeof(Tag(f, X)) xdual = Dual{T}(x, dx) @@ -52,7 +44,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - _dy::Y, ::ForwardDiffBackend, f, x::X, dx + _dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:Real} T = typeof(Tag(f, X)) # TODO: unsure xdual = Dual{T}.(x, dx) # TODO: allocation @@ -63,7 +55,7 @@ function DI.value_and_pushforward!( end function DI.value_and_pushforward!( - dy::Y, ::ForwardDiffBackend, f, x::X, dx + dy::Y, ::AutoForwardDiff, f, x::X, dx, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:AbstractArray} T = typeof(Tag(f, X)) # TODO: unsure xdual = Dual{T}.(x, dx) # TODO: allocation @@ -75,48 +67,71 @@ end ## Utilities (TODO: use DiffResults) -function DI.value_and_derivative(::ForwardDiffBackend{true}, f, x::Number) +function DI.value_and_derivative( + ::AutoForwardDiff, f, x::Number, extras::Nothing, ::CustomImplem=CustomImplem() +) y = f(x) der = derivative(f, x) return y, der end -function DI.value_and_multiderivative(::ForwardDiffBackend{true}, f, x::Number) +function DI.value_and_multiderivative( + ::AutoForwardDiff, f, x::Number, extras::Nothing, ::CustomImplem=CustomImplem() +) y = f(x) multider = derivative(f, x) return y, multider end function DI.value_and_multiderivative!( - multider::AbstractArray, ::ForwardDiffBackend{true}, f, x::Number + multider::AbstractArray, + ::AutoForwardDiff, + f, + x::Number, + extras::Nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) derivative!(multider, f, x) return y, multider end -function DI.value_and_gradient(::ForwardDiffBackend{true}, f, x::AbstractArray) +function DI.value_and_gradient( + ::AutoForwardDiff, f, x::AbstractArray, extras::Nothing, ::CustomImplem=CustomImplem() +) y = f(x) grad = gradient(f, x) return y, grad end function DI.value_and_gradient!( - grad::AbstractArray, ::ForwardDiffBackend{true}, f, x::AbstractArray + grad::AbstractArray, + ::AutoForwardDiff, + f, + x::AbstractArray, + extras::Nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) gradient!(grad, f, x) return y, grad end -function DI.value_and_jacobian(::ForwardDiffBackend{true}, f, x::AbstractArray) +function DI.value_and_jacobian( + ::AutoForwardDiff, f, x::AbstractArray, extras::Nothing, ::CustomImplem=CustomImplem() +) y = f(x) jac = jacobian(f, x) return y, jac end function DI.value_and_jacobian!( - jac::AbstractMatrix, ::ForwardDiffBackend{true}, f, x::AbstractArray + jac::AbstractMatrix, + ::AutoForwardDiff, + f, + x::AbstractArray, + extras::Nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) jacobian!(jac, f, x) diff --git a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl index 25d54d6f9..074e96091 100644 --- a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfacePolyesterForwardDiffExt -using DifferentiationInterface: ForwardDiffBackend, PolyesterForwardDiffBackend +using ADTypes: AutoPolyesterForwardDiff, AutoForwardDiff +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions @@ -8,29 +9,25 @@ using ForwardDiff: Chunk using LinearAlgebra: mul! using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! -## Backend construction - -""" - PolyesterForwardDiffBackend(C; custom=true) - -Construct a [`PolyesterForwardDiffBackend`](@ref) with chunk size `C`. -""" -function DI.PolyesterForwardDiffBackend(C::Integer; custom::Bool=true) - return PolyesterForwardDiffBackend{custom,C}() -end - ## Primitives function DI.value_and_pushforward!( - dy, ::PolyesterForwardDiffBackend{custom}, f, x, dx -) where {custom} - return DI.value_and_pushforward!(dy, ForwardDiffBackend{custom}(), f, x, dx) + dy, ::AutoPolyesterForwardDiff{C}, f, x, dx, extras::Nothing=nothing +) where {C} + return DI.value_and_pushforward!( + dy, AutoForwardDiff{C,Nothing}(nothing), f, x, dx, extras + ) end ## Utilities function DI.value_and_gradient!( - grad::AbstractArray, ::PolyesterForwardDiffBackend{true,C}, f, x::AbstractArray + grad::AbstractArray, + ::AutoPolyesterForwardDiff{C}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {C} y = f(x) threaded_gradient!(f, grad, x, Chunk{C}()) @@ -38,7 +35,12 @@ function DI.value_and_gradient!( end function DI.value_and_jacobian!( - jac::AbstractMatrix, ::PolyesterForwardDiffBackend{true,C}, f, x::AbstractArray + jac::AbstractMatrix, + ::AutoPolyesterForwardDiff{C}, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) where {C} y = f(x) threaded_jacobian!(f, jac, x, Chunk{C}()) diff --git a/ext/DifferentiationInterfaceReverseDiffExt.jl b/ext/DifferentiationInterfaceReverseDiffExt.jl index 99c5b325c..c5061584b 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt.jl @@ -1,25 +1,21 @@ module DifferentiationInterfaceReverseDiffExt -using DifferentiationInterface: ReverseDiffBackend +using ADTypes: AutoReverseDiff +using DifferentiationInterface: CustomImplem import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions using LinearAlgebra: mul! using ReverseDiff: gradient, gradient!, jacobian, jacobian! -## Backend construction +## Limitations -""" - ReverseDiffBackend(; custom) - -Construct a [`ReverseDiffBackend`](@ref). -""" -DI.ReverseDiffBackend(; custom::Bool=true) = ReverseDiffBackend{custom}() +DI.handles_input_type(::AutoReverseDiff, ::Type{<:Number}) = false ## Primitives function DI.value_and_pullback!( - dx, ::ReverseDiffBackend, f, x::X, dy::Y + dx, ::AutoReverseDiff, f, x::X, dy::Y, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:Real} res = DiffResults.DiffResult(zero(Y), dx) res = gradient!(res, f, x) @@ -29,7 +25,7 @@ function DI.value_and_pullback!( end function DI.value_and_pullback!( - dx, ::ReverseDiffBackend, f, x::X, dy::Y + dx, ::AutoReverseDiff, f, x::X, dy::Y, extras::Nothing=nothing ) where {X<:AbstractArray,Y<:AbstractArray} res = DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x))) res = jacobian!(res, f, x) @@ -41,28 +37,50 @@ end ## Utilities (TODO: use DiffResults) -function DI.value_and_gradient(::ReverseDiffBackend{true}, f, x::AbstractArray) +function DI.value_and_gradient( + ::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) grad = gradient(f, x) return y, grad end function DI.value_and_gradient!( - grad::AbstractArray, ::ReverseDiffBackend{true}, f, x::AbstractArray + grad::AbstractArray, + ::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) gradient!(grad, f, x) return y, grad end -function DI.value_and_jacobian(::ReverseDiffBackend{true}, f, x::AbstractArray) +function DI.value_and_jacobian( + ::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) jac = jacobian(f, x) return y, jac end function DI.value_and_jacobian!( - jac::AbstractMatrix, ::ReverseDiffBackend{true}, f, x::AbstractArray + jac::AbstractMatrix, + ::AutoReverseDiff, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), ) y = f(x) jacobian!(jac, f, x) diff --git a/ext/DifferentiationInterfaceZygoteExt.jl b/ext/DifferentiationInterfaceZygoteExt.jl index 28c0a7998..d6312d080 100644 --- a/ext/DifferentiationInterfaceZygoteExt.jl +++ b/ext/DifferentiationInterfaceZygoteExt.jl @@ -1,50 +1,74 @@ module DifferentiationInterfaceZygoteExt -using DifferentiationInterface: ChainRulesReverseBackend, ZygoteBackend +using ADTypes: AutoZygote +using DifferentiationInterface: AutoChainRules, CustomImplem, update! import DifferentiationInterface as DI using DocStringExtensions -using Zygote: ZygoteRuleConfig, gradient, jacobian, withgradient, withjacobian +using Zygote: ZygoteRuleConfig, gradient, jacobian, pullback, withgradient, withjacobian -## Backend construction +## Primitives -""" - ZygoteBackend(; custom=true) +const zygote_chainrules_backend = AutoChainRules(ZygoteRuleConfig()) -Enables the use of [Zygote.jl](https://github.com/FluxML/Zygote.jl) by constructing a [`ChainRulesReverseBackend`](@ref) from `ZygoteRuleConfig()`. -""" -DI.ZygoteBackend(; custom::Bool=true) = ChainRulesReverseBackend(ZygoteRuleConfig(); custom) - -const ZygoteBackendType{custom} = ChainRulesReverseBackend{custom,<:ZygoteRuleConfig} +function DI.value_and_pullback!(dx, ::AutoZygote, f, x, dy, extras::Nothing=nothing) + y, back = pullback(f, x) + new_dx = only(back(dy)) + return y, update!(dx, new_dx) +end -function Base.show(io::IO, backend::ZygoteBackendType{custom}) where {custom} - return print(io, "ZygoteBackend{$(custom ? "custom" : "fallback")}()") +function DI.value_and_pullback(::AutoZygote, f, x, dy, extras::Nothing=nothing) + y, back = pullback(f, x) + dx = only(back(dy)) + return y, dx end ## Utilities -function DI.value_and_gradient(::ZygoteBackendType{true}, f, x::AbstractArray) +function DI.value_and_gradient( + ::AutoZygote, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) res = withgradient(f, x) return res.val, only(res.grad) end function DI.value_and_gradient!( - grad::AbstractArray, backend::ZygoteBackendType{true}, f, x::AbstractArray + grad::AbstractArray, + backend::AutoZygote, + f, + x::AbstractArray, + extras=nothing, + implem::CustomImplem=CustomImplem(), ) - y, new_grad = DI.value_and_gradient(backend, f, x) + y, new_grad = DI.value_and_gradient(backend, f, x, extras, implem) grad .= new_grad return y, grad end -function DI.value_and_jacobian(::ZygoteBackendType{true}, f, x::AbstractArray) +function DI.value_and_jacobian( + ::AutoZygote, + f, + x::AbstractArray, + extras::Nothing=nothing, + ::CustomImplem=CustomImplem(), +) y = f(x) jac = jacobian(f, x) return y, only(jac) end function DI.value_and_jacobian!( - jac::AbstractMatrix, backend::ZygoteBackendType{true}, f, x::AbstractArray + jac::AbstractMatrix, + backend::AutoZygote, + f, + x::AbstractArray, + extras::Nothing=nothing, + implem::CustomImplem=CustomImplem(), ) - y, new_jac = DI.value_and_jacobian(backend, f, x) + y, new_jac = DI.value_and_jacobian(backend, f, x, extras, implem) jac .= new_jac return y, jac end diff --git a/src/DifferentiationInterface.jl b/src/DifferentiationInterface.jl index 51e33d01c..b8be910cc 100644 --- a/src/DifferentiationInterface.jl +++ b/src/DifferentiationInterface.jl @@ -9,11 +9,14 @@ $(EXPORTS) """ module DifferentiationInterface +using ADTypes: + AbstractADType, AbstractForwardMode, AbstractReverseMode, AbstractFiniteDifferencesMode using DocStringExtensions using FillArrays: OneElement -include("backends_abstract.jl") include("backends.jl") +include("implem.jl") +include("mode.jl") include("utils.jl") include("pushforward.jl") include("pullback.jl") @@ -22,19 +25,9 @@ include("scalar_array.jl") include("array_scalar.jl") include("array_array.jl") -export AbstractBackend, AbstractForwardBackend, AbstractReverseBackend -export autodiff_mode, is_custom -export handles_input_type, handles_output_type, handles_types +export AutoChainRules, AutoDiffractor -export ChainRulesForwardBackend, - ChainRulesReverseBackend, - EnzymeForwardBackend, - EnzymeReverseBackend, - FiniteDiffBackend, - ForwardDiffBackend, - PolyesterForwardDiffBackend, - ReverseDiffBackend, - ZygoteBackend +export handles_input_type, handles_output_type, handles_types export value_and_pushforward!, value_and_pushforward export pushforward!, pushforward diff --git a/src/array_array.jl b/src/array_array.jl index 38e23d2a1..c9607e032 100644 --- a/src/array_array.jl +++ b/src/array_array.jl @@ -6,85 +6,103 @@ This function acts as if the input and output had been flattened with `vec`. """ """ - value_and_jacobian!(jac, backend, f, x) -> (y, jac) + value_and_jacobian!(jac, backend, f, x, [extras]) -> (y, jac) Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overwriting `jac` if possible. $JAC_NOTES """ function value_and_jacobian!( - jac::AbstractMatrix, backend::AbstractBackend, f, x::AbstractArray + jac::AbstractMatrix, + backend::AbstractADType, + f, + x::AbstractArray, + extras=nothing, + implem::AbstractImplem=CustomImplem(), ) - y = f(x) + return value_and_jacobian!(jac, backend, f, x, extras, implem, autodiff_mode(backend)) +end + +function check_jac(jac::AbstractMatrix, x::AbstractArray, y::AbstractArray) nx, ny = length(x), length(y) size(jac) != (ny, nx) && throw( DimensionMismatch("Size of Jacobian buffer doesn't match expected size ($ny, $nx)"), ) - return _value_and_jacobian!(jac, backend, f, x, y) + return nothing end -function _value_and_jacobian!( +function value_and_jacobian!( jac::AbstractMatrix, - backend::AbstractForwardBackend, + backend::AbstractADType, f, x::AbstractArray, - y::AbstractArray, + extras, + ::AbstractImplem, + ::ForwardMode, ) + y = f(x) + check_jac(jac, x, y) for (k, j) in enumerate(eachindex(IndexCartesian(), x)) dx_j = basisarray(backend, x, j) jac_col_j = reshape(view(jac, :, k), size(y)) - pushforward!(jac_col_j, backend, f, x, dx_j) + pushforward!(jac_col_j, backend, f, x, dx_j, extras) end return y, jac end -function _value_and_jacobian!( +function value_and_jacobian!( jac::AbstractMatrix, - backend::AbstractReverseBackend, + backend::AbstractADType, f, x::AbstractArray, - y::AbstractArray, + extras, + ::AbstractImplem, + ::ReverseMode, ) + y = f(x) + check_jac(jac, x, y) for (k, i) in enumerate(eachindex(IndexCartesian(), y)) dy_i = basisarray(backend, y, i) jac_row_i = reshape(view(jac, k, :), size(x)) - pullback!(jac_row_i, backend, f, x, dy_i) + pullback!(jac_row_i, backend, f, x, dy_i, extras) end return y, jac end """ - value_and_jacobian(backend, f, x) -> (y, jac) + value_and_jacobian(backend, f, x, [extras]) -> (y, jac) Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. $JAC_NOTES """ -function value_and_jacobian(backend::AbstractBackend, f, x::AbstractArray) +function value_and_jacobian(backend::AbstractADType, f, x::AbstractArray, args...) y = f(x) T = promote_type(eltype(x), eltype(y)) jac = similar(y, T, length(y), length(x)) - return value_and_jacobian!(jac, backend, f, x) + return value_and_jacobian!(jac, backend, f, x, args...) end """ - jacobian!(jac, backend, f, x) -> jac + jacobian!(jac, backend, f, x, [extras]) -> jac Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overwriting `jac` if possible. $JAC_NOTES """ -function jacobian!(jac::AbstractMatrix, backend::AbstractBackend, f, x::AbstractArray) - return last(value_and_jacobian!(jac, backend, f, x)) +function jacobian!( + jac::AbstractMatrix, backend::AbstractADType, f, x::AbstractArray, args... +) + return last(value_and_jacobian!(jac, backend, f, x, args...)) end """ - jacobian(backend, f, x) -> jac + jacobian(backend, f, x, [extras]) -> jac Compute the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. $JAC_NOTES """ -function jacobian(backend::AbstractBackend, f, x::AbstractArray) - return last(value_and_jacobian(backend, f, x)) +function jacobian(backend::AbstractADType, f, x::AbstractArray, args...) + return last(value_and_jacobian(backend, f, x, args...)) end diff --git a/src/array_scalar.jl b/src/array_scalar.jl index 6617bbccc..98aca4abe 100644 --- a/src/array_scalar.jl +++ b/src/array_scalar.jl @@ -1,52 +1,75 @@ """ - value_and_gradient!(grad, backend, f, x) -> (y, grad) + value_and_gradient!(grad, backend, f, x, [extras]) -> (y, grad) Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ -function value_and_gradient! end +function value_and_gradient!( + grad::AbstractArray, + backend::AbstractADType, + f, + x::AbstractArray, + extras=nothing, + implem::AbstractImplem=CustomImplem(), +) + return value_and_gradient!(grad, backend, f, x, extras, implem, autodiff_mode(backend)) +end function value_and_gradient!( - grad::AbstractArray, backend::AbstractForwardBackend, f, x::AbstractArray + grad::AbstractArray, + backend::AbstractADType, + f, + x::AbstractArray, + extras, + ::AbstractImplem, + ::ForwardMode, ) y = f(x) for j in eachindex(IndexCartesian(), x) dx_j = basisarray(backend, x, j) - grad[j] = pushforward!(grad[j], backend, f, x, dx_j) + grad[j] = pushforward!(grad[j], backend, f, x, dx_j, extras) end return y, grad end function value_and_gradient!( - grad::AbstractArray, backend::AbstractReverseBackend, f, x::AbstractArray + grad::AbstractArray, + backend::AbstractADType, + f, + x::AbstractArray, + extras, + ::AbstractImplem, + ::ReverseMode, ) y = f(x) - return y, pullback!(grad, backend, f, x, one(y)) + return y, pullback!(grad, backend, f, x, one(y), extras) end """ - value_and_gradient(backend, f, x) -> (y, grad) + value_and_gradient(backend, f, x, [extras]) -> (y, grad) Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function. """ -function value_and_gradient(backend::AbstractBackend, f, x::AbstractArray) +function value_and_gradient(backend::AbstractADType, f, x::AbstractArray, args...) grad = similar(x) - return value_and_gradient!(grad, backend, f, x) + return value_and_gradient!(grad, backend, f, x, args...) end """ - gradient!(grad, backend, f, x) -> grad + gradient!(grad, backend, f, x, [extras]) -> grad Compute the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ -function gradient!(grad::AbstractArray, backend::AbstractBackend, f, x::AbstractArray) - return last(value_and_gradient!(grad, backend, f, x)) +function gradient!( + grad::AbstractArray, backend::AbstractADType, f, x::AbstractArray, args... +) + return last(value_and_gradient!(grad, backend, f, x, args...)) end """ - gradient(backend, f, x) -> grad + gradient(backend, f, x, [extras]) -> grad Compute the gradient `grad = ∇f(x)` of an array-to-scalar function. """ -function gradient(backend::AbstractBackend, f, x::AbstractArray) - return last(value_and_gradient(backend, f, x)) +function gradient(backend::AbstractADType, f, x::AbstractArray, args...) + return last(value_and_gradient(backend, f, x, args...)) end diff --git a/src/backends.jl b/src/backends.jl index 81581e87d..12cbdd9a8 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -1,115 +1,72 @@ -""" - ChainRulesForwardBackend <: AbstractForwardBackend +## Additional backends + +# TODO: remove once https://github.com/SciML/ADTypes.jl/pull/21 is merged and released -Enables the use of forward mode AD packages based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). """ -struct ChainRulesForwardBackend{custom,RC} <: AbstractForwardBackend{custom} - ruleconfig::RC -end + AutoChainRules{RC} -function Base.show(io::IO, backend::ChainRulesForwardBackend{custom}) where {custom} - return print( - io, - "ChainRulesForwardBackend{$(custom ? "custom" : "fallback")}($(backend.ruleconfig))", - ) -end +Enables the use of AD libraries based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). -""" - ChainRulesReverseBackend <: AbstractReverseBackend +# Fields + +- `ruleconfig::RC`: a [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) object -Enables the use of reverse mode AD packages based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). +# Example + +```julia +using DifferentiationInterface, Zygote +backend = AutoChainRules(Zygote.ZygoteRuleConfig()) +``` """ -struct ChainRulesReverseBackend{custom,RC} <: AbstractReverseBackend{custom} +struct AutoChainRules{RC} <: AbstractADType ruleconfig::RC end -function Base.show(io::IO, backend::ChainRulesReverseBackend{custom}) where {custom} - return print( - io, - "ChainRulesReverseBackend{$(custom ? "custom" : "fallback")}($(backend.ruleconfig))", - ) -end +# TODO: remove this once https://github.com/SciML/ADTypes.jl/issues/27 is solved """ - FiniteDiffBackend <: AbstractForwardBackend + AutoDiffractor -Enables the use of [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl). +Enables the use of [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl). """ -struct FiniteDiffBackend{custom,fdtype} <: AbstractForwardBackend{custom} end +struct AutoDiffractor <: AbstractADType end -function Base.show(io::IO, ::FiniteDiffBackend{custom,fdtype}) where {custom,fdtype} - return print(io, "FiniteDiffBackend{$(custom ? "custom" : "fallback"),$fdtype}()") -end +## Traits and access """ - EnzymeForwardBackend <: AbstractForwardBackend - -Enables the use of [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) in forward mode. -""" -struct EnzymeForwardBackend{custom} <: AbstractForwardBackend{custom} end - -function Base.show(io::IO, ::EnzymeForwardBackend{custom}) where {custom} - return print(io, "EnzymeForwardBackend{$(custom ? "custom" : "fallback")}()") -end + autodiff_mode(backend) -""" - EnzymeReverseBackend <: AbstractReverseBackend +Return `ForwardMode()` or `ReverseMode()` in a statically predictable way. -Enables the use of [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) in reverse mode. +This function must be overloaded for backends that do not inherit from `ADTypes.AbstractForwardMode` or `ADTypes.AbstractReverseMode` (e.g. because they support both forward and reverse). -!!! warning - This backend only works for scalar output. +We classify `ADTypes.AbstractFiniteDifferencesMode` as forward mode. """ -struct EnzymeReverseBackend{custom} <: AbstractReverseBackend{custom} end - -function Base.show(io::IO, ::EnzymeReverseBackend{custom}) where {custom} - return print(io, "EnzymeReverseBackend{$(custom ? "custom" : "fallback")}()") -end +autodiff_mode(::AbstractForwardMode) = ForwardMode() +autodiff_mode(::AbstractFiniteDifferencesMode) = ForwardMode() +autodiff_mode(::AbstractReverseMode) = ReverseMode() """ - ForwardDiffBackend <: AbstractForwardBackend + handles_input_type(backend, ::Type{X}) -Enables the use of [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). +Check if `backend` can differentiate functions with input type `X`. """ -struct ForwardDiffBackend{custom} <: AbstractForwardBackend{custom} end - -function Base.show(io::IO, ::ForwardDiffBackend{custom}) where {custom} - return print(io, "ForwardDiffBackend{$(custom ? "custom" : "fallback")}()") -end +handles_input_type(::AbstractADType, ::Type{<:Number}) = true +handles_input_type(::AbstractADType, ::Type{<:AbstractArray}) = true """ - PolyesterForwardDiffBackend <: AbstractForwardBackend - -Enables the use of [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl), falling back on [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) if needed. + handles_output_type(backend, ::Type{Y}) -!!! warning - This backend only works when the arrays are vectors. +Check if `backend` can differentiate functions with output type `Y`. """ -struct PolyesterForwardDiffBackend{custom,C} <: AbstractForwardBackend{custom} end - -function Base.show(io::IO, ::PolyesterForwardDiffBackend{custom,C}) where {custom,C} - return print(io, "PolyesterForwardDiffBackend{$(custom ? "custom" : "fallback"),$C}()") -end +handles_output_type(::AbstractADType, ::Type{<:Number}) = true +handles_output_type(::AbstractADType, ::Type{<:AbstractArray}) = true """ - ReverseDiffBackend <: AbstractReverseBackend - -Performs autodiff with [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl). + handles_types(backend, ::Type{X}, ::Type{Y}) -!!! warning - This backend only works for array input. +Check if `backend` can differentiate functions with input type `X` and output type `Y`. """ -struct ReverseDiffBackend{custom} <: AbstractReverseBackend{custom} end - -function Base.show(io::IO, ::ReverseDiffBackend{custom}) where {custom} - return print(io, "ReverseDiffBackend{$(custom ? "custom" : "fallback")}()") +function handles_types(backend::AbstractADType, ::Type{X}, ::Type{Y}) where {X,Y} + return handles_input_type(backend, X) && handles_output_type(backend, Y) end - -## Pseudo backends - -function ZygoteBackend end - -## Limitations - -handles_input_type(::ReverseDiffBackend, ::Type{<:Number}) = false -handles_output_type(::EnzymeReverseBackend, ::Type{<:AbstractArray}) = false diff --git a/src/backends_abstract.jl b/src/backends_abstract.jl deleted file mode 100644 index dbb9e540f..000000000 --- a/src/backends_abstract.jl +++ /dev/null @@ -1,68 +0,0 @@ - -""" - AbstractBackend - -Abstract type pointing to the AD package chosen by the user, which is called a "backend". - -## Custom - -When we say that a backend is "custom", it describes how the utilities (derivative, multiderivative, gradient and jacobian) are implemented: - -- Custom backends use specific routines defined in their package whenever they exist -- Non-custom backends use fallbacks defined in DifferentiationInterface.jl, which end up calling the pushforward or pullback -""" -abstract type AbstractBackend{custom} end - -""" - is_custom(backend) - -Return a boolean `custom` that describes how utilities (derivative, multiderivative, gradient and jacobian) are implemented. -""" -is_custom(::AbstractBackend{custom}) where {custom} = custom - -""" - handles_input_type(backend, ::Type{X}) - -Check if `backend` can differentiate functions with input type `X`. -""" -handles_input_type(::AbstractBackend, ::Type{<:Number}) = true -handles_input_type(::AbstractBackend, ::Type{<:AbstractArray}) = true - -""" - handles_output_type(backend, ::Type{Y}) - -Check if `backend` can differentiate functions with output type `Y`. -""" -handles_output_type(::AbstractBackend, ::Type{<:Number}) = true -handles_output_type(::AbstractBackend, ::Type{<:AbstractArray}) = true - -""" - handles_types(backend, ::Type{X}, ::Type{Y}) - -Check if `backend` can differentiate functions with input type `X` and output type `Y`. -""" -function handles_types(backend::AbstractBackend, ::Type{X}, ::Type{Y}) where {X,Y} - return handles_input_type(backend, X) && handles_output_type(backend, Y) -end - -""" - AbstractForwardBackend <: AbstractBackend - -Abstract subtype of [`AbstractBackend`](@ref) for forward mode AD packages. -""" -abstract type AbstractForwardBackend{custom} <: AbstractBackend{custom} end - -""" - AbstractReverseBackend <: AbstractBackend - -Abstract subtype of [`AbstractBackend`](@ref) for reverse mode AD packages. -""" -abstract type AbstractReverseBackend{custom} <: AbstractBackend{custom} end - -""" - autodiff_mode(backend) - -Return either `:forward` or `:reverse` depending on the mode of `backend`. -""" -autodiff_mode(::AbstractForwardBackend) = :forward -autodiff_mode(::AbstractReverseBackend) = :reverse diff --git a/src/implem.jl b/src/implem.jl new file mode 100644 index 000000000..de9413e87 --- /dev/null +++ b/src/implem.jl @@ -0,0 +1,17 @@ +abstract type AbstractImplem end + +""" + CustomImplem + +Trait specifying that the custom utilities from the backend should be used as much as possible. +Used for internal dispatch only. +""" +struct CustomImplem <: AbstractImplem end + +""" + FallbackImplem + +Trait specifying that the fallback utilities from DifferentiationInterface.jl should be used as much as possible, until they call a pushforward or pullback. +Used for internal dispatch only. +""" +struct FallbackImplem <: AbstractImplem end diff --git a/src/mode.jl b/src/mode.jl new file mode 100644 index 000000000..365e93e5a --- /dev/null +++ b/src/mode.jl @@ -0,0 +1,17 @@ +abstract type AbstractMode end + +""" + ForwardMode + +Trait identifying forward mode AD backends. +Used for internal dispatch only. +""" +struct ForwardMode <: AbstractMode end + +""" + ReverseMode + +Trait identifying reverse mode AD backends. +Used for internal dispatch only. +""" +struct ReverseMode <: AbstractMode end diff --git a/src/pullback.jl b/src/pullback.jl index 501ff7c57..bb496b392 100644 --- a/src/pullback.jl +++ b/src/pullback.jl @@ -1,41 +1,41 @@ """ - value_and_pullback!(dx, backend::AbstractReverseBackend, f, x, dy) -> (y, dx) + value_and_pullback!(dx, backend, f, x, dy, [extras]) -> (y, dx) Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. !!! info "Interface requirement" - This is the only required implementation for an [`AbstractReverseBackend`](@ref). + This is the only required implementation for a reverse mode backend. """ -function value_and_pullback!(dx, backend::AbstractReverseBackend, f, x, dy) +function value_and_pullback!(dx, backend::AbstractADType, f, x, dy, extras=nothing) return error( - "Backend $backend is not loaded or does not support this type combination." + "The package for $(typeof(backend)) is not loaded, or the backend does not support this type combination: `typeof(x) = $(typeof(x))` and `typeof(y) = $(typeof(dy))`", ) end """ - value_and_pullback(backend::AbstractReverseBackend, f, x, dy) -> (y, dx) + value_and_pullback(backend, f, x, dy, [extras]) -> (y, dx) Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`. """ -function value_and_pullback(backend::AbstractReverseBackend, f, x, dy) +function value_and_pullback(backend::AbstractADType, f, x, dy, extras=nothing) dx = mysimilar(x) - return value_and_pullback!(dx, backend, f, x, dy) + return value_and_pullback!(dx, backend, f, x, dy, extras) end """ - pullback!(dx, backend::AbstractReverseBackend, f, x, dy) -> dx + pullback!(dx, backend, f, x, dy, [extras]) -> dx Compute the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. """ -function pullback!(dx, backend::AbstractReverseBackend, f, x, dy) - return last(value_and_pullback!(dx, backend, f, x, dy)) +function pullback!(dx, backend::AbstractADType, f, x, dy, extras=nothing) + return last(value_and_pullback!(dx, backend, f, x, dy, extras)) end """ - pullback(backend::AbstractReverseBackend, f, x, dy) -> dx + pullback(backend, f, x, dy, [extras]) -> dx Compute the vector-Jacobian product `dx = ∂f(x)' * dy`. """ -function pullback(backend::AbstractReverseBackend, f, x, dy) - return last(value_and_pullback(backend, f, x, dy)) +function pullback(backend::AbstractADType, f, x, dy, extras=nothing) + return last(value_and_pullback(backend, f, x, dy, extras)) end diff --git a/src/pushforward.jl b/src/pushforward.jl index e71b92197..0b1836f6c 100644 --- a/src/pushforward.jl +++ b/src/pushforward.jl @@ -1,41 +1,41 @@ """ - value_and_pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) -> (y, dy) + value_and_pushforward!(dy, backend, f, x, dx, [extras]) -> (y, dy) Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. !!! info "Interface requirement" - This is the only required implementation for an [`AbstractForwardBackend`](@ref). + This is the only required implementation for a forward mode backend. """ -function value_and_pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) +function value_and_pushforward!(dy, backend::AbstractADType, f, x, dx, extras=nothing) return error( - "Backend $backend is not loaded or does not support this type combination." + "The package for $(typeof(backend)) is not loaded, or the backend does not support this type combination: `typeof(x) = $(typeof(x))` and `typeof(y) = $(typeof(dy))`", ) end """ - value_and_pushforward(backend::AbstractForwardBackend, f, x, dx) -> (y, dy) + value_and_pushforward(backend, f, x, dx, [extras]) -> (y, dy) Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function value_and_pushforward(backend::AbstractForwardBackend, f, x, dx) +function value_and_pushforward(backend::AbstractADType, f, x, dx, extras=nothing) dy = mysimilar(f(x)) - return value_and_pushforward!(dy, backend, f, x, dx) + return value_and_pushforward!(dy, backend, f, x, dx, extras) end """ - pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) -> dy + pushforward!(dy, backend, f, x, dx, [extras]) -> dy Compute the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. """ -function pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) - return last(value_and_pushforward!(dy, backend, f, x, dx)) +function pushforward!(dy, backend::AbstractADType, f, x, dx, extras=nothing) + return last(value_and_pushforward!(dy, backend, f, x, dx, extras)) end """ - pushforward(backend::AbstractForwardBackend, f, x, dx) -> dy + pushforward(backend, f, x, dx, [extras]) -> dy Compute the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function pushforward(backend::AbstractForwardBackend, f, x, dx) - return last(value_and_pushforward(backend, f, x, dx)) +function pushforward(backend::AbstractADType, f, x, dx, extras=nothing) + return last(value_and_pushforward(backend, f, x, dx, extras)) end diff --git a/src/scalar_array.jl b/src/scalar_array.jl index 1b6aed54c..c1a7d43e4 100644 --- a/src/scalar_array.jl +++ b/src/scalar_array.jl @@ -1,51 +1,76 @@ """ - value_and_multiderivative!(multider, backend, f, x) -> (y, multider) + value_and_multiderivative!(multider, backend, f, x, [extras]) -> (y, multider) Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. """ -function value_and_multiderivative! end +function value_and_multiderivative!( + multider::AbstractArray, + backend::AbstractADType, + f, + x::Number, + extras=nothing, + implem::AbstractImplem=CustomImplem(), +) + return value_and_multiderivative!( + multider, backend, f, x, extras, implem, autodiff_mode(backend) + ) +end function value_and_multiderivative!( - multider::AbstractArray, backend::AbstractForwardBackend, f, x::Number + multider::AbstractArray, + backend::AbstractADType, + f, + x::Number, + extras, + ::AbstractImplem, + ::ForwardMode, ) - return value_and_pushforward!(multider, backend, f, x, one(x)) + return value_and_pushforward!(multider, backend, f, x, one(x), extras) end function value_and_multiderivative!( - multider::AbstractArray, backend::AbstractReverseBackend, f, x::Number + multider::AbstractArray, + backend::AbstractADType, + f, + x::Number, + extras, + ::AbstractImplem, + ::ReverseMode, ) y = f(x) for i in eachindex(IndexCartesian(), y) dy_i = basisarray(backend, y, i) - multider[i] = pullback!(multider[i], backend, f, x, dy_i) + multider[i] = pullback!(multider[i], backend, f, x, dy_i, extras) end return y, multider end """ - value_and_multiderivative(backend, f, x) -> (y, multider) + value_and_multiderivative(backend, f, x, [extras]) -> (y, multider) Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function value_and_multiderivative(backend::AbstractBackend, f, x::Number) +function value_and_multiderivative(backend::AbstractADType, f, x::Number, args...) multider = similar(f(x)) - return value_and_multiderivative!(multider, backend, f, x) + return value_and_multiderivative!(multider, backend, f, x, args...) end """ - multiderivative!(multider, backend, f, x) -> multider + multiderivative!(multider, backend, f, x, [extras]) -> multider Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. """ -function multiderivative!(multider::AbstractArray, backend::AbstractBackend, f, x::Number) - return last(value_and_multiderivative!(multider, backend, f, x)) +function multiderivative!( + multider::AbstractArray, backend::AbstractADType, f, x::Number, args... +) + return last(value_and_multiderivative!(multider, backend, f, x, args...)) end """ - multiderivative(backend, f, x) -> multider + multiderivative(backend, f, x, [extras]) -> multider Compute the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ -function multiderivative(backend::AbstractBackend, f, x::Number) - return last(value_and_multiderivative(backend, f, x)) +function multiderivative(backend::AbstractADType, f, x::Number, args...) + return last(value_and_multiderivative(backend, f, x, args...)) end diff --git a/src/scalar_scalar.jl b/src/scalar_scalar.jl index 7cf8c9c2a..1b3c9c349 100644 --- a/src/scalar_scalar.jl +++ b/src/scalar_scalar.jl @@ -1,23 +1,35 @@ """ - value_and_derivative(backend, f, x) -> (y, der) + value_and_derivative(backend, f, x, [extras]) -> (y, der) Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function. """ -function value_and_derivative end +function value_and_derivative( + backend::AbstractADType, + f, + x::Number, + extras=nothing, + implem::AbstractImplem=CustomImplem(), +) + return value_and_derivative(backend, f, x, extras, implem, autodiff_mode(backend)) +end -function value_and_derivative(backend::AbstractForwardBackend, f, x::Number) - return value_and_pushforward!(one(x), backend, f, x, one(x)) +function value_and_derivative( + backend::AbstractADType, f, x::Number, extras, ::AbstractImplem, ::ForwardMode +) + return value_and_pushforward!(one(x), backend, f, x, one(x), extras) end -function value_and_derivative(backend::AbstractReverseBackend, f, x::Number) - return value_and_pullback!(one(x), backend, f, x, one(x)) +function value_and_derivative( + backend::AbstractADType, f, x::Number, extras, ::AbstractImplem, ::ReverseMode +) + return value_and_pullback!(one(x), backend, f, x, one(x), extras) end """ - derivative(backend, f, x) -> der + derivative(backend, f, x, [extras]) -> der Compute the derivative `der = f'(x)` of a scalar-to-scalar function. """ -function derivative(backend::AbstractBackend, f, x::Number) - return last(value_and_derivative(backend, f, x)) +function derivative(backend::AbstractADType, f, x::Number, args...) + return last(value_and_derivative(backend, f, x, args...)) end diff --git a/src/utils.jl b/src/utils.jl index 8c4f9746a..ebcbd4f19 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,7 +8,7 @@ Construct the `i`-th stardard basis array in the vector space of `a` with elemen If an AD backend benefits from a more specialized basis array implementation, this function can be extended on the backend type. """ -basisarray(::AbstractBackend, a::AbstractArray, i) = basisarray(a, i) +basisarray(::AbstractADType, a::AbstractArray, i) = basisarray(a, i) function basisarray(a::AbstractArray{T,N}, i::CartesianIndex{N}) where {T,N} return OneElement(one(T), Tuple(i), axes(a)) @@ -16,3 +16,6 @@ end mysimilar(x::Number) = zero(x) mysimilar(x::AbstractArray) = similar(x) + +update!(_old::Number, new::Number) = new +update!(old, new) = old .= new diff --git a/test/Project.toml b/test/Project.toml index e23957ab6..d903c0ef0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" diff --git a/test/backends.jl b/test/backends.jl deleted file mode 100644 index b694a6a89..000000000 --- a/test/backends.jl +++ /dev/null @@ -1,11 +0,0 @@ -using DifferentiationInterface -using DifferentiationInterface: is_custom -using Test - -backend = ChainRulesForwardBackend{false,Nothing}(nothing) -@test !is_custom(backend) -@test autodiff_mode(backend) == :forward - -backend = ChainRulesReverseBackend{true,Nothing}(nothing) -@test is_custom(backend) -@test autodiff_mode(backend) == :reverse diff --git a/test/chainrules_forward.jl b/test/chainrules_forward.jl new file mode 100644 index 000000000..3f2d2e549 --- /dev/null +++ b/test/chainrules_forward.jl @@ -0,0 +1,10 @@ +using DifferentiationInterface: AutoChainRules, CustomImplem, FallbackImplem +using Diffractor: DiffractorRuleConfig + +test_pushforward(AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false); +test_jacobian_and_friends( + CustomImplem(), AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false +); +test_jacobian_and_friends( + FallbackImplem().AutoChainRules(DiffractorRuleConfig()), scenarios; type_stability=false +); diff --git a/test/chainrules_reverse.jl b/test/chainrules_reverse.jl new file mode 100644 index 000000000..54e0ef506 --- /dev/null +++ b/test/chainrules_reverse.jl @@ -0,0 +1,10 @@ +using DifferentiationInterface: AutoChainRules, CustomImplem, FallbackImplem +using Zygote: ZygoteRuleConfig + +test_pullback(AutoChainRules(ZygoteRuleConfig()), scenarios; type_stability=false); +test_jacobian_and_friends( + CustomImplem(), AutoChainRules(ZygoteRuleConfig()), scenarios; type_stability=false +); +test_jacobian_and_friends( + FallbackImplem(), AutoChainRules(ZygoteRuleConfig()), scenarios; type_stability=false +); diff --git a/test/diffractor.jl b/test/diffractor.jl index 62b293e50..ad9c56c57 100644 --- a/test/diffractor.jl +++ b/test/diffractor.jl @@ -1,15 +1,8 @@ -using DifferentiationInterface +using DifferentiationInterface: AutoDiffractor, CustomImplem, FallbackImplem using Diffractor: Diffractor -# see https://github.com/JuliaDiff/Diffractor.jl/issues/277 - -@test_skip test_pushforward( - ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()), - scenarios; - type_stability=false, -); -@test_skip test_jacobian_and_friends( - ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()), - scenarios; - type_stability=false, +test_pushforward(AutoDiffractor(), scenarios; type_stability=false); +test_jacobian_and_friends(CustomImplem(), AutoDiffractor(), scenarios; type_stability=false); +test_jacobian_and_friends( + FallbackImplem(), AutoDiffractor(), scenarios; type_stability=false ); diff --git a/test/enzyme_forward.jl b/test/enzyme_forward.jl index 9289a07dc..d50e28dd7 100644 --- a/test/enzyme_forward.jl +++ b/test/enzyme_forward.jl @@ -1,10 +1,11 @@ -using DifferentiationInterface +using ADTypes: AutoEnzyme +using DifferentiationInterface: CustomImplem, FallbackImplem using Enzyme: Enzyme -test_pushforward(EnzymeForwardBackend(), scenarios; type_stability=true); +test_pushforward(AutoEnzyme(Val(:forward)), scenarios; type_stability=true); test_jacobian_and_friends( - EnzymeForwardBackend(; custom=true), scenarios; type_stability=true + CustomImplem(), AutoEnzyme(Val(:forward)), scenarios; type_stability=true ); test_jacobian_and_friends( - EnzymeForwardBackend(; custom=false), scenarios; type_stability=true + FallbackImplem(), AutoEnzyme(Val(:forward)), scenarios; type_stability=true ); diff --git a/test/enzyme_reverse.jl b/test/enzyme_reverse.jl index 022e64843..b3716215b 100644 --- a/test/enzyme_reverse.jl +++ b/test/enzyme_reverse.jl @@ -1,10 +1,11 @@ -using DifferentiationInterface +using ADTypes: AutoEnzyme +using DifferentiationInterface: CustomImplem, FallbackImplem using Enzyme: Enzyme -test_pullback(EnzymeReverseBackend(), scenarios; type_stability=true); +test_pullback(AutoEnzyme(Val(:reverse)), scenarios; type_stability=true); test_jacobian_and_friends( - EnzymeReverseBackend(; custom=true), scenarios; type_stability=true + CustomImplem(), AutoEnzyme(Val(:reverse)), scenarios; type_stability=true ) test_jacobian_and_friends( - EnzymeReverseBackend(; custom=false), scenarios; type_stability=true + FallbackImplem(), AutoEnzyme(Val(:reverse)), scenarios; type_stability=true ) diff --git a/test/finitediff.jl b/test/finitediff.jl index b976b1f62..b809cb268 100644 --- a/test/finitediff.jl +++ b/test/finitediff.jl @@ -1,6 +1,9 @@ -using DifferentiationInterface +using ADTypes: AutoFiniteDiff +using DifferentiationInterface: CustomImplem, FallbackImplem using FiniteDiff: FiniteDiff -test_pushforward(FiniteDiffBackend(), scenarios; type_stability=true); -test_jacobian_and_friends(FiniteDiffBackend(; custom=true), scenarios; type_stability=false); -test_jacobian_and_friends(FiniteDiffBackend(; custom=false), scenarios; type_stability=true); +test_pushforward(AutoFiniteDiff(), scenarios; type_stability=true); +test_jacobian_and_friends(CustomImplem(), AutoFiniteDiff(), scenarios; type_stability=false); +test_jacobian_and_friends( + FallbackImplem(), AutoFiniteDiff(), scenarios; type_stability=false +); diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index f6344fc61..5a2c92c2b 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -1,10 +1,11 @@ -using DifferentiationInterface +using ADTypes: AutoForwardDiff +using DifferentiationInterface: CustomImplem, FallbackImplem using ForwardDiff: ForwardDiff -test_pushforward(ForwardDiffBackend(), scenarios; type_stability=true); +test_pushforward(AutoForwardDiff(), scenarios; type_stability=true); test_jacobian_and_friends( - ForwardDiffBackend(; custom=true), scenarios; type_stability=false + CustomImplem(), AutoForwardDiff(), scenarios; type_stability=false ); test_jacobian_and_friends( - ForwardDiffBackend(; custom=false), scenarios; type_stability=true + FallbackImplem(), AutoForwardDiff(), scenarios; type_stability=false ); diff --git a/test/polyesterforwarddiff.jl b/test/polyesterforwarddiff.jl index 1a3a51233..bf7ab2c4b 100644 --- a/test/polyesterforwarddiff.jl +++ b/test/polyesterforwarddiff.jl @@ -1,14 +1,14 @@ -using DifferentiationInterface +using ADTypes: AutoPolyesterForwardDiff +using DifferentiationInterface: CustomImplem, FallbackImplem using PolyesterForwardDiff: PolyesterForwardDiff # see https://github.com/JuliaDiff/PolyesterForwardDiff.jl/issues/17 -test_pushforward( - PolyesterForwardDiffBackend(4; custom=true), scenarios; type_stability=false -); +test_pushforward(AutoPolyesterForwardDiff(; chunksize=4), scenarios; type_stability=false); test_jacobian_and_friends( - PolyesterForwardDiffBackend(4; custom=true), + CustomImplem(), + AutoPolyesterForwardDiff(; chunksize=4), scenarios; input_type=Union{Number,AbstractVector}, output_type=Union{Number,AbstractVector}, diff --git a/test/reversediff.jl b/test/reversediff.jl index 748949ac6..0bea69d92 100644 --- a/test/reversediff.jl +++ b/test/reversediff.jl @@ -1,10 +1,11 @@ -using DifferentiationInterface +using ADTypes: AutoReverseDiff +using DifferentiationInterface: CustomImplem, FallbackImplem using ReverseDiff: ReverseDiff -test_pullback(ReverseDiffBackend(), scenarios; type_stability=false); +test_pullback(AutoReverseDiff(), scenarios; type_stability=false); test_jacobian_and_friends( - ReverseDiffBackend(; custom=true), scenarios; type_stability=false + CustomImplem(), AutoReverseDiff(), scenarios; type_stability=false ); test_jacobian_and_friends( - ReverseDiffBackend(; custom=false), scenarios; type_stability=false + FallbackImplem(), AutoReverseDiff(), scenarios; type_stability=false ); diff --git a/test/runtests.jl b/test/runtests.jl index 50f0f19d2..8eb10b614 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,8 @@ using Test ## Utils -include("utils.jl") +include("scenarios.jl"); +include("utils.jl"); ## Main tests @@ -25,11 +26,13 @@ include("utils.jl") JET.test_package(DifferentiationInterface; target_defined_modules=true) end - @testset "Backend utilities" begin - include("backends.jl") + @testset "ChainRules (forward)" begin + @test_skip include("chainrules_forward.jl") end - - @testset "Diffractor" begin + @testset "ChainRules (reverse)" begin + include("chainrules_reverse.jl") + end + @testset "Diffractor (forward)" begin include("diffractor.jl") end @testset "Enzyme (forward)" begin diff --git a/test/scenarios.jl b/test/scenarios.jl new file mode 100644 index 000000000..2cf4e8813 --- /dev/null +++ b/test/scenarios.jl @@ -0,0 +1,123 @@ +using ForwardDiff: ForwardDiff +using LinearAlgebra +using Random: AbstractRNG, randn! +using StableRNGs + +## Test scenarios + +@kwdef struct Scenario{F,X,Y,D1,D2,D3,D4} + "function" + f::F + "argument" + x::X + "primal value" + y::Y + "pushforward seed" + dx::X + "pullback seed" + dy::Y + "pullback result" + dx_true::X + "pushforward result" + dy_true::Y + "derivative result" + der_true::D1 = nothing + "multiderivative result" + multider_true::D2 = nothing + "gradient result" + grad_true::D3 = nothing + "Jacobian result" + jac_true::D4 = nothing +end + +## Constructors + +function make_scenario(rng::AbstractRNG, f, x) + y = f(x) + return make_scenario(rng, f, x, y) +end + +function make_scenario(rng::AbstractRNG, f::F, x::X, y::Y) where {F,X<:Number,Y<:Number} + dx = randn(rng, X) + dy = randn(rng, Y) + der_true = ForwardDiff.derivative(f, x) + dx_true = der_true * dy + dy_true = der_true * dx + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, der_true) +end + +function make_scenario( + rng::AbstractRNG, f::F, x::X, y::Y +) where {F,X<:Number,Y<:AbstractArray} + dx = randn(rng, X) + dy = similar(y) + randn!(rng, dy) + multider_true = ForwardDiff.derivative(f, x) + dx_true = dot(multider_true, dy) + dy_true = multider_true .* dx + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, multider_true) +end + +function make_scenario( + rng::AbstractRNG, f::F, x::X, y::Y +) where {F,X<:AbstractArray,Y<:Number} + dx = similar(x) + randn!(rng, dx) + dy = randn(rng, Y) + grad_true = ForwardDiff.gradient(f, x) + dx_true = grad_true .* dy + dy_true = dot(grad_true, dx) + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, grad_true) +end + +function make_scenario( + rng::AbstractRNG, f::F, x::X, y::Y +) where {F,X<:AbstractArray,Y<:AbstractArray} + dx = similar(x) + randn!(rng, dx) + dy = similar(y) + randn!(rng, dy) + jac_true = ForwardDiff.jacobian(f, x) + dx_true = reshape(transpose(jac_true) * vec(dy), size(x)) + dy_true = reshape(jac_true * vec(dx), size(y)) + return Scenario(; f, x, y, dx, dy, dx_true, dy_true, jac_true) +end + +## Access + +get_input_type(::Scenario{F,X}) where {F,X} = X +get_output_type(::Scenario{F,X,Y}) where {F,X,Y} = Y + +## Seed + +rng = StableRNG(63) + +## Scenarios + +f_scalar_scalar(x::Number)::Number = sin(x) + +f_scalar_vector(x::Number)::AbstractVector = [sin(x), sin(2x)] +f_scalar_matrix(x::Number)::AbstractMatrix = hcat([sin(x) cos(x)], [sin(2x) cos(2x)]) + +f_vector_scalar(x::AbstractVector)::Number = sum(sin, x) +f_matrix_scalar(x::AbstractMatrix)::Number = sum(sin, x) + +f_vector_vector(x::AbstractVector)::AbstractVector = vcat(sin.(x), cos.(x)) +f_vector_matrix(x::AbstractVector)::AbstractMatrix = hcat(sin.(x), cos.(x)) + +f_matrix_vector(x::AbstractMatrix)::AbstractVector = vcat(vec(sin.(x)), vec(cos.(x))) +f_matrix_matrix(x::AbstractMatrix)::AbstractMatrix = hcat(vec(sin.(x)), vec(cos.(x))) + +## All + +scenarios = [ + make_scenario(rng, f_scalar_scalar, 1.0), + make_scenario(rng, f_scalar_vector, 1.0), + make_scenario(rng, f_scalar_matrix, 1.0), + make_scenario(rng, f_vector_scalar, [1.0, 2.0]), + make_scenario(rng, f_matrix_scalar, [1.0 2.0; 3.0 4.0]), + make_scenario(rng, f_vector_vector, [1.0, 2.0]), + make_scenario(rng, f_vector_matrix, [1.0, 2.0]), + make_scenario(rng, f_matrix_vector, [1.0 2.0; 3.0 4.0]), + make_scenario(rng, f_matrix_matrix, [1.0 2.0; 3.0 4.0]), +]; diff --git a/test/utils.jl b/test/utils.jl index 7fed9e7e2..052cc0cd8 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,166 +1,32 @@ +using ADTypes: AbstractADType using DifferentiationInterface using DifferentiationInterface: - AbstractBackend, AbstractReverseBackend, AbstractForwardBackend -using ForwardDiff: ForwardDiff -using LinearAlgebra + AbstractImplem, CustomImplem, FallbackImplem, ForwardMode, ReverseMode, autodiff_mode using JET -using Random: AbstractRNG, randn! -using StableRNGs using Test -## Test scenarios - -@kwdef struct Scenario{F,X,Y,D1,D2,D3,D4} - "function" - f::F - "argument" - x::X - "primal value" - y::Y - "pushforward seed" - dx::X - "pullback seed" - dy::Y - "pullback result" - dx_true::X - "pushforward result" - dy_true::Y - "derivative result" - der_true::D1 = nothing - "multiderivative result" - multider_true::D2 = nothing - "gradient result" - grad_true::D3 = nothing - "Jacobian result" - jac_true::D4 = nothing -end - -## Constructors - -function make_scenario(rng::AbstractRNG, f, x) - y = f(x) - return make_scenario(rng, f, x, y) -end - -function make_scenario(rng::AbstractRNG, f::F, x::X, y::Y) where {F,X<:Number,Y<:Number} - dx = randn(rng, X) - dy = randn(rng, Y) - der_true = ForwardDiff.derivative(f, x) - dx_true = der_true * dy - dy_true = der_true * dx - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, der_true) -end - -function make_scenario( - rng::AbstractRNG, f::F, x::X, y::Y -) where {F,X<:Number,Y<:AbstractArray} - dx = randn(rng, X) - dy = similar(y) - randn!(rng, dy) - multider_true = ForwardDiff.derivative(f, x) - dx_true = dot(multider_true, dy) - dy_true = multider_true .* dx - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, multider_true) -end - -function make_scenario( - rng::AbstractRNG, f::F, x::X, y::Y -) where {F,X<:AbstractArray,Y<:Number} - dx = similar(x) - randn!(rng, dx) - dy = randn(rng, Y) - grad_true = ForwardDiff.gradient(f, x) - dx_true = grad_true .* dy - dy_true = dot(grad_true, dx) - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, grad_true) -end - -function make_scenario( - rng::AbstractRNG, f::F, x::X, y::Y -) where {F,X<:AbstractArray,Y<:AbstractArray} - dx = similar(x) - randn!(rng, dx) - dy = similar(y) - randn!(rng, dy) - jac_true = ForwardDiff.jacobian(f, x) - dx_true = reshape(transpose(jac_true) * vec(dy), size(x)) - dy_true = reshape(jac_true * vec(dx), size(y)) - return Scenario(; f, x, y, dx, dy, dx_true, dy_true, jac_true) -end - -## Access +pretty(::CustomImplem) = "custom" +pretty(::FallbackImplem) = "fallback" -get_input_type(::Scenario{F,X}) where {F,X} = X -get_output_type(::Scenario{F,X,Y}) where {F,X,Y} = Y - -## Seed - -rng = StableRNG(63) - -## Scenarios - -f_scalar_scalar(x::Number)::Number = sin(x) -f_scalar_vector(x::Number)::AbstractVector = [sin(x), sin(2x)] -f_scalar_matrix(x::Number)::AbstractMatrix = hcat([sin(x) cos(x)], [sin(2x) cos(2x)]) - -function f_vector_scalar(x::AbstractVector)::Number - a = eachindex(x) - return sum(sin.(a .* x)) -end - -function f_matrix_scalar(x::AbstractMatrix)::Number - a, b = axes(x) - return sum(sin.(a .* x)) + sum(cos.(transpose(b) .* x)) -end - -function f_vector_vector(x::AbstractVector)::AbstractVector - a = eachindex(x) - return vcat(sin.(a .* x), cos.(a .* x)) -end - -function f_vector_matrix(x::AbstractVector)::AbstractMatrix - a = eachindex(x) - return hcat(sin.(a .* x), cos.(a .* x)) -end - -function f_matrix_vector(x::AbstractMatrix)::AbstractVector - a, b = axes(x) - return vcat(vec(sin.(a .* x)), vec(cos.(transpose(b) .* x))) -end - -function f_matrix_matrix(x::AbstractMatrix)::AbstractMatrix - a, b = axes(x) - return hcat(vec(sin.(a .* x)), vec(cos.(transpose(b) .* x))) -end - -## All - -scenarios = [ - make_scenario(rng, f_scalar_scalar, 1.0), - make_scenario(rng, f_scalar_vector, 1.0), - make_scenario(rng, f_scalar_matrix, 1.0), - make_scenario(rng, f_vector_scalar, [1.0, 2.0]), - make_scenario(rng, f_matrix_scalar, [1.0 2.0; 3.0 4.0]), - make_scenario(rng, f_vector_vector, [1.0, 2.0]), - make_scenario(rng, f_vector_matrix, [1.0, 2.0]), - make_scenario(rng, f_matrix_vector, [1.0 2.0; 3.0 4.0]), - make_scenario(rng, f_matrix_matrix, [1.0 2.0; 3.0 4.0]), -]; +const NO_EXTRAS = nothing ## Test utilities function test_pushforward( - backend::AbstractForwardBackend, + backend::AbstractADType, scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, type_stability::Bool=true, ) + if !isa(autodiff_mode(backend), ForwardMode) + return nothing + end scenarios = filter(scenarios) do s get_input_type(s) <: input_type && get_output_type(s) <: output_type end - @testset "Pushforward ($(is_custom(backend) ? "custom" : "fallback"))" begin + @testset "Pushforward" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -205,17 +71,20 @@ function test_pushforward( end function test_pullback( - backend::AbstractReverseBackend, + backend::AbstractADType, scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, type_stability::Bool=true, ) + if !isa(autodiff_mode(backend), ReverseMode) + return nothing + end scenarios = filter(scenarios) do s (get_input_type(s) <: input_type) && (get_output_type(s) <: output_type) end - @testset "Pullback ($(is_custom(backend) ? "custom" : "fallback"))" begin + @testset "Pullback" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -261,7 +130,8 @@ function test_pullback( end function test_derivative( - backend::AbstractBackend, + implem::AbstractImplem, + backend::AbstractADType, scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, @@ -269,7 +139,7 @@ function test_derivative( scenarios = filter(scenarios) do s (get_input_type(s) <: Number) && (get_output_type(s) <: Number) end - @testset "Derivative ($(is_custom(backend) ? "custom" : "fallback"))" begin + @testset "Derivative ⁻ $(pretty(implem))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -277,9 +147,9 @@ function test_derivative( @testset "$X -> $Y" begin (; f, x, y, der_true) = scenario - y_out1, der_out1 = value_and_derivative(backend, f, x) + y_out1, der_out1 = value_and_derivative(backend, f, x, NO_EXTRAS, implem) - der_out2 = derivative(backend, f, x) + der_out2 = derivative(backend, f, x, NO_EXTRAS, implem) @testset "Primal value" begin @test y_out1 ≈ y @@ -289,12 +159,14 @@ function test_derivative( @test der_out2 ≈ der_true rtol = 1e-3 end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_derivative(backend, f, x)) - @test iszero(@allocated derivative(backend, f, x)) + @test iszero( + @allocated value_and_derivative(backend, f, x, NO_EXTRAS, implem) + ) + @test iszero(@allocated derivative(backend, f, x, NO_EXTRAS, implem)) end type_stability && @testset "Type stability" begin - @test_opt value_and_derivative(backend, f, x) - @test_opt derivative(backend, f, x) + @test_opt value_and_derivative(backend, f, x, NO_EXTRAS, implem) + @test_opt derivative(backend, f, x, NO_EXTRAS, implem) end end end @@ -302,7 +174,8 @@ function test_derivative( end function test_multiderivative( - backend::AbstractBackend, + implem::AbstractImplem, + backend::AbstractADType, scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, @@ -310,7 +183,7 @@ function test_multiderivative( scenarios = filter(scenarios) do s (get_input_type(s) <: Number) && (get_output_type(s) <: AbstractArray) end - @testset "Multiderivative ($(is_custom(backend) ? "custom" : "fallback"))" begin + @testset "Multiderivative ⁻ $(pretty(implem))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -318,15 +191,19 @@ function test_multiderivative( @testset "$X -> $Y" begin (; f, x, y, multider_true) = scenario - y_out1, multider_out1 = value_and_multiderivative(backend, f, x) + y_out1, multider_out1 = value_and_multiderivative( + backend, f, x, NO_EXTRAS, implem + ) multider_in2 = zero(multider_out1) y_out2, multider_out2 = value_and_multiderivative!( - multider_in2, backend, f, x + multider_in2, backend, f, x, NO_EXTRAS, implem ) - multider_out3 = multiderivative(backend, f, x) + multider_out3 = multiderivative(backend, f, x, NO_EXTRAS, implem) multider_in4 = zero(multider_out3) - multider_out4 = multiderivative!(multider_in4, backend, f, x) + multider_out4 = multiderivative!( + multider_in4, backend, f, x, NO_EXTRAS, implem + ) @testset "Primal value" begin @test y_out1 ≈ y @@ -344,13 +221,23 @@ function test_multiderivative( end allocs && @testset "Allocations" begin @test iszero( - @allocated value_and_multiderivative!(multider_in2, backend, f, x) + @allocated value_and_multiderivative!( + multider_in2, backend, f, x, NO_EXTRAS, implem + ) + ) + @test iszero( + @allocated multiderivative!( + multider_in4, backend, f, x, NO_EXTRAS, implem + ) ) - @test iszero(@allocated multiderivative!(multider_in4, backend, f, x)) end type_stability && @testset "Type stability" begin - @test_opt value_and_multiderivative!(multider_in2, backend, f, x) - @test_opt multiderivative!(multider_in4, backend, f, x) + @test_opt value_and_multiderivative!( + multider_in2, backend, f, x, NO_EXTRAS, implem + ) + @test_opt multiderivative!( + multider_in4, backend, f, x, NO_EXTRAS, implem + ) end end end @@ -358,7 +245,8 @@ function test_multiderivative( end function test_gradient( - backend::AbstractBackend, + implem::AbstractImplem, + backend::AbstractADType, scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, @@ -366,7 +254,7 @@ function test_gradient( scenarios = filter(scenarios) do s (get_input_type(s) <: AbstractArray) && (get_output_type(s) <: Number) end - @testset "Gradient ($(is_custom(backend) ? "custom" : "fallback"))" begin + @testset "Gradient ⁻ $(pretty(implem))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -374,13 +262,15 @@ function test_gradient( @testset "$X -> $Y" begin (; f, x, y, grad_true) = scenario - y_out1, grad_out1 = value_and_gradient(backend, f, x) + y_out1, grad_out1 = value_and_gradient(backend, f, x, NO_EXTRAS, implem) grad_in2 = zero(grad_out1) - y_out2, grad_out2 = value_and_gradient!(grad_in2, backend, f, x) + y_out2, grad_out2 = value_and_gradient!( + grad_in2, backend, f, x, NO_EXTRAS, implem + ) - grad_out3 = gradient(backend, f, x) + grad_out3 = gradient(backend, f, x, NO_EXTRAS, implem) grad_in4 = zero(grad_out3) - grad_out4 = gradient!(grad_in4, backend, f, x) + grad_out4 = gradient!(grad_in4, backend, f, x, NO_EXTRAS, implem) @testset "Primal value" begin @test y_out1 ≈ y @@ -397,12 +287,18 @@ function test_gradient( end end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_gradient!(grad_in2, backend, f, x)) - @test iszero(@allocated gradient!(grad_in4, backend, f, x)) + @test iszero( + @allocated value_and_gradient!( + grad_in2, backend, f, x, NO_EXTRAS, implem + ) + ) + @test iszero( + @allocated gradient!(grad_in4, backend, f, x, NO_EXTRAS, implem) + ) end type_stability && @testset "Type stability" begin - @test_opt value_and_gradient!(grad_in2, backend, f, x) - @test_opt gradient!(grad_in4, backend, f, x) + @test_opt value_and_gradient!(grad_in2, backend, f, x, NO_EXTRAS, implem) + @test_opt gradient!(grad_in4, backend, f, x, NO_EXTRAS, implem) end end end @@ -410,7 +306,8 @@ function test_gradient( end function test_jacobian( - backend::AbstractBackend, + implem::AbstractImplem, + backend::AbstractADType, scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, @@ -418,7 +315,7 @@ function test_jacobian( scenarios = filter(scenarios) do s (get_input_type(s) <: AbstractArray) && (get_output_type(s) <: AbstractArray) end - @testset "Jacobian ($(is_custom(backend) ? "custom" : "fallback"))" begin + @testset "Jacobian ⁻ $(pretty(implem))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) handles_types(backend, X, Y) || continue @@ -426,13 +323,15 @@ function test_jacobian( @testset "$X -> $Y" begin (; f, x, y, jac_true) = scenario - y_out1, jac_out1 = value_and_jacobian(backend, f, x) + y_out1, jac_out1 = value_and_jacobian(backend, f, x, NO_EXTRAS, implem) jac_in2 = zero(jac_out1) - y_out2, jac_out2 = value_and_jacobian!(jac_in2, backend, f, x) + y_out2, jac_out2 = value_and_jacobian!( + jac_in2, backend, f, x, NO_EXTRAS, implem + ) - jac_out3 = jacobian(backend, f, x) + jac_out3 = jacobian(backend, f, x, NO_EXTRAS, implem) jac_in4 = zero(jac_out3) - jac_out4 = jacobian!(jac_in4, backend, f, x) + jac_out4 = jacobian!(jac_in4, backend, f, x, NO_EXTRAS, implem) @testset "Primal value" begin @test y_out1 ≈ y @@ -449,12 +348,18 @@ function test_jacobian( end end allocs && @testset "Allocations" begin - @test iszero(@allocated value_and_jacobian!(jac_in2, backend, f, x)) - @test iszero(@allocated jacobian!(jac_in4, backend, f, x)) + @test iszero( + @allocated value_and_jacobian!( + jac_in2, backend, f, x, NO_EXTRAS, implem + ) + ) + @test iszero( + @allocated jacobian!(jac_in4, backend, f, x, NO_EXTRAS, implem) + ) end type_stability && @testset "Type stability" begin - @test_opt value_and_jacobian!(jac_in2, backend, f, x) - @test_opt jacobian!(jac_in4, backend, f, x) + @test_opt value_and_jacobian!(jac_in2, backend, f, x, NO_EXTRAS, implem) + @test_opt jacobian!(jac_in4, backend, f, x, NO_EXTRAS, implem) end end end @@ -462,7 +367,8 @@ function test_jacobian( end function test_jacobian_and_friends( - backend::AbstractBackend, + implem::AbstractImplem, + backend::AbstractADType, scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, @@ -472,9 +378,9 @@ function test_jacobian_and_friends( scenarios = filter(scenarios) do s (get_input_type(s) <: input_type) && (get_output_type(s) <: output_type) end - test_derivative(backend, scenarios; allocs, type_stability) - test_multiderivative(backend, scenarios; allocs, type_stability) - test_gradient(backend, scenarios; allocs, type_stability) - test_jacobian(backend, scenarios; allocs, type_stability) + test_derivative(implem, backend, scenarios; allocs, type_stability) + test_multiderivative(implem, backend, scenarios; allocs, type_stability) + test_gradient(implem, backend, scenarios; allocs, type_stability) + test_jacobian(implem, backend, scenarios; allocs, type_stability) return nothing end diff --git a/test/zygote.jl b/test/zygote.jl index 503dfc4c0..12c4c18bf 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -1,6 +1,7 @@ -using DifferentiationInterface +using ADTypes: AutoZygote +using DifferentiationInterface: CustomImplem, FallbackImplem using Zygote: Zygote -test_pullback(ZygoteBackend(), scenarios; type_stability=false); -test_jacobian_and_friends(ZygoteBackend(; custom=true), scenarios; type_stability=false); -test_jacobian_and_friends(ZygoteBackend(; custom=false), scenarios; type_stability=false); +test_pullback(AutoZygote(), scenarios; type_stability=false); +test_jacobian_and_friends(CustomImplem(), AutoZygote(), scenarios; type_stability=false); +test_jacobian_and_friends(FallbackImplem(), AutoZygote(), scenarios; type_stability=false);