From 885a5699da11d4294f4a704a57494fca536e9988 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 18 Jul 2023 23:25:12 +0200 Subject: [PATCH 1/3] Support Enzyme forward + reverse mode --- Project.toml | 6 +++++- src/ADTypes.jl | 11 +++++++++-- test/runtests.jl | 49 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 157c9b9..426e10b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,13 @@ name = "ADTypes" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = ["Vaibhav Dixit and contributors"] -version = "0.1.5" +version = "0.1.6" + +[deps] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [compat] +EnzymeCore = "0.1, 0.2, 0.3, 0.4, 0.5" julia = "1" [extras] diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 8b2cf72..366eae4 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -1,5 +1,7 @@ module ADTypes +import EnzymeCore + """ Base type for AD choices. """ @@ -30,7 +32,11 @@ AutoReverseDiff(; compile = false) = AutoReverseDiff(compile) struct AutoZygote <: AbstractADType end -struct AutoEnzyme <: AbstractADType end +struct AutoEnzyme{M <: EnzymeCore.Mode} <: AbstractADType + mode::M +end + +AutoEnzyme(; mode::EnzymeCore.Mode = EnzymeCore.Reverse) = AutoEnzyme(mode) struct AutoTracker <: AbstractADType end @@ -51,5 +57,6 @@ function AutoSparseForwardDiff(chunksize = nothing) AutoSparseForwardDiff{chunksize}() end -export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff +export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, + AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff end diff --git a/test/runtests.jl b/test/runtests.jl index 7e1a1bd..d75f71e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,18 +1,33 @@ using ADTypes using Test +import EnzymeCore + @testset "ADTypes.jl" begin adtype = AutoFiniteDiff() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoFiniteDiff + @test adtype.fdtype === Val(:forward) + @test adtype.fdjtype === Val(:forward) + @test adtype.fdhtype === Val(:hcentral) adtype = AutoForwardDiff() @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoForwardDiff + @test adtype isa AutoForwardDiff{nothing} + + adtype = AutoForwardDiff(10) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoForwardDiff{10} adtype = AutoReverseDiff() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoReverseDiff + @test !adtype.compile + + adtype = AutoReverseDiff(; compile = true) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoReverseDiff + @test adtype.compile adtype = AutoZygote() @test adtype isa ADTypes.AbstractADType @@ -21,4 +36,36 @@ using Test adtype = AutoTracker() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoTracker + + adtype = AutoEnzyme() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoEnzyme{<:EnzymeCore.ReverseMode} + + adtype = AutoEnzyme(; mode = EnzymeCore.Forward) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoEnzyme{<:EnzymeCore.ForwardMode} + + adtype = AutoModelingToolkit() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoModelingToolkit + @test !adtype.obj_sparse + @test !adtype.cons_sparse + + adtype = AutoModelingToolkit(; obj_sparse = true, cons_sparse = true) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoModelingToolkit + @test adtype.obj_sparse + @test adtype.cons_sparse + + adtype = AutoSparseFiniteDiff() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseFiniteDiff + + adtype = AutoSparseForwardDiff() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseForwardDiff{nothing} + + adtype = AutoSparseForwardDiff(10) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseForwardDiff{10} end From 1229375527c7b062e8aa941b19c5c3fc960de02a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 19 Jul 2023 21:32:14 +0200 Subject: [PATCH 2/3] Remove type constraints and use default of `nothing` --- Project.toml | 7 ++----- src/ADTypes.jl | 6 ++---- test/runtests.jl | 6 +++--- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 426e10b..9da6bd5 100644 --- a/Project.toml +++ b/Project.toml @@ -3,15 +3,12 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = ["Vaibhav Dixit and contributors"] version = "0.1.6" -[deps] -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - [compat] -EnzymeCore = "0.1, 0.2, 0.3, 0.4, 0.5" julia = "1" [extras] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["EnzymeCore", "Test"] diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 366eae4..14bb43e 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -1,7 +1,5 @@ module ADTypes -import EnzymeCore - """ Base type for AD choices. """ @@ -32,11 +30,11 @@ AutoReverseDiff(; compile = false) = AutoReverseDiff(compile) struct AutoZygote <: AbstractADType end -struct AutoEnzyme{M <: EnzymeCore.Mode} <: AbstractADType +struct AutoEnzyme{M} <: AbstractADType mode::M end -AutoEnzyme(; mode::EnzymeCore.Mode = EnzymeCore.Reverse) = AutoEnzyme(mode) +AutoEnzyme(; mode = nothing) = AutoEnzyme(mode) struct AutoTracker <: AbstractADType end diff --git a/test/runtests.jl b/test/runtests.jl index d75f71e..c8988c5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,11 +39,11 @@ import EnzymeCore adtype = AutoEnzyme() @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoEnzyme{<:EnzymeCore.ReverseMode} + @test adtype isa AutoEnzyme{Nothing} - adtype = AutoEnzyme(; mode = EnzymeCore.Forward) + adtype = AutoEnzyme(; mode = EnzymeCore.Reverse) @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoEnzyme{<:EnzymeCore.ForwardMode} + @test adtype isa AutoEnzyme{<:EnzymeCore.ReverseMode} adtype = AutoModelingToolkit() @test adtype isa ADTypes.AbstractADType From acdf8958709f85339ea7bb92cf77cb88319de68e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 19 Jul 2023 23:24:15 +0200 Subject: [PATCH 3/3] Remove EnzymeCore test dependency --- Project.toml | 3 +-- test/runtests.jl | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 9da6bd5..236771b 100644 --- a/Project.toml +++ b/Project.toml @@ -7,8 +7,7 @@ version = "0.1.6" julia = "1" [extras] -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["EnzymeCore", "Test"] +test = ["Test"] diff --git a/test/runtests.jl b/test/runtests.jl index 65e840a..3944cf3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,6 @@ using ADTypes using Test -import EnzymeCore - @testset "ADTypes.jl" begin adtype = AutoFiniteDiff() @test adtype isa ADTypes.AbstractADType @@ -52,9 +50,11 @@ import EnzymeCore @test adtype isa ADTypes.AbstractADType @test adtype isa AutoEnzyme{Nothing} - adtype = AutoEnzyme(; mode = EnzymeCore.Reverse) + # In practice, you would rather specify a + # `mode::Enzyme.Mode`, e.g. `Enzyme.Reverse` or `Enzyme.Forward` + adtype = AutoEnzyme(; mode = Val(:Reverse)) @test adtype isa ADTypes.AbstractADType - @test adtype isa AutoEnzyme{<:EnzymeCore.ReverseMode} + @test adtype isa AutoEnzyme{Val{:Reverse}} adtype = AutoModelingToolkit() @test adtype isa ADTypes.AbstractADType