diff --git a/Project.toml b/Project.toml
index 1cf926be3..b88bbce2e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal and contributors"]
-version = "0.5.56"
+version = "0.5.57"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md
index 78461ac4a..bd8a71189 100644
--- a/docs/src/api/Lux/utilities.md
+++ b/docs/src/api/Lux/utilities.md
@@ -59,7 +59,9 @@ Lux.xlogx
 ## Recursive Operations
 
 ```@docs
+Lux.recursive_map
 Lux.recursive_add!!
+Lux.recursive_copyto!
 Lux.recursive_eltype
 Lux.recursive_make_zero
 Lux.recursive_make_zero!!
diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl
index a4c4994eb..32ef8aae5 100644
--- a/ext/LuxEnzymeExt.jl
+++ b/ext/LuxEnzymeExt.jl
@@ -1,76 +1,54 @@
 module LuxEnzymeExt
 
 using ADTypes: AutoEnzyme
-using ConcreteStructs: @concrete
 using Enzyme: Enzyme, Active, Const, Duplicated
 using Lux: Lux
+using Lux.Experimental: TrainingBackendCache, TrainState
 
-@concrete struct CachedEnzymeExtras{FT}
-    dparameters
-    objective_function
-    st_wrap
-    stats_wrap
-end
-
-# Case I: We have CachedEnzymeExtras and objective_function is unchanged.
-function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
-        ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras{FT}, F}) where {F, FT}
+# Case I: We have TrainingBackendCache{:Enzyme} and obj_fn is unchanged.
+function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
+        ts::TrainState{<:TrainingBackendCache{:Enzyme, false}, F}) where {F}
     dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
     _, loss = Enzyme.autodiff(
-        Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model),
+        Enzyme.ReverseWithPrimal, ts.cache.extras.obj_fn, Active, Const(ts.model),
         Duplicated(ts.parameters, dps), Const(ts.states), Const(data))
-    ts_new = __construct_new_trainstate(
-        ts.cache.st_wrap[], ts.states, ts, objective_function, dps,
-        ts.cache.objective_function, ts.cache.st_wrap, ts.cache.stats_wrap)
+    ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters,
+        ts.cache.extras.st_wrap[], ts.optimizer_state, ts.step)
 
-    return dps, loss, ts.cache.stats_wrap[], ts_new
+    return dps, loss, ts.cache.extras.stats_wrap[], ts_new
 end
 
-# Case II: We have CachedEnzymeExtras and objective_function is changed.
-function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
-        ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras{FT}}) where {F, FT}
-    dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
+# Case II: We have TrainingBackendCache{:Enzyme} but obj_fn is changed.
+function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
+        ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}}) where {F, FT}
+    dps = FT ?
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) - obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( - objective_function, ts.model, ts.parameters, ts.states, data, Val(FT)) + obj_fn_wrap, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( + obj_fn, ts.model, ts.parameters, ts.states, data, Val(FT)) - _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model), + _, loss = Enzyme.autodiff( + Enzyme.ReverseWithPrimal, obj_fn_wrap, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) - ts_new = __construct_new_trainstate( - st_wrap[], ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) + cache = TrainingBackendCache{:Enzyme, false}( + dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap)) + ts_new = TrainState( + cache, obj_fn, ts.model, ts.parameters, st_wrap[], ts.optimizer_state, ts.step) return dps, loss, stats_wrap[], ts_new end # Case III: Nothing is cached. First call to `compute_gradients` -function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data, - ts::Lux.Experimental.TrainState) where {F} +function Lux.Experimental.compute_gradients( + ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F} dps = Lux.recursive_make_zero(ts.parameters) - cache = CachedEnzymeExtras{true}(dps, nothing, nothing, nothing) - ts_new = Lux.Experimental.TrainState( + cache = TrainingBackendCache{:Enzyme, true}( + dps, (; obj_fn=nothing, st_wrap=nothing, stats_wrap=nothing)) + ts_new = TrainState( cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) - return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new) -end - -# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not -# storing the objective function. -function __construct_new_trainstate( - st_new::S, ::S, ts::Lux.Experimental.TrainState, objective_fn::O, - dps, obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2} - cache = CachedEnzymeExtras{false}(dps, obj_fn, st_wrap, stats_wrap) - return Lux.Experimental.TrainState( - cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) -end - -function __construct_new_trainstate( - st_new, _, ts::Lux.Experimental.TrainState, objective_fn::O, - dps, obj_fn::O2, st_wrap, stats_wrap) where {O, O2} - cache = CachedEnzymeExtras{false}(dps, nothing, nothing, nothing) - return Lux.Experimental.TrainState( - cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) + return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new) end end diff --git a/ext/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt.jl deleted file mode 100644 index 4af247d24..000000000 --- a/ext/LuxReverseDiffExt.jl +++ /dev/null @@ -1,53 +0,0 @@ -module LuxReverseDiffExt - -using ADTypes: AutoReverseDiff -using ArrayInterface: ArrayInterface -using Functors: fmap -using Lux: Lux, LuxCPUDevice -using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules -using Setfield: @set! - -function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data, - ts::Lux.Experimental.TrainState) where {F} - tape = ReverseDiff.InstructionTape() - grads = fmap(zero, ts.parameters) - ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads) - loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) - loss.deriv = true - ReverseDiff.reverse_pass!(tape) - @set! 
ts.states = st - return grads, ReverseDiff.value(loss), stats, ts -end - -# AoS to SoA conversion -function Lux.apply( - m::Lux.AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) - @warn "Lux.apply(m::Lux.AbstractExplicitLayer, \ - x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ - Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \ - st).\n\n\ - 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ - 2. This might have performance implications. Check which layer was causing this \ - problem using `Lux.Experimental.@debug_mode`." maxlog=1 - return Lux.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st) -end - -## Prevent an infinite loop -Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) - -# SimpleChains.jl -@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice) -@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice) -@grad_from_chainrules Lux.__apply_simple_chain( - layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) - -# DynamicExpressions.jl -@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr, - operator_enum, x::TrackedArray, ps, ::LuxCPUDevice) -@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr, - operator_enum, x, ps::TrackedArray, ::LuxCPUDevice) -@grad_from_chainrules Lux.__apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, - x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) - -end diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl new file mode 100644 index 000000000..5015e4c18 --- /dev/null +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -0,0 +1,28 @@ +module LuxReverseDiffExt + +using ADTypes: ADTypes, AutoReverseDiff +using ArrayInterface: ArrayInterface +using Lux: Lux, LuxCPUDevice +using Lux.Experimental: TrainingBackendCache, TrainState +using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules + +# AoS to SoA conversion +function Lux.apply( + m::Lux.AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) + @warn "Lux.apply(m::Lux.AbstractExplicitLayer, \ + x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ + Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \ + st).\n\n\ + 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ + 2. This might have performance implications. Check which layer was causing this \ + problem using `Lux.Experimental.@debug_mode`." 
maxlog=1 + return Lux.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st) +end + +## Prevent an infinite loop +Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) + +include("rules.jl") +include("training.jl") + +end diff --git a/ext/LuxReverseDiffExt/rules.jl b/ext/LuxReverseDiffExt/rules.jl new file mode 100644 index 000000000..122df5ab8 --- /dev/null +++ b/ext/LuxReverseDiffExt/rules.jl @@ -0,0 +1,14 @@ +# SimpleChains.jl +@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_simple_chain( + layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) + +# DynamicExpressions.jl +@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr, + operator_enum, x::TrackedArray, ps, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr, + operator_enum, x, ps::TrackedArray, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_dynamic_expression( + de::Lux.DynamicExpressionsLayer, expr, operator_enum, + x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl new file mode 100644 index 000000000..5d189053b --- /dev/null +++ b/ext/LuxReverseDiffExt/training.jl @@ -0,0 +1,127 @@ +@static if pkgversion(ADTypes) < v"1.5" + # older versions did not have `compile` type parameter. Use slower type-unstable code + function Lux.Experimental.compute_gradients( + ad::AutoReverseDiff, obj_fn::F, data, ts::TrainState) where {F} + ad.compile && return __compiled_reverse_diff(obj_fn, data, ts) + return __uncompiled_reverse_diff(obj_fn, data, ts) + end +else + for compiled in (false, true) + fname = compiled ? :__compiled_reverse_diff : :__uncompiled_reverse_diff + @eval function Lux.Experimental.compute_gradients( + ::AutoReverseDiff{$(compiled)}, obj_fn::F, data, ts::TrainState) where {F} + return $(fname)(obj_fn, data, ts) + end + end +end + +# Uncompiled ReverseDiff +@inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} + grads = Lux.recursive_make_zero(ts.parameters) + ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, + ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return __uncompiled_reverse_diff(obj_fn, data, ts_new) +end + +@inline function __uncompiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} + dparams = FT ? 
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + tape = ReverseDiff.InstructionTape() + ps_tracked = Lux.recursive_map( + Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams) + + loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) + loss.deriv = true + ReverseDiff.reverse_pass!(tape) + + ts_new = TrainState( + TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, nothing), + obj_fn, ts.model, ts.parameters, st, ts.optimizer_state, ts.step) + + return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new +end + +# Compiled ReverseDiff +@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} + grads = Lux.recursive_make_zero(ts.parameters) + data_cache = deepcopy(data) + ps_cache = deepcopy(ts.parameters) + extras = (; data_cache, ps_cache) + + ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing, + ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return __compiled_reverse_diff(obj_fn, data, ts_new) +end + +## Tape hasn't been compiled yet / Function mismatch so recompile +@inline function __compiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} + if Lux.statelength(ts.states) != 0 + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \ + models with non-empty state `st`.")) + end + + if FT # do a dry run + _, st_, stats = obj_fn(ts.model, ts.parameters, ts.states, data) + if stats != NamedTuple() + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \ + loss functions that return non-empty `stats`.")) + end + if Lux.statelength(st_) != 0 + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \ + models with non-empty state `st`.")) + end + end + + dparams = FT ? 
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + + (; ps_cache, data_cache) = ts.cache.extras + if !FT + Lux.recursive_copyto!(ps_cache, ts.parameters) + Lux.recursive_copyto!(data_cache, data) + end + + obj_fn_wrap = first ∘ obj_fn + + tape = ReverseDiff.InstructionTape() + ps_tracked = Lux.recursive_map( + Lux.__Fix3(ReverseDiff.TrackedArray, tape), ps_cache, dparams) + + loss = obj_fn_wrap(ts.model, ps_tracked, ts.states, data_cache) + loss.deriv = true + ReverseDiff.reverse_pass!(tape) + + forward_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ForwardExecutor(instruction)) + for instruction in tape] + reverse_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ReverseExecutor(tape[i])) + for i in length(tape):-1:1] + + compiled_extras = (; + ps_cache, data_cache, forward_executor, reverse_executor, output=loss) + ts_new = TrainState( + TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, compiled_extras), + obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + + return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new +end + +@inline function __compiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F} + (; ps_cache, data_cache, output) = ts.cache.extras + + dparams = Lux.recursive_make_zero!!(ts.cache.dparameters) + Lux.recursive_copyto!(ps_cache, ts.parameters) + Lux.recursive_copyto!(data_cache, data) + + for wrapper in ts.cache.extras.forward_executor + wrapper() + end + output.deriv = true + for wrapper in ts.cache.extras.reverse_executor + wrapper() + end + + ts_new = TrainState( + ts.cache, obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return dparams, ReverseDiff.value(output), NamedTuple(), ts_new +end diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index fe5785669..3ba136261 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -3,33 +3,13 @@ module LuxTrackerExt using ADTypes: AutoTracker using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore -using FastClosures: @closure -using Functors: fmap using Lux: Lux, LuxCPUDevice +using Lux.Experimental: TrainingBackendCache, TrainState using LuxCore: LuxCore -using Setfield: @set! 
using Tracker: Tracker, TrackedArray, @grad_from_chainrules const CRC = ChainRulesCore -# Type Piracy: Need to upstream -Tracker.param(nt::NamedTuple{F}) where {F} = NamedTuple{F}(Tracker.param.(values(nt))) -Tracker.param(t::Tuple) = map(Tracker.param, t) -Tracker.param(l::LuxCore.AbstractExplicitLayer) = l - -Tracker.zero_grad!(nt::NamedTuple) = Tracker.zero_grad!.(values(nt)) -Tracker.zero_grad!(::LuxCore.AbstractExplicitLayer) = nothing - -function Tracker.extract_grad!(nt::NamedTuple{F}) where {F} - return NamedTuple{F}(Tracker.extract_grad!.(values(nt))) -end -Tracker.extract_grad!(t::Tuple) = map(Tracker.extract_grad!, t) -Tracker.extract_grad!(::LuxCore.AbstractExplicitLayer) = nothing - -Tracker.data(nt::NamedTuple) = fmap(Tracker.data, nt) -Tracker.data(t::Tuple) = map(Tracker.data, t) -Tracker.data(l::LuxCore.AbstractExplicitLayer) = l - # Weight Norm Patch @inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) @@ -37,15 +17,33 @@ Tracker.data(l::LuxCore.AbstractExplicitLayer) = l @inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)] @inline Lux._gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :] +function __construct_tracked_params(ps, dps) + map_fn = (p, dp) -> Tracker.TrackedArray(Tracker.Call(), p, dp) + return Lux.recursive_map(map_fn, ps, dps) +end + # Lux.Training -function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F, data, - ts::Lux.Experimental.TrainState) where {F} - ps_tracked = fmap(Tracker.param, ts.parameters) - loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) +function Lux.Experimental.compute_gradients(::AutoTracker, obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT} + dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + ps_tracked = __construct_tracked_params(ts.parameters, dparams) + + loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) Tracker.back!(loss) - @set! ts.states = st - grads = fmap(Tracker.grad, ps_tracked) - return grads, Tracker.value(loss), stats, ts + + ts_new = TrainState( + TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, nothing), + obj_fn, ts.model, ts.parameters, st, ts.optimizer_state, ts.step) + + return dparams, Tracker.value(loss), stats, ts_new +end + +function Lux.Experimental.compute_gradients( + ::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} + grads = Lux.recursive_make_zero(ts.parameters) + ts_new = TrainState(TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn, + ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return Lux.Experimental.compute_gradients(AutoTracker(), obj_fn, data, ts_new) end # AoS to SoA conversion @@ -77,9 +75,11 @@ Tracker.@grad function Lux.__apply_simple_chain(layer, x, ps, ::LuxCPUDevice) As such please test your model with FiniteDifferences or Zygote before using \ `Tracker.jl` for your model." 
maxlog=1
     y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps))
-    __∇apply_simple_chain = @closure Δ -> begin
-        _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
-        return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))
+    __∇apply_simple_chain = let pb_f = pb_f
+        Δ -> begin
+            _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
+            return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))
+        end
     end
     # Tracker is not great at handling arbitrary types, so we convert to Array
     return Array(y), __∇apply_simple_chain
diff --git a/src/Lux.jl b/src/Lux.jl
index 9d59eacac..1c60d587e 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -129,6 +129,8 @@ export MPIBackend, NCCLBackend, DistributedUtils
 # Unexported functions that are part of the public API
 @compat public Experimental
 @compat public xlogx, xlogy
-@compat public recursive_add!!, recursive_eltype, recursive_make_zero, recursive_make_zero!!
+@compat(public,
+    (recursive_add!!, recursive_copyto!, recursive_eltype,
+        recursive_make_zero, recursive_map, recursive_make_zero!!))
 
 end
diff --git a/src/contrib/training.jl b/src/contrib/training.jl
index 8dd65acc7..bbe1b082b 100644
--- a/src/contrib/training.jl
+++ b/src/contrib/training.jl
@@ -29,6 +29,13 @@ Internal fields:
     step::Int
 end
 
+@concrete struct TrainingBackendCache{backend, first_try}
+    dparameters
+    extras
+end
+
+@inline __backend(::TrainingBackendCache{backend}) where {backend} = backend
+
 function Base.show(io::IO, ts::TrainState)
     println(io, "TrainState")
     println(io, "    model: ", ts.model)
@@ -36,7 +43,13 @@ function Base.show(io::IO, ts::TrainState)
     println(io, "    # of states: ", Lux.statelength(ts.states))
     println(io, "    optimizer_state: ", ts.optimizer_state)
     print(io, "    step: ", ts.step)
-    ts.cache !== nothing && print(io, "\n    cache: ", nameof(typeof(ts.cache)))
+    if ts.cache !== nothing
+        if ts.cache isa TrainingBackendCache
+            print(io, "\n    cache: $(nameof(typeof(ts.cache))){$(__backend(ts.cache))}")
+        else
+            print(io, "\n    cache: $(nameof(typeof(ts.cache)))")
+        end
+    end
     ts.objective_function !== nothing &&
         print(io, "\n    objective_function: ", nameof(typeof(ts.objective_function)))
 end
@@ -79,12 +92,12 @@ Compute the gradients of the objective function wrt parameters stored in `ts`.
 
 ## Backends & AD Packages
 
-| Supported Backends | Packages Needed  |
-|:------------------ |:---------------- |
-| `AutoZygote`       | `Zygote.jl`      |
-| `AutoReverseDiff`  | `ReverseDiff.jl` |
-| `AutoTracker`      | `Tracker.jl`     |
-| `AutoEnzyme`       | `Enzyme.jl`      |
+| Supported Backends           | Packages Needed  |
+|:---------------------------- |:---------------- |
+| `AutoZygote`                 | `Zygote.jl`      |
+| `AutoReverseDiff(; compile)` | `ReverseDiff.jl` |
+| `AutoTracker`                | `Tracker.jl`     |
+| `AutoEnzyme`                 | `Enzyme.jl`      |
 
 ## Arguments
 
@@ -105,14 +118,13 @@ A 4-Tuple containing:
  - `stats`: Any computed statistics from the objective function.
  - `ts`: Updated Training State.
 
-## Special Notes on Backends
+## Known Limitations
 
- - `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used. The first call
-   to `compute_gradients` will be type-unstable. It is recommended to call this function
-   once outside of the training loop and use the returned train_state for type stability.
- - `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled.
+- `AutoReverseDiff(; compile=true)` is not supported for Lux models with non-empty state
+  `st`. Additionally the returned stats must be empty (`NamedTuple()`). We catch these
+  issues in most cases and throw an error.
 
-!!!
danger +!!! danger "Aliased Gradients" `grads` returned by this function might be aliased by the implementation of the gradient backend. For example, if you cache the `grads` from step `i`, the new gradients diff --git a/src/helpers/recursive_ops.jl b/src/helpers/recursive_ops.jl index f358cb9bb..eba1f1238 100644 --- a/src/helpers/recursive_ops.jl +++ b/src/helpers/recursive_ops.jl @@ -7,17 +7,7 @@ common cases. Any leaves of `x` that are arrays and allow in-place addition will be modified in place. """ -function recursive_add!!(x::AbstractArray, y::AbstractArray) - ArrayInterface.can_setindex(x) || return x .+ y - @. x += y - return x -end -recursive_add!!(x::Tuple, y::Tuple) = map(recursive_add!!, x, y) -recursive_add!!(::Nothing, ::Nothing) = nothing -function recursive_add!!(x::NamedTuple{F}, y::NamedTuple{F}) where {F} - return NamedTuple{F}(map(recursive_add!!, values(x), values(y))) -end -recursive_add!!(x, y) = fmap(recursive_add!!, x, y) +@inline recursive_add!!(x, y) = recursive_map(__add!!, x, y) """ recursive_eltype(x) @@ -48,15 +38,7 @@ Recursively create a zero value for a nested structure `x`. This is equivalent t See also [`Lux.recursive_make_zero!!`](@ref). """ -@inline recursive_make_zero(x::Number) = zero(x) -@inline recursive_make_zero(x::AbstractArray{<:Number}) = zero(x) -@inline recursive_make_zero(x::AbstractArray) = map(recursive_make_zero, x) -@inline recursive_make_zero(x::Tuple) = map(recursive_make_zero, x) -@inline recursive_make_zero(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( - recursive_make_zero, values(x))) -@inline recursive_make_zero(::Nothing) = nothing -@inline recursive_make_zero(v::Val) = v -@inline recursive_make_zero(x) = fmap(recursive_make_zero, x) +@inline recursive_make_zero(x) = recursive_map(__zero, x)::typeof(x) """ recursive_make_zero!!(x) @@ -66,12 +48,59 @@ in-place zeroing will be modified in place. See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version. """ -@inline recursive_make_zero!!(x::Number) = zero(x) -@inline recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x))) -@inline recursive_make_zero!!(x::AbstractArray) = map(recursive_make_zero!!, x) -@inline recursive_make_zero!!(x::Tuple) = map(recursive_make_zero!!, x) -@inline recursive_make_zero!!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( - recursive_make_zero!!, values(x))) -@inline recursive_make_zero!!(::Nothing) = nothing -@inline recursive_make_zero!!(x::Val) = x -@inline recursive_make_zero!!(x) = fmap(recursive_make_zero!!, x) +@inline recursive_make_zero!!(x) = recursive_map(__zero!!, x)::typeof(x) + +""" + recursive_copyto!(x, y) + +Recursively copy the leaves of two nested structures `x` and `y`. In Functor language, this +is equivalent to doing `fmap(copyto!, x, y)`, but this implementation uses type stable code +for common cases. Note that any immutable leaf will lead to an error. +""" +@inline recursive_copyto!(x, y) = recursive_map(copyto!, x, y) + +""" + recursive_map(f, x, args...) + +Similar to `fmap(f, args...)` but with restricted support for the notion of "leaf" types. +However, this allows for more efficient and type stable implementations of recursive +operations. + +## How this works? + +For the following types it directly defines recursion rules: + + 1. `AbstractArray`: If eltype is `isbitstype`, then `f` is applied to the array, else we + recurse on the array. + 2. `Tuple/NamedTuple`: We recurse on the values. + 3. `Number/Val/Nothing`: We directly apply `f`. + 4. 
For all other types, we recurse on the fields using `Functors.fmap`. + +!!! note + + In most cases, users should gravitate towards `Functors.fmap` if it is being used + outside of hot loops. Even for other cases, it is always recommended to verify the + correctness of this implementation for specific usecases. +""" +function recursive_map end + +for direct_call in (Number, Val, Nothing) + @eval @inline recursive_map(f::F, x::$(direct_call), args...) where {F} = f(x, args...) +end +@inline function recursive_map(f::F, x::AbstractArray{T}, args...) where {F, T} + isbitstype(T) && return f(x, args...) + return f.(x, args...) +end +@inline function recursive_map(f::F, x::Tuple, args...) where {F} + map_fn = let f = f + (args_...) -> recursive_map(f, args_...) + end + return map(map_fn, x, args...) +end +@inline function recursive_map(f::F, x::NamedTuple{fields}, args...) where {F, fields} + map_fn = let f = f + (args_...) -> recursive_map(f, args_...) + end + return NamedTuple{fields}(map(map_fn, values(x), values.(args)...)) +end +@inline recursive_map(f::F, x, args...) where {F} = fmap(f, x, args...) diff --git a/src/utils.jl b/src/utils.jl index d5a99d3e3..3089a103a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -331,7 +331,9 @@ end end @inline function __fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y) - return mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) + # mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) leads to slowdowns, better to + # allocate a new array + return sum(lfn.(x, y)) end @inline __fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...) @@ -354,3 +356,20 @@ end @inline __get_dims(::AbstractVector) = Colon() @inline __get_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1) + +@inline __zero(x) = zero(x) +@inline __zero(::Nothing) = nothing +@inline __zero(x::Val) = x + +@inline __zero!!(x::Number) = zero(x) +@inline __zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x))) +@inline __zero!!(::Nothing) = nothing +@inline __zero!!(x::Val) = x + +@inline function __add!!(x::AbstractArray{<:Number}, y::AbstractArray{<:Number}) + ArrayInterface.can_setindex(x) || return x .+ y + @. 
x += y + return x +end +@inline __add!!(x::Number, y::Number) = x + y +@inline __add!!(::Nothing, ::Nothing) = nothing diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index 3077b745f..02e6d075a 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -177,3 +177,74 @@ end @inferred Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new) end + +@testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:contrib] begin + using ADTypes, Optimisers, ReverseDiff + + function mse1(model, ps, st, data) + x_data, y_data = data + y, st_ = model(x_data, ps, st) + return sum(abs2, y .- y_data), st_, (;) + end + + function mse2(model, ps, st, data) + l, st_, stats = mse1(model, ps, st, data) + return l, st_, (; data=2.0f0) + end + + rng = StableRNG(12345) + + dataset = [(randn(rng, Float32, 4, 32), randn(rng, Float32, 4, 32)) for _ in 1:100] + + @testset "Unhandled Cases" begin + model = Chain(Dense(4, 32, tanh), BatchNorm(32), + Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Stateful models are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) + + model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Loss functions that return non-empty `stats` are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse2, dataset[1], tstate) + + struct StrangeModel <: Lux.AbstractExplicitLayer end + + function (m::StrangeModel)(x, ps, st) + return x, (; new_state=0.0) + end + + model = StrangeModel() + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Stateful models are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) + end + + model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + loss_initial = first(mse1(model, ps, st, dataset[1])) + for i in 1:100 + for (x, y) in dataset + _, _, _, tstate = Lux.Experimental.single_train_step!( + AutoReverseDiff(; compile=true), mse1, (x, y), tstate) + end + end + loss_final = first(mse1(model, tstate.parameters, tstate.states, dataset[1])) + + @test loss_final * 100 < loss_initial +end diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 1d4a4167e..693acc184 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -21,24 +21,3 @@ y_t = first(nn(x_t, ps, st)) @test y_t isa AbstractArray{<:ReverseDiff.TrackedReal} end - -@testitem "Tracker.jl patches" setup=[SharedTestSetup] tags=[:autodiff] begin - using Tracker - - nested_st = (; m=Dense(2 => 3), v=rand(2), d=(; x=(rand(2), 1))) - tnested_st = Tracker.param(nested_st) - - @test tnested_st.m === nested_st.m - @test tnested_st.v isa TrackedArray - @test tnested_st.d.x[1] isa TrackedArray - @test tnested_st.d.x[2] isa Tracker.TrackedReal - - @test_nowarn Tracker.zero_grad!(nested_st) - @test_nowarn Tracker.zero_grad!(nested_st.m) - - @test_nowarn Tracker.extract_grad!(tnested_st) - @test_nowarn Tracker.data(tnested_st) - - x = ones(10) |> Tracker.param - @test Lux._gate(x, 1, 1) isa TrackedVector 
-end
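
For reference, a minimal training loop exercising the new `AutoReverseDiff(; compile=true)` path, mirroring the "Compiled ReverseDiff" test item added in this changeset. The model, optimizer settings, RNG, and the `mse` objective below are illustrative stand-ins, not part of the diff; the objective must leave `st` unchanged (stateless model) and return empty `stats` for the compiled tape to be used.

```julia
using ADTypes, Lux, Optimisers, Random, ReverseDiff

# Stateless model, as required by the compiled ReverseDiff backend.
model = Chain(Dense(4 => 32, tanh), Dense(32 => 32, tanh), Dense(32 => 4))
ps, st = Lux.setup(Random.Xoshiro(0), model)

# Objective returning (loss, st, stats) with empty `stats`.
function mse(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

function train(tstate, data, nsteps)
    for _ in 1:nsteps
        # The first call compiles the tape and stores it in the
        # TrainingBackendCache{:ReverseDiff}; later calls with the same
        # objective reuse the cached forward/reverse executors.
        _, _, _, tstate = Lux.Experimental.single_train_step!(
            AutoReverseDiff(; compile=true), mse, data, tstate)
    end
    return tstate
end

x, y = randn(Float32, 4, 32), randn(Float32, 4, 32)
tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))
tstate = train(tstate, (x, y), 100)
```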