diff --git a/Project.toml b/Project.toml index 236771b..f64103b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ADTypes" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = ["Vaibhav Dixit and contributors"] -version = "0.1.6" +version = "0.1.7" [compat] julia = "1" diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 06c7abe..88e63ad 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -22,10 +22,12 @@ end AutoFiniteDifferences(; fdm = nothing) = AutoFiniteDifferences(fdm) -struct AutoForwardDiff{chunksize} <: AbstractADType end +struct AutoForwardDiff{chunksize,T} <: AbstractADType + tag::T +end -function AutoForwardDiff(chunksize = nothing) - AutoForwardDiff{chunksize}() +function AutoForwardDiff(; chunksize = nothing, tag = nothing) + AutoForwardDiff{chunksize,typeof(tag)}(tag) end struct AutoReverseDiff <: AbstractADType @@ -51,10 +53,12 @@ end struct AutoSparseFiniteDiff <: AbstractADType end -struct AutoSparseForwardDiff{chunksize} <: AbstractADType end +struct AutoSparseForwardDiff{chunksize,T} <: AbstractADType + tag::T +end -function AutoSparseForwardDiff(chunksize = nothing) - AutoSparseForwardDiff{chunksize}() +function AutoSparseForwardDiff(; chunksize = nothing, tag = nothing) + AutoSparseForwardDiff{chunksize,typeof(tag)}(tag) end export AutoFiniteDiff, AutoFiniteDifferences, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff diff --git a/test/runtests.jl b/test/runtests.jl index c0bc197..d3c42cd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ using ADTypes using Test +struct CustomTag end + @testset "ADTypes.jl" begin adtype = AutoFiniteDiff() @test adtype isa ADTypes.AbstractADType @@ -19,7 +21,11 @@ using Test adtype = AutoForwardDiff() @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoForwardDiff + @test adtype isa AutoForwardDiff{nothing,Nothing} + + adtype = AutoForwardDiff(; chunksize = 10, tag = CustomTag()) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoForwardDiff{10,CustomTag} adtype = AutoReverseDiff() @test adtype isa ADTypes.AbstractADType @@ -32,4 +38,12 @@ using Test adtype = AutoTracker() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoTracker + + adtype = AutoSparseForwardDiff() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseForwardDiff{nothing,Nothing} + + adtype = AutoSparseForwardDiff(; chunksize = 10, tag = CustomTag()) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseForwardDiff{10,CustomTag} end