From efdc12e72b44dba4f858bbef85c2835b2347f50e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 12:38:50 -0700 Subject: [PATCH 1/2] Promote ReverseDiff compile field to type --- Project.toml | 2 +- src/ADTypes.jl | 3 +++ src/dense.jl | 13 +++++++++---- src/symbols.jl | 8 ++++---- test/dense.jl | 8 +++++++- 5 files changed, 24 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index c54c8ca..c894b31 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.4.0" +version = "1.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 56ebaae..0000a40 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -16,6 +16,9 @@ abstract type AbstractADType end Base.broadcastable(ad::AbstractADType) = Ref(ad) +@inline _unwrap_val(::Val{T}) where {T} = T +@inline _unwrap_val(x) = x + include("mode.jl") include("dense.jl") include("sparse.jl") diff --git a/src/dense.jl b/src/dense.jl index ed3aea7..8fc7b8b 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -186,14 +186,19 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Constructors - AutoReverseDiff(; compile=false) + AutoReverseDiff(; compile::Union{Val, Bool} = Val(false)) # Fields - - `compile::Bool`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation + - `compile::Union{Val, Bool}`: whether to [compile the tape](https://juliadiff.org/ReverseDiff.jl/api/#ReverseDiff.compile) prior to differentiation """ -Base.@kwdef struct AutoReverseDiff <: AbstractADType - compile::Bool = false +struct AutoReverseDiff{C} <: AbstractADType + compile::Bool # this field if left for legacy reasons + + function AutoReverseDiff(; compile::Union{Val, Bool} = Val(false)) + _compile = _unwrap_val(compile) + return new{_compile}(_compile) + end end mode(::AutoReverseDiff) = ReverseMode() diff --git a/src/symbols.jl b/src/symbols.jl index f349e84..b0f019c 100644 --- a/src/symbols.jl +++ b/src/symbols.jl @@ -22,8 +22,8 @@ ADTypes.AutoZygote() Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...) for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation, - :FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff, - :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) - @eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...) + :FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff, + :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) + @eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))( + args...; kws...) end - diff --git a/test/dense.jl b/test/dense.jl index abe3f9b..b33cc59 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -113,7 +113,7 @@ end end @testset "AutoReverseDiff" begin - ad = AutoReverseDiff() + ad = @inferred AutoReverseDiff() @test ad isa AbstractADType @test ad isa AutoReverseDiff @test mode(ad) isa ReverseMode @@ -124,6 +124,12 @@ end @test ad isa AutoReverseDiff @test mode(ad) isa ReverseMode @test ad.compile + + ad = @inferred AutoReverseDiff(; compile = Val(true)) + @test ad isa AbstractADType + @test ad isa AutoReverseDiff + @test mode(ad) isa ReverseMode + @test ad.compile end @testset "AutoSymbolics" begin From 6c27b870db53a1bf9f4bfc7ce734b3bc43503cb3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 16:27:38 -0700 Subject: [PATCH 2/2] Add depwarn for .compile --- src/dense.jl | 9 +++++++++ test/dense.jl | 3 +++ 2 files changed, 12 insertions(+) diff --git a/src/dense.jl b/src/dense.jl index 8fc7b8b..c6e812a 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -201,6 +201,15 @@ struct AutoReverseDiff{C} <: AbstractADType end end +function Base.getproperty(ad::AutoReverseDiff, s::Symbol) + if s === :compile + Base.depwarn( + "`ad.compile` where `ad` is `AutoReverseDiff` has been deprecated and will be removed in v2. Instead it is available as a compile-time constant as `AutoReverseDiff{true}` or `AutoReverseDiff{false}`.", + :getproperty) + end + return getfield(ad, s) +end + mode(::AutoReverseDiff) = ReverseMode() """ diff --git a/test/dense.jl b/test/dense.jl index b33cc59..8ff70b7 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -118,18 +118,21 @@ end @test ad isa AutoReverseDiff @test mode(ad) isa ReverseMode @test !ad.compile + @test_deprecated ad.compile ad = AutoReverseDiff(; compile = true) @test ad isa AbstractADType @test ad isa AutoReverseDiff @test mode(ad) isa ReverseMode @test ad.compile + @test_deprecated ad.compile ad = @inferred AutoReverseDiff(; compile = Val(true)) @test ad isa AbstractADType @test ad isa AutoReverseDiff @test mode(ad) isa ReverseMode @test ad.compile + @test_deprecated ad.compile end @testset "AutoSymbolics" begin