Skip to content

Commit

Permalink
Add additional fields to the struct
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 12, 2024
1 parent 475a8cc commit 19ad710
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
6 changes: 3 additions & 3 deletions ext/LuxOptimisersExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ function Lux.Experimental.TrainState(
transform_variables::Union{Function, AbstractLuxDevice}=gpu_device())
ps, st = Lux.setup(rng, model) .|> transform_variables
st_opt = Optimisers.setup(optimizer, ps)
return Lux.Experimental.TrainState(model, ps, st, st_opt, 0)
return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0)
end

function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads)
optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
return Lux.Experimental.TrainState(
ts.model, ps, ts.states, optimizer_state, ts.step + 1)
return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model,
ps, ts.states, optimizer_state, ts.step + 1)
end

# DistributedUtils
Expand Down
9 changes: 8 additions & 1 deletion src/contrib/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@ Training State containing:
- `states`: Non-trainable Variables of the `model`.
- `optimizer_state`: Optimizer State.
- `step`: Number of updates of the parameters made.
Internal fields:
- `cache`: Cached values. Implementations are free to use this for whatever they want.
- `objective_function`: Objective function might be cached.
"""
@concrete struct TrainState
@concrete struct TrainState{C, F}
cache::C
objective_function::F
model
parameters
states
Expand Down
2 changes: 1 addition & 1 deletion test/contrib/training_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
end

@testitem "AbstractADTypes" setup=[SharedTestSetup] tags=[:contrib] begin
using ADTypes, Optimisers
using ADTypes, Optimisers, Enzyme

function _loss_function(model, ps, st, data)
y, st = model(data, ps, st)
Expand Down

0 comments on commit 19ad710

Please sign in to comment.