Skip to content

Commit

Permalink
Switch backends to ADTypes.jl (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Mar 11, 2024
1 parent 83d3701 commit 04aec00
Show file tree
Hide file tree
Showing 47 changed files with 968 additions and 939 deletions.
16 changes: 15 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Expand All @@ -20,15 +23,26 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore"
DifferentiationInterfaceDiffractorExt = [
"Diffractor",
"AbstractDifferentiation",
]
DifferentiationInterfaceEnzymeExt = "Enzyme"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
DifferentiationInterfacePolyesterForwardDiffExt = [
"PolyesterForwardDiff",
"ForwardDiff",
"DiffResults",
]
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceZygoteExt = ["Zygote"]

[compat]
AbstractDifferentiation = "0.6"
ADTypes = "0.2.6"
ChainRulesCore = "1.19"
Diffractor = "0.2"
DiffResults = "1.1"
DocStringExtensions = "0.9"
Enzyme = "0.11"
Expand Down
42 changes: 19 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@ This package provides a backend-agnostic syntax to differentiate functions `f(x)

It started out as an experimental redesign for [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl).

## Example
## Compatibility

```jldoctest
julia> using DifferentiationInterface, Enzyme
We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

julia> backend = EnzymeReverseBackend();
- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) with `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))`
- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) with `AutoFiniteDiff()`
- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with `AutoForwardDiff()`
- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) with `AutoPolyesterForwardDiff(; chunksize=C)`
- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) with `AutoReverseDiff()`
- [Zygote.jl](https://github.com/FluxML/Zygote.jl) with `AutoZygote()`

julia> f(x) = sum(abs2, x);
We also support two more backends which are not yet part of ADTypes.jl:

julia> value_and_gradient(backend, f, [1., 2., 3.])
(14.0, [2.0, 4.0, 6.0])
```
- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) with `AutoChainRules(ruleconfig)`
- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) with `AutoDiffractor()`

## Design

Expand All @@ -40,22 +43,15 @@ From these primitives, several utilities are defined, depending on the type of t
| scalar input | derivative | multiderivative |
| array input | gradient | jacobian |

## Supported backends

Forward mode:

- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)
- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl)
- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl)
## Example

Reverse mode:
```jldoctest
julia> import DifferentiationInterface, ADTypes, ForwardDiff
- [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl)
- [Zygote.jl](https://github.com/FluxML/Zygote.jl)
- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)
julia> backend = ADTypes.AutoForwardDiff();
Experimental:
julia> f(x) = sum(abs2, x);
- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl)
- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (currently broken due to [#277](https://github.com/JuliaDiff/Diffractor.jl/issues/277))
julia> DifferentiationInterface.value_and_gradient(backend, f, [1., 2., 3.])
(14.0, [2.0, 4.0, 6.0])
```
1 change: 1 addition & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand Down
62 changes: 14 additions & 48 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
using ADTypes
using BenchmarkTools
using DifferentiationInterface
using LinearAlgebra

using Diffractor: Diffractor
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using PolyesterForwardDiff: PolyesterForwardDiff
using ReverseDiff: ReverseDiff
using Zygote: Zygote
using Zygote: Zygote, ZygoteRuleConfig

## Settings

Expand Down Expand Up @@ -41,38 +43,18 @@ end

## Backends

forward_custom_backends = [
EnzymeForwardBackend(; custom=true),
FiniteDiffBackend(; custom=true),
ForwardDiffBackend(; custom=true),
PolyesterForwardDiffBackend(4; custom=true),
all_backends = [
AutoChainRules(ZygoteRuleConfig()),
AutoDiffractor(),
AutoEnzyme(Val(:forward)),
AutoEnzyme(Val(:reverse)),
AutoFiniteDiff(),
AutoForwardDiff(),
AutoPolyesterForwardDiff(; chunksize=4),
AutoReverseDiff(),
AutoZygote(),
]

forward_fallback_backends = [
EnzymeForwardBackend(; custom=false),
FiniteDiffBackend(; custom=false),
ForwardDiffBackend(; custom=false),
]

reverse_custom_backends = [
ZygoteBackend(; custom=true),
EnzymeReverseBackend(; custom=true),
ReverseDiffBackend(; custom=true),
]

reverse_fallback_backends = [
ZygoteBackend(; custom=false),
EnzymeReverseBackend(; custom=false),
ReverseDiffBackend(; custom=false),
]

all_backends = vcat(
forward_custom_backends,
forward_fallback_backends,
reverse_custom_backends,
reverse_fallback_backends,
)

## Suite

function make_suite()
Expand All @@ -83,11 +65,7 @@ function make_suite()

for backend in all_backends
add_derivative_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1)
end
for backend in forward_fallback_backends
add_pushforward_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1)
end
for backend in reverse_fallback_backends
add_pullback_benchmarks!(SUITE, backend, scalar_to_scalar, 1, 1)
end

Expand All @@ -97,11 +75,7 @@ function make_suite()

for backend in all_backends
add_multiderivative_benchmarks!(SUITE, backend, scalar_to_vector, 1, m)
end
for backend in forward_fallback_backends
add_pushforward_benchmarks!(SUITE, backend, scalar_to_vector, 1, m)
end
for backend in reverse_fallback_backends
add_pullback_benchmarks!(SUITE, backend, scalar_to_vector, 1, m)
end
end
Expand All @@ -112,11 +86,7 @@ function make_suite()

for backend in all_backends
add_gradient_benchmarks!(SUITE, backend, vector_to_scalar, n, 1)
end
for backend in forward_fallback_backends
add_pushforward_benchmarks!(SUITE, backend, vector_to_scalar, n, 1)
end
for backend in reverse_fallback_backends
add_pullback_benchmarks!(SUITE, backend, vector_to_scalar, n, 1)
end
end
Expand All @@ -127,11 +97,7 @@ function make_suite()

for backend in all_backends
add_jacobian_benchmarks!(SUITE, backend, vector_to_vector, n, m)
end
for backend in forward_fallback_backends
add_pushforward_benchmarks!(SUITE, backend, vector_to_vector, n, m)
end
for backend in reverse_fallback_backends
add_pullback_benchmarks!(SUITE, backend, vector_to_vector, n, m)
end
end
Expand All @@ -144,7 +110,7 @@ include("utils.jl")
SUITE = make_suite()

# Run benchmarks locally
# results = BenchmarkTools.run(SUITE; verbose=true)
results = BenchmarkTools.run(SUITE; verbose=true)

# Compare commits locally
# using BenchmarkCI; BenchmarkCI.judge(baseline="origin/main"); BenchmarkCI.displayjudgement()
Expand Down
Loading

0 comments on commit 04aec00

Please sign in to comment.