diff --git a/ext/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt.jl
deleted file mode 100644
index 4af247d24..000000000
--- a/ext/LuxReverseDiffExt.jl
+++ /dev/null
@@ -1,53 +0,0 @@
-module LuxReverseDiffExt
-
-using ADTypes: AutoReverseDiff
-using ArrayInterface: ArrayInterface
-using Functors: fmap
-using Lux: Lux, LuxCPUDevice
-using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules
-using Setfield: @set!
-
-function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data,
-        ts::Lux.Experimental.TrainState) where {F}
-    tape = ReverseDiff.InstructionTape()
-    grads = fmap(zero, ts.parameters)
-    ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads)
-    loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
-    loss.deriv = true
-    ReverseDiff.reverse_pass!(tape)
-    @set! ts.states = st
-    return grads, ReverseDiff.value(loss), stats, ts
-end
-
-# AoS to SoA conversion
-function Lux.apply(
-        m::Lux.AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st)
-    @warn "Lux.apply(m::Lux.AbstractExplicitLayer, \
-           x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \
-           Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \
-           st).\n\n\
-           1. If this was not the desired behavior overload the dispatch on `m`.\n\n\
-           2. This might have performance implications. Check which layer was causing this \
-              problem using `Lux.Experimental.@debug_mode`." maxlog=1
-    return Lux.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st)
-end
-
-## Prevent an infinite loop
-Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
-
-# SimpleChains.jl
-@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice)
-@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice)
-@grad_from_chainrules Lux.__apply_simple_chain(
-    layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice)
-
-# DynamicExpressions.jl
-@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr,
-    operator_enum, x::TrackedArray, ps, ::LuxCPUDevice)
-@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr,
-    operator_enum, x, ps::TrackedArray, ::LuxCPUDevice)
-@grad_from_chainrules Lux.__apply_dynamic_expression(
-    de::Lux.DynamicExpressionsLayer, expr, operator_enum,
-    x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice)
-
-end
diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
new file mode 100644
index 000000000..706b24b90
--- /dev/null
+++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -0,0 +1,29 @@
+module LuxReverseDiffExt
+
+using ADTypes: AutoReverseDiff
+using ArrayInterface: ArrayInterface
+using Functors: fmap
+using Lux: Lux, LuxCPUDevice
+using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules
+using Setfield: @set!
+
+# AoS to SoA conversion
+function Lux.apply(
+        m::Lux.AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st)
+    @warn "Lux.apply(m::Lux.AbstractExplicitLayer, \
+           x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \
+           Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \
+           st).\n\n\
+           1. If this was not the desired behavior overload the dispatch on `m`.\n\n\
+           2. This might have performance implications. Check which layer was causing this \
+              problem using `Lux.Experimental.@debug_mode`." maxlog=1
+    return Lux.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st)
+end
+
+## Prevent an infinite loop
+Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
+
+include("rules.jl")
+include("training.jl")
+
+end
diff --git a/ext/LuxReverseDiffExt/rules.jl b/ext/LuxReverseDiffExt/rules.jl
new file mode 100644
index 000000000..122df5ab8
--- /dev/null
+++ b/ext/LuxReverseDiffExt/rules.jl
@@ -0,0 +1,14 @@
+# SimpleChains.jl
+@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice)
+@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice)
+@grad_from_chainrules Lux.__apply_simple_chain(
+    layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice)
+
+# DynamicExpressions.jl
+@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr,
+    operator_enum, x::TrackedArray, ps, ::LuxCPUDevice)
+@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr,
+    operator_enum, x, ps::TrackedArray, ::LuxCPUDevice)
+@grad_from_chainrules Lux.__apply_dynamic_expression(
+    de::Lux.DynamicExpressionsLayer, expr, operator_enum,
+    x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice)
diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl
new file mode 100644
index 000000000..6e56f899a
--- /dev/null
+++ b/ext/LuxReverseDiffExt/training.jl
@@ -0,0 +1,11 @@
+function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data,
+        ts::Lux.Experimental.TrainState) where {F}
+    tape = ReverseDiff.InstructionTape()
+    grads = fmap(zero, ts.parameters)
+    ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads)
+    loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
+    loss.deriv = true
+    ReverseDiff.reverse_pass!(tape)
+    @set! ts.states = st
+    return grads, ReverseDiff.value(loss), stats, ts
+end
diff --git a/src/contrib/training.jl b/src/contrib/training.jl
index 8dd65acc7..a6df98eb2 100644
--- a/src/contrib/training.jl
+++ b/src/contrib/training.jl
@@ -79,12 +79,12 @@ Compute the gradients of the objective function wrt parameters stored in `ts`.
 
 ## Backends & AD Packages
 
-| Supported Backends | Packages Needed  |
-|:------------------ |:---------------- |
-| `AutoZygote`       | `Zygote.jl`      |
-| `AutoReverseDiff`  | `ReverseDiff.jl` |
-| `AutoTracker`      | `Tracker.jl`     |
-| `AutoEnzyme`       | `Enzyme.jl`      |
+| Supported Backends           | Packages Needed  |
+|:---------------------------- |:---------------- |
+| `AutoZygote`                 | `Zygote.jl`      |
+| `AutoReverseDiff(; compile)` | `ReverseDiff.jl` |
+| `AutoTracker`                | `Tracker.jl`     |
+| `AutoEnzyme`                 | `Enzyme.jl`      |
 
 ## Arguments
 
@@ -105,14 +105,7 @@ A 4-Tuple containing:
   - `stats`: Any computed statistics from the objective function.
   - `ts`: Updated Training State.
 
-## Special Notes on Backends
-
-  - `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used. The first call
-    to `compute_gradients` will be type-unstable. It is recommended to call this function
-    once outside of the training loop and use the returned train_state for type stability.
-  - `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled.
-
-!!! danger
+!!! danger "Aliased Gradients"
 
     `grads` returned by this function might be aliased by the implementation of the
     gradient backend. For example, if you cache the `grads` from step `i`, the new gradients
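
For reviewer context, here is a minimal usage sketch (not part of the patch) of the ReverseDiff training path that moves into `ext/LuxReverseDiffExt/training.jl`. Only the `compute_gradients` signature, its `(model, ps, st, data) -> (loss, st, stats)` objective contract, and its 4-tuple return are taken from the diff; the model, loss function, `TrainState` constructor, and `apply_gradients` call are assumptions based on the `Lux.Experimental` API of this release series.

```julia
# Hypothetical sketch, not part of this patch. Assumes the Lux.Experimental
# TrainState constructor and apply_gradients helper; model and loss are illustrative.
using ADTypes: AutoReverseDiff
using Lux, Optimisers, Random
import ReverseDiff  # loading ReverseDiff activates the LuxReverseDiffExt extension

rng = Random.default_rng()
model = Dense(4 => 2)
ts = Lux.Experimental.TrainState(rng, model, Optimisers.Adam(0.001f0))

# The objective must follow the (model, ps, st, data) -> (loss, st, stats)
# contract expected by compute_gradients in training.jl above.
function mse_loss(model, ps, st, (x, y))
    ŷ, st = Lux.apply(model, x, ps, st)
    return sum(abs2, ŷ .- y), st, (;)
end

x, y = rand(rng, Float32, 4, 8), rand(rng, Float32, 2, 8)
grads, loss, stats, ts = Lux.Experimental.compute_gradients(
    AutoReverseDiff(), mse_loss, (x, y), ts)
ts = Lux.Experimental.apply_gradients(ts, grads)
```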