Commit: Move things a bit

avik-pal committed Jun 22, 2024
1 parent 66bd131 commit 6b81817
Showing 5 changed files with 61 additions and 67 deletions.
53 changes: 0 additions & 53 deletions ext/LuxReverseDiffExt.jl

This file was deleted.

29 changes: 29 additions & 0 deletions ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -0,0 +1,29 @@
module LuxReverseDiffExt

using ADTypes: AutoReverseDiff
using ArrayInterface: ArrayInterface
using Functors: fmap
using Lux: Lux, LuxCPUDevice
using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules
using Setfield: @set!

# AoS to SoA conversion
function Lux.apply(
        m::Lux.AbstractExplicitLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st)
    @warn "Lux.apply(m::Lux.AbstractExplicitLayer, \
           x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \
           Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray, ps, \
           st).\n\n\
           1. If this was not the desired behavior, overload the dispatch on `m`.\n\n\
           2. This might have performance implications. Check which layer was causing this \
              problem using `Lux.Experimental.@debug_mode`." maxlog=1
    return Lux.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st)
end

## Prevent an infinite loop: a `TrackedArray` is itself an `AbstractArray{<:TrackedReal}`,
## so the fallback above would otherwise match it and recurse
Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)

include("rules.jl")
include("training.jl")

end
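For context, a minimal sketch of the AoS-to-SoA correction the fallback above performs, assuming ReverseDiff and ArrayInterface (with its ReverseDiff extension) are loaded; the variable names are illustrative:

```julia
using ArrayInterface, ReverseDiff

x_soa = ReverseDiff.track(rand(Float32, 2, 4))  # one TrackedArray recorded on a single tape (SoA)
x_aos = [x_soa[i] for i in eachindex(x_soa)]    # a Vector of TrackedReal scalars (AoS)

# The fallback reroutes the AoS case through `aos_to_soa` plus a reshape,
# recovering a single tracked container with the original shape:
x_back = reshape(ArrayInterface.aos_to_soa(x_aos), size(x_soa))
```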
14 changes: 14 additions & 0 deletions ext/LuxReverseDiffExt/rules.jl
@@ -0,0 +1,14 @@
# SimpleChains.jl
@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice)
@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice)
@grad_from_chainrules Lux.__apply_simple_chain(
    layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice)

# DynamicExpressions.jl
@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr,
    operator_enum, x::TrackedArray, ps, ::LuxCPUDevice)
@grad_from_chainrules Lux.__apply_dynamic_expression(de::Lux.DynamicExpressionsLayer, expr,
    operator_enum, x, ps::TrackedArray, ::LuxCPUDevice)
@grad_from_chainrules Lux.__apply_dynamic_expression(
    de::Lux.DynamicExpressionsLayer, expr, operator_enum,
    x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice)
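For background, `ReverseDiff.@grad_from_chainrules` generates tracked methods for the listed signatures from an existing ChainRules `rrule`, which is how the `__apply_simple_chain` and `__apply_dynamic_expression` rules above are wired up. A hypothetical, self-contained illustration with a made-up function `f` (not part of Lux):

```julia
using ChainRulesCore, ReverseDiff
using ReverseDiff: @grad_from_chainrules, TrackedArray

f(x) = sum(abs2, x)

# An rrule for `f`: returns the primal value and a pullback producing the cotangents.
function ChainRulesCore.rrule(::typeof(f), x)
    f_pullback(ȳ) = (ChainRulesCore.NoTangent(), 2 .* x .* ȳ)
    return f(x), f_pullback
end

# Generate a ReverseDiff-tracked method for `f(::TrackedArray)` from that rrule.
@grad_from_chainrules f(x::TrackedArray)

ReverseDiff.gradient(f, rand(3))  # ≈ 2 .* x, computed via the rrule above
```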
11 changes: 11 additions & 0 deletions ext/LuxReverseDiffExt/training.jl
@@ -0,0 +1,11 @@
function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data,
        ts::Lux.Experimental.TrainState) where {F}
    tape = ReverseDiff.InstructionTape()
    grads = fmap(zero, ts.parameters)
    ps_tracked = fmap((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ts.parameters, grads)
    loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data)
    loss.deriv = true
    ReverseDiff.reverse_pass!(tape)
    @set! ts.states = st
    return grads, ReverseDiff.value(loss), stats, ts
end
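A standalone sketch of the tape pattern this method uses, with a plain NamedTuple standing in for the Lux parameter structure; everything here is illustrative rather than Lux API:

```julia
using ReverseDiff

ps    = (; w = rand(3), b = rand(3))
tape  = ReverseDiff.InstructionTape()
grads = map(zero, ps)  # adjoint buffers, one per parameter array
ps_tr = map((p, g) -> ReverseDiff.TrackedArray(p, g, tape), ps, grads)

loss = sum(abs2.(ps_tr.w .+ ps_tr.b))  # forward pass records instructions onto `tape`
loss.deriv = true                      # seed the output adjoint
ReverseDiff.reverse_pass!(tape)        # accumulates adjoints into `grads` in place

grads.w ≈ 2 .* (ps.w .+ ps.b)  # true
```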
21 changes: 7 additions & 14 deletions src/contrib/training.jl
@@ -79,12 +79,12 @@ Compute the gradients of the objective function wrt parameters stored in `ts`.
## Backends & AD Packages
| Supported Backends | Packages Needed  |
|:------------------ |:---------------- |
| `AutoZygote`       | `Zygote.jl`      |
| `AutoReverseDiff`  | `ReverseDiff.jl` |
| `AutoTracker`      | `Tracker.jl`     |
| `AutoEnzyme`       | `Enzyme.jl`      |

| Supported Backends           | Packages Needed  |
|:---------------------------- |:---------------- |
| `AutoZygote`                 | `Zygote.jl`      |
| `AutoReverseDiff(; compile)` | `ReverseDiff.jl` |
| `AutoTracker`                | `Tracker.jl`     |
| `AutoEnzyme`                 | `Enzyme.jl`      |
## Arguments
@@ -105,14 +105,7 @@ A 4-Tuple containing:
- `stats`: Any computed statistics from the objective function.
- `ts`: Updated Training State.
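A hedged usage sketch of this API with the `AutoReverseDiff` backend from the new extension; the rng-based `TrainState` constructor and the objective-function shape below are assumptions for this version of Lux, and `mse`, `model`, and `data` are illustrative names:

```julia
using ADTypes: AutoReverseDiff
using Lux, Optimisers, Random, ReverseDiff

model = Dense(4 => 2)
ts = Lux.Experimental.TrainState(Random.default_rng(), model, Optimisers.Adam(0.001f0))

# Objective returns (loss, updated_states, stats), matching what `compute_gradients` expects.
function mse(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2.(ŷ .- y)), st_, (;)
end

data = (rand(Float32, 4, 8), rand(Float32, 2, 8))
grads, loss, stats, ts = Lux.Experimental.compute_gradients(AutoReverseDiff(), mse, data, ts)
```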
## Special Notes on Backends
- `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used. The first call
  to `compute_gradients` will be type-unstable. It is recommended to call this function
  once outside of the training loop and use the returned train_state for type stability.
- `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled.
!!! danger
!!! danger "Aliased Gradients"
    `grads` returned by this function might be aliased by the implementation of the gradient
    backend. For example, if you cache the `grads` from step `i`, the new gradients