Skip to content

Commit

Permalink
Add support for custom tags to ForwardDiff types
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jul 19, 2023
1 parent 76daa27 commit 93f3824
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
version = "0.1.6"
version = "0.1.7"

[compat]
julia = "1"
Expand Down
16 changes: 10 additions & 6 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ end

AutoFiniteDifferences(; fdm = nothing) = AutoFiniteDifferences(fdm)

struct AutoForwardDiff{chunksize} <: AbstractADType end
struct AutoForwardDiff{chunksize,T} <: AbstractADType
tag::T
end

function AutoForwardDiff(chunksize = nothing)
AutoForwardDiff{chunksize}()
function AutoForwardDiff(; chunksize = nothing, tag = nothing)
AutoForwardDiff{chunksize,typeof(tag)}(tag)
end

struct AutoReverseDiff <: AbstractADType
Expand All @@ -51,10 +53,12 @@ end

struct AutoSparseFiniteDiff <: AbstractADType end

struct AutoSparseForwardDiff{chunksize} <: AbstractADType end
struct AutoSparseForwardDiff{chunksize,T} <: AbstractADType
tag::T
end

function AutoSparseForwardDiff(chunksize = nothing)
AutoSparseForwardDiff{chunksize}()
function AutoSparseForwardDiff(; chunksize = nothing, tag = nothing)
AutoSparseForwardDiff{chunksize,typeof(tag)}(tag)
end

export AutoFiniteDiff, AutoFiniteDifferences, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff
Expand Down
16 changes: 15 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using ADTypes
using Test

struct CustomTag end

@testset "ADTypes.jl" begin
adtype = AutoFiniteDiff()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -19,7 +21,11 @@ using Test

adtype = AutoForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff
@test adtype isa AutoForwardDiff{nothing,Nothing}

adtype = AutoForwardDiff(; chunksize = 10, tag = CustomTag())
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff{10,CustomTag}

adtype = AutoReverseDiff()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -32,4 +38,12 @@ using Test
adtype = AutoTracker()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoTracker

adtype = AutoSparseForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{nothing,Nothing}

adtype = AutoSparseForwardDiff(; chunksize = 10, tag = CustomTag())
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{10,CustomTag}
end

0 comments on commit 93f3824

Please sign in to comment.