From b15a924b6953eae55443e3e2e2b983471d7a7f8f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 Sep 2024 16:26:24 -0400 Subject: [PATCH] refactor: use setfield and make make_zero!! type-stable --- ext/LuxEnzymeExt/LuxEnzymeExt.jl | 2 + ext/LuxEnzymeExt/training.jl | 51 ++++++++------- ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 4 +- ext/LuxReverseDiffExt/training.jl | 73 ++++++++++------------ ext/LuxTrackerExt/LuxTrackerExt.jl | 4 +- ext/LuxTrackerExt/training.jl | 24 +++---- src/helpers/recursive_ops.jl | 8 +-- src/helpers/training.jl | 23 ++++--- test/runtests.jl | 2 +- 9 files changed, 92 insertions(+), 99 deletions(-) diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 2b3cf877d..0174f3972 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -3,6 +3,8 @@ module LuxEnzymeExt using ADTypes: AutoEnzyme using Enzyme: Enzyme, Active, Const, Duplicated using EnzymeCore: EnzymeCore +using Setfield: @set! +using Static: False, True using Lux: Lux using Lux.Training: TrainingBackendCache, TrainState diff --git a/ext/LuxEnzymeExt/training.jl b/ext/LuxEnzymeExt/training.jl index dda255e44..410b9f11e 100644 --- a/ext/LuxEnzymeExt/training.jl +++ b/ext/LuxEnzymeExt/training.jl @@ -1,42 +1,43 @@ function Lux.Training.compute_gradients( - ::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F} - dps = Enzyme.make_zero(ts.parameters) + ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F} + dps = Lux.recursive_make_zero(ts.parameters) obj_fn_wrap, st_wrap, stats_wrap = Lux.Training.wrap_objective_function( - obj_fn, ts.model, ts.parameters, ts.states, data, Val(true)) + obj_fn, ts.model, ts.parameters, ts.states, data, True()) _, loss = Enzyme.autodiff( EnzymeCore.ReverseWithPrimal, Const(obj_fn_wrap), Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) - cache = TrainingBackendCache{:Enzyme, false}( - dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap)) - ts_new = TrainState(cache, obj_fn, ts.model, ts.parameters, st_wrap[], - ts.optimizer, ts.optimizer_state, ts.step) - - return dps, loss, stats_wrap[], ts_new + cache = TrainingBackendCache( + ad, False(), dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap)) + @set! ts.cache = cache + @set! ts.objective_function = obj_fn + @set! ts.states = st_wrap[] + return dps, loss, stats_wrap[], ts end const AUTODIFF_CACHE_TYPE = TrainingBackendCache{ - :Enzyme, false, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS} + <:AutoEnzyme, False, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS} function Lux.Training.compute_gradients( ::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F} - dps = Lux.recursive_make_zero!!(ts.cache.dparameters) + # dps = Lux.recursive_make_zero!!(ts.cache.dparameters) + Enzyme.make_zero!(ts.cache.dparameters) + dps = ts.cache.dparameters _, loss = Enzyme.autodiff( EnzymeCore.ReverseWithPrimal, Const(ts.cache.extras.obj_fn), Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) - ts_new = TrainState( - ts.cache, obj_fn, ts.model, ts.parameters, ts.cache.extras.st_wrap[], - ts.optimizer, ts.optimizer_state, ts.step) + @set! ts.objective_function = obj_fn + @set! 
ts.states = ts.cache.extras.st_wrap[] - return dps, loss, ts.cache.extras.stats_wrap[], ts_new + return dps, loss, ts.cache.extras.stats_wrap[], ts end function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:Enzyme, false}}) where {F} + ts::TrainState{<:TrainingBackendCache{<:AutoEnzyme, False}}) where {F} @warn "Detected calls to `compute_gradients(::AutoEnzyme, ...)` with objective \ function that is changing across function calls. This can lead to the \ generation of slow code" maxlog=1 @@ -46,15 +47,14 @@ function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, Const{typeof(ts.states)}, Const{typeof(data)}) - cache = TrainingBackendCache{:Enzyme, false}(ts.cache.dparameters, (; forward, reverse)) - ts_new = TrainState(cache, obj_fn, ts.model, ts.parameters, ts.states, - ts.optimizer, ts.optimizer_state, ts.step) - - return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new) + cache = TrainingBackendCache(ad, False(), ts.cache.dparameters, (; forward, reverse)) + @set! ts.cache = cache + @set! ts.objective_function = obj_fn + return Lux.Training.compute_gradients(ad, obj_fn, data, ts) end const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{ - :Enzyme, false, PS, <:NamedTuple{(:forward, :reverse)}} where {PS} + <:AutoEnzyme, False, PS, <:NamedTuple{(:forward, :reverse)}} where {PS} function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F} @@ -67,8 +67,7 @@ function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data, Const(obj_fn), Const(ts.model), params, Const(ts.states), Const(data), (one(loss), Lux.recursive_make_zero(st_), Lux.recursive_make_zero(stats)), tape) - ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters, st_, - ts.optimizer, ts.optimizer_state, ts.step) - - return dps, loss, stats, ts_new + @set! ts.objective_function = obj_fn + @set! ts.states = st_ + return dps, loss, stats, ts end diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index 0295b4a28..bf6f9c667 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -5,9 +5,11 @@ using ArrayInterface: ArrayInterface using FunctionWrappers: FunctionWrapper using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal, @grad_from_chainrules +using Setfield: @set! +using Static: False, True using Lux: Lux, Utils -using Lux.Training: TrainingBackendCache, TrainState +using Lux.Training: Training, TrainingBackendCache, TrainState using LuxCore: LuxCore using MLDataDevices: CPUDevice diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index 47fa0dbd0..33bf01eb9 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -1,16 +1,15 @@ # Uncompiled ReverseDiff function Lux.Training.compute_gradients( ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F} - grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState( - TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, ts.model, - ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) - return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new) + @set! ts.cache = TrainingBackendCache( + ad, True(), Lux.recursive_make_zero(ts.parameters), nothing) + @set! 
ts.objective_function = obj_fn + return Lux.Training.compute_gradients(ad, obj_fn, data, ts) end function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} - dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{false}}}) where {F} + dparams = Training.dparameters(ts.cache) tape = ReverseDiff.InstructionTape() ps_tracked = Lux.recursive_map(Utils.Fix3(TrackedArray, tape), ts.parameters, dparams) @@ -18,36 +17,34 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, dat loss.deriv = true ReverseDiff.reverse_pass!(tape) - ts_new = TrainState( - TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, nothing), - obj_fn, ts.model, ts.parameters, st, ts.optimizer, ts.optimizer_state, ts.step) - - return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new + @set! ts.cache.first_try = False() + @set! ts.objective_function = obj_fn + @set! ts.states = st + return dparams, ReverseDiff.value(loss), stats, ts end # Compiled ReverseDiff function Lux.Training.compute_gradients( ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F} - grads = Lux.recursive_make_zero(ts.parameters) - data_cache = deepcopy(data) - ps_cache = deepcopy(ts.parameters) - extras = (; data_cache, ps_cache) - - ts_new = TrainState( - TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing, ts.model, - ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) - return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new) + @set! ts.cache = TrainingBackendCache( + ad, True(), Lux.recursive_make_zero(ts.parameters), + (; data_cache=deepcopy(data), ps_cache=deepcopy(ts.parameters))) + @set! ts.objective_function = nothing + + return Lux.Training.compute_gradients(ad, obj_fn, data, ts) end ## Tape hasn't been compiled yet / Function mismatch so recompile -function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} +function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}}) where {F} if LuxCore.statelength(ts.states) != 0 throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \ models with non-empty state `st`.")) end - if FT # do a dry run + first_try = ts.cache.first_try isa True + + if first_try # do a dry run _, st_, stats = obj_fn(ts.model, ts.parameters, ts.states, data) if stats != NamedTuple() throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \ @@ -59,20 +56,18 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data end end - dparams = FT ? 
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + dparams = Training.dparameters(ts.cache) (; ps_cache, data_cache) = ts.cache.extras - if !FT + if !first_try Lux.recursive_copyto!(ps_cache, ts.parameters) Lux.recursive_copyto!(data_cache, data) end - obj_fn_wrap = first ∘ obj_fn - tape = ReverseDiff.InstructionTape() ps_tracked = Lux.recursive_map(Utils.Fix3(TrackedArray, tape), ps_cache, dparams) - loss = obj_fn_wrap(ts.model, ps_tracked, ts.states, data_cache) + loss = first(obj_fn(ts.model, ps_tracked, ts.states, data_cache)) loss.deriv = true ReverseDiff.reverse_pass!(tape) @@ -81,18 +76,14 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data reverse_executor = [FunctionWrapper{Nothing, Tuple{}}(ReverseExecutor(tape[i])) for i in length(tape):-1:1] - compiled_extras = (; - ps_cache, data_cache, forward_executor, reverse_executor, output=loss) - ts_new = TrainState( - TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, compiled_extras), - obj_fn, ts.model, ts.parameters, ts.states, - ts.optimizer, ts.optimizer_state, ts.step) - - return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new + @set! ts.cache = TrainingBackendCache(ad, False(), dparams, + (; ps_cache, data_cache, forward_executor, reverse_executor, output=loss)) + @set! ts.objective_function = obj_fn + return dparams, ReverseDiff.value(loss), NamedTuple(), ts end function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F} + ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}, F}) where {F} (; ps_cache, data_cache, output) = ts.cache.extras dparams = Lux.recursive_make_zero!!(ts.cache.dparameters) @@ -107,7 +98,7 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data wrapper() end - ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters, ts.states, - ts.optimizer, ts.optimizer_state, ts.step) - return dparams, ReverseDiff.value(output), NamedTuple(), ts_new + @set! ts.cache.first_try = False() + @set! ts.objective_function = obj_fn + return dparams, ReverseDiff.value(output), NamedTuple(), ts end diff --git a/ext/LuxTrackerExt/LuxTrackerExt.jl b/ext/LuxTrackerExt/LuxTrackerExt.jl index 34dd0d527..e243611d2 100644 --- a/ext/LuxTrackerExt/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt/LuxTrackerExt.jl @@ -3,10 +3,12 @@ module LuxTrackerExt using ADTypes: AbstractADType, AutoTracker using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore +using Setfield: @set! +using Static: False, True using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules using Lux: Lux, Utils -using Lux.Training: TrainingBackendCache, TrainState +using Lux.Training: Training, TrainingBackendCache, TrainState const CRC = ChainRulesCore diff --git a/ext/LuxTrackerExt/training.jl b/ext/LuxTrackerExt/training.jl index 607f46994..0e0880b41 100644 --- a/ext/LuxTrackerExt/training.jl +++ b/ext/LuxTrackerExt/training.jl @@ -1,23 +1,23 @@ function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT} - dparams = FT ? 
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) - ps_tracked = construct_tracked_params(ts.parameters, dparams) + ts::TrainState{<:TrainingBackendCache{AutoTracker}}) where {F} + dps = Training.dparameters(ts.cache) + ps_tracked = construct_tracked_params(ts.parameters, dps) loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) Tracker.back!(loss) - ts_new = TrainState( - TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, nothing), obj_fn, - ts.model, ts.parameters, st, ts.optimizer, ts.optimizer_state, ts.step) + @set! ts.cache.first_try = False() + @set! ts.objective_function = obj_fn + @set! ts.states = st - return dparams, loss.data, stats, ts_new + return dps, loss.data, stats, ts end function Lux.Training.compute_gradients( - ::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} + ad::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState( - TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn, ts.model, - ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) - return Lux.Training.compute_gradients(AutoTracker(), obj_fn, data, ts_new) + cache = TrainingBackendCache(ad, True(), grads, nothing) + @set! ts.cache = cache + @set! ts.objective_function = obj_fn + return Lux.Training.compute_gradients(ad, obj_fn, data, ts) end diff --git a/src/helpers/recursive_ops.jl b/src/helpers/recursive_ops.jl index 329d2d174..48b837dd1 100644 --- a/src/helpers/recursive_ops.jl +++ b/src/helpers/recursive_ops.jl @@ -102,18 +102,12 @@ function recursive_map(f::F, x::AbstractArray{T}, args...) where {F, T} (T <: Number || isbitstype(T)) && return f(x, args...) # Not all Number types (BigFloat) are bitstype return f.(x, args...) end -function recursive_map(f::F, x::Tuple, args...) where {F} +function recursive_map(f::F, x::Union{NamedTuple, Tuple}, args...) where {F} map_fn = let f = f (args_...) -> recursive_map(f, args_...) end return map(map_fn, x, args...) end -function recursive_map(f::F, x::NamedTuple{fields}, args...) where {F, fields} - map_fn = let f = f - (args_...) -> recursive_map(f, args_...) - end - return NamedTuple{fields}(map(map_fn, values(x), values.(args)...)) -end recursive_map(f::F, x, args...) where {F} = fmap(f, x, args...) @compat(public, diff --git a/src/helpers/training.jl b/src/helpers/training.jl index d19a441ae..51fdb1a48 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -6,6 +6,7 @@ using ConcreteStructs: @concrete using FastClosures: @closure using Optimisers: Optimisers using Setfield: @set! +using Static: StaticBool, Static, False, True using ..Lux: Lux using LuxCore: LuxCore, AbstractLuxLayer @@ -50,13 +51,10 @@ Constructor for [`TrainState`](@ref). ## Arguments - - `rng`: Random Number Generator. - `ps`: Parameters of the model. - `st`: States of the model. - `model`: `Lux` model. - `optimizer`: Optimizer from `Optimisers.jl`. - - `transform_variables`: Function to transform the variables of the model. Typically used - to transfer variables to GPU / CPU. 
## Returns @@ -67,12 +65,18 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0) end -@concrete struct TrainingBackendCache{backend, first_try} +@concrete struct TrainingBackendCache + backend + first_try <: StaticBool dparameters extras end -training_backend(::TrainingBackendCache{backend}) where {backend} = backend +dparameters(cache::TrainingBackendCache) = dparameters(cache, cache.first_try) +function dparameters(cache::TrainingBackendCache, ::False) + return Lux.recursive_make_zero!!(cache.dparameters) +end +dparameters(cache::TrainingBackendCache, ::True) = cache.dparameters function Base.show(io::IO, ::MIME"text/plain", ts::TrainState) println(io, "TrainState") @@ -83,8 +87,7 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState) print(io, " step: ", ts.step) if ts.cache !== nothing if ts.cache isa TrainingBackendCache - print(io, - "\n cache: $(nameof(typeof(ts.cache))){$(training_backend(ts.cache))}") + print(io, "\n cache: $(nameof(typeof(ts.cache)))($(ts.cache.backend))") else print(io, "\n cache: $(nameof(typeof(ts.cache)))") end @@ -198,7 +201,7 @@ for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme) end end -function generate_wrappers(::F, m, ps, st, data, ::Val{false}) where {F} +function generate_wrappers(::F, m, ps, st, data, ::False) where {F} @warn "Detected function wrapper generation with function being updated between calls. \ This will generate type-unstable code. A possible reason for this is \ `TrainState` was compiled (first call to `compute_gradients`) with function \ @@ -208,13 +211,13 @@ function generate_wrappers(::F, m, ps, st, data, ::Val{false}) where {F} end # Run the code when trying to compile the function for the first time. -function generate_wrappers(objective_function::F, m, ps, st, data, ::Val{true}) where {F} +function generate_wrappers(objective_function::F, m, ps, st, data, ::True) where {F} _, stₙ, statsₙ = objective_function(m, ps, st, data) return Ref{typeof(stₙ)}(stₙ), Ref{typeof(statsₙ)}(statsₙ) end function wrap_objective_function( - objective_function::F, m, ps, st, data, first_try::Val) where {F} + objective_function::F, m, ps, st, data, first_try::StaticBool) where {F} st_updated, stats = generate_wrappers(objective_function, m, ps, st, data, first_try) wrapped_objective_function = @closure (model, ps, st, data) -> begin diff --git a/test/runtests.jl b/test/runtests.jl index 599697663..377a0075a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -109,7 +109,7 @@ const RETESTITEMS_NWORKERS = parse( @info "Running tests for group: [$(i)/$(length(LUX_TEST_GROUP))] $tag" ReTestItems.runtests(Lux; tags=(tag == "all" ? nothing : [Symbol(tag)]), - nworkers=RETESTITEMS_NWORKERS, testitem_timeout=2400, retries=1) + nworkers=RETESTITEMS_NWORKERS, testitem_timeout=2400) end end
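
Note on the `@set!` half of the refactor: each `ts_new = TrainState(ts.cache, obj_fn, ts.model, ...)` reconstruction above is replaced by Setfield's `@set!`, which rebuilds an immutable struct with one field changed and rebinds the variable. A minimal, self-contained sketch of that pattern (`ToyState` is a hypothetical stand-in, not the real `TrainState`):

    using Setfield: @set!

    # Hypothetical stand-in for `TrainState`; only the update pattern matches the patch.
    struct ToyState{C, O, S}
        cache::C
        objective_function::O
        states::S
    end

    ts = ToyState(nothing, nothing, (;))

    # `@set!` builds a copy with the named field replaced and rebinds `ts`, so the
    # call site reads like mutation without re-listing every field.
    @set! ts.objective_function = sum
    @set! ts.states = (; layer_1=(;))

Nested updates such as the patch's `@set! ts.cache.first_try = False()` work the same way: Setfield rebuilds the inner struct and then the outer one.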
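
Note on the type-stability half: `TrainingBackendCache` now stores the AD backend object and a `Static.StaticBool` `first_try` flag as fields, and the new `Training.dparameters` helper dispatches on that flag, replacing the `FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)` ternary that each backend extension previously duplicated. A sketch of the dispatch pattern under hypothetical names (`ZeroableCache` for `TrainingBackendCache`, `zero!!` for `Lux.recursive_make_zero!!`):

    using Static: StaticBool, True, False

    struct ZeroableCache{B, F <: StaticBool, P}
        backend::B
        first_try::F
        dparameters::P
    end

    # Stand-in for `Lux.recursive_make_zero!!`, here handling only plain arrays.
    zero!!(x::AbstractArray) = fill!(x, 0)

    # First call: the freshly allocated buffer is already zero, so hand it back as-is.
    # Later calls: zero it in place. Dispatching on `True`/`False` keeps both paths
    # inferable rather than hiding them behind a value-level branch.
    dparameters(cache::ZeroableCache) = dparameters(cache, cache.first_try)
    dparameters(cache::ZeroableCache, ::True) = cache.dparameters
    dparameters(cache::ZeroableCache, ::False) = zero!!(cache.dparameters)

    cache = ZeroableCache(:demo, True(), zeros(3))
    dparameters(cache)   # returns the untouched zero buffer on the first call

In the patch itself the callers then either rebuild the cache with `False()` or flip the flag via `@set! ts.cache.first_try = False()` after the first gradient computation, so later calls take the in-place zeroing path.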
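
Note on the `recursive_map` cleanup in src/helpers/recursive_ops.jl: the `NamedTuple` method is folded into the `Tuple` one because `Base.map` already zips NamedTuples by field name and returns a NamedTuple with the same keys, making the explicit `NamedTuple{fields}(...)` reconstruction redundant. A quick check of that Base behavior (plain Julia, no Lux involved):

    x = (; weight=[1.0, 2.0], bias=[3.0])
    y = (; weight=[10.0, 20.0], bias=[30.0])
    map(+, x, y)   # (weight = [11.0, 22.0], bias = [33.0])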