diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index e46cb01d0..e7bdb152a 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -1,26 +1,92 @@ module LuxEnzymeExt using ADTypes: AutoEnzyme -using Enzyme: Enzyme +using ConcreteStructs: @concrete +using Enzyme: Enzyme, Active, Const, Duplicated using Lux: Lux using Setfield: @set! +@concrete struct CachedEnzymeExtras + dparameters + forward + reverse +end + +# Case I: We have CachedEnzymeExtras and objective_function is unchanged. +function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, + ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras, F}) where {F} + Lux.__recursive_make_zero!(ts.cache.dparameters) + loss, st_new, stats = __compute_gradients!( + ts.cache.forward, ts.cache.reverse, objective_function, + ts.model, ts.parameters, ts.cache.dparameters, ts.states, data) + ts_new = __construct_new_trainstate( + st_new, ts.states, ts.cache.forward, ts.cache.reverse, + ts, objective_function, ts.cache.dparameters) + return ts.cache.dparameters, loss, stats, ts_new +end + +# Case II: We have CachedEnzymeExtras and objective_function is changed. +function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, + ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras}) where {F} + forward, reverse = Enzyme.autodiff_thunk( + Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)}, + Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, + Const{typeof(ts.states)}, Const{typeof(data)}) + + Lux.__recursive_make_zero!(ts.cache.dparameters) + loss, st_new, stats = __compute_gradients!( + forward, reverse, objective_function, ts.model, + ts.parameters, ts.cache.dparameters, ts.states, data) + + ts_new = __construct_new_trainstate( + st_new, ts.states, forward, reverse, ts, objective_function, ts.cache.dparameters) + return ts.cache.dparameters, loss, stats, ts_new +end + +# Case III: Nothing is cached function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} - dps = Enzyme.make_zero(ts.parameters) - fwd, rev = Enzyme.autodiff_thunk( - Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)}, - Enzyme.Active, Enzyme.Const{typeof(ts.model)}, - Enzyme.Duplicated{typeof(ts.parameters)}, - Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)}) - tape, (loss, st_new, stats), shadow_result = fwd( - Enzyme.Const(objective_function), Enzyme.Const(ts.model), - Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data)) - rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model), - Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data), - (one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape) - @set! ts.states = st_new - return dps, loss, stats, ts + dps = Lux.__recursive_make_zero(ts.parameters) + forward, reverse = Enzyme.autodiff_thunk( + Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)}, + Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, + Const{typeof(ts.states)}, Const{typeof(data)}) + + loss, st_new, stats = __compute_gradients!( + forward, reverse, objective_function, ts.model, ts.parameters, dps, ts.states, data) + ts_new = __construct_new_trainstate( + st_new, ts.states, forward, reverse, ts, objective_function, dps) + return dps, loss, stats, ts_new +end + +function __compute_gradients!( + forward::F, reverse::R, obj_fn::O, model, ps, dps, st, data) where {F, R, O} + pps = Duplicated(ps, dps) + args = (Const(obj_fn), Const(model), pps, Const(st), Const(data)) + tape, (loss, st_new, stats), shadow_result = forward(args...) + reverse(args..., + (one(loss), Lux.__recursive_make_zero(st_new), Lux.__recursive_make_zero(stats)), + tape) + return loss, st_new, stats +end + +# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it +# my not storing the objective function. +function __construct_new_trainstate( + st_new::S, ::S, forward::F, reverse::R, ts::Lux.Experimental.TrainState, + objective_fn::O, dps) where {S, F, R, O} + cache = CachedEnzymeExtras(dps, forward, reverse) + return Lux.Experimental.TrainState( + cache, ts.objective_function, ts.model, ts.parameters, + st_new, ts.optimizer_state, ts.step + 1) +end + +function __construct_new_trainstate( + st_new, _, forward::F, reverse::R, ts::Lux.Experimental.TrainState, + objective_fn::O, dps) where {F, R, O} + cache = CachedEnzymeExtras(dps, nothing, nothing) + return Lux.Experimental.TrainState( + cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step + 1) end end diff --git a/ext/LuxOptimisersExt.jl b/ext/LuxOptimisersExt.jl index c5f6950d7..5d549bcc3 100644 --- a/ext/LuxOptimisersExt.jl +++ b/ext/LuxOptimisersExt.jl @@ -36,10 +36,18 @@ function Lux.Experimental.TrainState( return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0) end -function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads) - optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads) - return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model, - ps, ts.states, optimizer_state, ts.step + 1) +function Lux.Experimental.apply_gradients( + ts::Lux.Experimental.TrainState, grads, update_inplace=false) + if update_inplace + optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads) + return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model, + ps, ts.states, optimizer_state, ts.step + 1) + else + Optimisers.update!(ts.optimizer_state, ts.parameters, grads) + return Lux.Experimental.TrainState( + ts.cache, ts.objective_function, ts.model, ts.parameters, + ts.states, ts.optimizer_state, ts.step + 1) + end end # DistributedUtils diff --git a/src/contrib/training.jl b/src/contrib/training.jl index ca496fc38..5a8ed60d4 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -25,7 +25,7 @@ Internal fields: end """ - apply_gradients(ts::TrainState, grads) + apply_gradients(ts::TrainState, grads, update_inplace::Bool=false) Update the parameters stored in `ts` using the gradients `grads`. @@ -33,6 +33,7 @@ Update the parameters stored in `ts` using the gradients `grads`. - `ts`: [`TrainState`](@ref) object. - `grads`: Gradients of the loss function wrt `ts.params`. + - `update_inplace`: Whether to update the parameters inplace or not. ## Returns @@ -73,6 +74,17 @@ A 4-Tuple containing: - `loss`: Loss from the objective function. - `stats`: Any computed statistics from the objective function. - `ts`: Updated Training State. + +## Special Notes on Backends + + - `AutoEnzyme`: `mode` is always ignored. + +!!! danger + + `grads` returned by this function might be aliased by the implementation of the gradient + backend. For example, if you cache the `grads` from step `i`, the new gradients + returned in step `i + 1` might be aliased by the old gradients. If you want to prevent + this, simply use `copy(grads)` or `deepcopy(grads)` to make a copy of the gradients. """ function compute_gradients(ad::ADTypes.AbstractADType, ::F, _, ::TrainState) where {F} return __maybe_implemented_compute_gradients(ad) diff --git a/src/utils.jl b/src/utils.jl index fdbded5b4..c7bb0815c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -287,3 +287,20 @@ end @inline __size(x::AbstractArray) = size(x) @inline __size(x::T) where {T} = hasmethod(size, Tuple{T}) ? size(x) : nothing + +@inline __recursive_make_zero(x::AbstractArray{<:Number}) = zero(x) +@inline __recursive_make_zero(x::AbstractArray) = map(__recursive_make_zero, x) +@inline __recursive_make_zero(x::Tuple) = map(__recursive_make_zero, x) +@inline __recursive_make_zero(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( + __recursive_make_zero, values(x))) +@inline __recursive_make_zero(::Nothing) = nothing +@inline __recursive_make_zero(v::Val) = v +@inline __recursive_make_zero(x) = fmap(__recursive_make_zero, x) + +@inline __recursive_make_zero!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x))) +@inline __recursive_make_zero!(x::AbstractArray) = map(__recursive_make_zero!, x) +@inline __recursive_make_zero!(x::Tuple) = map(__recursive_make_zero!, x) +@inline __recursive_make_zero!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( + __recursive_make_zero!, values(x))) +@inline __recursive_make_zero!(::Nothing) = nothing +@inline __recursive_make_zero!(x) = fmap(__recursive_make_zero!, x)