Finish implementation of compiled reversediff
avik-pal committed Jun 23, 2024
1 parent 66a5cda commit 1da2e52
Showing 8 changed files with 170 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.56"
version = "0.5.57"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1 change: 1 addition & 0 deletions docs/src/api/Lux/utilities.md
@@ -61,6 +61,7 @@ Lux.xlogx
```@docs
Lux.recursive_map
Lux.recursive_add!!
Lux.recursive_copyto!
Lux.recursive_eltype
Lux.recursive_make_zero
Lux.recursive_make_zero!!
31 changes: 8 additions & 23 deletions ext/LuxEnzymeExt.jl
@@ -7,16 +7,15 @@ using Lux.Experimental: TrainingBackendCache, TrainState

# Case I: We have TrainingBackendCache{:Enzyme} and obj_fn is unchanged.
function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {F, FT}
dps = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ts::TrainState{<:TrainingBackendCache{:Enzyme, false}, F}) where {F}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)

_, loss = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, ts.cache.extras.obj_fn, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
ts.cache.extras.st_wrap[], ts.states, ts, obj_fn, dps,
ts.cache.extras.obj_fn, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap)
ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters,
ts.cache.extras.st_wrap[], ts.optimizer_state, ts.step)

return dps, loss, ts.cache.extras.stats_wrap[], ts_new
end
@@ -33,8 +32,10 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
Enzyme.ReverseWithPrimal, obj_fn_wrap, Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = __construct_new_trainstate(
st_wrap[], ts.states, ts, obj_fn, dps, obj_fn_wrap, st_wrap, stats_wrap)
cache = TrainingBackendCache{:Enzyme, false}(
dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap))
ts_new = TrainState(
cache, obj_fn, ts.model, ts.parameters, st_wrap[], ts.optimizer_state, ts.step)

return dps, loss, stats_wrap[], ts_new
end
@@ -50,20 +51,4 @@ function Lux.Experimental.compute_gradients(
return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new)
end

# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not
# storing the objective function.
function __construct_new_trainstate(st_new::S, ::S, ts::TrainState, objective_fn::O, dps,
obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2}
cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap))
return TrainState(
cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

function __construct_new_trainstate(st_new, _, ts::TrainState, objective_fn::O, dps,
obj_fn::O2, st_wrap, stats_wrap) where {O, O2}
cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap))
return TrainState(
cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step)
end

end
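
The two `compute_gradients` methods above differ only in whether an existing `:Enzyme` cache (and its wrapped objective) can be reused. A minimal usage sketch — the model, loss function, and data below are illustrative and not part of this commit:

```julia
using ADTypes, Enzyme, Lux, Optimisers, Random

model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model)
ts = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

# Objective returning (loss, state, stats), as expected by compute_gradients.
mse(model, ps, st, (x, y)) = (sum(abs2, first(model(x, ps, st)) .- y), st, (;))

x, y = randn(Float32, 4, 8), randn(Float32, 2, 8)
# The first call builds the :Enzyme cache; later calls with the same
# objective hit the fast path that reuses it.
grads, loss, stats, ts = Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, y), ts)
grads, loss, stats, ts = Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, y), ts)
```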
80 changes: 71 additions & 9 deletions ext/LuxReverseDiffExt/training.jl
@@ -44,22 +44,84 @@ end
# Compiled ReverseDiff
@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn,
data_cache = deepcopy(data)
ps_cache = deepcopy(ts.parameters)
extras = (; data_cache, ps_cache)

ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing,
ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return __compiled_reverse_diff(obj_fn, data, ts_new)
end

## Tape hasn't been compiled yet
@inline function __compiled_reverse_diff(obj_fn::F,
data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT, P, Nothing}}) where {
F, FT, P}
## Tape hasn't been compiled yet / Function mismatch so recompile
@inline function __compiled_reverse_diff(obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
if Lux.statelength(ts.states) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \
models with non-empty state `st`."))
end

if FT # do a dry run
_, st_, stats = obj_fn(ts.model, ts.parameters, ts.states, data)
if stats != NamedTuple()
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \
loss functions that return non-empty `stats`."))
end
if Lux.statelength(st_) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \
models with non-empty state `st`."))
end
end

dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)

(; ps_cache, data_cache) = ts.cache.extras
if !FT
Lux.recursive_copyto!(ps_cache, ts.parameters)
Lux.recursive_copyto!(data_cache, data)
end

obj_fn_wrap = first ∘ obj_fn

tape = ReverseDiff.InstructionTape()
ps_tracked = Lux.recursive_map(
Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams)
Lux.__Fix3(ReverseDiff.TrackedArray, tape), ps_cache, dparams)

loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data)
loss = obj_fn_wrap(ts.model, ps_tracked, ts.states, data_cache)
loss.deriv = true
ReverseDiff.reverse_pass!(tape)

forward_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ForwardExecutor(instruction))
for instruction in tape]
reverse_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ReverseExecutor(tape[i]))
for i in length(tape):-1:1]

compiled_extras = (;
ps_cache, data_cache, forward_executor, reverse_executor, output=loss)
ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, compiled_extras),
obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)

return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new
end

@inline function __compiled_reverse_diff(obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F}
(; ps_cache, data_cache, output) = ts.cache.extras

dparams = Lux.recursive_make_zero!!(ts.cache.dparameters)
Lux.recursive_copyto!(ps_cache, ts.parameters)
Lux.recursive_copyto!(data_cache, data)

for wrapper in ts.cache.extras.forward_executor
wrapper()
end
output.deriv = true
for wrapper in ts.cache.extras.reverse_executor
wrapper()
end

ts_new = TrainState(
ts.cache, obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
return dparams, ReverseDiff.value(output), NamedTuple(), ts_new
end
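
The two compiled-path methods above implement a record-once/replay scheme: the first call records a `ReverseDiff.InstructionTape` and wraps its instructions in forward/reverse executors, while later calls only copy the new parameters and data into the cached buffers and replay those executors. For comparison, ReverseDiff's public compiled-tape API expresses the same idea; a minimal sketch with a stand-in objective (this is not the Lux code path):

```julia
using ReverseDiff

# Stand-in scalar objective of the parameters only.
f(p) = sum(abs2, p)

p = randn(Float32, 16)
tape = ReverseDiff.GradientTape(f, p)  # record the operations once
ctape = ReverseDiff.compile(tape)      # pre-build the forward/reverse executors

grad = similar(p)
ReverseDiff.gradient!(grad, ctape, randn(Float32, 16))  # replay on new input, no re-recording
```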
4 changes: 3 additions & 1 deletion src/Lux.jl
@@ -129,6 +129,8 @@ export MPIBackend, NCCLBackend, DistributedUtils
# Unexported functions that are part of the public API
@compat public Experimental
@compat public xlogx, xlogy
@compat public recursive_add!!, recursive_eltype, recursive_make_zero, recursive_make_zero!!
@compat(public,
(recursive_add!!, recursive_copyto!, recursive_eltype,
recursive_make_zero, recursive_map, recursive_make_zero!!))

end
6 changes: 6 additions & 0 deletions src/contrib/training.jl
@@ -118,6 +118,12 @@ A 4-Tuple containing:
- `stats`: Any computed statistics from the objective function.
- `ts`: Updated Training State.
## Known Limitations
- `AutoReverseDiff(; compile=true)` is not supported for Lux models with non-empty state
  `st`. Additionally, the returned stats must be empty (`NamedTuple()`). We catch these
  issues in most cases and throw an error.
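
For reference, a loss function compatible with the compiled path returns a scalar loss, leaves the state empty, and returns empty stats; an illustrative sketch (the name is made up):

```julia
# Illustrative loss for AutoReverseDiff(; compile=true): scalar loss,
# unchanged (empty) state, and empty stats.
function compiled_safe_mse(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end
```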
!!! danger "Aliased Gradients"
`grads` returned by this function might be aliased by the implementation of the gradient
9 changes: 9 additions & 0 deletions src/helpers/recursive_ops.jl
@@ -50,6 +50,15 @@ See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version.
"""
@inline recursive_make_zero!!(x) = recursive_map(__zero!!, x)::typeof(x)
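
A small illustrative usage sketch (the gradient `NamedTuple` below is made up):

```julia
using Lux

grads = (; weight=ones(Float32, 2, 2), bias=ones(Float32, 2))
grads = Lux.recursive_make_zero!!(grads)  # array leaves are zeroed in place where possible
```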

"""
recursive_copyto!(x, y)
Recursively copy the leaves of `y` into the corresponding leaves of `x`. In Functor language, this
is equivalent to `fmap(copyto!, x, y)`, but this implementation uses type-stable code
for common cases. Note that any immutable leaf will lead to an error.
"""
@inline recursive_copyto!(x, y) = recursive_map(copyto!, x, y)
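
For example (illustrative; assumes `x` and `y` share the same nested structure and all leaves are mutable arrays):

```julia
using Lux

x = (; weight=zeros(Float32, 2, 2), bias=zeros(Float32, 2))
y = (; weight=ones(Float32, 2, 2), bias=ones(Float32, 2))
Lux.recursive_copyto!(x, y)  # copies the values of y's leaves into x's leaves
x.weight                     # now all ones
```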

"""
recursive_map(f, x, args...)
71 changes: 71 additions & 0 deletions test/contrib/training_tests.jl
@@ -177,3 +177,74 @@ end

@inferred Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)
end

@testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:contrib] begin
using ADTypes, Optimisers, ReverseDiff

function mse1(model, ps, st, data)
x_data, y_data = data
y, st_ = model(x_data, ps, st)
return sum(abs2, y .- y_data), st_, (;)
end

function mse2(model, ps, st, data)
l, st_, stats = mse1(model, ps, st, data)
return l, st_, (; data=2.0f0)
end

rng = StableRNG(12345)

dataset = [(randn(rng, Float32, 4, 32), randn(rng, Float32, 4, 32)) for _ in 1:100]

@testset "Unhandled Cases" begin
model = Chain(Dense(4, 32, tanh), BatchNorm(32),
Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4))
ps, st = Lux.setup(rng, model)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

# Stateful models are not supported
@test_throws ArgumentError Lux.Experimental.compute_gradients(
AutoReverseDiff(; compile=true), mse1, dataset[1], tstate)

model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4))
ps, st = Lux.setup(rng, model)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

# Loss functions that return non-empty `stats` are not supported
@test_throws ArgumentError Lux.Experimental.compute_gradients(
AutoReverseDiff(; compile=true), mse2, dataset[1], tstate)

struct StrangeModel <: Lux.AbstractExplicitLayer end

function (m::StrangeModel)(x, ps, st)
return x, (; new_state=0.0)
end

model = StrangeModel()
ps, st = Lux.setup(rng, model)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

# Stateful models are not supported
@test_throws ArgumentError Lux.Experimental.compute_gradients(
AutoReverseDiff(; compile=true), mse1, dataset[1], tstate)
end

model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4))
ps, st = Lux.setup(rng, model)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

loss_initial = first(mse1(model, ps, st, dataset[1]))
for i in 1:100
for (x, y) in dataset
_, _, _, tstate = Lux.Experimental.single_train_step!(
AutoReverseDiff(; compile=true), mse1, (x, y), tstate)
end
end
loss_final = first(mse1(model, tstate.parameters, tstate.states, dataset[1]))

@test loss_final * 100 < loss_initial
end
