From 8331b575cee3ff964a87b199f86770757f39e249 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Jun 2024 19:09:25 -0700 Subject: [PATCH] Finish implementation of compiled reversediff --- Project.toml | 2 +- ext/LuxEnzymeExt.jl | 31 ++++-------- ext/LuxReverseDiffExt/training.jl | 80 +++++++++++++++++++++++++++---- src/helpers/recursive_ops.jl | 9 ++++ test/contrib/training_tests.jl | 71 +++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 33 deletions(-) diff --git a/Project.toml b/Project.toml index 1cf926be3..b88bbce2e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.56" +version = "0.5.57" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index 9cedf49d7..32ef8aae5 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -7,16 +7,15 @@ using Lux.Experimental: TrainingBackendCache, TrainState # Case I: We have TrainingBackendCache{:Enzyme} and obj_fn is unchanged. function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:Enzyme, FT}, F}) where {F, FT} - dps = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + ts::TrainState{<:TrainingBackendCache{:Enzyme, false}, F}) where {F} + dps = Lux.recursive_make_zero!!(ts.cache.dparameters) _, loss = Enzyme.autodiff( Enzyme.ReverseWithPrimal, ts.cache.extras.obj_fn, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) - ts_new = __construct_new_trainstate( - ts.cache.extras.st_wrap[], ts.states, ts, obj_fn, dps, - ts.cache.extras.obj_fn, ts.cache.extras.st_wrap, ts.cache.extras.stats_wrap) + ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters, + ts.cache.extras.st_wrap[], ts.optimizer_state, ts.step) return dps, loss, ts.cache.extras.stats_wrap[], ts_new end @@ -33,8 +32,10 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, Enzyme.ReverseWithPrimal, obj_fn_wrap, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) - ts_new = __construct_new_trainstate( - st_wrap[], ts.states, ts, obj_fn, dps, obj_fn_wrap, st_wrap, stats_wrap) + cache = TrainingBackendCache{:Enzyme, false}( + dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap)) + ts_new = TrainState( + cache, obj_fn, ts.model, ts.parameters, st_wrap[], ts.optimizer_state, ts.step) return dps, loss, stats_wrap[], ts_new end @@ -50,20 +51,4 @@ function Lux.Experimental.compute_gradients( return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new) end -# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not -# storing the objective function. 
-function __construct_new_trainstate(st_new::S, ::S, ts::TrainState, objective_fn::O, dps, - obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2} - cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap)) - return TrainState( - cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) -end - -function __construct_new_trainstate(st_new, _, ts::TrainState, objective_fn::O, dps, - obj_fn::O2, st_wrap, stats_wrap) where {O, O2} - cache = TrainingBackendCache{:Enzyme, false}(dps, (; obj_fn, st_wrap, stats_wrap)) - return TrainState( - cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) -end - end diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index ba63849bb..5d189053b 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -44,22 +44,84 @@ end # Compiled ReverseDiff @inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, + data_cache = deepcopy(data) + ps_cache = deepcopy(ts.parameters) + extras = (; data_cache, ps_cache) + + ts_new = TrainState(TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) return __compiled_reverse_diff(obj_fn, data, ts_new) end -## Tape hasn't been compiled yet -@inline function __compiled_reverse_diff(obj_fn::F, - data, - ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT, P, Nothing}}) where { - F, FT, P} +## Tape hasn't been compiled yet / Function mismatch so recompile +@inline function __compiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} + if Lux.statelength(ts.states) != 0 + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \ + models with non-empty state `st`.")) + end + + if FT # do a dry run + _, st_, stats = obj_fn(ts.model, ts.parameters, ts.states, data) + if stats != NamedTuple() + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \ + loss functions that return non-empty `stats`.")) + end + if Lux.statelength(st_) != 0 + throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \ + models with non-empty state `st`.")) + end + end + dparams = FT ? 
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + + (; ps_cache, data_cache) = ts.cache.extras + if !FT + Lux.recursive_copyto!(ps_cache, ts.parameters) + Lux.recursive_copyto!(data_cache, data) + end + + obj_fn_wrap = first ∘ obj_fn + tape = ReverseDiff.InstructionTape() ps_tracked = Lux.recursive_map( - Lux.__Fix3(ReverseDiff.TrackedArray, tape), ts.parameters, dparams) + Lux.__Fix3(ReverseDiff.TrackedArray, tape), ps_cache, dparams) - loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) + loss = obj_fn_wrap(ts.model, ps_tracked, ts.states, data_cache) + loss.deriv = true + ReverseDiff.reverse_pass!(tape) + + forward_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ForwardExecutor(instruction)) + for instruction in tape] + reverse_executor = [ReverseDiff.FunctionWrapper{Nothing, Tuple{}}(ReverseDiff.ReverseExecutor(tape[i])) + for i in length(tape):-1:1] + + compiled_extras = (; + ps_cache, data_cache, forward_executor, reverse_executor, output=loss) + ts_new = TrainState( + TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, compiled_extras), + obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + + return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new +end + +@inline function __compiled_reverse_diff(obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F} + (; ps_cache, data_cache, output) = ts.cache.extras - error(1) + dparams = Lux.recursive_make_zero!!(ts.cache.dparameters) + Lux.recursive_copyto!(ps_cache, ts.parameters) + Lux.recursive_copyto!(data_cache, data) + + for wrapper in ts.cache.extras.forward_executor + wrapper() + end + output.deriv = true + for wrapper in ts.cache.extras.reverse_executor + wrapper() + end + + ts_new = TrainState( + ts.cache, obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return dparams, ReverseDiff.value(output), NamedTuple(), ts_new end diff --git a/src/helpers/recursive_ops.jl b/src/helpers/recursive_ops.jl index e98912326..eba1f1238 100644 --- a/src/helpers/recursive_ops.jl +++ b/src/helpers/recursive_ops.jl @@ -50,6 +50,15 @@ See also [`Lux.recursive_make_zero`](@ref) for fully out-of-place version. """ @inline recursive_make_zero!!(x) = recursive_map(__zero!!, x)::typeof(x) +""" + recursive_copyto!(x, y) + +Recursively copy the leaves of two nested structures `x` and `y`. In Functor language, this +is equivalent to doing `fmap(copyto!, x, y)`, but this implementation uses type stable code +for common cases. Note that any immutable leaf will lead to an error. +""" +@inline recursive_copyto!(x, y) = recursive_map(copyto!, x, y) + """ recursive_map(f, x, args...) 
diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index 3077b745f..02e6d075a 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -177,3 +177,74 @@ end @inferred Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new) end + +@testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:contrib] begin + using ADTypes, Optimisers, ReverseDiff + + function mse1(model, ps, st, data) + x_data, y_data = data + y, st_ = model(x_data, ps, st) + return sum(abs2, y .- y_data), st_, (;) + end + + function mse2(model, ps, st, data) + l, st_, stats = mse1(model, ps, st, data) + return l, st_, (; data=2.0f0) + end + + rng = StableRNG(12345) + + dataset = [(randn(rng, Float32, 4, 32), randn(rng, Float32, 4, 32)) for _ in 1:100] + + @testset "Unhandled Cases" begin + model = Chain(Dense(4, 32, tanh), BatchNorm(32), + Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Stateful models are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) + + model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Loss functions that return non-empty `stats` are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse2, dataset[1], tstate) + + struct StrangeModel <: Lux.AbstractExplicitLayer end + + function (m::StrangeModel)(x, ps, st) + return x, (; new_state=0.0) + end + + model = StrangeModel() + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + # Stateful models are not supported + @test_throws ArgumentError Lux.Experimental.compute_gradients( + AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) + end + + model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + + loss_initial = first(mse1(model, ps, st, dataset[1])) + for i in 1:100 + for (x, y) in dataset + _, _, _, tstate = Lux.Experimental.single_train_step!( + AutoReverseDiff(; compile=true), mse1, (x, y), tstate) + end + end + loss_final = first(mse1(model, tstate.parameters, tstate.states, dataset[1])) + + @test loss_final * 100 < loss_initial +end
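
Note on the Enzyme extension changes above: both code paths reduce to a single Enzyme.autodiff call with the parameters passed as a Duplicated shadow pair. A minimal standalone sketch of that call pattern, assuming a scalar-valued loss (the names f, ps, dps, x are illustrative, not part of this patch):

using Enzyme

f(ps, x) = sum(abs2, ps .* x)   # scalar loss, as required by the Active return annotation

ps = rand(4)
dps = zero(ps)                  # shadow buffer; the gradient is accumulated into it in place
x = rand(4)

_, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Active, Duplicated(ps, dps), Const(x))
# dps now holds the gradient of f with respect to ps; loss is the primal value f(ps, x)

Because Enzyme accumulates into the shadow rather than overwriting it, the cached fast path zeroes ts.cache.dparameters with recursive_make_zero!! before reusing it.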
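
The compiled-ReverseDiff path hand-rolls tape recording and replay (InstructionTape, TrackedArray leaves, FunctionWrapper-wrapped forward/reverse executors) because Lux parameters are nested NamedTuples rather than a single flat array. For intuition only, the equivalent flow with ReverseDiff's public compiled-tape API on a flat parameter array looks roughly like this (W, x, y, loss are illustrative):

using ReverseDiff

x = rand(Float32, 4, 16)
y = rand(Float32, 4, 16)
W = rand(Float32, 4, 4)
loss(W) = sum(abs2, W * x .- y)   # x and y are captured by the closure and baked into the tape

tape  = ReverseDiff.GradientTape(loss, W)   # record the instruction tape once
ctape = ReverseDiff.compile(tape)           # wrap the instructions in compiled executors
grads = similar(W)
ReverseDiff.gradient!(grads, ctape, W)      # replay the fixed tape with the current values of W

A compiled tape replays a fixed instruction sequence, so anything not written into a tracked or cached buffer is frozen at recording time; that is why the new code copies fresh values into ps_cache/data_cache before every replay and rejects models with non-empty st and loss functions with non-empty stats.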
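
The new recursive_copyto! helper is what refreshes ps_cache/data_cache in place; per its docstring it behaves like fmap(copyto!, x, y) on the leaves. A small illustration on a NamedTuple of arrays (names are illustrative):

using Lux

ps     = (; weight=zeros(Float32, 2, 2), bias=zeros(Float32, 2))
ps_new = (; weight=ones(Float32, 2, 2), bias=ones(Float32, 2))
Lux.recursive_copyto!(ps, ps_new)   # the arrays inside ps now hold the values from ps_new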
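
Finally, mirroring the new test, an end-to-end sketch of driving the compiled backend through the training API: the first single_train_step! call records and compiles the tape, and later calls with the same loss function only copy ps and the data batch into the cached buffers and replay it (the model and loss definitions below are illustrative):

using ADTypes, Lux, Optimisers, Random, ReverseDiff

rng = Random.default_rng()
model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4))  # state-free model, as required
ps, st = Lux.setup(rng, model)
tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))

function mse(model, ps, st, (x, y))
    y_pred, st_ = model(x, ps, st)
    return sum(abs2, y_pred .- y), st_, (;)   # empty stats, as required by the compiled path
end

x, y = randn(rng, Float32, 4, 32), randn(rng, Float32, 4, 32)
grads, loss, _, tstate = Lux.Experimental.single_train_step!(
    AutoReverseDiff(; compile=true), mse, (x, y), tstate)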