Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AutoMooncake type #89

Merged
merged 13 commits into from
Sep 25, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com>, Guillaume Dalle and contributors"]
version = "1.8.1"
version = "1.9.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ AutoGTPSA
### Reverse mode

```@docs
AutoMooncake
AutoReverseDiff
AutoTapir
AutoTracker
AutoTapir
gdalle marked this conversation as resolved.
Show resolved Hide resolved
AutoZygote
```

Expand Down
1 change: 1 addition & 0 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export AutoChainRules,
AutoForwardDiff,
AutoGTPSA,
AutoModelingToolkit,
AutoMooncake,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
Expand Down
46 changes: 27 additions & 19 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,30 @@ function Base.show(io::IO, backend::AutoGTPSA{D}) where {D}
print(io, ")")
end

"""
AutoMooncake

Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend
for automatic differentiation.
gdalle marked this conversation as resolved.
Show resolved Hide resolved

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoMooncake(; config)

# Fields

- `config`: either `nothing` or an instance of `Mooncake.Config` -- see the docstring for
`Mooncake.Config` for more information. `AutoMooncake(; config=nothing)` is equivalent
to `AutoMooncake(; config=Mooncake.Config())`, i.e. the default configuration.
gdalle marked this conversation as resolved.
Show resolved Hide resolved
"""
Base.@kwdef struct AutoMooncake{Tconfig} <: AbstractADType
config::Tconfig
end

mode(::AutoMooncake) = ReverseMode()

"""
AutoPolyesterForwardDiff{chunksize,T}

Expand Down Expand Up @@ -323,26 +347,10 @@ mode(::AutoSymbolics) = SymbolicMode()
"""
AutoTapir

Struct used to select the [Tapir.jl](https://github.com/withbayes/Tapir.jl) backend for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoTapir(; safe_mode=true)

# Fields

- `safe_mode::Bool`: whether to run additional checks to catch errors early. While this is
gdalle marked this conversation as resolved.
Show resolved Hide resolved
on by default to ensure that users are aware of this option, you should generally turn
it off for actual use, as it has substantial performance implications.
If you encounter a problem with using Tapir (it fails to differentiate a function, or
something truly nasty like a segfault occurs), then you should try switching `safe_mode`
on and look at what happens. Often errors are caught earlier and the error messages are
more useful.
This ADType is deprecated. `AutoMooncake` should be used instead.
gdalle marked this conversation as resolved.
Show resolved Hide resolved
"""
Base.@kwdef struct AutoTapir <: AbstractADType
safe_mode::Bool = true
struct AutoTapir <: AbstractADType
safe_mode::Bool
end

mode(::AutoTapir) = ReverseMode()
Expand Down
7 changes: 7 additions & 0 deletions src/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,10 @@ function AutoModelingToolkit(; obj_sparse::Bool = false, cons_sparse::Bool = fal
:AutoModelingToolkit; force = false)
return mtk_to_symbolics(obj_sparse, cons_sparse)
end

function AutoTapir(; safe_mode=true)
Base.depwarn(
"AutoTapir is deprecated in favour of AutoMooncake.", :AutoTapir; force=false
gdalle marked this conversation as resolved.
Show resolved Hide resolved
)
return AutoTapir(safe_mode)
end
8 changes: 8 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ end
@test ad.descriptor == Val(:descriptor)
end

@testset "AutoMooncake" begin
ad = AutoMooncake(; config=nothing)
@test ad isa AbstractADType
@test ad isa AutoMooncake
@test mode(ad) isa ReverseMode
@test ad.config === nothing
end

@testset "AutoPolyesterForwardDiff" begin
ad = AutoPolyesterForwardDiff()
@test ad isa AbstractADType
Expand Down
4 changes: 4 additions & 0 deletions test/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ end
ad = @test_deprecated AutoReverseDiff(true)
@test ad.compile
end

@testset "AutoTapir" begin
@test_deprecated AutoTapir()
gdalle marked this conversation as resolved.
Show resolved Hide resolved
end
Loading