Commit 7c842fd
fix: reactant GPU support
avik-pal committed Oct 9, 2024
1 parent fcbf03a commit 7c842fd
Showing 3 changed files with 77 additions and 24 deletions.
ext/LuxReactantExt/training.jl (35 changes: 14 additions & 21 deletions)
@@ -1,15 +1,13 @@
 function Lux.Training.compute_gradients_impl(
         backend::ReactantBackend, objective_function::F,
         data, ts::Training.TrainState) where {F}
-    dps = Lux.recursive_make_zero(ts.parameters)
-
     compiled_gradient_function = @compile compute_gradients_internal(
-        objective_function, ts.model, data, ts.parameters, dps, ts.states)
+        objective_function, ts.model, data, ts.parameters, ts.states)
 
     grads, loss, stats, st = compiled_gradient_function(
-        objective_function, ts.model, data, ts.parameters, dps, ts.states)
+        objective_function, ts.model, data, ts.parameters, ts.states)
 
-    cache = TrainingBackendCache(backend, False(), dps, (; compiled_gradient_function))
+    cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
     @set! ts.cache = cache
     @set! ts.objective_function = objective_function
     @set! ts.states = st
@@ -18,17 +16,14 @@ end
 
 function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
         ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
-    dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
-
     grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
-        obj_fn, ts.model, data, ts.parameters, dps, ts.states)
-
+        obj_fn, ts.model, data, ts.parameters, ts.states)
     @set! ts.states = st
     return grads, loss, stats, ts
 end
 
-function compute_gradients_internal(
-        objective_function::F, model, data, ps, dps, st) where {F}
+function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
+    dps = Enzyme.make_zero(ps)
     _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
         Duplicated(ps, dps), Const(st), Const(data))
@@ -41,18 +36,16 @@ for inplace in ("!", "")
 
     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
-        dps = Lux.recursive_make_zero(ts.parameters)
-
         compiled_grad_and_step_function = @compile $(internal_fn)(
-            objective_function, ts.model, data, ts.parameters, dps, ts.states,
+            objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)
 
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
-            objective_function, ts.model, data, ts.parameters, dps, ts.states,
+            objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)
 
         cache = TrainingBackendCache(
-            backend, False(), dps, (; compiled_grad_and_step_function))
+            backend, False(), nothing, (; compiled_grad_and_step_function))
         @set! ts.cache = cache
         @set! ts.objective_function = objective_function
         @set! ts.states = st
@@ -65,10 +58,8 @@ for inplace in ("!", "")
 
     @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
             ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
-        dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
-
         grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
-            obj_fn, ts.model, data, ts.parameters, dps, ts.states, ts.optimizer_state)
+            obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
 
         @set! ts.states = st
         @set! ts.parameters = ps
@@ -79,17 +70,19 @@ for inplace in ("!", "")
     end
 end
 
-function compute_gradients_internal_and_step(objective_function::F, model, data, ps, dps,
+function compute_gradients_internal_and_step(objective_function::F, model, data, ps,
         st, opt_state) where {F}
+    dps = Enzyme.make_zero(ps)
     _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
         Duplicated(ps, dps), Const(st), Const(data))
     opt_state, ps = Optimisers.update(opt_state, ps, dps)
     return dps, ps, loss, stats, stₙ, opt_state
 end
 
-function compute_gradients_internal_and_step!(objective_function::F, model, data, ps, dps,
+function compute_gradients_internal_and_step!(objective_function::F, model, data, ps,
         st, opt_state) where {F}
+    dps = Enzyme.make_zero(ps)
     _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
         Duplicated(ps, dps), Const(st), Const(data))
src/helpers/training.jl (10 changes: 8 additions & 2 deletions)
@@ -10,7 +10,7 @@ using Static: StaticBool, Static, False, True
 
 using ..Lux: Lux
 using LuxCore: LuxCore, AbstractLuxLayer
-using MLDataDevices: XLADevice, get_device_type
+using MLDataDevices: XLADevice, get_device_type, get_device, cpu_device
 
 """
     TrainState
@@ -62,7 +62,13 @@ Constructor for [`TrainState`](@ref).
 [`TrainState`](@ref) object.
 """
 function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
-    st_opt = Optimisers.setup(optimizer, ps)
+    dev = get_device(ps)
+    st_opt = if dev isa XLADevice
+        ps_cpu = ps |> cpu_device()
+        Optimisers.setup(optimizer, ps_cpu) |> dev
+    else
+        Optimisers.setup(optimizer, ps)
+    end
     return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
 end
 
test/reactant/training_tests.jl (56 changes: 55 additions & 1 deletion)
@@ -1,5 +1,5 @@
 @testitem "Reactant: Training API" tags=[:reactant] setup=[SharedTestSetup] begin
-    using Reactant
+    using Reactant, Optimisers
 
     @testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
         if mode == "amdgpu"
@@ -12,5 +12,59 @@
         else
             Reactant.set_default_backend("cpu")
         end
+
+        # TODO: Test for compute_gradients
+
+        xdev = xla_device(; force=true)
+
+        @testset "MLP Training: $(version)" for version in (:iip, :oop)
+            model = Chain(
+                Dense(2 => 32, gelu),
+                Dense(32 => 32, gelu),
+                Dense(32 => 2)
+            )
+            ps, st = Lux.setup(StableRNG(1234), model) |> xdev
+
+            x_ra = randn(Float32, 2, 32) |> xdev
+
+            inference_fn = @compile model(x_ra, ps, Lux.testmode(st))
+
+            x = [rand(Float32, 2, 32) for _ in 1:32]
+            y = [xᵢ .^ 2 for xᵢ in x]
+
+            dataloader = DeviceIterator(xdev, zip(x, y))
+
+            total_initial_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+                ŷᵢ, _ = inference_fn(xᵢ, ps, Lux.testmode(st))
+                return MSELoss()(ŷᵢ, yᵢ)
+            end
+
+            train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
+
+            # FIXME: Use MSELoss <-- currently fails due to Enzyme
+            function sse(model, ps, st, (x, y))
+                z, stₙ = model(x, ps, st)
+                return sum(abs2, z .- y), stₙ, (;)
+            end
+
+            for epoch in 1:100, (xᵢ, yᵢ) in dataloader
+                grads, loss, stats, train_state = if version === :iip
+                    Training.single_train_step!(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
+                elseif version === :oop
+                    Training.single_train_step(AutoEnzyme(), sse, (xᵢ, yᵢ), train_state)
+                else
+                    error("Invalid version: $(version)")
+                end
+            end
+
+            total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
+                ŷᵢ, _ = inference_fn(xᵢ, train_state.parameters, Lux.testmode(st))
+                return MSELoss()(ŷᵢ, yᵢ)
+            end
+
+            @test total_final_loss < 100 * total_initial_loss
+        end
+
+        # TODO: Training a CNN
     end
 end
