Skip to content

Commit

Permalink
Merge branch 'main' into dw/forwarddiff
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 authored Jul 25, 2023
2 parents 93f3824 + 9e91d8a commit 13c2bc8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ AutoReverseDiff(; compile = false) = AutoReverseDiff(compile)

struct AutoZygote <: AbstractADType end

struct AutoEnzyme <: AbstractADType end
struct AutoEnzyme{M} <: AbstractADType
mode::M
end

AutoEnzyme(; mode = nothing) = AutoEnzyme(mode)

struct AutoTracker <: AbstractADType end

Expand Down
50 changes: 50 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ struct CustomTag end
adtype = AutoFiniteDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoFiniteDiff
@test adtype.fdtype === Val(:forward)
@test adtype.fdjtype === Val(:forward)
@test adtype.fdhtype === Val(:hcentral)

adtype = AutoFiniteDifferences()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -21,15 +24,27 @@ struct CustomTag end

adtype = AutoForwardDiff()
@test adtype isa ADTypes.AbstractADType

@test adtype isa AutoForwardDiff{nothing,Nothing}

adtype = AutoForwardDiff(; chunksize = 10, tag = CustomTag())
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff{10,CustomTag}

adtype = AutoForwardDiff(10)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff{10}


adtype = AutoReverseDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoReverseDiff
@test !adtype.compile

adtype = AutoReverseDiff(; compile = true)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoReverseDiff
@test adtype.compile

adtype = AutoZygote()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -46,4 +61,39 @@ struct CustomTag end
adtype = AutoSparseForwardDiff(; chunksize = 10, tag = CustomTag())
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{10,CustomTag}

adtype = AutoModelingToolkit()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoModelingToolkit
@test !adtype.obj_sparse
@test !adtype.cons_sparse

adtype = AutoModelingToolkit(; obj_sparse = true, cons_sparse = true)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoModelingToolkit
@test adtype.obj_sparse
@test adtype.cons_sparse

adtype = AutoSparseFiniteDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseFiniteDiff

adtype = AutoSparseForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{nothing}

adtype = AutoSparseForwardDiff(10)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{10}

adtype = AutoEnzyme()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{Nothing}

# In practice, you would rather specify a
# `mode::Enzyme.Mode`, e.g. `Enzyme.Reverse` or `Enzyme.Forward`
adtype = AutoEnzyme(; mode = Val(:Reverse))
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{Val{:Reverse}}

end

0 comments on commit 13c2bc8

Please sign in to comment.