diff --git a/src/ADTypes.jl b/src/ADTypes.jl index a30e6f2..2b2fca8 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -22,10 +22,12 @@ end AutoFiniteDifferences(; fdm = nothing) = AutoFiniteDifferences(fdm) -struct AutoForwardDiff{chunksize} <: AbstractADType end +struct AutoForwardDiff{chunksize,T} <: AbstractADType + tag::T +end -function AutoForwardDiff(chunksize = nothing) - AutoForwardDiff{chunksize}() +function AutoForwardDiff(; chunksize = nothing, tag = nothing) + AutoForwardDiff{chunksize,typeof(tag)}(tag) end struct AutoReverseDiff <: AbstractADType @@ -55,10 +57,12 @@ end struct AutoSparseFiniteDiff <: AbstractADType end -struct AutoSparseForwardDiff{chunksize} <: AbstractADType end +struct AutoSparseForwardDiff{chunksize,T} <: AbstractADType + tag::T +end -function AutoSparseForwardDiff(chunksize = nothing) - AutoSparseForwardDiff{chunksize}() +function AutoSparseForwardDiff(; chunksize = nothing, tag = nothing) + AutoSparseForwardDiff{chunksize,typeof(tag)}(tag) end export AutoFiniteDiff, AutoFiniteDifferences, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff diff --git a/test/runtests.jl b/test/runtests.jl index 3944cf3..d309cf6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ using ADTypes using Test +struct CustomTag end + @testset "ADTypes.jl" begin adtype = AutoFiniteDiff() @test adtype isa ADTypes.AbstractADType @@ -22,11 +24,12 @@ using Test adtype = AutoForwardDiff() @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoForwardDiff{nothing} - adtype = AutoForwardDiff(10) + @test adtype isa AutoForwardDiff{nothing,Nothing} + + adtype = AutoForwardDiff(; chunksize = 10, tag = CustomTag()) @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoForwardDiff{10} + @test adtype isa AutoForwardDiff{10,CustomTag} adtype = AutoReverseDiff() @test adtype isa ADTypes.AbstractADType @@ -46,15 +49,13 @@ using Test @test adtype isa ADTypes.AbstractADType @test adtype isa AutoTracker - adtype = AutoEnzyme() + adtype = AutoSparseForwardDiff() @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoEnzyme{Nothing} + @test adtype isa AutoSparseForwardDiff{nothing,Nothing} - # In practice, you would rather specify a - # `mode::Enzyme.Mode`, e.g. `Enzyme.Reverse` or `Enzyme.Forward` - adtype = AutoEnzyme(; mode = Val(:Reverse)) + adtype = AutoSparseForwardDiff(; chunksize = 10, tag = CustomTag()) @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoEnzyme{Val{:Reverse}} + @test adtype isa AutoSparseForwardDiff{10,CustomTag} adtype = AutoModelingToolkit() @test adtype isa ADTypes.AbstractADType @@ -76,7 +77,13 @@ using Test @test adtype isa ADTypes.AbstractADType @test adtype isa AutoSparseForwardDiff{nothing} - adtype = AutoSparseForwardDiff(10) + adtype = AutoEnzyme() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoEnzyme{Nothing} + + # In practice, you would rather specify a + # `mode::Enzyme.Mode`, e.g. `Enzyme.Reverse` or `Enzyme.Forward` + adtype = AutoEnzyme(; mode = Val(:Reverse)) @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoSparseForwardDiff{10} + @test adtype isa AutoEnzyme{Val{:Reverse}} end