From ebe293bbc5fc3563711000612721ae51ebb5f86b Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 22 Jun 2024 14:58:33 -0700
Subject: [PATCH] Consolidate the Backend parameter caching

---
 ext/LuxEnzymeExt.jl                        | 29 ++++-----
 ext/LuxReverseDiffExt/LuxReverseDiffExt.jl |  3 +-
 ext/LuxReverseDiffExt/training.jl          | 44 ++++++++-----
 ext/LuxTrackerExt.jl                       | 72 +++++++++++++---------
 src/contrib/training.jl                    | 16 ++++-
 src/utils.jl                               |  4 +-
 6 files changed, 104 insertions(+), 64 deletions(-)

diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl
index a4c4994eb..9823565d9 100644
--- a/ext/LuxEnzymeExt.jl
+++ b/ext/LuxEnzymeExt.jl
@@ -1,20 +1,16 @@
 module LuxEnzymeExt
 
 using ADTypes: AutoEnzyme
-using ConcreteStructs: @concrete
 using Enzyme: Enzyme, Active, Const, Duplicated
 using Lux: Lux
-
-@concrete struct CachedEnzymeExtras{FT}
-    dparameters
-    objective_function
-    st_wrap
-    stats_wrap
-end
+using Lux.Experimental: TrainingBackendCache
 
 # Case I: We have CachedEnzymeExtras and objective_function is unchanged.
-function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
-        ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras{FT}, F}) where {F, FT}
+function Lux.Experimental.compute_gradients(::AutoEnzyme,
+        objective_function::F,
+        data,
+        ts::Lux.Experimental.TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {
+        F, FT}
     dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
 
     _, loss = Enzyme.autodiff(
@@ -22,10 +18,10 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F,
         Duplicated(ts.parameters, dps), Const(ts.states), Const(data))
 
     ts_new = __construct_new_trainstate(
-        ts.cache.st_wrap[], ts.states, ts, objective_function, dps,
-        ts.cache.objective_function, ts.cache.st_wrap, ts.cache.stats_wrap)
+        ts.cache.extras.st_wrap[], ts.states, ts, objective_function, dps,
+        ts.cache.objective_function, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap)
 
-    return dps, loss, ts.cache.stats_wrap[], ts_new
+    return dps, loss, ts.cache.extras.stats_wrap[], ts_new
 end
 
 # Case II: We have CachedEnzymeExtras and objective_function is changed.
@@ -49,7 +45,8 @@ end
 function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data,
         ts::Lux.Experimental.TrainState) where {F}
     dps = Lux.recursive_make_zero(ts.parameters)
-    cache = CachedEnzymeExtras{true}(dps, nothing, nothing, nothing)
+    cache = TrainingBackendCache{:Enzyme, true}(
+        dps, nothing, (; st_wrap=nothing, stats_wrap=nothing))
     ts_new = Lux.Experimental.TrainState(
         cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
     return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new)
@@ -60,7 +57,7 @@ end
 function __construct_new_trainstate(
         st_new::S, ::S, ts::Lux.Experimental.TrainState, objective_fn::O, dps,
         obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2}
-    cache = CachedEnzymeExtras{false}(dps, obj_fn, st_wrap, stats_wrap)
+    cache = TrainingBackendCache{:Enzyme, false}(dps, obj_fn, (; st_wrap, stats_wrap))
     return Lux.Experimental.TrainState(
         cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
 end
@@ -68,7 +65,7 @@ end
 function __construct_new_trainstate(
         st_new, _, ts::Lux.Experimental.TrainState, objective_fn::O, dps,
         obj_fn::O2, st_wrap, stats_wrap) where {O, O2}
-    cache = CachedEnzymeExtras{false}(dps, nothing, nothing, nothing)
+    cache = TrainingBackendCache{:Enzyme, false}(dps, nothing, (; st_wrap, stats_wrap))
     return Lux.Experimental.TrainState(
         cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
 end
diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
index d22c39e5c..20dbb4c4b 100644
--- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
+++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -2,10 +2,9 @@ module LuxReverseDiffExt
 
 using ADTypes: ADTypes, AutoReverseDiff
 using ArrayInterface: ArrayInterface
-using Functors: fmap
 using Lux: Lux, LuxCPUDevice
+using Lux.Experimental: TrainingBackendCache
 using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules
-using Setfield: @set!
 
 # AoS to SoA conversion
 function Lux.apply(
diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl
index e41b15e85..6fb7d394a 100644
--- a/ext/LuxReverseDiffExt/training.jl
+++ b/ext/LuxReverseDiffExt/training.jl
@@ -16,29 +16,43 @@ else
     end
 end
 
-@inline function __uncompiled_reverse_diff(
-        objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
+# Uncompiled ReverseDiff
+@inline function __uncompiled_reverse_diff(objective_function::F, data,
+        ts::Lux.Experimental.TrainState{<:TrainingBackendCache{:ReverseDiff}}) where {F}
     tape = ReverseDiff.InstructionTape()
-    grads = Lux.recursive_make_zero(ts.parameters)
     ps_tracked = Lux.recursive_map(
-        Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads)
+        Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, ts.cache.dparameters)
+
     loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
     loss.deriv = true
     ReverseDiff.reverse_pass!(tape)
-    @set! ts.states = st
-    return grads, ReverseDiff.value(loss), stats, ts
+
+    ts_new = Lux.Experimental.TrainState(
+        TrainingBackendCache{:ReverseDiff, false}(
+            ts.cache.dparameters, objective_function, nothing),
+        objective_function,
+        ts.model,
+        ts.parameters,
+        st,
+        ts.optimizer_state,
+        ts.step)
+
+    return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new
+end
+
+# First call, nothing is cached
+@inline function __uncompiled_reverse_diff(
+        objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
+    grads = Lux.recursive_make_zero(ts.parameters)
+    ts_new = Lux.Experimental.TrainState(
+        TrainingBackendCache{:ReverseDiff, true}(grads, objective_function, nothing),
+        objective_function, ts.model, ts.parameters,
+        ts.states, ts.optimizer_state, ts.step)
+    return __uncompiled_reverse_diff(objective_function, data, ts_new)
 end
 
+# Compiled ReverseDiff
 @inline function __compiled_reverse_diff(
         objective_function::F, data, ts::Lux.Experimental.TrainState) where {F}
-    # tape = ReverseDiff.InstructionTape()
-    # grads = Lux.recursive_make_zero(ts.parameters)
-    # ps_tracked = Lux.recursive_map(
-    #     Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads)
-    # loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
-    # loss.deriv = true
-    # ReverseDiff.reverse_pass!(tape)
-    # @set! ts.states = st
-    # return grads, ReverseDiff.value(loss), stats, ts
     error("Not implemented yet")
 end
diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl
index fe5785669..2e608215f 100644
--- a/ext/LuxTrackerExt.jl
+++ b/ext/LuxTrackerExt.jl
@@ -3,33 +3,14 @@ module LuxTrackerExt
 using ADTypes: AutoTracker
 using ArrayInterface: ArrayInterface
 using ChainRulesCore: ChainRulesCore
-using FastClosures: @closure
 using Functors: fmap
 using Lux: Lux, LuxCPUDevice
+using Lux.Experimental: TrainingBackendCache
 using LuxCore: LuxCore
-using Setfield: @set!
 using Tracker: Tracker, TrackedArray, @grad_from_chainrules
 
 const CRC = ChainRulesCore
 
-# Type Piracy: Need to upstream
-Tracker.param(nt::NamedTuple{F}) where {F} = NamedTuple{F}(Tracker.param.(values(nt)))
-Tracker.param(t::Tuple) = map(Tracker.param, t)
-Tracker.param(l::LuxCore.AbstractExplicitLayer) = l
-
-Tracker.zero_grad!(nt::NamedTuple) = Tracker.zero_grad!.(values(nt))
-Tracker.zero_grad!(::LuxCore.AbstractExplicitLayer) = nothing
-
-function Tracker.extract_grad!(nt::NamedTuple{F}) where {F}
-    return NamedTuple{F}(Tracker.extract_grad!.(values(nt)))
-end
-Tracker.extract_grad!(t::Tuple) = map(Tracker.extract_grad!, t)
-Tracker.extract_grad!(::LuxCore.AbstractExplicitLayer) = nothing
-
-Tracker.data(nt::NamedTuple) = fmap(Tracker.data, nt)
-Tracker.data(t::Tuple) = map(Tracker.data, t)
-Tracker.data(l::LuxCore.AbstractExplicitLayer) = l
-
 # Weight Norm Patch
 @inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims))
 
@@ -37,15 +18,46 @@ Tracker.data(l::LuxCore.AbstractExplicitLayer) = l
 @inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)]
 @inline Lux._gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :]
 
+function __construct_tracked_params(ps, dps)
+    map_fn = (p, dp) -> Tracker.TrackedArray(Tracker.Call(), p, dp)
+    return Lux.recursive_map(map_fn, ps, dps)
+end
+
 # Lux.Training
-function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F, data,
-        ts::Lux.Experimental.TrainState) where {F}
-    ps_tracked = fmap(Tracker.param, ts.parameters)
+## Use the cached gradient parameters
+function Lux.Experimental.compute_gradients(::AutoTracker,
+        objective_function::F,
+        data,
+        ts::Lux.Experimental.TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT}
+    dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
+    ps_tracked = __construct_tracked_params(ts.parameters, dparams)
+
     loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
     Tracker.back!(loss)
-    @set! ts.states = st
-    grads = fmap(Tracker.grad, ps_tracked)
-    return grads, Tracker.value(loss), stats, ts
+
+    ts_new = Lux.Experimental.TrainState(
+        TrainingBackendCache{:Tracker, false}(
+            ts.cache.dparameters, objective_function, nothing),
+        objective_function,
+        ts.model,
+        ts.parameters,
+        st,
+        ts.optimizer_state,
+        ts.step)
+
+    return dparams, Tracker.value(loss), stats, ts_new
+end
+
+## First call, nothing is cached
+function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F, data,
+        ts::Lux.Experimental.TrainState) where {F}
+    grads = Lux.recursive_make_zero(ts.parameters)
+    ts_new = Lux.Experimental.TrainState(
+        TrainingBackendCache{:Tracker, true}(grads, objective_function, nothing),
+        objective_function, ts.model, ts.parameters,
+        ts.states, ts.optimizer_state, ts.step)
+    return Lux.Experimental.compute_gradients(
+        AutoTracker(), objective_function, data, ts_new)
 end
 
 # AoS to SoA conversion
@@ -77,9 +89,11 @@ Tracker.@grad function Lux.__apply_simple_chain(layer, x, ps, ::LuxCPUDevice)
            As such please test your model with FiniteDifferences or Zygote before using \
           `Tracker.jl` for your model." maxlog=1
     y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps))
-    __∇apply_simple_chain = @closure Δ -> begin
-        _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
-        return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))
+    __∇apply_simple_chain = let pb_f = pb_f
+        Δ -> begin
+            _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
+            return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))
+        end
     end
     # Tracker is not great at handling arbitrary types, so we convert to Array
     return Array(y), __∇apply_simple_chain
diff --git a/src/contrib/training.jl b/src/contrib/training.jl
index a6df98eb2..972606424 100644
--- a/src/contrib/training.jl
+++ b/src/contrib/training.jl
@@ -29,6 +29,14 @@ Internal fields:
     step::Int
 end
 
+@concrete struct TrainingBackendCache{backend, first_try}
+    dparameters
+    objective_function
+    extras
+end
+
+@inline __backend(::TrainingBackendCache{backend}) where {backend} = backend
+
 function Base.show(io::IO, ts::TrainState)
     println(io, "TrainState")
     println(io, "    model: ", ts.model)
@@ -36,7 +44,13 @@ function Base.show(io::IO, ts::TrainState)
     println(io, "    # of states: ", Lux.statelength(ts.states))
     println(io, "    optimizer_state: ", ts.optimizer_state)
     print(io, "    step: ", ts.step)
-    ts.cache !== nothing && print(io, "\n    cache: ", nameof(typeof(ts.cache)))
+    if ts.cache !== nothing
+        if ts.cache isa TrainingBackendCache
+            print(io, "\n    cache: $(nameof(typeof(ts.cache))){$(__backend(ts.cache))}")
+        else
+            print(io, "\n    cache: $(nameof(typeof(ts.cache)))")
+        end
+    end
     ts.objective_function !== nothing &&
         print(io, "\n    objective_function: ", nameof(typeof(ts.objective_function)))
 end
diff --git a/src/utils.jl b/src/utils.jl
index e2b0f0df4..3089a103a 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -331,7 +331,9 @@ end
 end
 @inline function __fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y)
     fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y)
-    return mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y)
+    # mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) leads to slowdowns, better to
+    # allocate a new array
+    return sum(lfn.(x, y))
 end
 
 @inline __fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...)
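
Usage sketch (illustrative only, not part of the patch): the consolidated TrainingBackendCache is meant to be transparent to users of the training API. The first compute_gradients call allocates the gradient buffers and stores them in the returned TrainState; subsequent calls dispatch on TrainState{<:TrainingBackendCache{backend}} and reuse those buffers. The snippet below assumes the Lux.Experimental training API around this commit (TrainState(rng, model, opt), compute_gradients, apply_gradients); the model, loss, and data are made up for illustration.

    using ADTypes, Lux, Optimisers, Random, Tracker

    model = Lux.Dense(2 => 2)
    ts = Lux.Experimental.TrainState(Random.default_rng(), model, Optimisers.Adam(0.01f0))

    # The objective must return (loss, st, stats), matching the calls in the extensions.
    function mse(model, ps, st, (x, y))
        y_pred, st_ = model(x, ps, st)
        return sum(abs2, y_pred .- y), st_, (;)
    end

    x, y = rand(Float32, 2, 8), rand(Float32, 2, 8)

    # First call: no cache yet, so a TrainingBackendCache{:Tracker, true} holding zeroed
    # gradient buffers is created and carried along in the returned TrainState.
    grads, loss, _, ts = Lux.Experimental.compute_gradients(AutoTracker(), mse, (x, y), ts)
    ts = Lux.Experimental.apply_gradients(ts, grads)

    # Later calls hit the TrainState{<:TrainingBackendCache{:Tracker}} method and reuse
    # (re-zero and refill) the cached dparameters instead of reallocating them.
    grads, loss, _, ts = Lux.Experimental.compute_gradients(AutoTracker(), mse, (x, y), ts)

The same first-call/cached-call split applies to the AutoEnzyme and uncompiled AutoReverseDiff paths, which now share the TrainingBackendCache{backend, first_try} wrapper instead of backend-specific cache structs.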