diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index 3dfae1845..5a25ff5ac 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -88,7 +88,7 @@ end for epoch in 1:100, (x, y) in dataset_ grads, loss, _, tstate = Lux.Experimental.compute_gradients( ad, mse, (x, y), tstate) - tstate = Lux.Experimental.apply_gradients(tstate, grads, true) + tstate = Lux.Experimental.apply_gradients!(tstate, grads) end final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1]))