From 19ad7105d01d44810f171b8e2ec17dcd2525f105 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 19:03:54 -0400 Subject: [PATCH] Add additional fields to the struct --- ext/LuxOptimisersExt.jl | 6 +++--- src/contrib/training.jl | 9 ++++++++- test/contrib/training_tests.jl | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/ext/LuxOptimisersExt.jl b/ext/LuxOptimisersExt.jl index 7a1275a52..c5f6950d7 100644 --- a/ext/LuxOptimisersExt.jl +++ b/ext/LuxOptimisersExt.jl @@ -33,13 +33,13 @@ function Lux.Experimental.TrainState( transform_variables::Union{Function, AbstractLuxDevice}=gpu_device()) ps, st = Lux.setup(rng, model) .|> transform_variables st_opt = Optimisers.setup(optimizer, ps) - return Lux.Experimental.TrainState(model, ps, st, st_opt, 0) + return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0) end function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads) optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads) - return Lux.Experimental.TrainState( - ts.model, ps, ts.states, optimizer_state, ts.step + 1) + return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model, + ps, ts.states, optimizer_state, ts.step + 1) end # DistributedUtils diff --git a/src/contrib/training.jl b/src/contrib/training.jl index f29bb5c6d..ca496fc38 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -8,8 +8,15 @@ Training State containing: - `states`: Non-trainable Variables of the `model`. - `optimizer_state`: Optimizer State. - `step`: Number of updates of the parameters made. + +Internal fields: + + - `cache`: Cached values. Implementations are free to use this for whatever they want. + - `objective_function`: Objective function might be cached. """ -@concrete struct TrainState +@concrete struct TrainState{C, F} + cache::C + objective_function::F model parameters states diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index a11281b1c..fc5cb56cd 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -23,7 +23,7 @@ end @testitem "AbstractADTypes" setup=[SharedTestSetup] tags=[:contrib] begin - using ADTypes, Optimisers + using ADTypes, Optimisers, Enzyme function _loss_function(model, ps, st, data) y, st = model(data, ps, st)