diff --git a/examples/OptimizationIntegration/main.jl b/examples/OptimizationIntegration/main.jl
index 7c617348a..3bc227963 100644
--- a/examples/OptimizationIntegration/main.jl
+++ b/examples/OptimizationIntegration/main.jl
@@ -117,15 +117,13 @@ function train_model(dataloader)
     res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, maxiters=epochs)
 
     ## Let's finetune a bit with L-BFGS
-    opt_prob = remake(opt_prob; u0=res_adam.u)
+    opt_prob = OptimizationProblem(opt_func, res_adam.u, (gdev(ode_data), TimeWrapper(t)))
     res_lbfgs = solve(opt_prob, LBFGS(); callback, maxiters=epochs)
 
     ## Now that we have a good fit, let's train it on the entire dataset without
     ## Minibatching. We need to do this since ODE solves can lead to accumulated errors if
     ## the model was trained on individual parts (without a data-shooting approach).
-    opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
-    opt_prob = OptimizationProblem(opt_func, res_lbfgs.u, (gdev(ode_data), TimeWrapper(t)))
-
+    opt_prob = remake(opt_prob; u0=res_lbfgs.u)
     res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)
 
     return StatefulLuxLayer{true}(model, res.u, smodel.st)
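For context, the warm-start pattern this hunk settles on (build the `OptimizationFunction`/`OptimizationProblem` once, then `remake` the problem with the previous stage's minimizer as the new `u0`) can be sketched standalone. The sketch below is a minimal illustration: the toy Rosenbrock objective is a stand-in for the example's actual neural-ODE loss, and the package set is assumed to match what the example already loads.

```julia
# Minimal sketch of the staged warm-start pattern, with a toy Rosenbrock
# objective standing in for the example's neural-ODE loss.
using Optimization, OptimizationOptimisers, OptimizationOptimJL, Zygote

rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

opt_func = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, zeros(2), (1.0, 100.0))

# Stage 1: coarse first-order pass with Adam
res_adam = solve(opt_prob, Optimisers.Adam(0.01); maxiters=1000)

# Stage 2: warm-start L-BFGS from Adam's minimizer. `remake` reuses the same
# objective and problem parameters, swapping in only the new initial point u0.
opt_prob = remake(opt_prob; u0=res_adam.u)
res_lbfgs = solve(opt_prob, LBFGS(); maxiters=100)
```

Note the hunk also moves the construction of the full-dataset problem (parameters `(gdev(ode_data), TimeWrapper(t))`) up to the L-BFGS stage, so the final full-dataset Adam run only needs a `remake` with `u0=res_lbfgs.u` rather than rebuilding the `OptimizationFunction` and `OptimizationProblem` from scratch.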