From 6b818172d5f859c2c3eedad1ca4c3e5c2e4415fd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 12:21:32 -0700 Subject: [PATCH 1/6] Move things a bit --- ext/LuxReverseDiffExt.jl | 53 ---------------------- ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 29 ++++++++++++ ext/LuxReverseDiffExt/rules.jl | 14 ++++++ ext/LuxReverseDiffExt/training.jl | 11 +++++ src/contrib/training.jl | 21 +++------ 5 files changed, 61 insertions(+), 67 deletions(-) delete mode 100644 ext/LuxReverseDiffExt.jl create mode 100644 ext/LuxReverseDiffExt/LuxReverseDiffExt.jl create mode 100644 ext/LuxReverseDiffExt/rules.jl create mode 100644 ext/LuxReverseDiffExt/training.jl diff --git a/ext/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt.jl deleted file mode 100644 index 4af247d24..000000000 --- a/ext/LuxReverseDiffExt.jl +++ /dev/null @@ -1,53 +0,0 @@ -module LuxReverseDiffExt - -using ADTypes: AutoReverseDiff -using ArrayInterface: ArrayInterface -using Functors: fmap -using Lux: Lux, LuxCPUDevice -using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules -using Setfield: @set! - -function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data, - ts::Lux.Experimental.TrainState) where {F} - tape = ReverseDiff.InstructionTape() - grads = fmap(zero, ts.parameters) - ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads) - loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) - loss.deriv = true - ReverseDiff.reverse_pass!(tape) - @set! ts.states = st - return grads, ReverseDiff.value(loss), stats, ts -end - -# AoS to SoA conversion -function Lux.apply( - m::Lux.AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) - @warn "Lux.apply(m::Lux.AbstractExplicitLayer, \ - x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ - Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \ - st).\n\n\ - 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ - 2. This might have performance implications. Check which layer was causing this \ - problem using `Lux.Experimental.@debug_mode`." 
maxlog=1 - return Lux.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st) -end - -## Prevent an infinite loop -Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) - -# SimpleChains.jl -@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice) -@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice) -@grad_from_chainrules Lux.__apply_simple_chain( - layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) - -# DynamicExpressions.jl -@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr, - operator_enum, x::TrackedArray, ps, ::LuxCPUDevice) -@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr, - operator_enum, x, ps::TrackedArray, ::LuxCPUDevice) -@grad_from_chainrules Lux.__apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, - x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) - -end diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl new file mode 100644 index 000000000..706b24b90 --- /dev/null +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -0,0 +1,29 @@ +module LuxReverseDiffExt + +using ADTypes: AutoReverseDiff +using ArrayInterface: ArrayInterface +using Functors: fmap +using Lux: Lux, LuxCPUDevice +using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules +using Setfield: @set! + +# AoS to SoA conversion +function Lux.apply( + m::Lux.AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) + @warn "Lux.apply(m::Lux.AbstractExplicitLayer, \ + x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ + Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \ + st).\n\n\ + 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ + 2. This might have performance implications. Check which layer was causing this \ + problem using `Lux.Experimental.@debug_mode`." 
maxlog=1 + return Lux.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st) +end + +## Prevent an infinite loop +Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) + +include("rules.jl") +include("training.jl") + +end diff --git a/ext/LuxReverseDiffExt/rules.jl b/ext/LuxReverseDiffExt/rules.jl new file mode 100644 index 000000000..122df5ab8 --- /dev/null +++ b/ext/LuxReverseDiffExt/rules.jl @@ -0,0 +1,14 @@ +# SimpleChains.jl +@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_simple_chain( + layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) + +# DynamicExpressions.jl +@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr, + operator_enum, x::TrackedArray, ps, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr, + operator_enum, x, ps::TrackedArray, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_dynamic_expression( + de::Lux.DynamicExpressionsLayer, expr, operator_enum, + x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl new file mode 100644 index 000000000..6e56f899a --- /dev/null +++ b/ext/LuxReverseDiffExt/training.jl @@ -0,0 +1,11 @@ +function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data, + ts::Lux.Experimental.TrainState) where {F} + tape = ReverseDiff.InstructionTape() + grads = fmap(zero, ts.parameters) + ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads) + loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) + loss.deriv = true + ReverseDiff.reverse_pass!(tape) + @set! ts.states = st + return grads, ReverseDiff.value(loss), stats, ts +end diff --git a/src/contrib/training.jl b/src/contrib/training.jl index 8dd65acc7..a6df98eb2 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -79,12 +79,12 @@ Compute the gradients of the objective function wrt parameters stored in `ts`. ## Backends & AD Packages -| Supported Backends | Packages Needed | -|:------------------ |:---------------- | -| `AutoZygote` | `Zygote.jl` | -| `AutoReverseDiff` | `ReverseDiff.jl` | -| `AutoTracker` | `Tracker.jl` | -| `AutoEnzyme` | `Enzyme.jl` | +| Supported Backends | Packages Needed | +|:---------------------------- |:---------------- | +| `AutoZygote` | `Zygote.jl` | +| `AutoReverseDiff(; compile)` | `ReverseDiff.jl` | +| `AutoTracker` | `Tracker.jl` | +| `AutoEnzyme` | `Enzyme.jl` | ## Arguments @@ -105,14 +105,7 @@ A 4-Tuple containing: - `stats`: Any computed statistics from the objective function. - `ts`: Updated Training State. -## Special Notes on Backends - - - `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used. The first call - to `compute_gradients` will be type-unstable. It is recommended to call this function - once outside of the training loop and use the returned train_state for type stability. - - `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled. - -!!! danger +!!! danger "Aliased Gradients" `grads` returned by this function might be aliased by the implementation of the gradient backend. 
For example, if you cache the `grads` from step `i`, the new gradients From e0d262a62bc5f5ec2d7400ace4b5de9bc38346d1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 14:04:26 -0700 Subject: [PATCH 2/6] Use recursive_map --- docs/src/api/Lux/utilities.md | 1 + ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 2 +- ext/LuxReverseDiffExt/training.jl | 41 ++++++++++-- src/helpers/recursive_ops.jl | 78 ++++++++++++++-------- src/utils.jl | 17 +++++ 5 files changed, 105 insertions(+), 34 deletions(-) diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 78461ac4a..e73b31530 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -59,6 +59,7 @@ Lux.xlogx ## Recursive Operations ```@docs +Lux.recursive_map Lux.recursive_add!! Lux.recursive_eltype Lux.recursive_make_zero diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index 706b24b90..d22c39e5c 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -1,6 +1,6 @@ module LuxReverseDiffExt -using ADTypes: AutoReverseDiff +using ADTypes: ADTypes, AutoReverseDiff using ArrayInterface: ArrayInterface using Functors: fmap using Lux: Lux, LuxCPUDevice diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index 6e56f899a..e41b15e85 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -1,11 +1,44 @@ -function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data, - ts::Lux.Experimental.TrainState) where {F} +@static if pkgversion(ADTypes) < v"1.5" + # older versions did not have `compile` type parameter. Use slower type-unstable code + function Lux.Experimental.compute_gradients(ad::AutoReverseDiff, objective_function::F, + data, ts::Lux.Experimental.TrainState) where {F} + ad.compile && return __compiled_reverse_diff(objective_function, data, ts) + return __uncompiled_reverse_diff(objective_function, data, ts) + end +else + for compiled in (false, true) + fname = compiled ? :__compiled_reverse_diff : :__uncompiled_reverse_diff + @eval function Lux.Experimental.compute_gradients( + ::AutoReverseDiff{$(compiled)}, objective_function::F, + data, ts::Lux.Experimental.TrainState) where {F} + return $(fname)(objective_function, data, ts) + end + end +end + +@inline function __uncompiled_reverse_diff( + objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} tape = ReverseDiff.InstructionTape() - grads = fmap(zero, ts.parameters) - ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads) + grads = Lux.recursive_make_zero(ts.parameters) + ps_tracked = Lux.recursive_map( + Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads) loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) loss.deriv = true ReverseDiff.reverse_pass!(tape) @set! ts.states = st return grads, ReverseDiff.value(loss), stats, ts end + +@inline function __compiled_reverse_diff( + objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} + # tape = ReverseDiff.InstructionTape() + # grads = Lux.recursive_make_zero(ts.parameters) + # ps_tracked = Lux.recursive_map( + # Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads) + # loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) + # loss.deriv = true + # ReverseDiff.reverse_pass!(tape) + # @set! 
ts.states = st + # return grads, ReverseDiff.value(loss), stats, ts + error("Not implemented yet") +end diff --git a/src/helpers/recursive_ops.jl b/src/helpers/recursive_ops.jl index f358cb9bb..42460348f 100644 --- a/src/helpers/recursive_ops.jl +++ b/src/helpers/recursive_ops.jl @@ -7,17 +7,7 @@ common cases. Any leaves of `x` that are arrays and allow in-place addition will be modified in place. """ -function recursive_add!!(x::AbstractArray, y::AbstractArray) - ArrayInterface.can_setindex(x) || return x .+ y - @. x += y - return x -end -recursive_add!!(x::Tuple, y::Tuple) = map(recursive_add!!, x, y) -recursive_add!!(::Nothing, ::Nothing) = nothing -function recursive_add!!(x::NamedTuple{F}, y::NamedTuple{F}) where {F} - return NamedTuple{F}(map(recursive_add!!, values(x), values(y))) -end -recursive_add!!(x, y) = fmap(recursive_add!!, x, y) +recursive_add!!(x, y) = recursive_map(__add!!, x, y) """ recursive_eltype(x) @@ -48,15 +38,7 @@ Recursively create a zero value for a nested structure `x`. This is equivalent t See also [`Lux.recursive_make_zero!!`](@ref). """ -@inline recursive_make_zero(x::Number) = zero(x) -@inline recursive_make_zero(x::AbstractArray{<:Number}) = zero(x) -@inline recursive_make_zero(x::AbstractArray) = map(recursive_make_zero, x) -@inline recursive_make_zero(x::Tuple) = map(recursive_make_zero, x) -@inline recursive_make_zero(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( - recursive_make_zero, values(x))) -@inline recursive_make_zero(::Nothing) = nothing -@inline recursive_make_zero(v::Val) = v -@inline recursive_make_zero(x) = fmap(recursive_make_zero, x) +@inline recursive_make_zero(x) = recursive_map(__zero, x) """ recursive_make_zero!!(x) @@ -66,12 +48,50 @@ in-place zeroing will be modified in place. See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version. """ -@inline recursive_make_zero!!(x::Number) = zero(x) -@inline recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x))) -@inline recursive_make_zero!!(x::AbstractArray) = map(recursive_make_zero!!, x) -@inline recursive_make_zero!!(x::Tuple) = map(recursive_make_zero!!, x) -@inline recursive_make_zero!!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( - recursive_make_zero!!, values(x))) -@inline recursive_make_zero!!(::Nothing) = nothing -@inline recursive_make_zero!!(x::Val) = x -@inline recursive_make_zero!!(x) = fmap(recursive_make_zero!!, x) +@inline recursive_make_zero!!(x) = recursive_map(__zero!!, x) + +""" + recursive_map(f, x, args...) + +Similar to `fmap(f, args...)` but with restricted support for the notion of "leaf" types. +However, this allows for more efficient and type stable implementations of recursive +operations. + +## How this works? + +For the following types it directly defines recursion rules: + + 1. `AbstractArray`: If eltype is `isbitstype`, then `f` is applied to the array, else we + recurse on the array. + 2. `Tuple/NamedTuple`: We recurse on the values. + 3. `Number/Val/Nothing`: We directly apply `f`. + 4. For all other types, we recurse on the fields using `Functors.fmap`. + +!!! note + + In most cases, users should gravitate towards `Functors.fmap` if it is being used + outside of hot loops. Even for other cases, it is always recommended to verify the + correctness of this implementation for specific usecases. +""" +function recursive_map end + +for direct_call in (Number, Val, Nothing) + @eval @inline recursive_map(f::F, x::$(direct_call), args...) where {F} = f(x, args...) 
+end +@inline function recursive_map(f::F, x::AbstractArray{T}, args...) where {F, T} + isbitstype(T) && return f(x, args...) + return f.(x, args...) +end +@inline function recursive_map(f::F, x::Tuple, args...) where {F} + map_fn = let f = f + (args_...) -> recursive_map(f, args_...) + end + return map(map_fn, x, args...) +end +@inline function recursive_map(f::F, x::NamedTuple{fields}, args...) where {F, fields} + map_fn = let f = f + (args_...) -> recursive_map(f, args_...) + end + return NamedTuple{fields}(map(map_fn, values(x), values.(args)...)) +end +@inline recursive_map(f::F, x, args...) where {F} = fmap(f, x, args...) diff --git a/src/utils.jl b/src/utils.jl index d5a99d3e3..e2b0f0df4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -354,3 +354,20 @@ end @inline __get_dims(::AbstractVector) = Colon() @inline __get_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1) + +@inline __zero(x) = zero(x) +@inline __zero(::Nothing) = nothing +@inline __zero(x::Val) = x + +@inline __zero!!(x::Number) = zero(x) +@inline __zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x))) +@inline __zero!!(::Nothing) = nothing +@inline __zero!!(x::Val) = x + +@inline function __add!!(x::AbstractArray{<:Number}, y::AbstractArray{<:Number}) + ArrayInterface.can_setindex(x) || return x .+ y + @. x += y + return x +end +@inline __add!!(x::Number, y::Number) = x + y +@inline __add!!(::Nothing, ::Nothing) = nothing From aa99508e712a8732a6ec555fe13d510f7468e3a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 14:58:33 -0700 Subject: [PATCH 3/6] Consolidate the Backend parameter caching --- ext/LuxEnzymeExt.jl | 62 +++++++++------------ ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 3 +- ext/LuxReverseDiffExt/training.jl | 51 +++++++++-------- ext/LuxTrackerExt.jl | 64 ++++++++++++---------- src/contrib/training.jl | 16 +++++- src/utils.jl | 4 +- 6 files changed, 107 insertions(+), 93 deletions(-) diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index a4c4994eb..e3c3950a6 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -1,75 +1,67 @@ module LuxEnzymeExt using ADTypes: AutoEnzyme -using ConcreteStructs: @concrete using Enzyme: Enzyme, Active, Const, Duplicated using Lux: Lux +using Lux.Experimental: TrainingBackendCache, TrainState -@concrete struct CachedEnzymeExtras{FT} - dparameters - objective_function - st_wrap - stats_wrap -end - -# Case I: We have CachedEnzymeExtras and objective_function is unchanged. -function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, - ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras{FT}, F}) where {F, FT} +# Case I: We have TrainingBackendCache{:Enzyme} and obj_fn is unchanged. 
+function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {F, FT} dps = Lux.recursive_make_zero!!(ts.cache.dparameters) _, loss = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model), + Enzyme.ReverseWithPrimal, ts.cache.obj_fn, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) ts_new = __construct_new_trainstate( - ts.cache.st_wrap[], ts.states, ts, objective_function, dps, - ts.cache.objective_function, ts.cache.st_wrap, ts.cache.stats_wrap) + ts.cache.extras.st_wrap[], ts.states, ts, obj_fn, dps, + ts.cache.obj_fn, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap) - return dps, loss, ts.cache.stats_wrap[], ts_new + return dps, loss, ts.cache.extras.stats_wrap[], ts_new end -# Case II: We have CachedEnzymeExtras and objective_function is changed. -function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, - ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras{FT}}) where {F, FT} +# Case II: We have CachedEnzymeExtras and obj_fn is changed. +function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}}) where {F, FT} dps = Lux.recursive_make_zero!!(ts.cache.dparameters) obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( - objective_function, ts.model, ts.parameters, ts.states, data, Val(FT)) + obj_fn, ts.model, ts.parameters, ts.states, data, Val(FT)) _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) ts_new = __construct_new_trainstate( - st_wrap[], ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) + st_wrap[], ts.states, ts, obj_fn, dps, obj_fn, st_wrap, stats_wrap) return dps, loss, stats_wrap[], ts_new end # Case III: Nothing is cached. First call to `compute_gradients` -function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data, - ts::Lux.Experimental.TrainState) where {F} +function Lux.Experimental.compute_gradients( + ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F} dps = Lux.recursive_make_zero(ts.parameters) - cache = CachedEnzymeExtras{true}(dps, nothing, nothing, nothing) - ts_new = Lux.Experimental.TrainState( + cache = TrainingBackendCache{:Enzyme, true}( + dps, nothing, (; st_wrap=nothing, stats_wrap=nothing)) + ts_new = TrainState( cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) - return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new) + return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new) end # If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not # storing the objective function. 
-function __construct_new_trainstate( - st_new::S, ::S, ts::Lux.Experimental.TrainState, objective_fn::O, - dps, obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2} - cache = CachedEnzymeExtras{false}(dps, obj_fn, st_wrap, stats_wrap) - return Lux.Experimental.TrainState( +function __construct_new_trainstate(st_new::S, ::S, ts::TrainState, objective_fn::O, dps, + obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2} + cache = TrainingBackendCache{:Enzyme, false}(dps, obj_fn, (; st_wrap, stats_wrap)) + return TrainState( cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) end -function __construct_new_trainstate( - st_new, _, ts::Lux.Experimental.TrainState, objective_fn::O, - dps, obj_fn::O2, st_wrap, stats_wrap) where {O, O2} - cache = CachedEnzymeExtras{false}(dps, nothing, nothing, nothing) - return Lux.Experimental.TrainState( +function __construct_new_trainstate(st_new, _, ts::TrainState, objective_fn::O, dps, + obj_fn::O2, st_wrap, stats_wrap) where {O, O2} + cache = TrainingBackendCache{:Enzyme, false}(dps, nothing, (; st_wrap, stats_wrap)) + return TrainState( cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) end diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index d22c39e5c..5015e4c18 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -2,10 +2,9 @@ module LuxReverseDiffExt using ADTypes: ADTypes, AutoReverseDiff using ArrayInterface: ArrayInterface -using Functors: fmap using Lux: Lux, LuxCPUDevice +using Lux.Experimental: TrainingBackendCache, TrainState using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules -using Setfield: @set! # AoS to SoA conversion function Lux.apply( diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index e41b15e85..872a270a9 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -1,44 +1,47 @@ @static if pkgversion(ADTypes) < v"1.5" # older versions did not have `compile` type parameter. Use slower type-unstable code - function Lux.Experimental.compute_gradients(ad::AutoReverseDiff, objective_function::F, - data, ts::Lux.Experimental.TrainState) where {F} - ad.compile && return __compiled_reverse_diff(objective_function, data, ts) - return __uncompiled_reverse_diff(objective_function, data, ts) + function Lux.Experimental.compute_gradients( + ad::AutoReverseDiff, obj_fn::F, data, ts::TrainState) where {F} + ad.compile && return __compiled_reverse_diff(obj_fn, data, ts) + return __uncompiled_reverse_diff(obj_fn, data, ts) end else for compiled in (false, true) fname = compiled ? 
:__compiled_reverse_diff : :__uncompiled_reverse_diff @eval function Lux.Experimental.compute_gradients( - ::AutoReverseDiff{$(compiled)}, objective_function::F, - data, ts::Lux.Experimental.TrainState) where {F} - return $(fname)(objective_function, data, ts) + ::AutoReverseDiff{$(compiled)}, obj_fn::F, data, ts::TrainState) where {F} + return $(fname)(obj_fn, data, ts) end end end +# Uncompiled ReverseDiff @inline function __uncompiled_reverse_diff( - objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} + obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:ReverseDiff}}) where {F} tape = ReverseDiff.InstructionTape() - grads = Lux.recursive_make_zero(ts.parameters) ps_tracked = Lux.recursive_map( - Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads) - loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) + Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, ts.cache.dparameters) + + loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) loss.deriv = true ReverseDiff.reverse_pass!(tape) - @set! ts.states = st - return grads, ReverseDiff.value(loss), stats, ts + + ts_new = TrainState( + TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, obj_fn, nothing), + obj_fn, ts.model, ts.parameters, st, ts.optimizer_state, ts.step) + + return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new +end + +# First call, nothing is cached +@inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} + grads = Lux.recursive_make_zero(ts.parameters) + ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, obj_fn, nothing), + obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return __uncompiled_reverse_diff(obj_fn, data, ts_new) end -@inline function __compiled_reverse_diff( - objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} - # tape = ReverseDiff.InstructionTape() - # grads = Lux.recursive_make_zero(ts.parameters) - # ps_tracked = Lux.recursive_map( - # Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, grads) - # loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) - # loss.deriv = true - # ReverseDiff.reverse_pass!(tape) - # @set! ts.states = st - # return grads, ReverseDiff.value(loss), stats, ts +# Compiled ReverseDiff +@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} error("Not implemented yet") end diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index fe5785669..0e0b5acf5 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -3,33 +3,14 @@ module LuxTrackerExt using ADTypes: AutoTracker using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore -using FastClosures: @closure using Functors: fmap using Lux: Lux, LuxCPUDevice +using Lux.Experimental: TrainingBackendCache, TrainState using LuxCore: LuxCore -using Setfield: @set! 
using Tracker: Tracker, TrackedArray, @grad_from_chainrules const CRC = ChainRulesCore -# Type Piracy: Need to upstream -Tracker.param(nt::NamedTuple{F}) where {F} = NamedTuple{F}(Tracker.param.(values(nt))) -Tracker.param(t::Tuple) = map(Tracker.param, t) -Tracker.param(l::LuxCore.AbstractExplicitLayer) = l - -Tracker.zero_grad!(nt::NamedTuple) = Tracker.zero_grad!.(values(nt)) -Tracker.zero_grad!(::LuxCore.AbstractExplicitLayer) = nothing - -function Tracker.extract_grad!(nt::NamedTuple{F}) where {F} - return NamedTuple{F}(Tracker.extract_grad!.(values(nt))) -end -Tracker.extract_grad!(t::Tuple) = map(Tracker.extract_grad!, t) -Tracker.extract_grad!(::LuxCore.AbstractExplicitLayer) = nothing - -Tracker.data(nt::NamedTuple) = fmap(Tracker.data, nt) -Tracker.data(t::Tuple) = map(Tracker.data, t) -Tracker.data(l::LuxCore.AbstractExplicitLayer) = l - # Weight Norm Patch @inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) @@ -37,15 +18,36 @@ Tracker.data(l::LuxCore.AbstractExplicitLayer) = l @inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)] @inline Lux._gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :] +function __construct_tracked_params(ps, dps) + map_fn = (p, dp) -> Tracker.TrackedArray(Tracker.Call(), p, dp) + return Lux.recursive_map(map_fn, ps, dps) +end + # Lux.Training -function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F, data, - ts::Lux.Experimental.TrainState) where {F} - ps_tracked = fmap(Tracker.param, ts.parameters) - loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) +## Use the cached gradient parameters +function Lux.Experimental.compute_gradients(::AutoTracker, obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT} + dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + ps_tracked = __construct_tracked_params(ts.parameters, dparams) + + loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) Tracker.back!(loss) - @set! ts.states = st - grads = fmap(Tracker.grad, ps_tracked) - return grads, Tracker.value(loss), stats, ts + + ts_new = TrainState( + TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, obj_fn, nothing), + obj_fn, ts.model, ts.parameters, st, ts.optimizer_state, ts.step) + + return dparams, Tracker.value(loss), stats, ts_new +end + +## First call, nothing is cached +function Lux.Experimental.compute_gradients( + ::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} + grads = Lux.recursive_make_zero(ts.parameters) + ts_new = TrainState( + TrainingBackendCache{:Tracker, true}(grads, obj_fn, nothing), obj_fn, + ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return Lux.Experimental.compute_gradients(AutoTracker(), obj_fn, data, ts_new) end # AoS to SoA conversion @@ -77,9 +79,11 @@ Tracker.@grad function Lux.__apply_simple_chain(layer, x, ps, ::LuxCPUDevice) As such please test your model with FiniteDifferences or Zygote before using \ `Tracker.jl` for your model." 
maxlog=1 y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps)) - __∇apply_simple_chain = @closure Δ -> begin - _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ))) - return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing)) + __∇apply_simple_chain = let pb_f = pb_f + Δ -> begin + _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ))) + return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing)) + end end # Tracker is not great at handling arbitrary types, so we convert to Array return Array(y), __∇apply_simple_chain diff --git a/src/contrib/training.jl b/src/contrib/training.jl index a6df98eb2..972606424 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -29,6 +29,14 @@ Internal fields: step::Int end +@concrete struct TrainingBackendCache{backend, first_try} + dparameters + objective_function + extras +end + +@inline __backend(::TrainingBackendCache{backend}) where {backend} = backend + function Base.show(io::IO, ts::TrainState) println(io, "TrainState") println(io, " model: ", ts.model) @@ -36,7 +44,13 @@ function Base.show(io::IO, ts::TrainState) println(io, " # of states: ", Lux.statelength(ts.states)) println(io, " optimizer_state: ", ts.optimizer_state) print(io, " step: ", ts.step) - ts.cache !== nothing && print(io, "\n cache: ", nameof(typeof(ts.cache))) + if ts.cache !== nothing + if ts.cache isa TrainingBackendCache + print(io, "\n cache: $(nameof(typeof(ts.cache))){$(__backend(ts.cache))}") + else + print(io, "\n cache: $(nameof(typeof(ts.cache)))") + end + end ts.objective_function !== nothing && print(io, "\n objective_function: ", nameof(typeof(ts.objective_function))) end diff --git a/src/utils.jl b/src/utils.jl index e2b0f0df4..3089a103a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -331,7 +331,9 @@ end end @inline function __fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y) - return mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) + # mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) leads to slowdowns, better to + # allocate a new array + return sum(lfn.(x, y)) end @inline __fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...) From 2b002668b45a9242ce44a0af9d1807b0c4132b27 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 15:11:47 -0700 Subject: [PATCH 4/6] Zero the gradients --- ext/LuxEnzymeExt.jl | 21 +++++++++++---------- ext/LuxReverseDiffExt/training.jl | 30 +++++++++++++++++++----------- ext/LuxTrackerExt.jl | 8 ++------ src/contrib/training.jl | 1 - test/misc_tests.jl | 21 --------------------- 5 files changed, 32 insertions(+), 49 deletions(-) diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index e3c3950a6..9cedf49d7 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -8,15 +8,15 @@ using Lux.Experimental: TrainingBackendCache, TrainState # Case I: We have TrainingBackendCache{:Enzyme} and obj_fn is unchanged. function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {F, FT} - dps = Lux.recursive_make_zero!!(ts.cache.dparameters) + dps = FT ? 
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) _, loss = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, ts.cache.obj_fn, Active, Const(ts.model), + Enzyme.ReverseWithPrimal, ts.cache.extras.obj_fn, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) ts_new = __construct_new_trainstate( ts.cache.extras.st_wrap[], ts.states, ts, obj_fn, dps, - ts.cache.obj_fn, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap) + ts.cache.extras.obj_fn, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap) return dps, loss, ts.cache.extras.stats_wrap[], ts_new end @@ -24,16 +24,17 @@ end # Case II: We have CachedEnzymeExtras and obj_fn is changed. function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}}) where {F, FT} - dps = Lux.recursive_make_zero!!(ts.cache.dparameters) + dps = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) - obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( + obj_fn_wrap, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( obj_fn, ts.model, ts.parameters, ts.states, data, Val(FT)) - _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model), + _, loss = Enzyme.autodiff( + Enzyme.ReverseWithPrimal, obj_fn_wrap, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) ts_new = __construct_new_trainstate( - st_wrap[], ts.states, ts, obj_fn, dps, obj_fn, st_wrap, stats_wrap) + st_wrap[], ts.states, ts, obj_fn, dps, obj_fn_wrap, st_wrap, stats_wrap) return dps, loss, stats_wrap[], ts_new end @@ -43,7 +44,7 @@ function Lux.Experimental.compute_gradients( ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F} dps = Lux.recursive_make_zero(ts.parameters) cache = TrainingBackendCache{:Enzyme, true}( - dps, nothing, (; st_wrap=nothing, stats_wrap=nothing)) + dps, (; obj_fn=nothing, st_wrap=nothing, stats_wrap=nothing)) ts_new = TrainState( cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new) @@ -53,14 +54,14 @@ end # storing the objective function. 
function __construct_new_trainstate(st_new::S, ::S, ts::TrainState, objective_fn::O, dps, obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2} - cache = TrainingBackendCache{:Enzyme, false}(dps, obj_fn, (; st_wrap, stats_wrap)) + cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap)) return TrainState( cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) end function __construct_new_trainstate(st_new, _, ts::TrainState, objective_fn::O, dps, obj_fn::O2, st_wrap, stats_wrap) where {O, O2} - cache = TrainingBackendCache{:Enzyme, false}(dps, nothing, (; st_wrap, stats_wrap)) + cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap)) return TrainState( cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) end diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index 872a270a9..1bf69af88 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -16,32 +16,40 @@ else end # Uncompiled ReverseDiff -@inline function __uncompiled_reverse_diff( - obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:ReverseDiff}}) where {F} +@inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} + grads = Lux.recursive_make_zero(ts.parameters) + ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), + obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return __uncompiled_reverse_diff(obj_fn, data, ts_new) +end + +@inline function __uncompiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} tape = ReverseDiff.InstructionTape() + dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) ps_tracked = Lux.recursive_map( - Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, ts.cache.dparameters) + Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams) loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) loss.deriv = true ReverseDiff.reverse_pass!(tape) ts_new = TrainState( - TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, obj_fn, nothing), + TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, nothing), obj_fn, ts.model, ts.parameters, st, ts.optimizer_state, ts.step) return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new end -# First call, nothing is cached -@inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} +# Compiled ReverseDiff +@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, obj_fn, nothing), + ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) - return __uncompiled_reverse_diff(obj_fn, data, ts_new) + return __compiled_reverse_diff(obj_fn, data, ts_new) end -# Compiled ReverseDiff -@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} - error("Not implemented yet") +@inline function __compiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} + error(1) end diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index 0e0b5acf5..3ba136261 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -3,7 +3,6 @@ module LuxTrackerExt using ADTypes: 
AutoTracker using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore -using Functors: fmap using Lux: Lux, LuxCPUDevice using Lux.Experimental: TrainingBackendCache, TrainState using LuxCore: LuxCore @@ -24,7 +23,6 @@ function __construct_tracked_params(ps, dps) end # Lux.Training -## Use the cached gradient parameters function Lux.Experimental.compute_gradients(::AutoTracker, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT} dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) @@ -34,18 +32,16 @@ function Lux.Experimental.compute_gradients(::AutoTracker, obj_fn::F, data, Tracker.back!(loss) ts_new = TrainState( - TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, obj_fn, nothing), + TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, nothing), obj_fn, ts.model, ts.parameters, st, ts.optimizer_state, ts.step) return dparams, Tracker.value(loss), stats, ts_new end -## First call, nothing is cached function Lux.Experimental.compute_gradients( ::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState( - TrainingBackendCache{:Tracker, true}(grads, obj_fn, nothing), obj_fn, + ts_new = TrainState(TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) return Lux.Experimental.compute_gradients(AutoTracker(), obj_fn, data, ts_new) end diff --git a/src/contrib/training.jl b/src/contrib/training.jl index 972606424..e62bbba9a 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -31,7 +31,6 @@ end @concrete struct TrainingBackendCache{backend, first_try} dparameters - objective_function extras end diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 1d4a4167e..693acc184 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -21,24 +21,3 @@ y_t = first(nn(x_t, ps, st)) @test y_t isa AbstractArray{<:ReverseDiff.TrackedReal} end - -@testitem "Tracker.jl patches" setup=[SharedTestSetup] tags=[:autodiff] begin - using Tracker - - nested_st = (; m=Dense(2 => 3), v=rand(2), d=(; x=(rand(2), 1))) - tnested_st = Tracker.param(nested_st) - - @test tnested_st.m === nested_st.m - @test tnested_st.v isa TrackedArray - @test tnested_st.d.x[1] isa TrackedArray - @test tnested_st.d.x[2] isa Tracker.TrackedReal - - @test_nowarn Tracker.zero_grad!(nested_st) - @test_nowarn Tracker.zero_grad!(nested_st.m) - - @test_nowarn Tracker.extract_grad!(tnested_st) - @test_nowarn Tracker.data(tnested_st) - - x = ones(10) |> Tracker.param - @test Lux._gate(x, 1, 1) isa TrackedVector -end From 66a5cdac4f6b929b122c3cbb76ef67df4057f67d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 17:19:31 -0700 Subject: [PATCH 5/6] Type assert for safety --- ext/LuxReverseDiffExt/training.jl | 24 +++++++++++++++++------- src/helpers/recursive_ops.jl | 6 +++--- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index 1bf69af88..ba63849bb 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -18,15 +18,15 @@ end # Uncompiled ReverseDiff @inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), - obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + ts_new = 
TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, + ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) return __uncompiled_reverse_diff(obj_fn, data, ts_new) end @inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} - tape = ReverseDiff.InstructionTape() dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + tape = ReverseDiff.InstructionTape() ps_tracked = Lux.recursive_map( Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams) @@ -44,12 +44,22 @@ end # Compiled ReverseDiff @inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), - obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, + ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) return __compiled_reverse_diff(obj_fn, data, ts_new) end -@inline function __compiled_reverse_diff(obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} +## Tape hasn't been compiled yet +@inline function __compiled_reverse_diff(obj_fn::F, + data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT, P, Nothing}}) where { + F, FT, P} + dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + tape = ReverseDiff.InstructionTape() + ps_tracked = Lux.recursive_map( + Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams) + + loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) + error(1) end diff --git a/src/helpers/recursive_ops.jl b/src/helpers/recursive_ops.jl index 42460348f..e98912326 100644 --- a/src/helpers/recursive_ops.jl +++ b/src/helpers/recursive_ops.jl @@ -7,7 +7,7 @@ common cases. Any leaves of `x` that are arrays and allow in-place addition will be modified in place. """ -recursive_add!!(x, y) = recursive_map(__add!!, x, y) +@inline recursive_add!!(x, y) = recursive_map(__add!!, x, y) """ recursive_eltype(x) @@ -38,7 +38,7 @@ Recursively create a zero value for a nested structure `x`. This is equivalent t See also [`Lux.recursive_make_zero!!`](@ref). """ -@inline recursive_make_zero(x) = recursive_map(__zero, x) +@inline recursive_make_zero(x) = recursive_map(__zero, x)::typeof(x) """ recursive_make_zero!!(x) @@ -48,7 +48,7 @@ in-place zeroing will be modified in place. See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version. """ -@inline recursive_make_zero!!(x) = recursive_map(__zero!!, x) +@inline recursive_make_zero!!(x) = recursive_map(__zero!!, x)::typeof(x) """ recursive_map(f, x, args...) 
From 1da2e52be15bdf6c6af9e58fcdeff177b6424622 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 19:09:25 -0700 Subject: [PATCH 6/6] Finish implementation of compiled reversediff --- Project.toml | 2 +- docs/src/api/Lux/utilities.md | 1 + ext/LuxEnzymeExt.jl | 31 ++++-------- ext/LuxReverseDiffExt/training.jl | 80 +++++++++++++++++++++++++++---- src/Lux.jl | 4 +- src/contrib/training.jl | 6 +++ src/helpers/recursive_ops.jl | 9 ++++ test/contrib/training_tests.jl | 71 +++++++++++++++++++++++++++ 8 files changed, 170 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index 1cf926be3..b88bbce2e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.56" +version = "0.5.57" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index e73b31530..bd8a71189 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -61,6 +61,7 @@ Lux.xlogx ```@docs Lux.recursive_map Lux.recursive_add!! +Lux.recursive_copyto! Lux.recursive_eltype Lux.recursive_make_zero Lux.recursive_make_zero!! diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index 9cedf49d7..32ef8aae5 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -7,16 +7,15 @@ using Lux.Experimental: TrainingBackendCache, TrainState # Case I: We have TrainingBackendCache{:Enzyme} and obj_fn is unchanged. function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {F, FT} - dps = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + ts::TrainState{<:TrainingBackendCache{:Enzyme, false}, F}) where {F} + dps = Lux.recursive_make_zero!!(ts.cache.dparameters) _, loss = Enzyme.autodiff( Enzyme.ReverseWithPrimal, ts.cache.extras.obj_fn, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) - ts_new = __construct_new_trainstate( - ts.cache.extras.st_wrap[], ts.states, ts, obj_fn, dps, - ts.cache.extras.obj_fn, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap) + ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters, + ts.cache.extras.st_wrap[], ts.optimizer_state, ts.step) return dps, loss, ts.cache.extras.stats_wrap[], ts_new end @@ -33,8 +32,10 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, Enzyme.ReverseWithPrimal, obj_fn_wrap, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) - ts_new = __construct_new_trainstate( - st_wrap[], ts.states, ts, obj_fn, dps, obj_fn_wrap, st_wrap, stats_wrap) + cache = TrainingBackendCache{:Enzyme, false}( + dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap)) + ts_new = TrainState( + cache, obj_fn, ts.model, ts.parameters, st_wrap[], ts.optimizer_state, ts.step) return dps, loss, stats_wrap[], ts_new end @@ -50,20 +51,4 @@ function Lux.Experimental.compute_gradients( return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new) end -# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not -# storing the objective function. 
-function __construct_new_trainstate(st_new::S, ::S, ts::TrainState, objective_fn::O, dps, - obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2} - cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap)) - return TrainState( - cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) -end - -function __construct_new_trainstate(st_new, _, ts::TrainState, objective_fn::O, dps, - obj_fn::O2, st_wrap, stats_wrap) where {O, O2} - cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap)) - return TrainState( - cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) -end - end diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index ba63849bb..5d189053b 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -44,22 +44,84 @@ end # Compiled ReverseDiff @inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, + data_cache = deepcopy(data) + ps_cache = deepcopy(ts.parameters) + extras = (; data_cache, ps_cache) + + ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) return __compiled_reverse_diff(obj_fn, data, ts_new) end -## Tape hasn't been compiled yet -@inline function __compiled_reverse_diff(obj_fn::F, - data, - ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT, P, Nothing}}) where { - F, FT, P} +## Tape hasn't been compiled yet / Function mismatch so recompile +@inline function __compiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} + if Lux.statelength(ts.states) != 0 + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \ + models with non-empty state `st`.")) + end + + if FT # do a dry run + _, st_, stats = obj_fn(ts.model, ts.parameters, ts.states, data) + if stats != NamedTuple() + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \ + loss functions that return non-empty `stats`.")) + end + if Lux.statelength(st_) != 0 + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \ + models with non-empty state `st`.")) + end + end + dparams = FT ? 
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + + (; ps_cache, data_cache) = ts.cache.extras + if !FT + Lux.recursive_copyto!(ps_cache, ts.parameters) + Lux.recursive_copyto!(data_cache, data) + end + + obj_fn_wrap = first ∘ obj_fn + tape = ReverseDiff.InstructionTape() ps_tracked = Lux.recursive_map( - Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams) + Lux.__Fix3(ReverseDiff.TrackedArray, tape), ps_cache, dparams) - loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) + loss = obj_fn_wrap(ts.model, ps_tracked, ts.states, data_cache) + loss.deriv = true + ReverseDiff.reverse_pass!(tape) + + forward_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ForwardExecutor(instruction)) + for instruction in tape] + reverse_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ReverseExecutor(tape[i])) + for i in length(tape):-1:1] + + compiled_extras = (; + ps_cache, data_cache, forward_executor, reverse_executor, output=loss) + ts_new = TrainState( + TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, compiled_extras), + obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + + return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new +end + +@inline function __compiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F} + (; ps_cache, data_cache, output) = ts.cache.extras - error(1) + dparams = Lux.recursive_make_zero!!(ts.cache.dparameters) + Lux.recursive_copyto!(ps_cache, ts.parameters) + Lux.recursive_copyto!(data_cache, data) + + for wrapper in ts.cache.extras.forward_executor + wrapper() + end + output.deriv = true + for wrapper in ts.cache.extras.reverse_executor + wrapper() + end + + ts_new = TrainState( + ts.cache, obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return dparams, ReverseDiff.value(output), NamedTuple(), ts_new end diff --git a/src/Lux.jl b/src/Lux.jl index 9d59eacac..1c60d587e 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -129,6 +129,8 @@ export MPIBackend, NCCLBackend, DistributedUtils # Unexported functions that are part of the public API @compat public Experimental @compat public xlogx, xlogy -@compat public recursive_add!!, recursive_eltype, recursive_make_zero, recursive_make_zero!! +@compat(public, + (recursive_add!!, recursive_copyto!, recursive_eltype, + recursive_make_zero, recursive_map, recursive_make_zero!!)) end diff --git a/src/contrib/training.jl b/src/contrib/training.jl index e62bbba9a..bbe1b082b 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -118,6 +118,12 @@ A 4-Tuple containing: - `stats`: Any computed statistics from the objective function. - `ts`: Updated Training State. +## Known Limitations + +- `AutoReverseDiff(; compile=true)` is not supported for Lux models with empty state + `st`. Additionally the returned stats must be empty (`NamedTuple()`). We catch these + issues in most cases and throw an error. + !!! danger "Aliased Gradients" `grads` returned by this function might be aliased by the implementation of the gradient diff --git a/src/helpers/recursive_ops.jl b/src/helpers/recursive_ops.jl index e98912326..eba1f1238 100644 --- a/src/helpers/recursive_ops.jl +++ b/src/helpers/recursive_ops.jl @@ -50,6 +50,15 @@ See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version. 
""" @inline recursive_make_zero!!(x) = recursive_map(__zero!!, x)::typeof(x) +""" + recursive_copyto!(x, y) + +Recursively copy the leaves of two nested structures `x` and `y`. In Functor language, this +is equivalent to doing `fmap(copyto!, x, y)`, but this implementation uses type stable code +for common cases. Note that any immutable leaf will lead to an error. +""" +@inline recursive_copyto!(x, y) = recursive_map(copyto!, x, y) + """ recursive_map(f, x, args...) diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index 3077b745f..02e6d075a 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -177,3 +177,74 @@ end @inferred Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new) end + +@testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:contrib] begin + using ADTypes, Optimisers, ReverseDiff + + function mse1(model, ps, st, data) + x_data, y_data = data + y, st_ = model(x_data, ps, st) + return sum(abs2, y .- y_data), st_, (;) + end + + function mse2(model, ps, st, data) + l, st_, stats = mse1(model, ps, st, data) + return l, st_, (; data=2.0f0) + end + + rng = StableRNG(12345) + + dataset = [(randn(rng, Float32, 4, 32), randn(rng, Float32, 4, 32)) for _ in 1:100] + + @testset "Unhandled Cases" begin + model = Chain(Dense(4, 32, tanh), BatchNorm(32), + Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Stateful models are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) + + model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Loss functions that return non-empty `stats` are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse2, dataset[1], tstate) + + struct StrangeModel <: Lux.AbstractExplicitLayer end + + function (m::StrangeModel)(x, ps, st) + return x, (; new_state=0.0) + end + + model = StrangeModel() + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Stateful models are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) + end + + model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + loss_initial = first(mse1(model, ps, st, dataset[1])) + for i in 1:100 + for (x, y) in dataset + _, _, _, tstate = Lux.Experimental.single_train_step!( + AutoReverseDiff(; compile=true), mse1, (x, y), tstate) + end + end + loss_final = first(mse1(model, tstate.parameters, tstate.states, dataset[1])) + + @test loss_final * 100 < loss_initial +end