Skip to content

Commit

Permalink
Merge pull request #10 from SciML/dw/enzyme
Browse files Browse the repository at this point in the history
Support Enzyme forward + reverse mode
  • Loading branch information
Vaibhavdixit02 authored Jul 25, 2023
2 parents 76daa27 + acdf895 commit 9e91d8a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ AutoReverseDiff(; compile = false) = AutoReverseDiff(compile)

struct AutoZygote <: AbstractADType end

struct AutoEnzyme <: AbstractADType end
struct AutoEnzyme{M} <: AbstractADType
mode::M
end

AutoEnzyme(; mode = nothing) = AutoEnzyme(mode)

struct AutoTracker <: AbstractADType end

Expand Down
49 changes: 48 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ using Test
adtype = AutoFiniteDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoFiniteDiff
@test adtype.fdtype === Val(:forward)
@test adtype.fdjtype === Val(:forward)
@test adtype.fdhtype === Val(:hcentral)

adtype = AutoFiniteDifferences()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -19,11 +22,21 @@ using Test

adtype = AutoForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff
@test adtype isa AutoForwardDiff{nothing}

adtype = AutoForwardDiff(10)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff{10}

adtype = AutoReverseDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoReverseDiff
@test !adtype.compile

adtype = AutoReverseDiff(; compile = true)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoReverseDiff
@test adtype.compile

adtype = AutoZygote()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -32,4 +45,38 @@ using Test
adtype = AutoTracker()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoTracker

adtype = AutoEnzyme()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{Nothing}

# In practice, you would rather specify a
# `mode::Enzyme.Mode`, e.g. `Enzyme.Reverse` or `Enzyme.Forward`
adtype = AutoEnzyme(; mode = Val(:Reverse))
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{Val{:Reverse}}

adtype = AutoModelingToolkit()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoModelingToolkit
@test !adtype.obj_sparse
@test !adtype.cons_sparse

adtype = AutoModelingToolkit(; obj_sparse = true, cons_sparse = true)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoModelingToolkit
@test adtype.obj_sparse
@test adtype.cons_sparse

adtype = AutoSparseFiniteDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseFiniteDiff

adtype = AutoSparseForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{nothing}

adtype = AutoSparseForwardDiff(10)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{10}
end

0 comments on commit 9e91d8a

Please sign in to comment.