From 885a5699da11d4294f4a704a57494fca536e9988 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 18 Jul 2023 23:25:12 +0200 Subject: [PATCH] Support Enzyme forward + reverse mode --- Project.toml | 6 +++++- src/ADTypes.jl | 11 +++++++++-- test/runtests.jl | 49 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 157c9b9..426e10b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,13 @@ name = "ADTypes" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = ["Vaibhav Dixit and contributors"] -version = "0.1.5" +version = "0.1.6" + +[deps] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [compat] +EnzymeCore = "0.1, 0.2, 0.3, 0.4, 0.5" julia = "1" [extras] diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 8b2cf72..366eae4 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -1,5 +1,7 @@ module ADTypes +import EnzymeCore + """ Base type for AD choices. """ @@ -30,7 +32,11 @@ AutoReverseDiff(; compile = false) = AutoReverseDiff(compile) struct AutoZygote <: AbstractADType end -struct AutoEnzyme <: AbstractADType end +struct AutoEnzyme{M <: EnzymeCore.Mode} <: AbstractADType + mode::M +end + +AutoEnzyme(; mode::EnzymeCore.Mode = EnzymeCore.Reverse) = AutoEnzyme(mode) struct AutoTracker <: AbstractADType end @@ -51,5 +57,6 @@ function AutoSparseForwardDiff(chunksize = nothing) AutoSparseForwardDiff{chunksize}() end -export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff +export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, + AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff end diff --git a/test/runtests.jl b/test/runtests.jl index 7e1a1bd..d75f71e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,18 +1,33 @@ using ADTypes using Test +import EnzymeCore + @testset "ADTypes.jl" begin adtype = AutoFiniteDiff() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoFiniteDiff + @test adtype.fdtype === Val(:forward) + @test adtype.fdjtype === Val(:forward) + @test adtype.fdhtype === Val(:hcentral) adtype = AutoForwardDiff() @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoForwardDiff + @test adtype isa AutoForwardDiff{nothing} + + adtype = AutoForwardDiff(10) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoForwardDiff{10} adtype = AutoReverseDiff() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoReverseDiff + @test !adtype.compile + + adtype = AutoReverseDiff(; compile = true) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoReverseDiff + @test adtype.compile adtype = AutoZygote() @test adtype isa ADTypes.AbstractADType @@ -21,4 +36,36 @@ using Test adtype = AutoTracker() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoTracker + + adtype = AutoEnzyme() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoEnzyme{<:EnzymeCore.ReverseMode} + + adtype = AutoEnzyme(; mode = EnzymeCore.Forward) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoEnzyme{<:EnzymeCore.ForwardMode} + + adtype = AutoModelingToolkit() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoModelingToolkit + @test !adtype.obj_sparse + @test !adtype.cons_sparse + + adtype = AutoModelingToolkit(; obj_sparse = true, cons_sparse = true) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoModelingToolkit + @test adtype.obj_sparse + @test adtype.cons_sparse + + adtype = AutoSparseFiniteDiff() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseFiniteDiff + + adtype = AutoSparseForwardDiff() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseForwardDiff{nothing} + + adtype = AutoSparseForwardDiff(10) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseForwardDiff{10} end