Commit

rename apply_gradients
avik-pal committed May 13, 2024
1 parent eb8918d commit 35c7648
Showing 7 changed files with 34 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/src/api/Lux/contrib.md
@@ -35,6 +35,7 @@ basic building blocks which can be seamlessly composed to create complex training
 Lux.Experimental.TrainState
 Lux.Experimental.compute_gradients
 Lux.Experimental.apply_gradients
+Lux.Experimental.apply_gradients!
 ```
 
 ## Parameter Freezing
2 changes: 1 addition & 1 deletion examples/HyperNet/main.jl
@@ -102,7 +102,7 @@ function train()
             y = y |> dev
             (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                 AutoZygote(), loss, (data_idx, x, y), train_state)
-            train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
+            train_state = Lux.Experimental.apply_gradients!(train_state, gs)
         end
         ttime = time() - stime

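The example scripts touched by this commit (HyperNet, PolynomialFitting, SimpleChains, SimpleRNN) all receive the same one-line change. A minimal before/after sketch of the call-site migration, assuming `train_state` and the gradients `gs` come from `Lux.Experimental.compute_gradients` as in the hunk above:

```julia
# Before this commit: in-place behaviour was requested via a trailing Bool flag.
train_state = Lux.Experimental.apply_gradients(train_state, gs, true)

# After this commit: the in-place update is a separate bang (!) function.
train_state = Lux.Experimental.apply_gradients!(train_state, gs)
```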
2 changes: 1 addition & 1 deletion examples/PolynomialFitting/main.jl
@@ -79,7 +79,7 @@ function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs)
         if epoch % 50 == 1 || epoch == epochs
             @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
         end
-        tstate = Lux.Training.apply_gradients(tstate, grads, true)
+        tstate = Lux.Training.apply_gradients!(tstate, grads)
     end
     return tstate
 end
2 changes: 1 addition & 1 deletion examples/SimpleChains/main.jl
@@ -82,7 +82,7 @@ function train(model; rng=Xoshiro(0), kwargs...)
         for (x, y) in train_dataloader
             (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                 AutoZygote(), loss, (x, y), train_state)
-            train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
+            train_state = Lux.Experimental.apply_gradients!(train_state, gs)
         end
         ttime = time() - stime
2 changes: 1 addition & 1 deletion examples/SimpleRNN/main.jl
@@ -157,7 +157,7 @@ function main(model_type)
 
             gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                 AutoZygote(), compute_loss, (x, y), train_state)
-            train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
+            train_state = Lux.Experimental.apply_gradients!(train_state, gs)
 
             @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
         end
23 changes: 11 additions & 12 deletions ext/LuxOptimisersExt.jl
@@ -36,18 +36,17 @@ function Lux.Experimental.TrainState(
     return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0)
 end
 
-function Lux.Experimental.apply_gradients(
-        ts::Lux.Experimental.TrainState, grads, update_inplace=false)
-    if update_inplace
-        optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
-        return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model,
-            ps, ts.states, optimizer_state, ts.step + 1)
-    else
-        Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
-        return Lux.Experimental.TrainState(
-            ts.cache, ts.objective_function, ts.model, ts.parameters,
-            ts.states, ts.optimizer_state, ts.step + 1)
-    end
+function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads)
+    optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
+    return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model,
+        ps, ts.states, optimizer_state, ts.step + 1)
+end
+
+function Lux.Experimental.apply_gradients!(ts::Lux.Experimental.TrainState, grads)
+    Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
+    return Lux.Experimental.TrainState(
+        ts.cache, ts.objective_function, ts.model, ts.parameters,
+        ts.states, ts.optimizer_state, ts.step + 1)
 end
 
 # DistributedUtils
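With the rename, each method of the extension above maps onto one Optimisers.jl entry point: `apply_gradients` calls `Optimisers.update` and leaves the existing parameters untouched, while `apply_gradients!` calls `Optimisers.update!` and mutates the parameter and optimizer-state arrays. A minimal sketch of that underlying distinction at the Optimisers.jl level, using a hypothetical parameter NamedTuple that is not taken from this commit:

```julia
using Optimisers

ps = (; weight=randn(Float32, 2, 2))   # hypothetical parameters
gs = (; weight=ones(Float32, 2, 2))    # hypothetical gradients
opt_state = Optimisers.setup(Adam(0.01f0), ps)

# Out-of-place (what `apply_gradients` now wraps): returns fresh state and parameters.
opt_state_new, ps_new = Optimisers.update(opt_state, ps, gs)

# In-place (what `apply_gradients!` now wraps): mutates the arrays behind `opt_state` and `ps`,
# avoiding the allocation of a new parameter tree on every step.
Optimisers.update!(opt_state, ps, gs)
```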
19 changes: 18 additions & 1 deletion src/contrib/training.jl
@@ -37,7 +37,7 @@ function Base.show(io::IO, ts::TrainState)
 end
 
 """
-    apply_gradients(ts::TrainState, grads, update_inplace::Bool=false)
+    apply_gradients(ts::TrainState, grads)
 
 Update the parameters stored in `ts` using the gradients `grads`.
@@ -53,6 +53,23 @@ Updated [`TrainState`](@ref) object.
 """
 function apply_gradients end
 
+"""
+    apply_gradients!(ts::TrainState, grads)
+
+Update the parameters stored in `ts` using the gradients `grads`. This is an inplace version
+of [`apply_gradients`](@ref).
+
+## Arguments
+
+  - `ts`: [`TrainState`](@ref) object.
+  - `grads`: Gradients of the loss function wrt `ts.params`.
+
+## Returns
+
+Updated [`TrainState`](@ref) object.
+"""
+function apply_gradients! end
+
 """
     compute_gradients(ad::ADTypes.AbstractADType, objective_function::Function, data,
         ts::TrainState)
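Combined with the `compute_gradients` docstring above, one training step with the renamed API reads roughly as below. This is a sketch only; `loss_function`, `data`, and the construction of `train_state` are assumed to follow the example scripts earlier in this commit.

```julia
# Assumes: using Lux, ADTypes, Zygote, Optimisers, plus a `train_state` built as in the examples.
function train_step!(train_state, loss_function, data)
    gs, loss, stats, train_state = Lux.Experimental.compute_gradients(
        AutoZygote(), loss_function, data, train_state)
    train_state = Lux.Experimental.apply_gradients!(train_state, gs)  # in-place parameter update
    return train_state, loss
end
```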
