Skip to content

Commit

Permalink
Merge pull request #12 from SciML/dw/forwarddiff
Browse files Browse the repository at this point in the history
Add support for custom tags to ForwardDiff types
  • Loading branch information
Vaibhavdixit02 authored Jul 25, 2023
2 parents 9e91d8a + b3ae9ac commit d29dd9d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
16 changes: 10 additions & 6 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ end

AutoFiniteDifferences(; fdm = nothing) = AutoFiniteDifferences(fdm)

struct AutoForwardDiff{chunksize} <: AbstractADType end
struct AutoForwardDiff{chunksize,T} <: AbstractADType
tag::T
end

function AutoForwardDiff(chunksize = nothing)
AutoForwardDiff{chunksize}()
function AutoForwardDiff(; chunksize = nothing, tag = nothing)
AutoForwardDiff{chunksize,typeof(tag)}(tag)
end

struct AutoReverseDiff <: AbstractADType
Expand Down Expand Up @@ -55,10 +57,12 @@ end

struct AutoSparseFiniteDiff <: AbstractADType end

struct AutoSparseForwardDiff{chunksize} <: AbstractADType end
struct AutoSparseForwardDiff{chunksize,T} <: AbstractADType
tag::T
end

function AutoSparseForwardDiff(chunksize = nothing)
AutoSparseForwardDiff{chunksize}()
function AutoSparseForwardDiff(; chunksize = nothing, tag = nothing)
AutoSparseForwardDiff{chunksize,typeof(tag)}(tag)
end

export AutoFiniteDiff, AutoFiniteDifferences, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff
Expand Down
29 changes: 18 additions & 11 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using ADTypes
using Test

struct CustomTag end

@testset "ADTypes.jl" begin
adtype = AutoFiniteDiff()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -22,11 +24,12 @@ using Test

adtype = AutoForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff{nothing}

adtype = AutoForwardDiff(10)
@test adtype isa AutoForwardDiff{nothing,Nothing}

adtype = AutoForwardDiff(; chunksize = 10, tag = CustomTag())
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff{10}
@test adtype isa AutoForwardDiff{10,CustomTag}

adtype = AutoReverseDiff()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -46,15 +49,13 @@ using Test
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoTracker

adtype = AutoEnzyme()
adtype = AutoSparseForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{Nothing}
@test adtype isa AutoSparseForwardDiff{nothing,Nothing}

# In practice, you would rather specify a
# `mode::Enzyme.Mode`, e.g. `Enzyme.Reverse` or `Enzyme.Forward`
adtype = AutoEnzyme(; mode = Val(:Reverse))
adtype = AutoSparseForwardDiff(; chunksize = 10, tag = CustomTag())
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{Val{:Reverse}}
@test adtype isa AutoSparseForwardDiff{10,CustomTag}

adtype = AutoModelingToolkit()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -76,7 +77,13 @@ using Test
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{nothing}

adtype = AutoSparseForwardDiff(10)
adtype = AutoEnzyme()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{Nothing}

# In practice, you would rather specify a
# `mode::Enzyme.Mode`, e.g. `Enzyme.Reverse` or `Enzyme.Forward`
adtype = AutoEnzyme(; mode = Val(:Reverse))
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{10}
@test adtype isa AutoEnzyme{Val{:Reverse}}
end

0 comments on commit d29dd9d

Please sign in to comment.