From 66db48ff0393b17e88e1887e67c6d17e4a855aad Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 12 Mar 2024 04:46:48 +0000 Subject: [PATCH] docs: use BFGS with back tracking for low level system tutorial --- docs/src/tutorials/low_level.md | 7 +++---- docs/src/tutorials/pdesystem.md | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/src/tutorials/low_level.md b/docs/src/tutorials/low_level.md index d625bbcf0..63fa0d476 100644 --- a/docs/src/tutorials/low_level.md +++ b/docs/src/tutorials/low_level.md @@ -13,7 +13,7 @@ u(t, -1) = u(t, 1) = 0 \, , with Physics-Informed Neural Networks. Here is an example of using the low-level API: ```@example low_level -using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL +using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, LineSearches using ModelingToolkit: Interval, infimum, supremum @parameters t, x @@ -37,7 +37,7 @@ domains = [t ∈ Interval(0.0, 1.0), # Neural network chain = Lux.Chain(Dense(2, 16, Lux.σ), Dense(16, 16, Lux.σ), Dense(16, 1)) -strategy = NeuralPDE.QuadratureTraining() +strategy = NeuralPDE.QuadratureTraining(; abstol = 1e-6, reltol = 1e-6, batch = 200) indvars = [t, x] depvars = [u(t, x)] @@ -67,8 +67,7 @@ end f_ = OptimizationFunction(loss_function, Optimization.AutoZygote()) prob = Optimization.OptimizationProblem(f_, sym_prob.flat_init_params) -res = Optimization.solve(prob, OptimizationOptimJL.BFGS(); callback = callback, - maxiters = 2000) +res = Optimization.solve(prob, BFGS(linesearch = BackTracking()); callback = callback, maxiters = 3000) ``` And some analysis: diff --git a/docs/src/tutorials/pdesystem.md b/docs/src/tutorials/pdesystem.md index d5962f8d1..4076bce5c 100644 --- a/docs/src/tutorials/pdesystem.md +++ b/docs/src/tutorials/pdesystem.md @@ -66,7 +66,7 @@ end # Optimizer opt = OptimizationOptimJL.LBFGS(linesearch = BackTracking()) -res = solve(prob, opt, callback = callback, maxiters = 2000) +res = solve(prob, opt, callback = callback, maxiters = 1000) phi = discretization.phi dx = 0.05