From 47e26ece0d9ff5ea9624f10b72c4682dae2d73cf Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 8 Mar 2024 12:54:32 +0100 Subject: [PATCH] Backend-specific utilities (derivative, multiderivative, gradient, jacobian) (#24) * Add special cases * More complex backends * Fix benchmarks * More benchmarks * Reactivate JET test * Remove weird build folder * Don't import dangerous stuff before JET test * Re-skip JET * Wrong JET test * Better docstrings, more exhaustive tests, benchmark with NN layer * Minor doc fixes --- .gitignore | 1 + Project.toml | 3 + README.md | 55 ++++- benchmark/Project.toml | 1 + benchmark/benchmarks.jl | 225 +++++++++++++++--- docs/Project.toml | 3 +- docs/make.jl | 48 ++-- docs/src/backends.md | 73 ++++++ docs/src/extensions.md | 41 ---- docs/src/index.md | 62 ++++- docs/src/{api.md => interface.md} | 40 ++-- ...fferentiationInterfaceChainRulesCoreExt.jl | 62 ++++- ext/DifferentiationInterfaceEnzymeExt.jl | 62 ----- .../DifferentiationInterfaceEnzymeExt.jl | 29 +++ .../forward.jl | 35 +++ .../reverse.jl | 44 ++++ ext/DifferentiationInterfaceFiniteDiffExt.jl | 138 ++++++----- ext/DifferentiationInterfaceForwardDiffExt.jl | 93 ++++++-- ...tiationInterfacePolyesterForwardDiffExt.jl | 48 ++++ ext/DifferentiationInterfaceReverseDiffExt.jl | 52 +++- ext/DifferentiationInterfaceZygoteExt.jl | 49 +++- src/DifferentiationInterface.jl | 18 +- src/array_array.jl | 16 +- src/array_scalar.jl | 5 +- src/backends.jl | 120 ++++++---- src/backends_abstract.jl | 68 ++++++ src/pullback.jl | 29 +-- src/pushforward.jl | 29 +-- src/scalar_array.jl | 5 +- src/scalar_scalar.jl | 3 +- src/utils.jl | 4 +- test/Project.toml | 1 + test/backends.jl | 11 + test/diffractor.jl | 8 +- test/enzyme.jl | 24 -- test/enzyme_forward.jl | 10 + test/enzyme_reverse.jl | 10 + test/finitediff.jl | 5 +- test/forwarddiff.jl | 9 +- test/polyesterforwarddiff.jl | 16 ++ test/reversediff.jl | 7 +- test/runtests.jl | 23 +- test/utils.jl | 223 
+++++++++-------- test/zygote.jl | 7 +- 44 files changed, 1271 insertions(+), 544 deletions(-) create mode 100644 docs/src/backends.md delete mode 100644 docs/src/extensions.md rename docs/src/{api.md => interface.md} (89%) delete mode 100644 ext/DifferentiationInterfaceEnzymeExt.jl create mode 100644 ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl create mode 100644 ext/DifferentiationInterfaceEnzymeExt/forward.jl create mode 100644 ext/DifferentiationInterfaceEnzymeExt/reverse.jl create mode 100644 ext/DifferentiationInterfacePolyesterForwardDiffExt.jl create mode 100644 src/backends_abstract.jl create mode 100644 test/backends.jl delete mode 100644 test/enzyme.jl create mode 100644 test/enzyme_forward.jl create mode 100644 test/enzyme_reverse.jl create mode 100644 test/polyesterforwarddiff.jl diff --git a/.gitignore b/.gitignore index 5a1bc4c48..8b14d40cd 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ /benchmark/*.json /docs/build/ +/docs/src/index.md /Manifest.toml /docs/Manifest.toml diff --git a/Project.toml b/Project.toml index 1c7fd16d4..06c8b9268 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -22,6 +23,7 @@ DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" DifferentiationInterfaceEnzymeExt = "Enzyme" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] +DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] 
DifferentiationInterfaceZygoteExt = ["Zygote"] @@ -34,6 +36,7 @@ FillArrays = "1" FiniteDiff = "2.22" ForwardDiff = "0.10" LinearAlgebra = "1" +PolyesterForwardDiff = "0.1" ReverseDiff = "1.15" Zygote = "0.6" julia = "1.10" diff --git a/README.md b/README.md index 57da14887..f76db1417 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,57 @@ [![Coverage](https://codecov.io/gh/gdalle/DifferentiationInterface.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/gdalle/DifferentiationInterface.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) -An experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). +An interface to various automatic differentiation backends in Julia. -See the documentation for details. +## Goal + +This package provides a backend-agnostic syntax to differentiate functions `f(x) = y`, where `x` and `y` are either numbers or abstract arrays. + +It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). 
+ +## Example + +```jldoctest +julia> using DifferentiationInterface, Enzyme + +julia> backend = EnzymeReverseBackend(); + +julia> f(x) = sum(abs2, x); + +julia> value_and_gradient(backend, f, [1., 2., 3.]) +(14.0, [2.0, 4.0, 6.0]) +``` + +## Design + +Each backend must implement only one primitive: + +- forward mode: the pushforward, computing a Jacobian-vector product +- reverse mode: the pullback, computing a vector-Jacobian product + +From these primitives, several utilities are defined, depending on the type of the input and output: + +| | scalar output | array output | +| ------------ | ------------- | --------------- | +| scalar input | derivative | multiderivative | +| array input | gradient | jacobian | + +## Supported backends + +Forward mode: + +- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) +- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) +- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) +- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) + +Reverse mode: + +- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) +- [Zygote.jl](https://github.com/FluxML/Zygote.jl) +- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) + +Experimental: + +- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) +- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (currently broken due to [#277](https://github.com/JuliaDiff/Diffractor.jl/issues/277)) diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 8beb074b7..209efe139 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -6,5 +6,6 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = 
"e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 0b89c57e3..7dab20d5e 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,63 +1,210 @@ -# Run benchmarks locally by calling: -# julia -e 'using BenchmarkCI; BenchmarkCI.judge(baseline="origin/main"); BenchmarkCI.displayjudgement()' - -using Base: Fix2 using BenchmarkTools using DifferentiationInterface using LinearAlgebra -using Diffractor: Diffractor using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff +using PolyesterForwardDiff: PolyesterForwardDiff using ReverseDiff: ReverseDiff using Zygote: Zygote -scalar_to_scalar(x::Real) = x -scalar_to_vector(x::Real, n) = collect((1:n) .* x) -vector_to_scalar(x::AbstractVector{<:Real}) = dot(1:length(x), x) -vector_to_vector(x::AbstractVector{<:Real}) = (1:length(x)) .* x +## Settings + +BenchmarkTools.DEFAULT_PARAMETERS.evals = 1 +BenchmarkTools.DEFAULT_PARAMETERS.samples = 100 +BenchmarkTools.DEFAULT_PARAMETERS.seconds = 1 + +## Functions + +struct Layer{W<:Union{Number,AbstractArray},B<:Union{Number,AbstractArray},S<:Function} + w::W + b::B + σ::S +end + +function (l::Layer{<:Number,<:Number})(x::Number)::Number + return l.σ(l.w * x + l.b) +end + +function (l::Layer{<:AbstractVector,<:AbstractVector})(x::Number)::AbstractVector + return l.σ.(l.w .* x .+ l.b) +end -forward_backends = [EnzymeForwardBackend(), FiniteDiffBackend(), ForwardDiffBackend()] +function (l::Layer{<:AbstractVector,<:Number})(x::AbstractVector)::Number + return l.σ(dot(l.w, x) + l.b) +end + +function (l::Layer{<:AbstractMatrix,<:AbstractVector})(x::AbstractVector)::AbstractVector + return l.σ.(l.w * x .+ l.b) +end -reverse_backends = [ - ChainRulesReverseBackend(Zygote.ZygoteRuleConfig()), - EnzymeReverseBackend(), - ReverseDiffBackend(), +## Backends + +forward_custom_backends = [ + EnzymeForwardBackend(; custom=true), + FiniteDiffBackend(; custom=true), + ForwardDiffBackend(; custom=true), + 
PolyesterForwardDiffBackend(4; custom=true), +] + +forward_fallback_backends = [ + EnzymeForwardBackend(; custom=false), + FiniteDiffBackend(; custom=false), + ForwardDiffBackend(; custom=false), ] -n_values = [10] +reverse_custom_backends = [ + ZygoteBackend(; custom=true), + EnzymeReverseBackend(; custom=true), + ReverseDiffBackend(; custom=true), +] + +reverse_fallback_backends = [ + ZygoteBackend(; custom=false), + EnzymeReverseBackend(; custom=false), + ReverseDiffBackend(; custom=false), +] + +all_custom_backends = vcat(forward_custom_backends, reverse_custom_backends) +all_fallback_backends = vcat(forward_fallback_backends, reverse_fallback_backends) +all_backends = vcat(all_custom_backends, all_fallback_backends) + +## Suite SUITE = BenchmarkGroup() -for n in n_values - for backend in forward_backends - SUITE["forward"]["scalar_to_scalar"][n][string(backend)] = @benchmarkable begin - value_and_pushforward!(dy, $backend, scalar_to_scalar, x, dx) - end setup = (x = 1.0; dx = 1.0; dy = 0.0) evals = 1 - if backend != EnzymeForwardBackend() # type instability? 
- SUITE["forward"]["scalar_to_vector"][n][string(backend)] = @benchmarkable begin - value_and_pushforward!(dy, $backend, Fix2(scalar_to_vector, $n), x, dx) - end setup = (x = 1.0; dx = 1.0; dy = zeros($n)) evals = 1 +### Scalar to scalar + +scalar_to_scalar = Layer(randn(), randn(), tanh) + +for backend in all_backends + handles_types(backend, Number, Number) || continue + SUITE["value_and_derivative"][(1, 1)][string(backend)] = @benchmarkable begin + value_and_derivative($backend, $scalar_to_scalar, x) + end setup = (x = randn()) +end + +for backend in all_fallback_backends + handles_types(backend, Number, Number) || continue + if autodiff_mode(backend) == :forward + SUITE["value_and_pushforward"][(1, 1)][string(backend)] = @benchmarkable begin + value_and_pushforward($backend, $scalar_to_scalar, x, dx) + end setup = (x = randn(); dx = randn()) + else + SUITE["value_and_pullback"][(1, 1)][string(backend)] = @benchmarkable begin + value_and_pullback($backend, $scalar_to_scalar, x, dy) + end setup = (x = randn(); dy = randn()) + end +end + +### Scalar to vector + +for m in [10] + scalar_to_vector = Layer(randn(m), randn(m), tanh) + + for backend in all_backends + handles_types(backend, Number, Vector) || continue + SUITE["value_and_multiderivative"][(1, m)][string(backend)] = @benchmarkable begin + value_and_multiderivative($backend, $scalar_to_vector, x) + end setup = (x = randn()) + SUITE["value_and_multiderivative!"][(1, m)][string(backend)] = @benchmarkable begin + value_and_multiderivative!(multider, $backend, $scalar_to_vector, x) + end setup = (x = randn(); multider = zeros($m)) + end + + for backend in all_fallback_backends + handles_types(backend, Number, Vector) || continue + if autodiff_mode(backend) == :forward + SUITE["value_and_pushforward"][(1, m)][string(backend)] = @benchmarkable begin + value_and_pushforward($backend, $scalar_to_vector, x, dx) + end setup = (x = randn(); dx = randn()) + SUITE["value_and_pushforward!"][(1, m)][string(backend)] = 
@benchmarkable begin + value_and_pushforward!(dy, $backend, $scalar_to_vector, x, dx) + end setup = (x = randn(); dx = randn(); dy = zeros($m)) + else + SUITE["value_and_pullback"][(1, m)][string(backend)] = @benchmarkable begin + value_and_pullback($backend, $scalar_to_vector, x, dy) + end setup = (x = randn(); dy = ones($m)) + SUITE["value_and_pullback!"][(1, m)][string(backend)] = @benchmarkable begin + value_and_pullback!(dx, $backend, $scalar_to_vector, x, dy) + end setup = (x = randn(); dy = ones($m); dx = 0.0) end - SUITE["forward"]["vector_to_vector"][n][string(backend)] = @benchmarkable begin - value_and_pushforward!(dy, $backend, vector_to_vector, x, dx) - end setup = (x = randn($n); dx = randn($n); dy = zeros($n)) evals = 1 + end +end + +### Vector to scalar + +for n in [10] + vector_to_scalar = Layer(randn(n), randn(), tanh) + + for backend in all_backends + handles_types(backend, Vector, Number) || continue + SUITE["value_and_gradient"][(n, 1)][string(backend)] = @benchmarkable begin + value_and_gradient($backend, $vector_to_scalar, x) + end setup = (x = randn($n)) + SUITE["value_and_gradient!"][(n, 1)][string(backend)] = @benchmarkable begin + value_and_gradient!(grad, $backend, $vector_to_scalar, x) + end setup = (x = randn($n); grad = zeros($n)) end - for backend in reverse_backends - if backend != ReverseDiffBackend() - SUITE["reverse"]["scalar_to_scalar"][n][string(backend)] = @benchmarkable begin - value_and_pullback!(dx, $backend, scalar_to_scalar, x, dy) - end setup = (x = 1.0; dy = 1.0; dx = 0.0) evals = 1 + for backend in all_fallback_backends + handles_types(backend, Vector, Number) || continue + if autodiff_mode(backend) == :forward + SUITE["value_and_pushforward"][(n, 1)][string(backend)] = @benchmarkable begin + value_and_pushforward($backend, $vector_to_scalar, x, dx) + end setup = (x = randn($n); dx = randn($n)) + SUITE["value_and_pushforward!"][(n, 1)][string(backend)] = @benchmarkable begin + value_and_pushforward!(dy, $backend, 
$vector_to_scalar, x, dx) + end setup = (x = randn($n); dx = randn($n); dy = 0.0) + else + SUITE["value_and_pullback"][(n, 1)][string(backend)] = @benchmarkable begin + value_and_pullback($backend, $vector_to_scalar, x, dy) + end setup = (x = randn($n); dy = randn()) + SUITE["value_and_pullback!"][(n, 1)][string(backend)] = @benchmarkable begin + value_and_pullback!(dx, $backend, $vector_to_scalar, x, dy) + end setup = (x = randn($n); dy = randn(); dx = zeros($n)) end - SUITE["reverse"]["vector_to_scalar"][n][string(backend)] = @benchmarkable begin - value_and_pullback!(dx, $backend, vector_to_scalar, x, dy) - end setup = (x = randn($n); dy = 1.0; dx = zeros($n)) evals = 1 - if backend != EnzymeReverseBackend() - SUITE["reverse"]["vector_to_vector"][n][string(backend)] = @benchmarkable begin - value_and_pullback!(dx, $backend, vector_to_vector, x, dy) - end setup = (x = randn($n); dy = randn($n); dx = zeros($n)) evals = 1 + end +end + +### Vector to vector + +for (n, m) in [(10, 10)] + vector_to_vector = Layer(randn(m, n), randn(m), tanh) + + for backend in all_backends + handles_types(backend, Vector, Vector) || continue + SUITE["value_and_jacobian"][(n, m)][string(backend)] = @benchmarkable begin + value_and_jacobian($backend, $vector_to_vector, x) + end setup = (x = randn($n)) + SUITE["value_and_jacobian!"][(n, m)][string(backend)] = @benchmarkable begin + value_and_jacobian!(jac, $backend, $vector_to_vector, x) + end setup = (x = randn($n); jac = zeros($m, $n)) + end + + for backend in all_fallback_backends + handles_types(backend, Vector, Vector) || continue + if autodiff_mode(backend) == :forward + SUITE["value_and_pushforward"][(n, m)][string(backend)] = @benchmarkable begin + value_and_pushforward($backend, $vector_to_vector, x, dx) + end setup = (x = randn($n); dx = randn($n)) + SUITE["value_and_pushforward!"][(n, m)][string(backend)] = @benchmarkable begin + value_and_pushforward!(dy, $backend, $vector_to_vector, x, dx) + end setup = (x = randn($n); dx = 
randn($n); dy = zeros($m)) + else + SUITE["value_and_pullback"][(n, m)][string(backend)] = @benchmarkable begin + value_and_pullback($backend, $vector_to_vector, x, dy) + end setup = (x = randn($n); dy = randn($m)) + SUITE["value_and_pullback!"][(n, m)][string(backend)] = @benchmarkable begin + value_and_pullback!(dx, $backend, $vector_to_vector, x, dy) + end setup = (x = randn($n); dy = randn($m); dx = zeros($n)) end end end + +# Run benchmarks locally +# results = BenchmarkTools.run(SUITE; verbose=true) + +# Compare commits locally +# using BenchmarkCI; BenchmarkCI.judge(baseline="origin/main"); BenchmarkCI.displayjudgement() diff --git a/docs/Project.toml b/docs/Project.toml index f2efbc6b8..7850516f4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,11 +1,12 @@ [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/make.jl b/docs/make.jl index fe5164808..16c3a5566 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,20 +1,25 @@ +using Base: get_extension using DifferentiationInterface import DifferentiationInterface as DI using Documenter -using Diffractor: Diffractor +using DiffResults: DiffResults using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff +using PolyesterForwardDiff: PolyesterForwardDiff using ReverseDiff: ReverseDiff using Zygote: Zygote -DIChainRulesCoreExt = Base.get_extension(DI, :DifferentiationInterfaceChainRulesCoreExt) -DIEnzymeExt = 
Base.get_extension(DI, :DifferentiationInterfaceEnzymeExt) -DIFiniteDiffExt = Base.get_extension(DI, :DifferentiationInterfaceFiniteDiffExt) -DIForwardDiffExt = Base.get_extension(DI, :DifferentiationInterfaceForwardDiffExt) -DIReverseDiffExt = Base.get_extension(DI, :DifferentiationInterfaceReverseDiffExt) -DIZygoteExt = Base.get_extension(DI, :DifferentiationInterfaceZygoteExt) +ChainRulesCoreExt = get_extension(DI, :DifferentiationInterfaceChainRulesCoreExt) +EnzymeExt = get_extension(DI, :DifferentiationInterfaceEnzymeExt) +FiniteDiffExt = get_extension(DI, :DifferentiationInterfaceFiniteDiffExt) +ForwardDiffExt = get_extension(DI, :DifferentiationInterfaceForwardDiffExt) +PolyesterForwardDiffExt = get_extension( + DI, :DifferentiationInterfacePolyesterForwardDiffExt +) +ReverseDiffExt = get_extension(DI, :DifferentiationInterfaceReverseDiffExt) +ZygoteExt = get_extension(DI, :DifferentiationInterfaceZygoteExt) DocMeta.setdocmeta!( DifferentiationInterface, @@ -23,15 +28,30 @@ DocMeta.setdocmeta!( recursive=true, ) +open(joinpath(@__DIR__, "src", "index.md"), "w") do io + println( + io, + """ + ```@meta + EditURL = "https://github.com/gdalle/DifferentiationInterface.jl/blob/main/README.md" + ``` + """, + ) + for line in eachline(joinpath(dirname(@__DIR__), "README.md")) + println(io, line) + end +end + makedocs(; modules=[ DifferentiationInterface, - DIChainRulesCoreExt, - DIEnzymeExt, - DIFiniteDiffExt, - DIForwardDiffExt, - DIReverseDiffExt, - DIZygoteExt, + ChainRulesCoreExt, + EnzymeExt, + FiniteDiffExt, + ForwardDiffExt, + PolyesterForwardDiffExt, + ReverseDiffExt, + ZygoteExt, ], authors="Guillaume Dalle, Adrian Hill", sitename="DifferentiationInterface.jl", @@ -40,7 +60,7 @@ makedocs(; canonical="https://gdalle.github.io/DifferentiationInterface.jl", ), pages=[ - "Home" => "index.md", "API reference" => "api.md", "Extensions" => "extensions.md" + "Home" => "index.md", "Interface" => "interface.md", "Backends" => "backends.md" ], ) diff --git 
a/docs/src/backends.md b/docs/src/backends.md new file mode 100644 index 000000000..df9b2b271 --- /dev/null +++ b/docs/src/backends.md @@ -0,0 +1,73 @@ +```@meta +CollapsedDocStrings = true +``` + +# Backends + +## ChainRulesCore + +```@docs +ChainRulesForwardBackend +ChainRulesReverseBackend +``` + +```@autodocs +Modules = [ChainRulesCoreExt] +``` + +## Enzyme + +```@docs +EnzymeForwardBackend +EnzymeReverseBackend +``` + +```@autodocs +Modules = [EnzymeExt] +``` + +## FiniteDiff + +```@docs +FiniteDiffBackend +``` + +```@autodocs +Modules = [FiniteDiffExt] +``` + +## ForwardDiff + +```@docs +ForwardDiffBackend +``` + +```@autodocs +Modules = [ForwardDiffExt] +``` + +## PolyesterForwardDiff + +```@docs +PolyesterForwardDiffBackend +``` + +```@autodocs +Modules = [PolyesterForwardDiffExt] +``` + +## ReverseDiff + +```@docs +ReverseDiffBackend +``` + +```@autodocs +Modules = [ReverseDiffExt] +``` + +## Zygote + +```@autodocs +Modules = [ZygoteExt] +``` diff --git a/docs/src/extensions.md b/docs/src/extensions.md deleted file mode 100644 index b50ef240b..000000000 --- a/docs/src/extensions.md +++ /dev/null @@ -1,41 +0,0 @@ -```@meta -CollapsedDocStrings = true -``` - -# Extensions - -## ChainRulesCore - -```@autodocs -Modules = [DIChainRulesCoreExt] -``` - -## Enzyme - -```@autodocs -Modules = [DIEnzymeExt] -``` - -## FiniteDiff - -```@autodocs -Modules = [DIFiniteDiffExt] -``` - -## ForwardDiff - -```@autodocs -Modules = [DIForwardDiffExt] -``` - -## ReverseDiff - -```@autodocs -Modules = [DIReverseDiffExt] -``` - -## Zygote - -```@autodocs -Modules = [DIZygoteExt] -``` diff --git a/docs/src/index.md b/docs/src/index.md index e08407892..9cdd27b68 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,17 +1,65 @@ ```@meta -CurrentModule = DifferentiationInterface -CollapsedDocStrings = true +EditURL = "https://github.com/gdalle/DifferentiationInterface.jl/blob/main/README.md" ``` # DifferentiationInterface -Documentation for 
[DifferentiationInterface](https://github.com/gdalle/DifferentiationInterface.jl). +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://gdalle.github.io/DifferentiationInterface.jl/dev/) +[![Build Status](https://github.com/gdalle/DifferentiationInterface.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/gdalle/DifferentiationInterface.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/gdalle/DifferentiationInterface.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/gdalle/DifferentiationInterface.jl) +[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) -This is an interface to various autodiff backends for differentiating functions of the form `f(x) = y`, where `x` and `y` are either numbers or abstract arrays. +An interface to various automatic differentiation backends in Julia. -## Terminology +## Goal -| | scalar output | vector output | +This package provides a backend-agnostic syntax to differentiate functions `f(x) = y`, where `x` and `y` are either numbers or abstract arrays. + +It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl). 
+ +## Example + +```jldoctest +julia> using DifferentiationInterface, Enzyme + +julia> backend = EnzymeReverseBackend(); + +julia> f(x) = sum(abs2, x); + +julia> value_and_gradient(backend, f, [1., 2., 3.]) +(14.0, [2.0, 4.0, 6.0]) +``` + +## Design + +Each backend must implement only one primitive: + +- forward mode: the pushforward, computing a Jacobian-vector product +- reverse mode: the pullback, computing a vector-Jacobian product + +From these primitives, several utilities are defined, depending on the type of the input and output: + +| | scalar output | array output | | ------------ | ------------- | --------------- | | scalar input | derivative | multiderivative | -| vector input | gradient | jacobian | +| array input | gradient | jacobian | + +## Supported backends + +Forward mode: + +- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) +- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) +- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) +- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) + +Reverse mode: + +- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) +- [Zygote.jl](https://github.com/FluxML/Zygote.jl) +- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) + +Experimental: + +- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) +- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (currently broken due to [#277](https://github.com/JuliaDiff/Diffractor.jl/issues/277)) diff --git a/docs/src/api.md b/docs/src/interface.md similarity index 89% rename from docs/src/api.md rename to docs/src/interface.md index 491c2d513..6f25db112 100644 --- a/docs/src/api.md +++ b/docs/src/interface.md @@ -3,68 +3,68 @@ CurrentModule = DifferentiationInterface CollapsedDocStrings = true ``` -# API reference +# Interface ```@docs DifferentiationInterface ``` -## Backends +## Utilities + +### Scalar to scalar ```@autodocs Modules = [DifferentiationInterface] 
-Pages = ["backends.jl"] +Pages = ["scalar_scalar.jl"] ``` -## Primitives - -### Pushforward +### Scalar to array ```@autodocs Modules = [DifferentiationInterface] -Pages = ["pushforward.jl"] +Pages = ["scalar_array.jl"] ``` -### Pullback +### Array to scalar ```@autodocs Modules = [DifferentiationInterface] -Pages = ["pullback.jl"] +Pages = ["array_scalar.jl"] ``` -## Special cases - -### Scalar to scalar +### Array to array ```@autodocs Modules = [DifferentiationInterface] -Pages = ["scalar_scalar.jl"] +Pages = ["array_array.jl"] ``` -### Scalar to array +## Primitives + +### Pushforward ```@autodocs Modules = [DifferentiationInterface] -Pages = ["scalar_array.jl"] +Pages = ["pushforward.jl"] ``` -### Array to scalar +### Pullback ```@autodocs Modules = [DifferentiationInterface] -Pages = ["array_scalar.jl"] +Pages = ["pullback.jl"] ``` -### Array to array +## Abstract backends ```@autodocs Modules = [DifferentiationInterface] -Pages = ["array_array.jl"] +Pages = ["backends_abstract.jl"] ``` ## Internals ```@autodocs Modules = [DifferentiationInterface] -Public = false +Pages = ["utils.jl"] ``` diff --git a/ext/DifferentiationInterfaceChainRulesCoreExt.jl b/ext/DifferentiationInterfaceChainRulesCoreExt.jl index f7d6fcfd7..5e116c557 100644 --- a/ext/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/ext/DifferentiationInterfaceChainRulesCoreExt.jl @@ -1,7 +1,9 @@ module DifferentiationInterfaceChainRulesCoreExt -using ChainRulesCore: NoTangent, frule_via_ad, rrule_via_ad -using DifferentiationInterface +using ChainRulesCore: + HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad +using DifferentiationInterface: ChainRulesForwardBackend, ChainRulesReverseBackend +import DifferentiationInterface as DI using DocStringExtensions ruleconfig(backend::ChainRulesForwardBackend) = backend.ruleconfig @@ -10,26 +12,62 @@ ruleconfig(backend::ChainRulesReverseBackend) = backend.ruleconfig update!(_old::Number, new::Number) = new update!(old, new) 
= old .= new +## Backend construction + +""" + ChainRulesForwardBackend(rc::RuleConfig; custom=true) + +Construct a [`ChainRulesForwardBackend`](@ref) from a [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) object that `HasForwardsMode`. + +## Example + +```julia +using Diffractor, DifferentiationInterface +backend = ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()) +``` """ -$(TYPEDSIGNATURES) +function DI.ChainRulesForwardBackend(rc::RuleConfig{>:HasForwardsMode}; custom::Bool=true) + return ChainRulesForwardBackend{custom,typeof(rc)}(rc) +end + +""" + ChainRulesReverseBackend(rc::RuleConfig; custom=true) + +Construct a [`ChainRulesReverseBackend`](@ref) from a [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) object that `HasReverseMode`. + +## Example + +```julia +using Zygote, DifferentiationInterface +backend = ChainRulesReverseBackend(Zygote.ZygoteRuleConfig()) +``` """ -function DifferentiationInterface.value_and_pushforward!( - dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx -) where {X,Y} +function DI.ChainRulesReverseBackend(rc::RuleConfig{>:HasReverseMode}; custom::Bool=true) + return ChainRulesReverseBackend{custom,typeof(rc)}(rc) +end + +## Primitives + +function DI.value_and_pushforward(backend::ChainRulesForwardBackend, f, x, dx) rc = ruleconfig(backend) y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x) + return y, new_dy +end + +function DI.value_and_pushforward!(dy, backend::ChainRulesForwardBackend, f, x, dx) + y, new_dy = DI.value_and_pushforward(backend, f, x, dx) return y, update!(dy, new_dy) end -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pullback!( - dx, backend::ChainRulesReverseBackend, f, x::X, dy::Y -) where {X,Y} +function DI.value_and_pullback(backend::ChainRulesReverseBackend, f, x, dy) rc = ruleconfig(backend) y, pullback = rrule_via_ad(rc, f, x) _, new_dx = pullback(dy) + 
return y, new_dx +end + +function DI.value_and_pullback!(dx, backend::ChainRulesReverseBackend, f, x, dy) + y, new_dx = DI.value_and_pullback(backend, f, x, dy) return y, update!(dx, new_dx) end diff --git a/ext/DifferentiationInterfaceEnzymeExt.jl b/ext/DifferentiationInterfaceEnzymeExt.jl deleted file mode 100644 index 47bacf870..000000000 --- a/ext/DifferentiationInterfaceEnzymeExt.jl +++ /dev/null @@ -1,62 +0,0 @@ -module DifferentiationInterfaceEnzymeExt - -using DifferentiationInterface -using DocStringExtensions -using Enzyme: Forward, ReverseWithPrimal, Active, Duplicated, autodiff - -const EnzymeBackends = Union{EnzymeForwardBackend,EnzymeReverseBackend} - -## Unit vector - -# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type -function DifferentiationInterface.basisarray( - ::EnzymeBackends, a::AbstractArray{T}, i -) where {T} - b = zero(a) - b[i] = one(T) - return b -end - -## Forward mode - -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pushforward!( - _dy::Y, ::EnzymeForwardBackend, f, x::X, dx -) where {X,Y<:Real} - y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) - return y, new_dy -end - -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pushforward!( - dy::Y, ::EnzymeForwardBackend, f, x::X, dx -) where {X,Y<:AbstractArray} - y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) - dy .= new_dy - return y, dy -end - -## Reverse mode - -function DifferentiationInterface.value_and_pullback!( - _dx, ::EnzymeReverseBackend, f, x::X, dy::Y -) where {X<:Number,Y<:Union{Real,Nothing}} - der, y = autodiff(ReverseWithPrimal, f, Active, Active(x)) - new_dx = dy * only(der) - return y, new_dx -end - -function DifferentiationInterface.value_and_pullback!( - dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y -) where {X<:AbstractArray,Y<:Union{Real,Nothing}} - dx .= zero(eltype(dx)) - _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx)) - dx .*= dy - 
return y, dx -end - -end # module diff --git a/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl new file mode 100644 index 000000000..fcbeda077 --- /dev/null +++ b/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -0,0 +1,29 @@ +module DifferentiationInterfaceEnzymeExt + +using DifferentiationInterface: EnzymeForwardBackend, EnzymeReverseBackend +import DifferentiationInterface as DI +using DocStringExtensions +using Enzyme: + Active, + Duplicated, + Forward, + Reverse, + ReverseWithPrimal, + autodiff, + gradient, + gradient!, + jacobian + +# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type +function DI.basisarray( + ::Union{EnzymeForwardBackend,EnzymeReverseBackend}, a::AbstractArray{T}, i +) where {T} + b = zero(a) + b[i] = one(T) + return b +end + +include("forward.jl") +include("reverse.jl") + +end # module diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward.jl b/ext/DifferentiationInterfaceEnzymeExt/forward.jl new file mode 100644 index 000000000..95aadfa51 --- /dev/null +++ b/ext/DifferentiationInterfaceEnzymeExt/forward.jl @@ -0,0 +1,35 @@ + +## Backend construction + +""" + EnzymeForwardBackend(; custom=true) + +Construct a [`EnzymeForwardBackend`](@ref). 
+""" +DI.EnzymeForwardBackend(; custom::Bool=true) = EnzymeForwardBackend{custom}() + +## Primitives + +function DI.value_and_pushforward!( + _dy::Y, ::EnzymeForwardBackend, f, x::X, dx +) where {X,Y<:Real} + y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) + return y, new_dy +end + +function DI.value_and_pushforward!( + dy::Y, ::EnzymeForwardBackend, f, x::X, dx +) where {X,Y<:AbstractArray} + y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx)) + dy .= new_dy + return y, dy +end + +## Utilities + +function DI.value_and_jacobian(::EnzymeForwardBackend{true}, f, x::AbstractArray) + y = f(x) + jac = jacobian(Forward, f, x) + # see https://github.com/EnzymeAD/Enzyme.jl/issues/1332 + return y, reshape(jac, length(y), length(x)) +end diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl new file mode 100644 index 000000000..053bd3e37 --- /dev/null +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse.jl @@ -0,0 +1,44 @@ + +## Backend construction + +""" + EnzymeReverseBackend(; custom=true) + +Construct an [`EnzymeReverseBackend`](@ref). 
+""" +DI.EnzymeReverseBackend(; custom::Bool=true) = EnzymeReverseBackend{custom}() + +## Primitives + +function DI.value_and_pullback!( + _dx, ::EnzymeReverseBackend, f, x::X, dy::Y +) where {X<:Number,Y<:Union{Real,Nothing}} + der, y = autodiff(ReverseWithPrimal, f, Active, Active(x)) + new_dx = dy * only(der) + return y, new_dx +end + +function DI.value_and_pullback!( + dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y +) where {X<:AbstractArray,Y<:Union{Real,Nothing}} + dx .= zero(eltype(dx)) + _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx)) + dx .*= dy + return y, dx +end + +## Utilities + +function DI.value_and_gradient(::EnzymeReverseBackend{true}, f, x::AbstractArray) + y = f(x) + grad = gradient(Reverse, f, x) + return y, grad +end + +function DI.value_and_gradient!( + grad::AbstractArray, ::EnzymeReverseBackend{true}, f, x::AbstractArray +) + y = f(x) + gradient!(Reverse, grad, f, x) + return y, grad +end diff --git a/ext/DifferentiationInterfaceFiniteDiffExt.jl b/ext/DifferentiationInterfaceFiniteDiffExt.jl index 8cf29cf63..a1c18c3fe 100644 --- a/ext/DifferentiationInterfaceFiniteDiffExt.jl +++ b/ext/DifferentiationInterfaceFiniteDiffExt.jl @@ -1,6 +1,7 @@ module DifferentiationInterfaceFiniteDiffExt -using DifferentiationInterface +using DifferentiationInterface: FiniteDiffBackend +import DifferentiationInterface as DI using DocStringExtensions using FiniteDiff: finite_difference_derivative, @@ -9,80 +10,101 @@ using FiniteDiff: finite_difference_jacobian using LinearAlgebra: dot, mul! -const DEFAULT_FDTYPE = Val{:central} +# see https://docs.sciml.ai/FiniteDiff/stable/#f-Definitions +const FUNCTION_INPLACE = Val{true} +const FUNCTION_NOT_INPLACE = Val{false} + +## Backend construction """ -$(TYPEDSIGNATURES) + FiniteDiffBackend(::Type{fdtype}=Val{:central}; custom=true) + +Construct a [`FiniteDiffBackend`](@ref) with any finite difference type `fdtype` (`Val{:forward}` or `Val{:central}`). 
""" -function DifferentiationInterface.value_and_pushforward!( - dy::Y, ::FiniteDiffBackend, f, x::X, dx -) where {X<:Number,Y<:Number} +function DI.FiniteDiffBackend( + ::Type{fdtype}=Val{:central}; custom::Bool=true +) where {fdtype} + return FiniteDiffBackend{custom,fdtype}() +end + +## Primitives + +function DI.value_and_pushforward!( + dy::Y, ::FiniteDiffBackend{custom,fdtype}, f, x, dx +) where {Y<:Number,custom,fdtype} y = f(x) - der = finite_difference_derivative( - f, - x, - DEFAULT_FDTYPE, # fdtype - eltype(dy), # returntype - y, # fx - ) - new_dy = der * dx + step(t::Number)::Number = f(x .+ t .* dx) + new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y) return y, new_dy end -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pushforward!( - dy::Y, ::FiniteDiffBackend, f, x::X, dx -) where {X<:Number,Y<:AbstractArray} +function DI.value_and_pushforward!( + dy::Y, ::FiniteDiffBackend{custom,fdtype}, f, x, dx +) where {Y<:AbstractArray,custom,fdtype} y = f(x) + step(t::Number)::AbstractArray = f(x .+ t .* dx) finite_difference_gradient!( - dy, - f, - x, - DEFAULT_FDTYPE, # fdtype - eltype(dy), # returntype - Val{false}, # inplace - y, # fx + dy, step, zero(eltype(dx)), fdtype, eltype(y), FUNCTION_NOT_INPLACE, y ) - dy .*= dx return y, dy end -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pushforward!( - dy::Y, ::FiniteDiffBackend, f, x::X, dx -) where {X<:AbstractArray,Y<:Number} +## Utilities + +function DI.value_and_derivative( + ::FiniteDiffBackend{true,fdtype}, f, x::Number +) where {fdtype} y = f(x) - g = finite_difference_gradient( - f, - x, - DEFAULT_FDTYPE, # fdtype - eltype(dy), # returntype - Val{false}, # inplace - y, # fx - ) - new_dy = dot(g, dx) - return y, new_dy + der = finite_difference_derivative(f, x, fdtype, eltype(y), y) + return y, der end -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pushforward!( - dy::Y, 
::FiniteDiffBackend, f, x::X, dx -) where {X<:AbstractArray,Y<:AbstractArray} +function DI.value_and_multiderivative!( + multider::AbstractArray, ::FiniteDiffBackend{true,fdtype}, f, x::Number +) where {fdtype} y = f(x) - J = finite_difference_jacobian( - f, - x, - DEFAULT_FDTYPE, # fdtype - eltype(dy), # returntype - ) - mul!(vec(dy), J, vec(dx)) - return y, dy + finite_difference_gradient!(multider, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) + return y, multider +end + +function DI.value_and_multiderivative( + ::FiniteDiffBackend{true,fdtype}, f, x::Number +) where {fdtype} + y = f(x) + multider = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) + return y, multider +end + +function DI.value_and_gradient!( + grad::AbstractArray, ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray +) where {fdtype} + y = f(x) + finite_difference_gradient!(grad, f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) + return y, grad +end + +function DI.value_and_gradient( + ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray +) where {fdtype} + y = f(x) + grad = finite_difference_gradient(f, x, fdtype, eltype(y), FUNCTION_NOT_INPLACE, y) + return y, grad +end + +function DI.value_and_jacobian( + ::FiniteDiffBackend{true,fdtype}, f, x::AbstractArray +) where {fdtype} + y = f(x) + jac = finite_difference_jacobian(f, x, fdtype, eltype(y)) + return y, jac +end + +function DI.value_and_jacobian!( + jac::AbstractMatrix, backend::FiniteDiffBackend{true}, f, x::AbstractArray +) + y, new_jac = DI.value_and_jacobian(backend, f, x) + jac .= new_jac + return y, jac end end # module diff --git a/ext/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt.jl index a419d2815..367fca5d9 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt.jl @@ -1,15 +1,35 @@ module DifferentiationInterfaceForwardDiffExt -using DifferentiationInterface +using DifferentiationInterface: 
ForwardDiffBackend +import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions -using ForwardDiff: Dual, Tag, value, extract_derivative, extract_derivative! +using ForwardDiff: + Dual, + Tag, + derivative, + derivative!, + extract_derivative, + extract_derivative!, + gradient, + gradient!, + jacobian, + jacobian!, + value using LinearAlgebra: mul! +## Backend construction + """ -$(TYPEDSIGNATURES) + ForwardDiffBackend(; custom=true) + +Construct a [`ForwardDiffBackend`](@ref). """ -function DifferentiationInterface.value_and_pushforward!( +DI.ForwardDiffBackend(; custom::Bool=true) = ForwardDiffBackend{custom}() + +## Primitives + +function DI.value_and_pushforward!( _dy::Y, ::ForwardDiffBackend, f, x::X, dx ) where {X<:Real,Y<:Real} T = typeof(Tag(f, X)) @@ -20,10 +40,7 @@ function DifferentiationInterface.value_and_pushforward!( return y, new_dy end -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pushforward!( +function DI.value_and_pushforward!( dy::Y, ::ForwardDiffBackend, f, x::X, dx ) where {X<:Real,Y<:AbstractArray} T = typeof(Tag(f, X)) @@ -34,10 +51,7 @@ function DifferentiationInterface.value_and_pushforward!( return y, dy end -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pushforward!( +function DI.value_and_pushforward!( _dy::Y, ::ForwardDiffBackend, f, x::X, dx ) where {X<:AbstractArray,Y<:Real} T = typeof(Tag(f, X)) # TODO: unsure @@ -48,10 +62,7 @@ function DifferentiationInterface.value_and_pushforward!( return y, new_dy end -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pushforward!( +function DI.value_and_pushforward!( dy::Y, ::ForwardDiffBackend, f, x::X, dx ) where {X<:AbstractArray,Y<:AbstractArray} T = typeof(Tag(f, X)) # TODO: unsure @@ -62,4 +73,54 @@ function DifferentiationInterface.value_and_pushforward!( return y, dy end +## Utilities (TODO: use DiffResults) + +function DI.value_and_derivative(::ForwardDiffBackend{true}, 
f, x::Number) + y = f(x) + der = derivative(f, x) + return y, der +end + +function DI.value_and_multiderivative(::ForwardDiffBackend{true}, f, x::Number) + y = f(x) + multider = derivative(f, x) + return y, multider +end + +function DI.value_and_multiderivative!( + multider::AbstractArray, ::ForwardDiffBackend{true}, f, x::Number +) + y = f(x) + derivative!(multider, f, x) + return y, multider +end + +function DI.value_and_gradient(::ForwardDiffBackend{true}, f, x::AbstractArray) + y = f(x) + grad = gradient(f, x) + return y, grad +end + +function DI.value_and_gradient!( + grad::AbstractArray, ::ForwardDiffBackend{true}, f, x::AbstractArray +) + y = f(x) + gradient!(grad, f, x) + return y, grad +end + +function DI.value_and_jacobian(::ForwardDiffBackend{true}, f, x::AbstractArray) + y = f(x) + jac = jacobian(f, x) + return y, jac +end + +function DI.value_and_jacobian!( + jac::AbstractMatrix, ::ForwardDiffBackend{true}, f, x::AbstractArray +) + y = f(x) + jacobian!(jac, f, x) + return y, jac +end + end # module diff --git a/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl new file mode 100644 index 000000000..25d54d6f9 --- /dev/null +++ b/ext/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -0,0 +1,48 @@ +module DifferentiationInterfacePolyesterForwardDiffExt + +using DifferentiationInterface: ForwardDiffBackend, PolyesterForwardDiffBackend +import DifferentiationInterface as DI +using DiffResults: DiffResults +using DocStringExtensions +using ForwardDiff: Chunk +using LinearAlgebra: mul! +using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! + +## Backend construction + +""" + PolyesterForwardDiffBackend(C; custom=true) + +Construct a [`PolyesterForwardDiffBackend`](@ref) with chunk size `C`. 
+""" +function DI.PolyesterForwardDiffBackend(C::Integer; custom::Bool=true) + return PolyesterForwardDiffBackend{custom,C}() +end + +## Primitives + +function DI.value_and_pushforward!( + dy, ::PolyesterForwardDiffBackend{custom}, f, x, dx +) where {custom} + return DI.value_and_pushforward!(dy, ForwardDiffBackend{custom}(), f, x, dx) +end + +## Utilities + +function DI.value_and_gradient!( + grad::AbstractArray, ::PolyesterForwardDiffBackend{true,C}, f, x::AbstractArray +) where {C} + y = f(x) + threaded_gradient!(f, grad, x, Chunk{C}()) + return y, grad +end + +function DI.value_and_jacobian!( + jac::AbstractMatrix, ::PolyesterForwardDiffBackend{true,C}, f, x::AbstractArray +) where {C} + y = f(x) + threaded_jacobian!(f, jac, x, Chunk{C}()) + return y, jac +end + +end # module diff --git a/ext/DifferentiationInterfaceReverseDiffExt.jl b/ext/DifferentiationInterfaceReverseDiffExt.jl index 2a8922da7..99c5b325c 100644 --- a/ext/DifferentiationInterfaceReverseDiffExt.jl +++ b/ext/DifferentiationInterfaceReverseDiffExt.jl @@ -1,15 +1,24 @@ module DifferentiationInterfaceReverseDiffExt -using DifferentiationInterface +using DifferentiationInterface: ReverseDiffBackend +import DifferentiationInterface as DI using DiffResults: DiffResults using DocStringExtensions using LinearAlgebra: mul! -using ReverseDiff: gradient!, jacobian! +using ReverseDiff: gradient, gradient!, jacobian, jacobian! + +## Backend construction """ -$(TYPEDSIGNATURES) + ReverseDiffBackend(; custom=true) + +Construct a [`ReverseDiffBackend`](@ref). 
""" -function DifferentiationInterface.value_and_pullback!( +DI.ReverseDiffBackend(; custom::Bool=true) = ReverseDiffBackend{custom}() + +## Primitives + +function DI.value_and_pullback!( dx, ::ReverseDiffBackend, f, x::X, dy::Y ) where {X<:AbstractArray,Y<:Real} res = DiffResults.DiffResult(zero(Y), dx) @@ -19,10 +28,7 @@ function DifferentiationInterface.value_and_pullback!( return y, dx end -""" -$(TYPEDSIGNATURES) -""" -function DifferentiationInterface.value_and_pullback!( +function DI.value_and_pullback!( dx, ::ReverseDiffBackend, f, x::X, dy::Y ) where {X<:AbstractArray,Y<:AbstractArray} res = DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x))) @@ -33,4 +39,34 @@ function DifferentiationInterface.value_and_pullback!( return y, dx end +## Utilities (TODO: use DiffResults) + +function DI.value_and_gradient(::ReverseDiffBackend{true}, f, x::AbstractArray) + y = f(x) + grad = gradient(f, x) + return y, grad +end + +function DI.value_and_gradient!( + grad::AbstractArray, ::ReverseDiffBackend{true}, f, x::AbstractArray +) + y = f(x) + gradient!(grad, f, x) + return y, grad +end + +function DI.value_and_jacobian(::ReverseDiffBackend{true}, f, x::AbstractArray) + y = f(x) + jac = jacobian(f, x) + return y, jac +end + +function DI.value_and_jacobian!( + jac::AbstractMatrix, ::ReverseDiffBackend{true}, f, x::AbstractArray +) + y = f(x) + jacobian!(jac, f, x) + return y, jac +end + end # module diff --git a/ext/DifferentiationInterfaceZygoteExt.jl b/ext/DifferentiationInterfaceZygoteExt.jl index a762fb9c3..28c0a7998 100644 --- a/ext/DifferentiationInterfaceZygoteExt.jl +++ b/ext/DifferentiationInterfaceZygoteExt.jl @@ -1,7 +1,52 @@ module DifferentiationInterfaceZygoteExt -using DifferentiationInterface +using DifferentiationInterface: ChainRulesReverseBackend, ZygoteBackend +import DifferentiationInterface as DI using DocStringExtensions -using Zygote +using Zygote: ZygoteRuleConfig, gradient, jacobian, withgradient, withjacobian + +## Backend 
construction + +""" + ZygoteBackend(; custom=true) + +Enables the use of [Zygote.jl](https://github.com/FluxML/Zygote.jl) by constructing a [`ChainRulesReverseBackend`](@ref) from `ZygoteRuleConfig()`. +""" +DI.ZygoteBackend(; custom::Bool=true) = ChainRulesReverseBackend(ZygoteRuleConfig(); custom) + +const ZygoteBackendType{custom} = ChainRulesReverseBackend{custom,<:ZygoteRuleConfig} + +function Base.show(io::IO, backend::ZygoteBackendType{custom}) where {custom} + return print(io, "ZygoteBackend{$(custom ? "custom" : "fallback")}()") +end + +## Utilities + +function DI.value_and_gradient(::ZygoteBackendType{true}, f, x::AbstractArray) + res = withgradient(f, x) + return res.val, only(res.grad) +end + +function DI.value_and_gradient!( + grad::AbstractArray, backend::ZygoteBackendType{true}, f, x::AbstractArray +) + y, new_grad = DI.value_and_gradient(backend, f, x) + grad .= new_grad + return y, grad +end + +function DI.value_and_jacobian(::ZygoteBackendType{true}, f, x::AbstractArray) + y = f(x) + jac = jacobian(f, x) + return y, only(jac) +end + +function DI.value_and_jacobian!( + jac::AbstractMatrix, backend::ZygoteBackendType{true}, f, x::AbstractArray +) + y, new_jac = DI.value_and_jacobian(backend, f, x) + jac .= new_jac + return y, jac +end end diff --git a/src/DifferentiationInterface.jl b/src/DifferentiationInterface.jl index 80c46bedb..3e44f5f07 100644 --- a/src/DifferentiationInterface.jl +++ b/src/DifferentiationInterface.jl @@ -1,8 +1,7 @@ """ DifferentiationInterface -An experimental redesign for [AbstractDifferentiation.jl] -(https://github.com/JuliaDiff/AbstractDifferentiation.jl). +An interface to various automatic differentiation backends in Julia. 
# Exports @@ -13,6 +12,7 @@ module DifferentiationInterface using DocStringExtensions using FillArrays: OneElement +include("backends_abstract.jl") include("backends.jl") include("utils.jl") include("pushforward.jl") @@ -22,13 +22,19 @@ include("scalar_array.jl") include("array_scalar.jl") include("array_array.jl") -export ChainRulesReverseBackend, - ChainRulesForwardBackend, - EnzymeReverseBackend, +export AbstractBackend, AbstractForwardBackend, AbstractReverseBackend +export autodiff_mode, is_custom +export handles_input_type, handles_output_type, handles_types + +export ChainRulesForwardBackend, + ChainRulesReverseBackend, EnzymeForwardBackend, + EnzymeReverseBackend, FiniteDiffBackend, ForwardDiffBackend, - ReverseDiffBackend + PolyesterForwardDiffBackend, + ReverseDiffBackend, + ZygoteBackend export value_and_pushforward!, value_and_pushforward export value_and_pullback!, value_and_pullback diff --git a/src/array_array.jl b/src/array_array.jl index 265d7358b..2ba82d021 100644 --- a/src/array_array.jl +++ b/src/array_array.jl @@ -1,13 +1,12 @@ """ - value_and_jacobian!(jac, backend, f, x[, stuff]) -> (y, jac) + value_and_jacobian!(jac, backend, f, x) -> (y, jac) -Compute the Jacobian inside `jac`. -Returns the primal output `f(x)` and the Jacobian. +Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of an array-to-array function, overwriting `jac` if possible. ## Notes -For a function `f: ℝⁿ → ℝᵐ`, `jac` is expected to be a `m × n` matrix. -If the input or output is a higher-order array, it is flattened with `vec`. +Regardless of the shape of `x` and `y`, if `x` has length `n` and `y` has length `m`, then `jac` is expected to be a `m × n` matrix. +This function acts as if the input and output had been flattened with `vec`. 
""" function value_and_jacobian!( jac::AbstractMatrix, backend::AbstractBackend, f, x::AbstractArray @@ -53,7 +52,12 @@ end """ value_and_jacobian(backend, f, x) -> (y, jac) -Call [`value_and_jacobian!`](@ref) after allocating memory for the Jacobian matrix. +Compute the primal value `y = f(x)` and the Jacobian matrix `jac = ∂f(x)` of an array-to-array function. + +## Notes + +Regardless of the shape of `x` and `y`, if `x` has length `n` and `y` has length `m`, then `jac` is expected to be a `m × n` matrix. +This function acts as if the input and output had been flattened with `vec`. """ function value_and_jacobian(backend::AbstractBackend, f, x::AbstractArray) y = f(x) diff --git a/src/array_scalar.jl b/src/array_scalar.jl index 2a2e69bc5..2224dcd77 100644 --- a/src/array_scalar.jl +++ b/src/array_scalar.jl @@ -1,8 +1,7 @@ """ value_and_gradient!(grad, backend, f, x) -> (y, grad) -Compute the gradient of an array-to-scalar function inside `dx`. -Returns the primal output `f(x)` and the gradient. +Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function, overwriting `grad` if possible. """ function value_and_gradient! end @@ -27,7 +26,7 @@ end """ value_and_gradient(backend, f, x) -> (y, grad) -Call [`value_and_gradient!`](@ref) after allocating memory for the gradient. +Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function. 
""" function value_and_gradient(backend::AbstractBackend, f, x::AbstractArray) grad = similar(x) diff --git a/src/backends.jl b/src/backends.jl index a341b786b..81581e87d 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -1,79 +1,115 @@ -abstract type AbstractBackend end -abstract type AbstractForwardBackend <: AbstractBackend end -abstract type AbstractReverseBackend <: AbstractBackend end +""" + ChainRulesForwardBackend <: AbstractForwardBackend +Enables the use of forward mode AD packages based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). """ - ChainRulesReverseBackend{RC} +struct ChainRulesForwardBackend{custom,RC} <: AbstractForwardBackend{custom} + ruleconfig::RC +end -Performs autodiff with reverse-mode AD packages based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl), like [Zygote.jl](https://github.com/FluxML/Zygote.jl). +function Base.show(io::IO, backend::ChainRulesForwardBackend{custom}) where {custom} + return print( + io, + "ChainRulesForwardBackend{$(custom ? "custom" : "fallback")}($(backend.ruleconfig))", + ) +end -This must be constructed with an appropriate [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) instance: +""" + ChainRulesReverseBackend <: AbstractReverseBackend -```julia -using Zygote, DifferentiationInterface -backend = ChainRulesReverseBackend(Zygote.ZygoteRuleConfig()) -``` +Enables the use of reverse mode AD packages based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). 
""" -struct ChainRulesReverseBackend{RC} <: AbstractReverseBackend - # TODO: check RC<:RuleConfig{>:HasReverseMode} +struct ChainRulesReverseBackend{custom,RC} <: AbstractReverseBackend{custom} ruleconfig::RC end -function Base.string(backend::ChainRulesReverseBackend) - return "ChainRulesReverseBackend($(backend.ruleconfig))" +function Base.show(io::IO, backend::ChainRulesReverseBackend{custom}) where {custom} + return print( + io, + "ChainRulesReverseBackend{$(custom ? "custom" : "fallback")}($(backend.ruleconfig))", + ) end """ - ChainRulesForwardBackend{RC} - -Performs autodiff with forward-mode AD packages based on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl), like [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl). + FiniteDiffBackend <: AbstractForwardBackend -This must be constructed with an appropriate [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html) instance. -```julia -using Diffractor, DifferentiationInterface -backend = ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()) -``` +Enables the use of [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl). """ -struct ChainRulesForwardBackend{RC} <: AbstractForwardBackend - # TODO: check RC<:RuleConfig{>:HasForwardsMode} - ruleconfig::RC -end +struct FiniteDiffBackend{custom,fdtype} <: AbstractForwardBackend{custom} end -function Base.string(backend::ChainRulesForwardBackend) - return "ChainRulesForwardBackend($(backend.ruleconfig))" +function Base.show(io::IO, ::FiniteDiffBackend{custom,fdtype}) where {custom,fdtype} + return print(io, "FiniteDiffBackend{$(custom ? "custom" : "fallback"),$fdtype}()") end """ - FiniteDiffBackend + EnzymeForwardBackend <: AbstractForwardBackend -Performs autodiff with [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl). +Enables the use of [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) in forward mode. 
""" -struct FiniteDiffBackend <: AbstractForwardBackend end +struct EnzymeForwardBackend{custom} <: AbstractForwardBackend{custom} end + +function Base.show(io::IO, ::EnzymeForwardBackend{custom}) where {custom} + return print(io, "EnzymeForwardBackend{$(custom ? "custom" : "fallback")}()") +end """ - EnzymeReverseBackend + EnzymeReverseBackend <: AbstractReverseBackend -Performs reverse-mode autodiff with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl). +Enables the use of [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) in reverse mode. + +!!! warning + This backend only works for scalar output. """ -struct EnzymeReverseBackend <: AbstractReverseBackend end +struct EnzymeReverseBackend{custom} <: AbstractReverseBackend{custom} end + +function Base.show(io::IO, ::EnzymeReverseBackend{custom}) where {custom} + return print(io, "EnzymeReverseBackend{$(custom ? "custom" : "fallback")}()") +end """ - EnzymeForwardBackend + ForwardDiffBackend <: AbstractForwardBackend -Performs forward-mode autodiff with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl). +Enables the use of [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). """ -struct EnzymeForwardBackend <: AbstractForwardBackend end +struct ForwardDiffBackend{custom} <: AbstractForwardBackend{custom} end + +function Base.show(io::IO, ::ForwardDiffBackend{custom}) where {custom} + return print(io, "ForwardDiffBackend{$(custom ? "custom" : "fallback")}()") +end """ - ForwardDiffBackend + PolyesterForwardDiffBackend <: AbstractForwardBackend + +Enables the use of [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl), falling back on [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) if needed. -Performs autodiff with [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). +!!! warning + This backend only works when the arrays are vectors. 
""" -struct ForwardDiffBackend <: AbstractForwardBackend end +struct PolyesterForwardDiffBackend{custom,C} <: AbstractForwardBackend{custom} end + +function Base.show(io::IO, ::PolyesterForwardDiffBackend{custom,C}) where {custom,C} + return print(io, "PolyesterForwardDiffBackend{$(custom ? "custom" : "fallback"),$C}()") +end """ - ReverseDiffBackend + ReverseDiffBackend <: AbstractReverseBackend Performs autodiff with [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl). + +!!! warning + This backend only works for array input. """ -struct ReverseDiffBackend <: AbstractReverseBackend end +struct ReverseDiffBackend{custom} <: AbstractReverseBackend{custom} end + +function Base.show(io::IO, ::ReverseDiffBackend{custom}) where {custom} + return print(io, "ReverseDiffBackend{$(custom ? "custom" : "fallback")}()") +end + +## Pseudo backends + +function ZygoteBackend end + +## Limitations + +handles_input_type(::ReverseDiffBackend, ::Type{<:Number}) = false +handles_output_type(::EnzymeReverseBackend, ::Type{<:AbstractArray}) = false diff --git a/src/backends_abstract.jl b/src/backends_abstract.jl new file mode 100644 index 000000000..dbb9e540f --- /dev/null +++ b/src/backends_abstract.jl @@ -0,0 +1,68 @@ + +""" + AbstractBackend + +Abstract type pointing to the AD package chosen by the user, which is called a "backend". + +## Custom + +When we say that a backend is "custom", it describes how the utilities (derivative, multiderivative, gradient and jacobian) are implemented: + +- Custom backends use specific routines defined in their package whenever they exist +- Non-custom backends use fallbacks defined in DifferentiationInterface.jl, which end up calling the pushforward or pullback +""" +abstract type AbstractBackend{custom} end + +""" + is_custom(backend) + +Return a boolean `custom` that describes how utilities (derivative, multiderivative, gradient and jacobian) are implemented. 
+""" +is_custom(::AbstractBackend{custom}) where {custom} = custom + +""" + handles_input_type(backend, ::Type{X}) + +Check if `backend` can differentiate functions with input type `X`. +""" +handles_input_type(::AbstractBackend, ::Type{<:Number}) = true +handles_input_type(::AbstractBackend, ::Type{<:AbstractArray}) = true + +""" + handles_output_type(backend, ::Type{Y}) + +Check if `backend` can differentiate functions with output type `Y`. +""" +handles_output_type(::AbstractBackend, ::Type{<:Number}) = true +handles_output_type(::AbstractBackend, ::Type{<:AbstractArray}) = true + +""" + handles_types(backend, ::Type{X}, ::Type{Y}) + +Check if `backend` can differentiate functions with input type `X` and output type `Y`. +""" +function handles_types(backend::AbstractBackend, ::Type{X}, ::Type{Y}) where {X,Y} + return handles_input_type(backend, X) && handles_output_type(backend, Y) +end + +""" + AbstractForwardBackend <: AbstractBackend + +Abstract subtype of [`AbstractBackend`](@ref) for forward mode AD packages. +""" +abstract type AbstractForwardBackend{custom} <: AbstractBackend{custom} end + +""" + AbstractReverseBackend <: AbstractBackend + +Abstract subtype of [`AbstractBackend`](@ref) for reverse mode AD packages. +""" +abstract type AbstractReverseBackend{custom} <: AbstractBackend{custom} end + +""" + autodiff_mode(backend) + +Return either `:forward` or `:reverse` depending on the mode of `backend`. +""" +autodiff_mode(::AbstractForwardBackend) = :forward +autodiff_mode(::AbstractReverseBackend) = :reverse diff --git a/src/pullback.jl b/src/pullback.jl index 37d69d10f..d413641ff 100644 --- a/src/pullback.jl +++ b/src/pullback.jl @@ -1,32 +1,23 @@ """ - value_and_pullback!(dx, backend, f, x, dy[, stuff]) -> (y, dx) + value_and_pullback!(dx, backend::AbstractReverseBackend, f, x, dy) -> (y, dx) -Compute the vector-Jacobian product of `dy` with the Jacobian of `f` at `x` inside `dx`. -Returns the primal output `f(x)` and the VJP `dx`. 
+Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`, overwriting `dx` if possible. !!! info "Interface requirement" - This is the only required implementation for a reverse mode backend. - -# Arguments - -- `y`: primal output -- `dx`: tangent, might be modified -- `backend`: reverse-mode autodiff backend -- `f`: function `x -> y` to differentiate -- `x`: argument -- `dy`: cotangent -- `stuff`: optional backend-specific storage (cache, config), might be modified + This is the only required implementation for an [`AbstractReverseBackend`](@ref). """ -function value_and_pullback!(dx, backend::AbstractBackend, f, x, dy) - return error("No package extension loaded for backend $backend.") +function value_and_pullback!(dx, backend::AbstractReverseBackend, f, x, dy) + return error( + "Backend $backend is not loaded or does not support this type combination." + ) end """ - value_and_pullback(backend, f, x, dy[, stuff]) -> (y, dx) + value_and_pullback(backend::AbstractReverseBackend, f, x, dy) -> (y, dx) -Call [`value_and_pullback!`](@ref) after allocating memory for the vector-Jacobian product. +Compute the primal value `y = f(x)` and the vector-Jacobian product `dx = ∂f(x)' * dy`. """ -function value_and_pullback(backend::AbstractBackend, f, x, dy) +function value_and_pullback(backend::AbstractReverseBackend, f, x, dy) dx = mysimilar(x) return value_and_pullback!(dx, backend, f, x, dy) end diff --git a/src/pushforward.jl b/src/pushforward.jl index e95082242..656c22f01 100644 --- a/src/pushforward.jl +++ b/src/pushforward.jl @@ -1,32 +1,23 @@ """ - value_and_pushforward!(dy, backend, f, x, dx[, stuff]) -> (y, dy) + value_and_pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) -> (y, dy) -Compute the Jacobian-vector product of the Jacobian of `f` at `x` with `dx` inside `dy`. -Returns the primal output `f(x)` and the JVP `dy`. 
+Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`, overwriting `dy` if possible. !!! info "Interface requirement" - This is the only required implementation for a forward mode backend. - -# Arguments - -- `y`: primal output -- `dy`: cotangent, might be modified -- `backend`: forward-mode autodiff backend -- `f`: function `x -> y` to differentiate -- `x`: argument -- `dx`: tangent -- `stuff`: optional backend-specific storage (cache, config), might be modified + This is the only required implementation for an [`AbstractForwardBackend`](@ref). """ -function value_and_pushforward!(dy, backend::AbstractBackend, f, x, dx) - return error("No package extension loaded for backend $backend.") +function value_and_pushforward!(dy, backend::AbstractForwardBackend, f, x, dx) + return error( + "Backend $backend is not loaded or does not support this type combination." + ) end """ - value_and_pushforward(backend, f, x, dx[, stuff]) -> (y, dy) + value_and_pushforward(backend::AbstractForwardBackend, f, x, dx) -> (y, dy) -Call [`value_and_pushforward!`](@ref) after allocating memory for the Jacobian-vector product. +Compute the primal value `y = f(x)` and the Jacobian-vector product `dy = ∂f(x) * dx`. """ -function value_and_pushforward(backend::AbstractBackend, f, x, dx) +function value_and_pushforward(backend::AbstractForwardBackend, f, x, dx) dy = mysimilar(f(x)) return value_and_pushforward!(dy, backend, f, x, dx) end diff --git a/src/scalar_array.jl b/src/scalar_array.jl index 2bca050b9..aa55faff0 100644 --- a/src/scalar_array.jl +++ b/src/scalar_array.jl @@ -1,8 +1,7 @@ """ value_and_multiderivative!(multider, backend, f, x) -> (y, multider) -Compute the derivative of a scalar-to-array function inside `multider`. -Returns the primal output `f(x)` and the derivative. +Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function, overwriting `multider` if possible. 
""" function value_and_multiderivative! end @@ -26,7 +25,7 @@ end """ value_and_multiderivative(backend, f, x) -> (y, multider) -Call [`value_and_multiderivative!`](@ref) after allocating memory for the multiderivative. +Compute the primal value `y = f(x)` and the (array-valued) derivative `multider = f'(x)` of a scalar-to-array function. """ function value_and_multiderivative(backend::AbstractBackend, f, x::Number) multider = similar(f(x)) diff --git a/src/scalar_scalar.jl b/src/scalar_scalar.jl index fca8b6080..d65219df7 100644 --- a/src/scalar_scalar.jl +++ b/src/scalar_scalar.jl @@ -1,8 +1,7 @@ """ value_and_derivative(backend, f, x) -> (y, der) -Compute the derivative of a scalar-to-scalar function. -Returns the primal output `f(x)` and the derivative. +Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function. """ function value_and_derivative end diff --git a/src/utils.jl b/src/utils.jl index cdbc461af..8c4f9746a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,11 +1,11 @@ """ - basisarray(backend, a::AbstractArray, i) + basisarray(backend, a::AbstractArray, i::CartesianIndex) Construct the `i`-th stardard basis array in the vector space of `a` with element type `eltype(a)`. ## Note -If an AD backend benefits from a more specialized unit vector implementation, +If an AD backend benefits from a more specialized basis array implementation, this function can be extended on the backend type. 
""" basisarray(::AbstractBackend, a::AbstractArray, i) = basisarray(a, i) diff --git a/test/Project.toml b/test/Project.toml index 8a11015e6..e23957ab6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/test/backends.jl b/test/backends.jl new file mode 100644 index 000000000..b694a6a89 --- /dev/null +++ b/test/backends.jl @@ -0,0 +1,11 @@ +using DifferentiationInterface +using DifferentiationInterface: is_custom +using Test + +backend = ChainRulesForwardBackend{false,Nothing}(nothing) +@test !is_custom(backend) +@test autodiff_mode(backend) == :forward + +backend = ChainRulesReverseBackend{true,Nothing}(nothing) +@test is_custom(backend) +@test autodiff_mode(backend) == :reverse diff --git a/test/diffractor.jl b/test/diffractor.jl index 52b029cd5..5f70dc31f 100644 --- a/test/diffractor.jl +++ b/test/diffractor.jl @@ -4,8 +4,12 @@ using DifferentiationInterface # see https://github.com/JuliaDiff/Diffractor.jl/issues/277 @test_skip test_pushforward( - ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()); type_stability=false + ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()), + scenarios; + type_stability=false, ); @test_skip test_jacobian_and_friends( - ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()); type_stability=false + ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()), + scenarios; + type_stability=false, ); diff --git a/test/enzyme.jl b/test/enzyme.jl deleted file mode 100644 index 89ca3157c..000000000 --- a/test/enzyme.jl +++ /dev/null @@ -1,24 +0,0 @@ -using 
DifferentiationInterface -using Enzyme - -# see https://github.com/EnzymeAD/Enzyme.jl/issues/1330 - -@testset "EnzymeForwardBackend" begin - test_pushforward( - EnzymeForwardBackend(); - output_type=Union{Number,AbstractVector}, # TODO: remove - type_stability=true, - ) - test_jacobian_and_friends( - EnzymeForwardBackend(); - output_type=Union{Number,AbstractVector}, # TODO: remove - type_stability=true, - ) -end; - -@testset "EnzymeReverseBackend" begin - test_pullback(EnzymeReverseBackend(); output_type=Number, type_stability=true) - test_jacobian_and_friends( - EnzymeReverseBackend(); output_type=Number, type_stability=true - ) -end; diff --git a/test/enzyme_forward.jl b/test/enzyme_forward.jl new file mode 100644 index 000000000..4dc06e0bd --- /dev/null +++ b/test/enzyme_forward.jl @@ -0,0 +1,10 @@ +using DifferentiationInterface +using Enzyme + +test_pushforward(EnzymeForwardBackend(), scenarios; type_stability=true); +test_jacobian_and_friends( + EnzymeForwardBackend(; custom=true), scenarios; type_stability=true +); +test_jacobian_and_friends( + EnzymeForwardBackend(; custom=false), scenarios; type_stability=true +); diff --git a/test/enzyme_reverse.jl b/test/enzyme_reverse.jl new file mode 100644 index 000000000..68c6d5c4a --- /dev/null +++ b/test/enzyme_reverse.jl @@ -0,0 +1,10 @@ +using DifferentiationInterface +using Enzyme + +test_pullback(EnzymeReverseBackend(), scenarios; type_stability=true); +test_jacobian_and_friends( + EnzymeReverseBackend(; custom=true), scenarios; type_stability=true +) +test_jacobian_and_friends( + EnzymeReverseBackend(; custom=false), scenarios; type_stability=true +) diff --git a/test/finitediff.jl b/test/finitediff.jl index f830e71c2..cb9b657fc 100644 --- a/test/finitediff.jl +++ b/test/finitediff.jl @@ -1,5 +1,6 @@ using DifferentiationInterface using FiniteDiff -test_pushforward(FiniteDiffBackend(); type_stability=false); -test_jacobian_and_friends(FiniteDiffBackend(); type_stability=false); 
+test_pushforward(FiniteDiffBackend(), scenarios; type_stability=true); +test_jacobian_and_friends(FiniteDiffBackend(; custom=true), scenarios; type_stability=false); +test_jacobian_and_friends(FiniteDiffBackend(; custom=false), scenarios; type_stability=true); diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index 301941b23..62744d88a 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -1,5 +1,10 @@ using DifferentiationInterface using ForwardDiff -test_pushforward(ForwardDiffBackend(); type_stability=false); -test_jacobian_and_friends(ForwardDiffBackend(); type_stability=false); +test_pushforward(ForwardDiffBackend(), scenarios; type_stability=true); +test_jacobian_and_friends( + ForwardDiffBackend(; custom=true), scenarios; type_stability=false +); +test_jacobian_and_friends( + ForwardDiffBackend(; custom=false), scenarios; type_stability=true +); diff --git a/test/polyesterforwarddiff.jl b/test/polyesterforwarddiff.jl new file mode 100644 index 000000000..d1cb88d41 --- /dev/null +++ b/test/polyesterforwarddiff.jl @@ -0,0 +1,16 @@ +using DifferentiationInterface +using PolyesterForwardDiff + +# see https://github.com/JuliaDiff/PolyesterForwardDiff.jl/issues/17 + +test_pushforward( + PolyesterForwardDiffBackend(4; custom=true), scenarios; type_stability=false +); + +test_jacobian_and_friends( + PolyesterForwardDiffBackend(4; custom=true), + scenarios; + input_type=Union{Number,AbstractVector}, + output_type=Union{Number,AbstractVector}, + type_stability=false, +); diff --git a/test/reversediff.jl b/test/reversediff.jl index f5948cfdf..aebfdbaac 100644 --- a/test/reversediff.jl +++ b/test/reversediff.jl @@ -1,7 +1,10 @@ using DifferentiationInterface using ReverseDiff -test_pullback(ReverseDiffBackend(); input_type=AbstractArray, type_stability=false); +test_pullback(ReverseDiffBackend(), scenarios; type_stability=false); test_jacobian_and_friends( - ReverseDiffBackend(); input_type=AbstractArray, type_stability=false + ReverseDiffBackend(; 
custom=true), scenarios; type_stability=false +); +test_jacobian_and_friends( + ReverseDiffBackend(; custom=false), scenarios; type_stability=false ); diff --git a/test/runtests.jl b/test/runtests.jl index 59e55106b..50f0f19d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,13 +6,6 @@ using JET: JET using JuliaFormatter: JuliaFormatter using Test -using Diffractor: Diffractor -using Enzyme: Enzyme -using FiniteDiff: FiniteDiff -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff -using Zygote: Zygote - ## Utils include("utils.jl") @@ -29,14 +22,21 @@ include("utils.jl") ) end @testset "JET" begin - @test_skip JET.test_package(DifferentiationInterface; target_defined_modules=true) + JET.test_package(DifferentiationInterface; target_defined_modules=true) + end + + @testset "Backend utilities" begin + include("backends.jl") end @testset "Diffractor" begin include("diffractor.jl") end - @testset "Enzyme" begin - include("enzyme.jl") + @testset "Enzyme (forward)" begin + include("enzyme_forward.jl") + end + @testset "Enzyme (reverse)" begin + include("enzyme_reverse.jl") end @testset "FiniteDiff" begin include("finitediff.jl") @@ -44,6 +44,9 @@ include("utils.jl") @testset "ForwardDiff" begin include("forwarddiff.jl") end + @testset "PolyesterForwardDiff" begin + include("polyesterforwarddiff.jl") + end @testset "ReverseDiff" begin include("reversediff.jl") end diff --git a/test/utils.jl b/test/utils.jl index bbf028203..aeb8b97a1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -15,7 +15,7 @@ using Test f::F "argument" x::X - "primal output" + "primal value" y::Y "pushforward seed" dx::X @@ -101,7 +101,7 @@ rng = StableRNG(63) f_scalar_scalar(x::Number)::Number = sin(x) f_scalar_vector(x::Number)::AbstractVector = [sin(x), sin(2x)] -f_scalar_matrix(x::Number)::AbstractMatrix = [sin(x) cos(x); sin(2x) cos(2x)] +f_scalar_matrix(x::Number)::AbstractMatrix = hcat([sin(x) cos(x)], [sin(2x) cos(2x)]) function f_vector_scalar(x::AbstractVector)::Number a 
= eachindex(x) @@ -151,7 +151,7 @@ scenarios = [ function test_pushforward( backend::AbstractForwardBackend, - scenarios::Vector{<:Scenario}=scenarios; + scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, @@ -160,37 +160,35 @@ function test_pushforward( scenarios = filter(scenarios) do s get_input_type(s) <: input_type && get_output_type(s) <: output_type end - @testset "Pushforward" begin + @testset "Pushforward ($(is_custom(backend) ? "custom" : "fallback"))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) + handles_types(backend, X, Y) || continue + @testset "$X -> $Y" begin (; f, x, y, dx, dy_true) = scenario y_out, dy_out = value_and_pushforward(backend, f, x, dx) + dy_in = zero(dy_out) + y_out2, dy_out2 = value_and_pushforward!(dy_in, backend, f, x, dx) - @testset "Primal output" begin - @test y_out == y + @testset "Primal value" begin + @test y_out ≈ y + @test y_out2 ≈ y end - @testset "Tangent output" begin + @testset "Tangent value" begin @test dy_out ≈ dy_true rtol = 1e-3 - end - if ismutable(dy_out) - @testset "Mutation" begin - dy_in = similar(dy_out) - value_and_pushforward!(dy_in, backend, f, x, dx) - @test dy_in ≈ dy_true rtol = 1e-3 + @test dy_out2 ≈ dy_true rtol = 1e-3 + if ismutable(dy_in) + @testset "Mutation" begin + @test dy_in ≈ dy_true rtol = 1e-3 + end end end - if allocs - @testset "Allocations" begin - @test (@allocated value_and_pushforward!( - dy_out, backend, f, x, dx - )) == 0 - end + allocs && @testset "Allocations" begin + @test (@allocated value_and_pushforward!(dy_in, backend, f, x, dx)) == 0 end - if type_stability - @testset "Type stability" begin - @test_opt value_and_pushforward!(dy_out, backend, f, x, dx) - end + type_stability && @testset "Type stability" begin + @test_opt value_and_pushforward!(dy_in, backend, f, x, dx) end end end @@ -199,7 +197,7 @@ end function test_pullback( backend::AbstractReverseBackend, - scenarios=scenarios; + 
scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, @@ -208,36 +206,35 @@ function test_pullback( scenarios = filter(scenarios) do s (get_input_type(s) <: input_type) && (get_output_type(s) <: output_type) end - @testset "Pullback" begin + @testset "Pullback ($(is_custom(backend) ? "custom" : "fallback"))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) + handles_types(backend, X, Y) || continue + @testset "$X -> $Y" begin (; f, x, y, dy, dx_true) = scenario y_out, dx_out = value_and_pullback(backend, f, x, dy) + dx_in = zero(dx_out) + y_out2, dx_out2 = value_and_pullback!(dx_in, backend, f, x, dy) - @testset "Primal output" begin - @test y_out == y + @testset "Primal value" begin + @test y_out ≈ y + @test y_out2 ≈ y end - @testset "Co-tangent output" begin + @testset "Cotangent value" begin @test dx_out ≈ dx_true rtol = 1e-3 - end - if ismutable(dx_out) - @testset "Mutation" begin - dx_in = similar(dx_out) - value_and_pullback!(dx_in, backend, f, x, dy) - @test dx_in ≈ dx_true rtol = 1e-3 + @test dx_out2 ≈ dx_true rtol = 1e-3 + if ismutable(dx_out) + @testset "Mutation" begin + @test dx_in ≈ dx_true rtol = 1e-3 + end end end - if allocs - @testset "Allocations" begin - @test (@allocated value_and_pullback!(dx_out, backend, f, x, dy)) == - 0 - end + allocs && @testset "Allocations" begin + @test (@allocated value_and_pullback!(dx_in, backend, f, x, dy)) == 0 end - if type_stability - @testset "Type stability" begin - @test_opt value_and_pullback!(dx_out, backend, f, x, dy) - end + type_stability && @testset "Type stability" begin + @test_opt value_and_pullback!(dx_in, backend, f, x, dy) end end end @@ -246,35 +243,33 @@ end function test_derivative( backend::AbstractBackend, - scenarios::Vector{<:Scenario}=scenarios; + scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: Number) && 
(get_output_type(s) <: Number) end - @testset "Derivative" begin + @testset "Derivative ($(is_custom(backend) ? "custom" : "fallback"))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) + handles_types(backend, X, Y) || continue + @testset "$X -> $Y" begin (; f, x, y, der_true) = scenario y_out, der_out = value_and_derivative(backend, f, x) - @testset "Primal output" begin - @test y_out == y + @testset "Primal value" begin + @test y_out ≈ y end - @testset "Derivative output" begin + @testset "Derivative value" begin @test der_out ≈ der_true rtol = 1e-3 end - if allocs - @testset "Allocations" begin - @test (@allocated value_and_derivative(backend, f, x)) == 0 - end + allocs && @testset "Allocations" begin + @test (@allocated value_and_derivative(backend, f, x)) == 0 end - if type_stability - @testset "Type stability" begin - @test_opt value_and_derivative(backend, f, x) - end + type_stability && @testset "Type stability" begin + @test_opt value_and_derivative(backend, f, x) end end end @@ -283,42 +278,42 @@ end function test_multiderivative( backend::AbstractBackend, - scenarios::Vector{<:Scenario}=scenarios; + scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: Number) && (get_output_type(s) <: AbstractArray) end - @testset "Multiderivative" begin + @testset "Multiderivative ($(is_custom(backend) ? 
"custom" : "fallback"))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) + handles_types(backend, X, Y) || continue + @testset "$X -> $Y" begin (; f, x, y, multider_true) = scenario y_out, multider_out = value_and_multiderivative(backend, f, x) - - @testset "Primal output" begin - @test y_out == y + multider_in = zero(multider_out) + y_out2, multider_out2 = value_and_multiderivative!( + multider_in, backend, f, x + ) + + @testset "Primal value" begin + @test y_out ≈ y + @test y_out2 ≈ y end - @testset "Multiderivative output" begin + @testset "Multiderivative value" begin @test multider_out ≈ multider_true rtol = 1e-3 - end - @testset "Mutation" begin - multider_in = similar(multider_out) - value_and_multiderivative!(multider_in, backend, f, x) - @test multider_in ≈ multider_true rtol = 1e-3 - end - if allocs - @testset "Allocations" begin - @test (@allocated value_and_multiderivative!( - multider_out, backend, f, x - )) == 0 + @test multider_out2 ≈ multider_true rtol = 1e-3 + @testset "Mutation" begin + @test multider_in ≈ multider_true rtol = 1e-3 end end - if type_stability - @testset "Type stability" begin - @test_opt value_and_multiderivative!(multider_out, backend, f, x) - end + allocs && @testset "Allocations" begin + @test (@allocated value_and_multiderivative!(multider_in, backend, f, x)) == 0 + end + type_stability && @testset "Type stability" begin + @test_opt value_and_multiderivative!(multider_in, backend, f, x) end end end @@ -327,40 +322,40 @@ end function test_gradient( backend::AbstractBackend, - scenarios::Vector{<:Scenario}=scenarios; + scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: AbstractArray) && (get_output_type(s) <: Number) end - @testset "Gradient" begin + @testset "Gradient ($(is_custom(backend) ? 
"custom" : "fallback"))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) + handles_types(backend, X, Y) || continue + @testset "$X -> $Y" begin (; f, x, y, grad_true) = scenario y_out, grad_out = value_and_gradient(backend, f, x) + grad_in = zero(grad_out) + y_out2, grad_out2 = value_and_gradient!(grad_in, backend, f, x) - @testset "Primal output" begin - @test y_out == y + @testset "Primal value" begin + @test y_out ≈ y + @test y_out2 ≈ y end - @testset "Gradient output" begin + @testset "Gradient value" begin @test grad_out ≈ grad_true rtol = 1e-3 - end - @testset "Mutation" begin - grad_in = similar(grad_out) - value_and_gradient!(grad_in, backend, f, x) - @test grad_in ≈ grad_true rtol = 1e-3 - end - if allocs - @testset "Allocations" begin - @test (@allocated value_and_gradient!(grad_out, backend, f, x)) == 0 + @test grad_out2 ≈ grad_true rtol = 1e-3 + @testset "Mutation" begin + @test grad_in ≈ grad_true rtol = 1e-3 end end - if type_stability - @testset "Type stability" begin - @test_opt value_and_gradient!(grad_out, backend, f, x) - end + allocs && @testset "Allocations" begin + @test (@allocated value_and_gradient!(grad_in, backend, f, x)) == 0 + end + type_stability && @testset "Type stability" begin + @test_opt value_and_gradient!(grad_in, backend, f, x) end end end @@ -369,40 +364,40 @@ end function test_jacobian( backend::AbstractBackend, - scenarios::Vector{<:Scenario}=scenarios; + scenarios::Vector{<:Scenario}; allocs::Bool=false, type_stability::Bool=true, ) scenarios = filter(scenarios) do s (get_input_type(s) <: AbstractArray) && (get_output_type(s) <: AbstractArray) end - @testset "Jacobian" begin + @testset "Jacobian ($(is_custom(backend) ? 
"custom" : "fallback"))" begin for scenario in scenarios X, Y = get_input_type(scenario), get_output_type(scenario) + handles_types(backend, X, Y) || continue + @testset "$X -> $Y" begin (; f, x, y, jac_true) = scenario y_out, jac_out = value_and_jacobian(backend, f, x) + jac_in = zero(jac_out) + y_out2, jac_out2 = value_and_jacobian!(jac_in, backend, f, x) - @testset "Primal output" begin - @test y_out == y + @testset "Primal value" begin + @test y_out ≈ y + @test y_out2 ≈ y end - @testset "Jacobian output" begin + @testset "Jacobian value" begin @test jac_out ≈ jac_true rtol = 1e-3 - end - @testset "Mutation" begin - jac_in = similar(jac_out) - value_and_jacobian!(jac_in, backend, f, x) - @test jac_in ≈ jac_true rtol = 1e-3 - end - if allocs - @testset "Allocations" begin - @test (@allocated value_and_jacobian!(jac_out, backend, f, x)) == 0 + @test jac_out2 ≈ jac_true rtol = 1e-3 + @testset "Mutation" begin + @test jac_in ≈ jac_true rtol = 1e-3 end end - if type_stability - @testset "Type stability" begin - @test_opt value_and_jacobian!(jac_out, backend, f, x) - end + allocs && @testset "Allocations" begin + @test (@allocated value_and_jacobian!(jac_in, backend, f, x)) == 0 + end + type_stability && @testset "Type stability" begin + @test_opt value_and_jacobian!(jac_in, backend, f, x) end end end @@ -411,7 +406,7 @@ end function test_jacobian_and_friends( backend::AbstractBackend, - scenarios::Vector{<:Scenario}=scenarios; + scenarios::Vector{<:Scenario}; input_type::Type=Any, output_type::Type=Any, allocs::Bool=false, diff --git a/test/zygote.jl b/test/zygote.jl index c91acb8e3..69dc583ac 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -1,7 +1,6 @@ using DifferentiationInterface using Zygote -test_pullback(ChainRulesReverseBackend(Zygote.ZygoteRuleConfig()); type_stability=false); -test_jacobian_and_friends( - ChainRulesReverseBackend(Zygote.ZygoteRuleConfig()); type_stability=false -); +test_pullback(ZygoteBackend(), scenarios; type_stability=false); 
+test_jacobian_and_friends(ZygoteBackend(; custom=true), scenarios; type_stability=false); +test_jacobian_and_friends(ZygoteBackend(; custom=false), scenarios; type_stability=false);