Skip to content

Commit

Permalink
Support Enzyme forward + reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jul 18, 2023
1 parent d19986f commit 885a569
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 4 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
version = "0.1.5"
version = "0.1.6"

[deps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[compat]
EnzymeCore = "0.1, 0.2, 0.3, 0.4, 0.5"
julia = "1"

[extras]
Expand Down
11 changes: 9 additions & 2 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module ADTypes

import EnzymeCore

"""
Base type for AD choices.
"""
Expand Down Expand Up @@ -30,7 +32,11 @@ AutoReverseDiff(; compile = false) = AutoReverseDiff(compile)

struct AutoZygote <: AbstractADType end

struct AutoEnzyme <: AbstractADType end
struct AutoEnzyme{M <: EnzymeCore.Mode} <: AbstractADType
mode::M
end

AutoEnzyme(; mode::EnzymeCore.Mode = EnzymeCore.Reverse) = AutoEnzyme(mode)

struct AutoTracker <: AbstractADType end

Expand All @@ -51,5 +57,6 @@ function AutoSparseForwardDiff(chunksize = nothing)
AutoSparseForwardDiff{chunksize}()
end

export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff
export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme,
AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff
end
49 changes: 48 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
using ADTypes
using Test

import EnzymeCore

@testset "ADTypes.jl" begin
adtype = AutoFiniteDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoFiniteDiff
@test adtype.fdtype === Val(:forward)
@test adtype.fdjtype === Val(:forward)
@test adtype.fdhtype === Val(:hcentral)

adtype = AutoForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff
@test adtype isa AutoForwardDiff{nothing}

adtype = AutoForwardDiff(10)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff{10}

adtype = AutoReverseDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoReverseDiff
@test !adtype.compile

adtype = AutoReverseDiff(; compile = true)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoReverseDiff
@test adtype.compile

adtype = AutoZygote()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -21,4 +36,36 @@ using Test
adtype = AutoTracker()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoTracker

adtype = AutoEnzyme()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{<:EnzymeCore.ReverseMode}

adtype = AutoEnzyme(; mode = EnzymeCore.Forward)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoEnzyme{<:EnzymeCore.ForwardMode}

adtype = AutoModelingToolkit()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoModelingToolkit
@test !adtype.obj_sparse
@test !adtype.cons_sparse

adtype = AutoModelingToolkit(; obj_sparse = true, cons_sparse = true)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoModelingToolkit
@test adtype.obj_sparse
@test adtype.cons_sparse

adtype = AutoSparseFiniteDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseFiniteDiff

adtype = AutoSparseForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{nothing}

adtype = AutoSparseForwardDiff(10)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{10}
end

0 comments on commit 885a569

Please sign in to comment.