diff --git a/Project.toml b/Project.toml index c54c8ca..c894b31 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.4.0" +version = "1.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 56ebaae..0000a40 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -16,6 +16,9 @@ abstract type AbstractADType end Base.broadcastable(ad::AbstractADType) = Ref(ad) +@inline _unwrap_val(::Val{T}) where {T} = T +@inline _unwrap_val(x) = x + include("mode.jl") include("dense.jl") include("sparse.jl") diff --git a/src/dense.jl b/src/dense.jl index ed3aea7..c6e812a 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -186,14 +186,28 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Constructors - AutoReverseDiff(; compile=false) + AutoReverseDiff(; compile::Union{Val, Bool} = Val(false)) # Fields - - `compile::Bool`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation + - `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation """ -Base.@kwdef struct AutoReverseDiff <: AbstractADType - compile::Bool = false +struct AutoReverseDiff{C} <: AbstractADType + compile::Bool # this field if left for legacy reasons + + function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false)) + _compile = _unwrap_val(compile) + return new{_compile}(_compile) + end +end + +function Base.getproperty(ad::AutoReverseDiff, s::Symbol) + if s === :compile + Base.depwarn( + "`ad.compile` where `ad` is `AutoReverseDiff` has been deprecated and will be removed in v2. Instead it is available as a compile-time constant as `AutoReverseDiff{true}` or `AutoReverseDiff{false}`.", + :getproperty) + end + return getfield(ad, s) end mode(::AutoReverseDiff) = ReverseMode() diff --git a/src/symbols.jl b/src/symbols.jl index f349e84..b0f019c 100644 --- a/src/symbols.jl +++ b/src/symbols.jl @@ -22,8 +22,8 @@ ADTypes.AutoZygote() Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...) for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation, - :FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff, - :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) - @eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...) + :FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff, + :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) + @eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))( + args...; kws...) end - diff --git a/test/dense.jl b/test/dense.jl index abe3f9b..8ff70b7 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -113,17 +113,26 @@ end end @testset "AutoReverseDiff" begin - ad = AutoReverseDiff() + ad = @inferred AutoReverseDiff() @test ad isa AbstractADType @test ad isa AutoReverseDiff @test mode(ad) isa ReverseMode @test !ad.compile + @test_deprecated ad.compile ad = AutoReverseDiff(; compile = true) @test ad isa AbstractADType @test ad isa AutoReverseDiff @test mode(ad) isa ReverseMode @test ad.compile + @test_deprecated ad.compile + + ad = @inferred AutoReverseDiff(; compile = Val(true)) + @test ad isa AbstractADType + @test ad isa AutoReverseDiff + @test mode(ad) isa ReverseMode + @test ad.compile + @test_deprecated ad.compile end @testset "AutoSymbolics" begin