From 79845648dfcfd2391e158cfed7571eb662c133b4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 17 Jul 2024 08:49:26 +0200 Subject: [PATCH 1/3] Add constant_function kwarg to AutoEnzyme --- Project.toml | 2 +- src/dense.jl | 19 +++++++++++++++---- test/dense.jl | 14 ++++++++++---- test/misc.jl | 5 ----- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 2ada300..7a7fad8 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.5.4" +version = "1.6.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index 4ca7045..57b064b 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -39,7 +39,7 @@ struct AutoDiffractor <: AbstractADType end mode(::AutoDiffractor) = ForwardOrReverseMode() """ - AutoEnzyme{M} + AutoEnzyme{M,constant_function} Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation. @@ -47,7 +47,10 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Constructors - AutoEnzyme(; mode=nothing) + AutoEnzyme(; mode=nothing, constant_function::Bool=false) + +The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl. +For simple functions, this should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data, it should be set to `true`. # Fields @@ -56,8 +59,16 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). 
+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + `nothing` to choose the best mode automatically """ -Base.@kwdef struct AutoEnzyme{M} <: AbstractADType - mode::M = nothing +struct AutoEnzyme{M, constant_function} <: AbstractADType + mode::M +end + +function AutoEnzyme(mode::M; constant_function::Bool = false) where {M} + return AutoEnzyme{M, constant_function}(mode) +end + +function AutoEnzyme(; mode::M = nothing, constant_function::Bool = false) where {M} + return AutoEnzyme{M, constant_function}(mode) end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension diff --git a/test/dense.jl b/test/dense.jl index 15f784d..739cf59 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -28,19 +28,25 @@ end @testset "AutoEnzyme" begin ad = AutoEnzyme() @test ad isa AbstractADType - @test ad isa AutoEnzyme{Nothing} + @test ad isa AutoEnzyme{Nothing, false} @test mode(ad) isa ForwardOrReverseMode @test ad.mode === nothing + ad = AutoEnzyme(EnzymeCore.Forward; constant_function = true) + @test ad isa AbstractADType + @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), true} + @test mode(ad) isa ForwardMode + @test ad.mode == EnzymeCore.Forward + ad = AutoEnzyme(; mode = EnzymeCore.Forward) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), false} @test mode(ad) isa ForwardMode @test ad.mode == EnzymeCore.Forward - ad = AutoEnzyme(; mode = EnzymeCore.Reverse) + ad = AutoEnzyme(; mode = EnzymeCore.Reverse, constant_function = true) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), true} @test mode(ad) isa ReverseMode @test ad.mode == EnzymeCore.Reverse end diff --git a/test/misc.jl b/test/misc.jl index 0ca1ddd..b3e501f 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -21,11 +21,6 @@ end @test 
length(string(sparse_backend1)) < length(string(sparse_backend2)) end -import ADTypes - -struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end -struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end - for backend in [ # dense ADTypes.AutoChainRules(; ruleconfig = :rc), From 995986ad5e3c57d14ac78e258e21c7feb2fa027b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 17 Jul 2024 09:44:29 +0200 Subject: [PATCH 2/3] More explicit docstring --- src/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dense.jl b/src/dense.jl index 57b064b..ad2fcbf 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -50,7 +50,7 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). AutoEnzyme(; mode=nothing, constant_function::Bool=false) The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl. -For simple functions, this should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data, it should be set to `true`. +For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance. # Fields From 091d3b67094bd1320075522ab4aeb6953bdf4e4e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 17 Jul 2024 10:37:05 +0200 Subject: [PATCH 3/3] More details --- src/dense.jl | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/dense.jl b/src/dense.jl index ad2fcbf..8f52f62 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -50,7 +50,7 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). 
AutoEnzyme(; mode=nothing, constant_function::Bool=false) The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl. -For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance. +For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs whose enclosed data can be treated as constant (because it is never modified based on the differentiated input), `constant_function` should be set to `true` for increased performance (more details below). # Fields @@ -58,6 +58,33 @@ For simple functions, `constant_function` should usually be set to `false`, but + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + `nothing` to choose the best mode automatically + +# Notes + +If `constant_function = true` but the enclosed data is not truly constant, then Enzyme.jl will not compute the correct derivative values. +An example of such a function is: + +```julia +cache = [0.0] +function f(x) + cache[1] = x[1]^2 + cache[1] + x[1] +end +``` + +In this case, the enclosed cache is a function of the differentiated input, and thus its values are non-constant with respect to the input. +Therefore, in order to compute the correct derivative of the output, the derivative must propagate through the `cache` value, and said `cache` must not be treated as constant.
+ +Conversely, the following function can treat `parameter` as a constant, because `parameter` is never modified based on the input `x`: + +```julia +parameter = [0.0] +function f(x) + parameter[1] + x[1] +end +``` + +In this case, `constant_function = true` would allow the chosen differentiation system to perform extra memory and compute optimizations, under the assumption that `parameter` is kept constant. """ struct AutoEnzyme{M, constant_function} <: AbstractADType mode::M