diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl
index e4bdc00f6..cc0a6318b 100644
--- a/ext/LuxEnzymeExt.jl
+++ b/ext/LuxEnzymeExt.jl
@@ -16,8 +16,7 @@ end
 # Case I: We have CachedEnzymeExtras and objective_function is unchanged.
 function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
         ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras, F}) where {F}
-    dps = ts.cache.dparameters
-    Lux.__recursive_make_zero!(dps)
+    dps = Lux.__recursive_make_zero!!(ts.cache.dparameters)
 
     _, loss = Enzyme.autodiff(
         Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model),
@@ -33,8 +32,7 @@ end
 # Case II: We have CachedEnzymeExtras and objective_function is changed.
 function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
         ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras}) where {F}
-    dps = ts.cache.dparameters
-    Lux.__recursive_make_zero!(dps)
+    dps = Lux.__recursive_make_zero!!(ts.cache.dparameters)
 
     obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function(
         objective_function, ts.states)
@@ -49,20 +47,13 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F,
 end
 
 # Case III: Nothing is cached. First call to `compute_gradients`
-function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
+function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data,
         ts::Lux.Experimental.TrainState) where {F}
     dps = Lux.__recursive_make_zero(ts.parameters)
-
-    obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function(
-        objective_function, ts.states)
-
-    _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model),
-        Duplicated(ts.parameters, dps), Const(ts.states), Const(data))
-
-    ts_new = __construct_new_trainstate(
-        st_wrap, ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap)
-
-    return dps, loss, stats_wrap, ts_new
+    cache = CachedEnzymeExtras(dps, nothing, nothing, nothing)
+    ts_new = Lux.Experimental.TrainState(
+        cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
+    return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new)
 end
 
 # If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not
diff --git a/src/utils.jl b/src/utils.jl
index c7bb0815c..8dbd774ab 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -288,6 +288,7 @@ end
 @inline __size(x::AbstractArray) = size(x)
 @inline __size(x::T) where {T} = hasmethod(size, Tuple{T}) ? size(x) : nothing
 
+@inline __recursive_make_zero(x::Number) = zero(x)
 @inline __recursive_make_zero(x::AbstractArray{<:Number}) = zero(x)
 @inline __recursive_make_zero(x::AbstractArray) = map(__recursive_make_zero, x)
 @inline __recursive_make_zero(x::Tuple) = map(__recursive_make_zero, x)
@@ -297,10 +298,11 @@ end
 @inline __recursive_make_zero(v::Val) = v
 @inline __recursive_make_zero(x) = fmap(__recursive_make_zero, x)
 
-@inline __recursive_make_zero!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
-@inline __recursive_make_zero!(x::AbstractArray) = map(__recursive_make_zero!, x)
-@inline __recursive_make_zero!(x::Tuple) = map(__recursive_make_zero!, x)
-@inline __recursive_make_zero!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
-    __recursive_make_zero!, values(x)))
-@inline __recursive_make_zero!(::Nothing) = nothing
-@inline __recursive_make_zero!(x) = fmap(__recursive_make_zero!, x)
+@inline __recursive_make_zero!!(x::Number) = zero(x)
+@inline __recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
+@inline __recursive_make_zero!!(x::AbstractArray) = map(__recursive_make_zero!!, x)
+@inline __recursive_make_zero!!(x::Tuple) = map(__recursive_make_zero!!, x)
+@inline __recursive_make_zero!!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
+    __recursive_make_zero!!, values(x)))
+@inline __recursive_make_zero!!(::Nothing) = nothing
+@inline __recursive_make_zero!!(x) = fmap(__recursive_make_zero!!, x)
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
index 9e19bde1b..30d608c14 100644
--- a/test/qa_tests.jl
+++ b/test/qa_tests.jl
@@ -8,7 +8,8 @@ end
 
 @testitem "Explicit Imports: Quality Assurance" tags=[:others] begin
     # Load all trigger packages
-    import Lux, ComponentArrays, ReverseDiff, Flux, LuxAMDGPU, SimpleChains, Tracker, Zygote
+    import Lux, ComponentArrays, ReverseDiff, Flux, LuxAMDGPU, SimpleChains, Tracker,
+        Zygote, Enzyme
     using ExplicitImports
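Side note on the `!!` ("maybe in-place") convention this diff adopts: `__recursive_make_zero!!` zeroes mutable leaves (arrays) in place but returns fresh zeros for immutable leaves (plain numbers), so callers must rebind the result — hence `dps = Lux.__recursive_make_zero!!(ts.cache.dparameters)` instead of the old mutate-only `Lux.__recursive_make_zero!(dps)`. A minimal, self-contained sketch of the idea, where `recursive_zero!!` is a hypothetical stand-in and not the internal Lux function:

# Hypothetical simplified analogue of Lux.__recursive_make_zero!!
recursive_zero!!(x::Number) = zero(x)                                     # immutable leaf: return a new zero
recursive_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))  # mutable leaf: zero in place, return it
recursive_zero!!(x::Union{Tuple, NamedTuple}) = map(recursive_zero!!, x)  # recurse through containers
recursive_zero!!(::Nothing) = nothing

ps = (weight=[1.0 2.0; 3.0 4.0], bias=0.5)
dps = recursive_zero!!(ps)
@assert dps.weight === ps.weight  # same array, zeroed in place
@assert dps.bias == 0.0           # fresh zero; `ps.bias` itself was never mutated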
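The Case III rewrite follows the same spirit: the first call no longer duplicates the `Enzyme.autodiff` invocation; it just allocates zeroed shadow parameters, stores them in a `CachedEnzymeExtras` whose objective function is `nothing`, and re-dispatches so the Case II method performs the actual differentiation. A rough sketch of that first-call-builds-cache-then-redispatch pattern, using hypothetical simplified types rather than the real `TrainState`:

# Hypothetical, simplified mirror of the dispatch pattern above.
struct Cache{D}
    dparameters::D
end

struct State{C}
    cache::C
    parameters::Any
end

# "Cached" method: stands in for Cases I/II, which do the real autodiff work.
compute(st::State{<:Cache}) = ("cached path", st.cache.dparameters)

# "First call" method: stands in for Case III — build the cache, then re-dispatch.
function compute(st::State)
    dps = zero.(st.parameters)
    return compute(State(Cache(dps), st.parameters))
end

compute(State(nothing, [1.0, 2.0]))  # -> ("cached path", [0.0, 0.0])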