From 7c842fd713d7d52e038c8969a55b1fd0c4e08fc3 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 8 Oct 2024 16:27:34 -0400
Subject: [PATCH] fix: reactant GPU support

---
 ext/LuxReactantExt/training.jl  | 35 +++++++++------------
 src/helpers/training.jl         | 10 ++++--
 test/reactant/training_tests.jl | 56 ++++++++++++++++++++++++++++++++-
 3 files changed, 77 insertions(+), 24 deletions(-)

diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl
index 874742f83..182ca9c86 100644
--- a/ext/LuxReactantExt/training.jl
+++ b/ext/LuxReactantExt/training.jl
@@ -1,15 +1,13 @@
 function Lux.Training.compute_gradients_impl(
         backend::ReactantBackend, objective_function::F, data,
         ts::Training.TrainState) where {F}
-    dps = Lux.recursive_make_zero(ts.parameters)
-
     compiled_gradient_function = @compile compute_gradients_internal(
-        objective_function, ts.model, data, ts.parameters, dps, ts.states)
+        objective_function, ts.model, data, ts.parameters, ts.states)
 
     grads, loss, stats, st = compiled_gradient_function(
-        objective_function, ts.model, data, ts.parameters, dps, ts.states)
+        objective_function, ts.model, data, ts.parameters, ts.states)
 
-    cache = TrainingBackendCache(backend, False(), dps, (; compiled_gradient_function))
+    cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
     @set! ts.cache = cache
     @set! ts.objective_function = objective_function
     @set! ts.states = st
@@ -18,17 +16,14 @@ end
 
 function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
         ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
-    dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
-
     grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
-        obj_fn, ts.model, data, ts.parameters, dps, ts.states)
-
+        obj_fn, ts.model, data, ts.parameters, ts.states)
     @set! ts.states = st
     return grads, loss, stats, ts
 end
 
-function compute_gradients_internal(
-        objective_function::F, model, data, ps, dps, st) where {F}
+function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
+    dps = Enzyme.make_zero(ps)
     _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
         Duplicated(ps, dps), Const(st), Const(data))
@@ -41,18 +36,16 @@ for inplace in ("!", "")
 
     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
-        dps = Lux.recursive_make_zero(ts.parameters)
-
         compiled_grad_and_step_function = @compile $(internal_fn)(
-            objective_function, ts.model, data, ts.parameters, dps, ts.states,
+            objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)
 
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
-            objective_function, ts.model, data, ts.parameters, dps, ts.states,
+            objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)
 
         cache = TrainingBackendCache(
-            backend, False(), dps, (; compiled_grad_and_step_function))
+            backend, False(), nothing, (; compiled_grad_and_step_function))
         @set! ts.cache = cache
         @set! ts.objective_function = objective_function
         @set! ts.states = st
@@ -65,10 +58,8 @@ for inplace in ("!", "")
 
     @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
            ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
-        dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
-
         grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
-            obj_fn, ts.model, data, ts.parameters, dps, ts.states, ts.optimizer_state)
+            obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
 
         @set! ts.states = st
         @set! ts.parameters = ps
@@ -79,8 +70,9 @@ for inplace in ("!", "")
     end
 end
 
-function compute_gradients_internal_and_step(objective_function::F, model, data, ps, dps,
+function compute_gradients_internal_and_step(objective_function::F, model, data, ps,
         st, opt_state) where {F}
+    dps = Enzyme.make_zero(ps)
     _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
         Duplicated(ps, dps), Const(st), Const(data))
@@ -88,8 +80,9 @@ function compute_gradients_internal_and_step(objective_function::F, model, data,
     return dps, ps, loss, stats, stₙ, opt_state
 end
 
-function compute_gradients_internal_and_step!(objective_function::F, model, data, ps, dps,
+function compute_gradients_internal_and_step!(objective_function::F, model, data, ps,
         st, opt_state) where {F}
+    dps = Enzyme.make_zero(ps)
     _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
         Duplicated(ps, dps), Const(st), Const(data))
diff --git a/src/helpers/training.jl b/src/helpers/training.jl
index d495dbf4c..c0e6644ff 100644
--- a/src/helpers/training.jl
+++ b/src/helpers/training.jl
@@ -10,7 +10,7 @@ using Static: StaticBool, Static, False, True
 
 using ..Lux: Lux
 using LuxCore: LuxCore, AbstractLuxLayer
-using MLDataDevices: XLADevice, get_device_type
+using MLDataDevices: XLADevice, get_device_type, get_device, cpu_device
 
 """
     TrainState
@@ -62,7 +62,13 @@ Constructor for [`TrainState`](@ref).
 [`TrainState`](@ref) object.
 """
 function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
-    st_opt = Optimisers.setup(optimizer, ps)
+    dev = get_device(ps)
+    st_opt = if dev isa XLADevice
+        ps_cpu = ps |> cpu_device()
+        Optimisers.setup(optimizer, ps_cpu) |> dev
+    else
+        Optimisers.setup(optimizer, ps)
+    end
     return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
 end
 
diff --git a/test/reactant/training_tests.jl b/test/reactant/training_tests.jl
index 54d24b3ad..427f3a171 100644
--- a/test/reactant/training_tests.jl
+++ b/test/reactant/training_tests.jl
@@ -1,5 +1,5 @@
 @testitem "Reactant: Training API" tags=[:reactant] setup=[SharedTestSetup] begin
-    using Reactant
+    using Reactant, Optimisers
 
     @testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
         if mode == "amdgpu"
@@ -12,5 +12,59 @@
         else
             Reactant.set_default_backend("cpu")
         end
+
+        # TODO: Test for compute_gradients
+
+        xdev = xla_device(; force=true)
+
+        @testset "MLP Training: $(version)" for version in (:iip, :oop)
+            model = Chain(
+                Dense(2 => 32, gelu),
+                Dense(32 => 32, gelu),
+                Dense(32 => 2)
+            )
+            ps, st = Lux.setup(StableRNG(1234), model) |> xdev
+
+            x_ra = randn(Float32, 2, 32) |> xdev
+
+            inference_fn = @compile model(x_ra, ps, Lux.testmode(st))
+
+            x = [rand(Float32, 2, 32) for _ in 1:32]
+            y = [xᵢ .^ 2 for xᵢ in x]
+
+            dataloader = DeviceIterator(xdev, zip(x, y))
+
+            total_initial_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+                ŷᵢ, _ = inference_fn(xᵢ, ps, Lux.testmode(st))
+                return MSELoss()(ŷᵢ, yᵢ)
+            end
+
+            train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
+
+            # FIXME: Use MSELoss <-- currently fails due to Enzyme
+            function sse(model, ps, st, (x, y))
+                z, stₙ = model(x, ps, st)
+                return sum(abs2, z .- y), stₙ, (;)
+            end
+
+            for epoch in 1:100, (xᵢ, yᵢ) in dataloader
+                grads, loss, stats, train_state = if version === :iip
+                    Training.single_train_step!(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
+                elseif version === :oop
+                    Training.single_train_step(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
+                else
+                    error("Invalid version: $(version)")
+                end
+            end
+
+            total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+                ŷᵢ, _ = inference_fn(xᵢ, train_state.parameters, Lux.testmode(st))
+                return MSELoss()(ŷᵢ, yᵢ)
+            end
+
+            @test total_final_loss < 100 * total_initial_loss
+        end
+
+        # TODO: Training a CNN
     end
 end