From 35c76488f2fb503c3e0b49a37e6cedd23af4c2da Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 13 May 2024 00:26:46 -0400
Subject: [PATCH] rename apply_gradients

---
 docs/src/api/Lux/contrib.md        |  1 +
 examples/HyperNet/main.jl          |  2 +-
 examples/PolynomialFitting/main.jl |  2 +-
 examples/SimpleChains/main.jl      |  2 +-
 examples/SimpleRNN/main.jl         |  2 +-
 ext/LuxOptimisersExt.jl            | 23 +++++++++++------------
 src/contrib/training.jl            | 19 ++++++++++++++++++-
 7 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md
index bf9c9c238..e0ce8f01d 100644
--- a/docs/src/api/Lux/contrib.md
+++ b/docs/src/api/Lux/contrib.md
@@ -35,6 +35,7 @@ basic building blocks which can be seamlessly composed to create complex trainin
 Lux.Experimental.TrainState
 Lux.Experimental.compute_gradients
 Lux.Experimental.apply_gradients
+Lux.Experimental.apply_gradients!
 ```

 ## Parameter Freezing
diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl
index 44f802d28..37da291ca 100644
--- a/examples/HyperNet/main.jl
+++ b/examples/HyperNet/main.jl
@@ -102,7 +102,7 @@ function train()
             y = y |> dev
             (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                 AutoZygote(), loss, (data_idx, x, y), train_state)
-            train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
+            train_state = Lux.Experimental.apply_gradients!(train_state, gs)
         end
         ttime = time() - stime

diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl
index 303c2d070..efe2442de 100644
--- a/examples/PolynomialFitting/main.jl
+++ b/examples/PolynomialFitting/main.jl
@@ -79,7 +79,7 @@ function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs)
         if epoch % 50 == 1 || epoch == epochs
             @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
         end
-        tstate = Lux.Training.apply_gradients(tstate, grads, true)
+        tstate = Lux.Training.apply_gradients!(tstate, grads)
     end
     return tstate
 end
diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl
index c628181a5..7e92d64c4 100644
--- a/examples/SimpleChains/main.jl
+++ b/examples/SimpleChains/main.jl
@@ -82,7 +82,7 @@ function train(model; rng=Xoshiro(0), kwargs...)
        for (x, y) in train_dataloader
            (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                AutoZygote(), loss, (x, y), train_state)
-            train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
+            train_state = Lux.Experimental.apply_gradients!(train_state, gs)
         end
         ttime = time() - stime

diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl
index 1cb624eb3..1a03492b3 100644
--- a/examples/SimpleRNN/main.jl
+++ b/examples/SimpleRNN/main.jl
@@ -157,7 +157,7 @@ function main(model_type)

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
-            train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
+            train_state = Lux.Experimental.apply_gradients!(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end
diff --git a/ext/LuxOptimisersExt.jl b/ext/LuxOptimisersExt.jl
index 5d549bcc3..54652e996 100644
--- a/ext/LuxOptimisersExt.jl
+++ b/ext/LuxOptimisersExt.jl
@@ -36,18 +36,17 @@ function Lux.Experimental.TrainState(
     return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0)
 end

-function Lux.Experimental.apply_gradients(
-        ts::Lux.Experimental.TrainState, grads, update_inplace=false)
-    if update_inplace
-        optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
-        return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model,
-            ps, ts.states, optimizer_state, ts.step + 1)
-    else
-        Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
-        return Lux.Experimental.TrainState(
-            ts.cache, ts.objective_function, ts.model, ts.parameters,
-            ts.states, ts.optimizer_state, ts.step + 1)
-    end
+function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads)
+    optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
+    return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model,
+        ps, ts.states, optimizer_state, ts.step + 1)
+end
+
+function Lux.Experimental.apply_gradients!(ts::Lux.Experimental.TrainState, grads)
+    Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
+    return Lux.Experimental.TrainState(
+        ts.cache, ts.objective_function, ts.model, ts.parameters,
+        ts.states, ts.optimizer_state, ts.step + 1)
 end

 # DistributedUtils
diff --git a/src/contrib/training.jl b/src/contrib/training.jl
index 1a8ee8f8b..c0f285992 100644
--- a/src/contrib/training.jl
+++ b/src/contrib/training.jl
@@ -37,7 +37,7 @@ function Base.show(io::IO, ts::TrainState)
 end

 """
-    apply_gradients(ts::TrainState, grads, update_inplace::Bool=false)
+    apply_gradients(ts::TrainState, grads)

 Update the parameters stored in `ts` using the gradients `grads`.

@@ -53,6 +53,23 @@ Updated [`TrainState`](@ref) object.
 """
 function apply_gradients end

+"""
+    apply_gradients!(ts::TrainState, grads)
+
+Update the parameters stored in `ts` using the gradients `grads`. This is an inplace version
+of [`apply_gradients`](@ref).
+
+## Arguments
+
+  - `ts`: [`TrainState`](@ref) object.
+  - `grads`: Gradients of the loss function wrt `ts.params`.
+
+## Returns
+
+Updated [`TrainState`](@ref) object.
+"""
+function apply_gradients! end
+
 """
     compute_gradients(ad::ADTypes.AbstractADType, objective_function::Function, data,
         ts::TrainState)
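
For reviewers, here is a minimal training-loop sketch (not part of the patch) showing the renamed in-place API next to the out-of-place variant. It mirrors the usage in the example scripts touched above; `train_one_epoch!`, `loss_fn`, `train_state`, and `dataloader` are illustrative names and are assumed to be set up as in those examples.

```julia
using Lux
using ADTypes: AutoZygote   # the AutoZygote backend also requires Zygote to be loaded
using Zygote

# Sketch only: `loss_fn`, `train_state`, and `dataloader` are assumed to be
# constructed as in the example scripts modified by this patch.
function train_one_epoch!(train_state, loss_fn, dataloader)
    for (x, y) in dataloader
        # Gradients of `loss_fn` w.r.t. the parameters stored in `train_state`.
        gs, loss, _, train_state = Lux.Experimental.compute_gradients(
            AutoZygote(), loss_fn, (x, y), train_state)

        # In-place update; replaces the old `apply_gradients(train_state, gs, true)`.
        train_state = Lux.Experimental.apply_gradients!(train_state, gs)

        # Out-of-place variant retained by this patch:
        # train_state = Lux.Experimental.apply_gradients(train_state, gs)
    end
    return train_state
end
```

Per the `LuxOptimisersExt.jl` hunk, both variants return a fresh `TrainState` with `step` incremented, so the returned value should always be rebound as shown.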