Skip to content

Commit

Permalink
Remove subtyping, replace by traits
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Apr 13, 2024
1 parent b1ae98a commit fe11e33
Show file tree
Hide file tree
Showing 15 changed files with 459 additions and 280 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,9 @@ jobs:
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
22 changes: 20 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
authors = [
"Vaibhav Dixit <vaibhavyashdixit@gmail.com>, Guillaume Dalle and contributors",
]
version = "0.3.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
ADTypesChainRulesCoreExt = "ChainRulesCore"
ADTypesEnzymeCoreExt = "EnzymeCore"

[compat]
ChainRulesCore = "1.23.0"
EnzymeCore = "0.7.2"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["ChainRulesCore", "EnzymeCore", "Test"]
31 changes: 13 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,24 @@
[![Docs dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://SciML.github.io/ADTypes.jl/dev/)
[![Build Status](https://github.com/SciML/ADTypes.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/SciML/ADTypes.jl/actions/workflows/CI.yml?query=branch%3Amain)

ADTypes.jl is a multi-valued logic system specifying the choice of an automatic differentiation (AD) library and its parameters.
ADTypes.jl is a multi-valued logic system to choose an automatic differentiation (AD) package and specify its parameters.

## Which AD libraries are supported?

See the API reference in the documentation.
If a given package is missing, feel free to open an issue or pull request.

## Why should packages adopt this standard?
## Why should AD users adopt this standard?

A common practice is the use of a boolean keyword argument like `autodiff = true/false`.
However, boolean logic is not precise enough for all the choices required.
For instance, forward mode AD is implemented by both ForwardDiff and Enzyme, which makes `autodiff = true` ambiguous.
Something like `ChooseForwardDiff()` is thus required, possibly with additional parameters depending on the library.
A natural approach is to use a keyword argument with e.g. `Bool` or `Symbol` values.
Let's see a few examples to understand why this is not enough:

The risk is that every package developer might develop their own version of `ChooseForwardDiff()`, which would ruin interoperability.
This is why ADTypes.jl provides a single set of shared types for this task, as an extremely lightweight dependency.
Wonder no more: `ADTypes.AutoForwardDiff()` is the way to go.

## Why define types instead of enums?

If we used enums, they would not contain type-level information useful for dispatch.
This is needed by many AD libraries to ensure type stability.
Notably, the choice of config or cache type is different with each AD, so we must know statically which AD library is chosen.
- `autodiff = true`: ambiguous, we don't know which AD package should be used
- `autodiff = :forward`: ambiguous, there are several AD packages implementing both forward and reverse mode (and there are other modes beyond that)
- `autodiff = :Enzyme`: ambiguous, some AD packages can work both in forward and reverse mode
- `autodiff = (:Enzyme, :forward)`: not too bad, but many AD packages require additional configuration (number of chunks, tape compilation, etc.)

## Why is this AD package missing?

Feel free to open a pull request adding it.
A more involved struct is thus required, with package-specific parameters.
If every AD user develops their own version of said struct, it will ruin interoperability.
This is why ADTypes.jl provides a single set of shared types for this task, as an extremely lightweight dependency.
They are types and not enums because we need AD choice information statically to use it for dispatch.
14 changes: 9 additions & 5 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ AutoPolyesterForwardDiff

```@docs
AutoReverseDiff
AutoTapir
AutoTracker
AutoZygote
```
Expand Down Expand Up @@ -78,11 +79,14 @@ ADTypes.row_coloring
ADTypes.NoColoringAlgorithm
```

## Internals
## Modes

```@docs
AbstractFiniteDifferencesMode
AbstractForwardMode
AbstractReverseMode
AbstractSymbolicDifferentiationMode
ADTypes.mode
ADTypes.AbstractMode
ADTypes.FiniteDifferencesMode
ADTypes.ForwardMode
ADTypes.ForwardOrReverseMode
ADTypes.ReverseMode
ADTypes.SymbolicMode
```
28 changes: 28 additions & 0 deletions ext/ADTypesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module ADTypesChainRulesCoreExt

using ADTypes: ADTypes, AutoChainRules
using ChainRulesCore: HasForwardsMode, HasReverseMode,
NoForwardsMode, NoReverseMode,
RuleConfig

# see https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html

function ADTypes.mode(::AutoChainRules{RC}) where {
RC <: RuleConfig{>:Union{HasForwardsMode, NoReverseMode}}
}
return ADTypes.ForwardMode()
end

function ADTypes.mode(::AutoChainRules{RC}) where {
RC <: RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}
}
return ADTypes.ReverseMode()
end

function ADTypes.mode(::AutoChainRules{RC}) where {
RC <: RuleConfig{>:Union{HasForwardsMode, HasReverseMode}}
}
return ADTypes.ForwardOrReverseMode()
end

end
9 changes: 9 additions & 0 deletions ext/ADTypesEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module ADTypesEnzymeCoreExt

using ADTypes: ADTypes, AutoEnzyme
using EnzymeCore: EnzymeCore

ADTypes.mode(::AutoEnzyme{M}) where {M <: EnzymeCore.ForwardMode} = ADTypes.ForwardMode()
ADTypes.mode(::AutoEnzyme{M}) where {M <: EnzymeCore.ReverseMode} = ADTypes.ReverseMode()

end
17 changes: 16 additions & 1 deletion src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,25 @@ module ADTypes

using Base: @deprecate

include("abstract.jl")
"""
AbstractADType
Abstract supertype for all AD choices.
"""
abstract type AbstractADType end

Base.broadcastable(ad::AbstractADType) = Ref(ad)

include("mode.jl")
include("dense.jl")
include("sparse.jl")
include("legacy.jl")

if !isdefined(Base, :get_extension)
include("../ext/ADTypesChainRulesCoreExt.jl")
include("../ext/ADTypesEnzymeCoreExt.jl")
end

export AbstractADType

export AutoChainRules,
Expand All @@ -24,6 +38,7 @@ export AutoChainRules,
AutoModelingToolkit,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoTapir,
AutoTracker,
AutoZygote

Expand Down
36 changes: 0 additions & 36 deletions src/abstract.jl

This file was deleted.

Loading

0 comments on commit fe11e33

Please sign in to comment.