From b6fe280305b08d0571a6674cecf2c381491c1786 Mon Sep 17 00:00:00 2001 From: Dorival Pedroso Date: Wed, 15 May 2024 18:25:08 +1000 Subject: [PATCH] [wip] Store Output in OdeSolver --- .vscode/settings.json | 5 + README.md | 3 +- russell_ode/examples/amplifier1t_radau5.rs | 11 +- russell_ode/examples/arenstorf_dopri8.rs | 16 +- .../examples/brusselator_ode_dopri8.rs | 16 +- .../examples/brusselator_ode_fix_step.rs | 2 +- .../examples/brusselator_ode_var_step.rs | 2 +- .../brusselator_pde_2nd_comparison.rs | 6 +- .../examples/brusselator_pde_radau5.rs | 9 +- .../examples/brusselator_pde_radau5_2nd.rs | 9 +- russell_ode/examples/hairer_wanner_eq1.rs | 27 +- .../pde_1d_heat_spectral_collocation.rs | 2 +- russell_ode/examples/robertson.rs | 29 +- .../examples/simple_ode_single_equation.rs | 2 +- .../examples/simple_system_with_mass.rs | 2 +- russell_ode/examples/van_der_pol_dopri5.rs | 25 +- russell_ode/examples/van_der_pol_radau5.rs | 13 +- russell_ode/src/bin/amplifier1t.rs | 9 +- russell_ode/src/bin/brusselator_pde.rs | 2 +- russell_ode/src/ode_solver.rs | 321 ++++++++++-------- russell_ode/src/output.rs | 31 +- russell_ode/tests/test_bweuler.rs | 6 +- russell_ode/tests/test_dopri5_arenstorf.rs | 28 +- .../tests/test_dopri5_arenstorf_debug.rs | 6 +- .../tests/test_dopri5_hairer_wanner_eq1.rs | 19 +- .../tests/test_dopri5_van_der_pol_debug.rs | 17 +- russell_ode/tests/test_dopri8_van_der_pol.rs | 21 +- .../tests/test_dopri8_van_der_pol_debug.rs | 14 +- russell_ode/tests/test_fweuler.rs | 2 +- russell_ode/tests/test_mdeuler.rs | 2 +- russell_ode/tests/test_radau5_amplifier1t.rs | 30 +- .../tests/test_radau5_brusselator_pde.rs | 6 +- .../tests/test_radau5_hairer_wanner_eq1.rs | 19 +- .../test_radau5_hairer_wanner_eq1_debug.rs | 6 +- russell_ode/tests/test_radau5_robertson.rs | 21 +- .../tests/test_radau5_robertson_debug.rs | 6 +- .../tests/test_radau5_robertson_small_h.rs | 6 +- russell_ode/tests/test_radau5_van_der_pol.rs | 21 +- .../tests/test_radau5_van_der_pol_debug.rs | 6 +- 39 files changed, 444 insertions(+), 334 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index af2deddf..f86097da 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -151,5 +151,10 @@ "zscal", "zsyrk", "zticks" + ], + "rust-analyzer.cargo.features": [ + "intel_mkl", + "local_suitesparse", + "with_mumps" ] } \ No newline at end of file diff --git a/README.md b/README.md index 7f1baea7..ab1236a4 100644 --- a/README.md +++ b/README.md @@ -600,10 +600,9 @@ fn main() -> Result<(), StrError> { let mut solver = OdeSolver::new(params, &system)?; // enable dense output - let mut out = Output::new(); let h_out = 0.01; let selected_y_components = &[0, 1]; - out.set_dense_recording(true, h_out, selected_y_components)?; + solver.enable_output().set_dense_recording(true, h_out, selected_y_components)?; // solve the problem solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args)?; diff --git a/russell_ode/examples/amplifier1t_radau5.rs b/russell_ode/examples/amplifier1t_radau5.rs index 901d5e7b..7b3ef299 100644 --- a/russell_ode/examples/amplifier1t_radau5.rs +++ b/russell_ode/examples/amplifier1t_radau5.rs @@ -24,14 +24,15 @@ fn main() -> Result<(), StrError> { let mut solver = OdeSolver::new(params, &system)?; // enable dense output - let mut out = Output::new(); let h_out = 0.0001; let selected_y_components = &[0, 4]; - out.set_dense_recording(true, h_out, selected_y_components)?; + solver + .enable_output() + .set_dense_recording(true, h_out, selected_y_components)?; // solve the problem let x1 = 0.2; - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args)?; + solver.solve(&mut y0, x0, x1, None, &mut args)?; // print the results and stats let y_ref = &[ @@ -74,9 +75,9 @@ fn main() -> Result<(), StrError> { .set_marker_style("+") .set_line_style("None"); - curve1.draw(&out.dense_x, out.dense_y.get(&0).unwrap()); + curve1.draw(&solver.out().dense_x, solver.out().dense_y.get(&0).unwrap()); curve2.draw(&math.x, &math.y0); - curve3.draw(&out.dense_x, out.dense_y.get(&4).unwrap()); + curve3.draw(&solver.out().dense_x, solver.out().dense_y.get(&4).unwrap()); curve4.draw(&math.x, &math.y4); // save figure diff --git a/russell_ode/examples/arenstorf_dopri8.rs b/russell_ode/examples/arenstorf_dopri8.rs index df2998df..10f9d67d 100644 --- a/russell_ode/examples/arenstorf_dopri8.rs +++ b/russell_ode/examples/arenstorf_dopri8.rs @@ -19,19 +19,22 @@ fn main() -> Result<(), StrError> { // get the ODE system let (system, x0, mut y0, x1, mut args, y_ref) = Samples::arenstorf(); - // solver + // set configuration parameters let params = Params::new(Method::DoPri8); + + // allocate the solver let mut solver = OdeSolver::new(params, &system)?; // enable dense output - let mut out = Output::new(); let h_out = 0.01; let selected_y_components = &[0, 1]; - out.set_dense_recording(true, h_out, selected_y_components)?; + solver + .enable_output() + .set_dense_recording(true, h_out, selected_y_components)?; // solve the problem let y = &mut y0; - solver.solve(y, x0, x1, None, Some(&mut out), &mut args)?; + solver.solve(y, x0, x1, None, &mut args)?; // print the results and stats println!("y_russell = {:?}", y.as_data()); @@ -46,7 +49,10 @@ fn main() -> Result<(), StrError> { let mut curve2 = Curve::new(); curve1.set_label("russell"); curve2.set_label("mathematica"); - curve1.draw(out.dense_y.get(&0).unwrap(), out.dense_y.get(&1).unwrap()); + curve1.draw( + solver.out().dense_y.get(&0).unwrap(), + solver.out().dense_y.get(&1).unwrap(), + ); curve2.set_marker_style(".").set_line_style("None"); curve2.draw(&math.y0, &math.y1); diff --git a/russell_ode/examples/brusselator_ode_dopri8.rs b/russell_ode/examples/brusselator_ode_dopri8.rs index 9a3f5864..771fdae3 100644 --- a/russell_ode/examples/brusselator_ode_dopri8.rs +++ b/russell_ode/examples/brusselator_ode_dopri8.rs @@ -22,18 +22,21 @@ fn main() -> Result<(), StrError> { // final x let x1 = 20.0; - // solver + // set configuration parameters let params = Params::new(Method::DoPri8); + + // allocate the solver let mut solver = OdeSolver::new(params, &system)?; // enable dense output - let mut out = Output::new(); let h_out = 0.01; let selected_y_components = &[0, 1]; - out.set_dense_recording(true, h_out, selected_y_components)?; + solver + .enable_output() + .set_dense_recording(true, h_out, selected_y_components)?; // solve the problem - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args)?; + solver.solve(&mut y0, x0, x1, None, &mut args)?; // print the results and stats println!("y_russell = {:?}", y0.as_data()); @@ -48,7 +51,10 @@ fn main() -> Result<(), StrError> { let mut curve2 = Curve::new(); curve1.set_label("russell"); curve2.set_label("mathematica"); - curve1.draw(out.dense_y.get(&0).unwrap(), out.dense_y.get(&1).unwrap()); + curve1.draw( + solver.out().dense_y.get(&0).unwrap(), + solver.out().dense_y.get(&1).unwrap(), + ); curve2.set_marker_style(".").set_line_style("None"); curve2.draw(&math.y0, &math.y1); diff --git a/russell_ode/examples/brusselator_ode_fix_step.rs b/russell_ode/examples/brusselator_ode_fix_step.rs index 7789c1e0..5c15dc38 100644 --- a/russell_ode/examples/brusselator_ode_fix_step.rs +++ b/russell_ode/examples/brusselator_ode_fix_step.rs @@ -55,7 +55,7 @@ fn main() -> Result<(), StrError> { print!("{:>w$}", name, w = w1); for i in 0..hh.len() { let mut y = y0.clone(); - solver.solve(&mut y, x0, x1, Some(hh[i]), None, &mut args).unwrap(); + solver.solve(&mut y, x0, x1, Some(hh[i]), &mut args).unwrap(); // compare with the reference solution let (_, err) = vec_max_abs_diff(&y, &y_ref)?; diff --git a/russell_ode/examples/brusselator_ode_var_step.rs b/russell_ode/examples/brusselator_ode_var_step.rs index c6e2b646..65886e32 100644 --- a/russell_ode/examples/brusselator_ode_var_step.rs +++ b/russell_ode/examples/brusselator_ode_var_step.rs @@ -61,7 +61,7 @@ fn main() -> Result<(), StrError> { // call solve let mut y = y0.clone(); - solver.solve(&mut y, x0, x1, None, None, &mut args).unwrap(); + solver.solve(&mut y, x0, x1, None, &mut args).unwrap(); // compare with the reference solution let (_, err) = vec_max_abs_diff(&y, &y_ref)?; diff --git a/russell_ode/examples/brusselator_pde_2nd_comparison.rs b/russell_ode/examples/brusselator_pde_2nd_comparison.rs index 11a7cd09..77f038d1 100644 --- a/russell_ode/examples/brusselator_pde_2nd_comparison.rs +++ b/russell_ode/examples/brusselator_pde_2nd_comparison.rs @@ -21,10 +21,12 @@ fn main() { let mut params = Params::new(Method::Radau5); params.set_tolerances(1e-4, 1e-4, None).unwrap(); - // solve the ODE system + // allocate the solver let mut solver = OdeSolver::new(params, &system).unwrap(); + + // solve the ODE system let mut yy = yy0.clone(); - solver.solve(&mut yy, t0, t1, None, None, &mut args).unwrap(); + solver.solve(&mut yy, t0, t1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); diff --git a/russell_ode/examples/brusselator_pde_radau5.rs b/russell_ode/examples/brusselator_pde_radau5.rs index 61c2dbb9..37117d00 100644 --- a/russell_ode/examples/brusselator_pde_radau5.rs +++ b/russell_ode/examples/brusselator_pde_radau5.rs @@ -25,14 +25,15 @@ fn main() -> Result<(), StrError> { let mut params = Params::new(Method::Radau5); params.set_tolerances(1e-4, 1e-4, None)?; + // allocate the solver + let mut solver = OdeSolver::new(params, &system)?; + // output - let mut out = Output::new(); let h_out = 0.5; - out.set_dense_file_writing(true, h_out, PATH_KEY)?; + solver.enable_output().set_dense_file_writing(true, h_out, PATH_KEY)?; // solve the ODE system - let mut solver = OdeSolver::new(params, &system)?; - solver.solve(&mut yy0, t0, t1, None, Some(&mut out), &mut args)?; + solver.solve(&mut yy0, t0, t1, None, &mut args)?; // get statistics let stat = solver.stats(); diff --git a/russell_ode/examples/brusselator_pde_radau5_2nd.rs b/russell_ode/examples/brusselator_pde_radau5_2nd.rs index 957907d6..d6f8c19f 100644 --- a/russell_ode/examples/brusselator_pde_radau5_2nd.rs +++ b/russell_ode/examples/brusselator_pde_radau5_2nd.rs @@ -26,14 +26,15 @@ fn main() -> Result<(), StrError> { let mut params = Params::new(Method::Radau5); params.set_tolerances(1e-4, 1e-4, None)?; + // allocate the solver + let mut solver = OdeSolver::new(params, &system)?; + // output - let mut out = Output::new(); let h_out = 1.0; - out.set_dense_file_writing(true, h_out, PATH_KEY)?; + solver.enable_output().set_dense_file_writing(true, h_out, PATH_KEY)?; // solve the ODE system - let mut solver = OdeSolver::new(params, &system)?; - solver.solve(&mut yy0, t0, t1, None, Some(&mut out), &mut args)?; + solver.solve(&mut yy0, t0, t1, None, &mut args)?; // get statistics let stat = solver.stats(); diff --git a/russell_ode/examples/hairer_wanner_eq1.rs b/russell_ode/examples/hairer_wanner_eq1.rs index 241c1493..089f4c91 100644 --- a/russell_ode/examples/hairer_wanner_eq1.rs +++ b/russell_ode/examples/hairer_wanner_eq1.rs @@ -30,25 +30,26 @@ fn main() -> Result<(), StrError> { let mut fweuler = OdeSolver::new(Params::new(Method::FwEuler), &system)?; // solve the problem with BwEuler and h = 0.5 - let mut out1 = Output::new(); - out1.set_step_recording(true, &[0]); + bweuler.enable_output().set_step_recording(true, &[0]); let h = 0.5; let mut y = y0.clone(); - bweuler.solve(&mut y, x0, x1, Some(h), Some(&mut out1), &mut args)?; + bweuler.solve(&mut y, x0, x1, Some(h), &mut args)?; // solve the problem with FwEuler and h = 1.974/50.0 - let mut out2 = Output::new(); - out2.set_step_recording(true, &[0]); + fweuler.enable_output().set_step_recording(true, &[0]); let h = 1.974 / 50.0; let mut y = y0.clone(); - fweuler.solve(&mut y, x0, x1, Some(h), Some(&mut out2), &mut args)?; + fweuler.solve(&mut y, x0, x1, Some(h), &mut args)?; + + // save the results for later + let out2_x = fweuler.out().step_x.clone(); + let out2_y = fweuler.out().step_y.get(&0).unwrap().clone(); // solve the problem with FwEuler and h = 1.875/50.0 - let mut out3 = Output::new(); - out3.set_step_recording(true, &[0]); + fweuler.enable_output().clear().set_step_recording(true, &[0]); let h = 1.875 / 50.0; let mut y = y0.clone(); - fweuler.solve(&mut y, x0, x1, Some(h), Some(&mut out3), &mut args)?; + fweuler.solve(&mut y, x0, x1, Some(h), &mut args)?; // analytical solution let mut y_aux = Vector::new(system.get_ndim()); @@ -67,17 +68,15 @@ fn main() -> Result<(), StrError> { let mut curve1 = Curve::new(); curve1 .set_label("BwEuler h = 0.5") - .draw(&out1.step_x, out1.step_y.get(&0).unwrap()); + .draw(&bweuler.out().step_x, bweuler.out().step_y.get(&0).unwrap()); // FwEuler curves let mut curve2 = Curve::new(); let mut curve3 = Curve::new(); - curve2 - .set_label("FwEuler h = 1.974/50") - .draw(&out2.step_x, out2.step_y.get(&0).unwrap()); + curve2.set_label("FwEuler h = 1.974/50").draw(&out2_x, &out2_y); curve3 .set_label("FwEuler h = 1.875/50") - .draw(&out3.step_x, out3.step_y.get(&0).unwrap()); + .draw(&fweuler.out().step_x, fweuler.out().step_y.get(&0).unwrap()); // save figure let mut plot = Plot::new(); diff --git a/russell_ode/examples/pde_1d_heat_spectral_collocation.rs b/russell_ode/examples/pde_1d_heat_spectral_collocation.rs index 8744e2b8..e2e964e3 100644 --- a/russell_ode/examples/pde_1d_heat_spectral_collocation.rs +++ b/russell_ode/examples/pde_1d_heat_spectral_collocation.rs @@ -106,7 +106,7 @@ fn run( |x| f64::sin((x + 1.0) * PI)); // solve the problem - ode.solve(&mut uu, t0, t1, None, None, &mut args)?; + ode.solve(&mut uu, t0, t1, None, &mut args)?; // print stats if print_stats { diff --git a/russell_ode/examples/robertson.rs b/russell_ode/examples/robertson.rs index 51e66835..3202eaed 100644 --- a/russell_ode/examples/robertson.rs +++ b/russell_ode/examples/robertson.rs @@ -50,27 +50,28 @@ fn main() -> Result<(), StrError> { let sel = 1; // solve the problem with Radau5 - let mut out1 = Output::new(); - out1.set_step_recording(true, &[sel]); + radau5.enable_output().set_step_recording(true, &[sel]); let mut y = y0.clone(); - radau5.solve(&mut y, x0, x1, None, Some(&mut out1), &mut args)?; + radau5.solve(&mut y, x0, x1, None, &mut args)?; println!("{}", radau5.stats()); let n_accepted1 = radau5.stats().n_accepted; // solve the problem with DoPri5 and Tol = 1e-2 - let mut out2 = Output::new(); - out2.set_step_recording(true, &[sel]); + dopri5.enable_output().set_step_recording(true, &[sel]); let mut y = y0.clone(); - dopri5.solve(&mut y, x0, x1, None, Some(&mut out2), &mut args)?; + dopri5.solve(&mut y, x0, x1, None, &mut args)?; println!("\nTol = 1e-2\n{}", dopri5.stats()); let n_accepted2 = dopri5.stats().n_accepted; - // solve the problem with DoPri5 and Tol = 1e-3 - let mut out3 = Output::new(); - out3.set_step_recording(true, &[sel]); + // save the results for later + let out2_x = dopri5.out().step_x.clone(); + let out2_y = dopri5.out().step_y.get(&sel).unwrap().clone(); + + // solve the problem again with DoPri5 and Tol = 1e-3 + dopri5.enable_output().clear().set_step_recording(true, &[sel]); let mut y = y0.clone(); dopri5.update_params(params3)?; - dopri5.solve(&mut y, x0, x1, None, Some(&mut out3), &mut args)?; + dopri5.solve(&mut y, x0, x1, None, &mut args)?; println!("\nTol = 1e-3\n{}", dopri5.stats()); let n_accepted3 = dopri5.stats().n_accepted; @@ -79,7 +80,7 @@ fn main() -> Result<(), StrError> { curve1 .set_label(&format!("Radau5, n_accepted = {}", n_accepted1)) .set_marker_style("o") - .draw(&out1.step_x, out1.step_y.get(&sel).unwrap()); + .draw(&radau5.out().step_x, radau5.out().step_y.get(&sel).unwrap()); // DoPri5 curves let mut curve2 = Curve::new(); @@ -87,11 +88,11 @@ fn main() -> Result<(), StrError> { let mut curve4 = Curve::new(); curve2 .set_label(&format!("DoPri5, Tol = 1e-2, n_accepted = {}", n_accepted2)) - .draw(&out2.step_x, out2.step_y.get(&sel).unwrap()); + .draw(&out2_x, &out2_y); curve3 .set_label(&format!("DoPri5, Tol = 1e-3, n_accepted = {}", n_accepted3)) - .draw(&out3.step_x, out3.step_y.get(&sel).unwrap()); - curve4.draw(&out3.step_x, &out3.step_h); + .draw(&dopri5.out().step_x, dopri5.out().step_y.get(&sel).unwrap()); + curve4.draw(&dopri5.out().step_x, &dopri5.out().step_h); // save figures let mut plot1 = Plot::new(); diff --git a/russell_ode/examples/simple_ode_single_equation.rs b/russell_ode/examples/simple_ode_single_equation.rs index 14254857..6dcbcfde 100644 --- a/russell_ode/examples/simple_ode_single_equation.rs +++ b/russell_ode/examples/simple_ode_single_equation.rs @@ -27,7 +27,7 @@ fn main() -> Result<(), StrError> { // solve from x = 0 to x = 1 let x1 = 1.0; let mut args = 0; - solver.solve(&mut y, x, x1, None, None, &mut args)?; + solver.solve(&mut y, x, x1, None, &mut args)?; println!("y =\n{}", y); // check the results diff --git a/russell_ode/examples/simple_system_with_mass.rs b/russell_ode/examples/simple_system_with_mass.rs index 2c2f02b4..817fd89c 100644 --- a/russell_ode/examples/simple_system_with_mass.rs +++ b/russell_ode/examples/simple_system_with_mass.rs @@ -47,7 +47,7 @@ fn main() -> Result<(), StrError> { // solve from x = 0 to x = 20 let x1 = 20.0; let mut args = 0; - solver.solve(&mut y, x, x1, None, None, &mut args)?; + solver.solve(&mut y, x, x1, None, &mut args)?; println!("y =\n{}", y); // check the results diff --git a/russell_ode/examples/van_der_pol_dopri5.rs b/russell_ode/examples/van_der_pol_dopri5.rs index d44c0f10..eff1bb64 100644 --- a/russell_ode/examples/van_der_pol_dopri5.rs +++ b/russell_ode/examples/van_der_pol_dopri5.rs @@ -21,24 +21,27 @@ fn main() -> Result<(), StrError> { let (system, x0, _, x1, mut args) = Samples::van_der_pol(EPS, false); let mut y0 = Vector::from(&[2.0, 0.0]); - // solver + // set configuration parameters let mut params = Params::new(Method::DoPri5); params.stiffness.enabled = true; params.stiffness.stop_with_error = false; params.stiffness.save_results = true; params.step.h_ini = 1e-4; params.set_tolerances(1e-5, 1e-5, None)?; + + // allocate the solver let mut solver = OdeSolver::new(params, &system)?; // enable step and dense output - let mut out = Output::new(); let h_out = 0.01; let selected_y_components = &[0, 1]; - out.set_step_recording(true, selected_y_components); - out.set_dense_recording(true, h_out, selected_y_components)?; + solver + .enable_output() + .set_step_recording(true, selected_y_components) + .set_dense_recording(true, h_out, selected_y_components)?; // solve the problem - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args)?; + solver.solve(&mut y0, x0, x1, None, &mut args)?; println!("y =\n{}", y0); // print stats @@ -51,17 +54,17 @@ fn main() -> Result<(), StrError> { let mut curve4 = Curve::new(); curve1 .set_line_color("black") - .draw(&out.dense_x, out.dense_y.get(&0).unwrap()); + .draw(&solver.out().dense_x, solver.out().dense_y.get(&0).unwrap()); curve2 .set_marker_style(".") .set_marker_color("cyan") - .draw(&out.step_x, out.step_y.get(&0).unwrap()); + .draw(&solver.out().step_x, solver.out().step_y.get(&0).unwrap()); curve3.set_line_color("red").set_line_style("--"); - for i in 0..out.stiff_x.len() { - curve3.draw_ray(out.stiff_x[i], 0.0, RayEndpoint::Vertical); + for i in 0..solver.out().stiff_x.len() { + curve3.draw_ray(solver.out().stiff_x[i], 0.0, RayEndpoint::Vertical); } - let fac: Vec<_> = out.stiff_h_times_rho.iter().map(|hr| hr / 3.3).collect(); - curve4.set_marker_style(".").draw(&out.step_x, &fac); + let fac: Vec<_> = solver.out().stiff_h_times_rho.iter().map(|hr| hr / 3.3).collect(); + curve4.set_marker_style(".").draw(&solver.out().step_x, &fac); // save figure let mut plot = Plot::new(); diff --git a/russell_ode/examples/van_der_pol_radau5.rs b/russell_ode/examples/van_der_pol_radau5.rs index 64bd5900..f799258f 100644 --- a/russell_ode/examples/van_der_pol_radau5.rs +++ b/russell_ode/examples/van_der_pol_radau5.rs @@ -21,19 +21,20 @@ fn main() -> Result<(), StrError> { let (system, x0, _, x1, mut args) = Samples::van_der_pol(EPS, false); let mut y0 = Vector::from(&[2.0, -0.6]); - // solver + // set configuration parameters let mut params = Params::new(Method::Radau5); params.step.h_ini = 1e-4; params.set_tolerances(1e-4, 1e-4, None)?; + + // allocate the solver let mut solver = OdeSolver::new(params, &system)?; // enable step output - let mut out = Output::new(); let selected_y_components = &[0, 1]; - out.set_step_recording(true, selected_y_components); + solver.enable_output().set_step_recording(true, selected_y_components); // solve the problem - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args)?; + solver.solve(&mut y0, x0, x1, None, &mut args)?; println!("y =\n{}", y0); // print stats @@ -46,12 +47,12 @@ fn main() -> Result<(), StrError> { .set_marker_color("red") .set_marker_line_color("red") .set_marker_style(".") - .draw(&out.step_x, out.step_y.get(&0).unwrap()); + .draw(&solver.out().step_x, solver.out().step_y.get(&0).unwrap()); curve2 .set_marker_color("green") .set_marker_line_color("green") .set_marker_style(".") - .draw(&out.step_x, &out.step_h); + .draw(&solver.out().step_x, &solver.out().step_h); // save figure let mut plot = Plot::new(); diff --git a/russell_ode/src/bin/amplifier1t.rs b/russell_ode/src/bin/amplifier1t.rs index f4860dc7..ff6a8e98 100644 --- a/russell_ode/src/bin/amplifier1t.rs +++ b/russell_ode/src/bin/amplifier1t.rs @@ -13,14 +13,15 @@ fn main() -> Result<(), StrError> { params.step.h_ini = 1e-6; params.set_tolerances(1e-4, 1e-4, None)?; + // allocate the solver + let mut solver = OdeSolver::new(params, &system)?; + // enable dense output - let mut out = Output::new(); - out.set_dense_recording(true, 0.001, &[0, 4])?; + solver.enable_output().set_dense_recording(true, 0.001, &[0, 4])?; // solve the ODE system let y = &mut y0; - let mut solver = OdeSolver::new(params, &system)?; - solver.solve(y, x0, x1, None, Some(&mut out), &mut args)?; + solver.solve(y, x0, x1, None, &mut args)?; // compare with radau5.f approx_eq(y[0], -2.226517868073645E-02, 1e-10); diff --git a/russell_ode/src/bin/brusselator_pde.rs b/russell_ode/src/bin/brusselator_pde.rs index e183d69f..5f174410 100644 --- a/russell_ode/src/bin/brusselator_pde.rs +++ b/russell_ode/src/bin/brusselator_pde.rs @@ -100,7 +100,7 @@ fn main() -> Result<(), StrError> { // solve the ODE system let mut solver = OdeSolver::new(params, &system)?; - solver.solve(&mut yy0, t0, t1, None, None, &mut args)?; + solver.solve(&mut yy0, t0, t1, None, &mut args)?; // print stat let stat = solver.stats(); diff --git a/russell_ode/src/ode_solver.rs b/russell_ode/src/ode_solver.rs index f0872f8c..2c731960 100644 --- a/russell_ode/src/ode_solver.rs +++ b/russell_ode/src/ode_solver.rs @@ -122,6 +122,12 @@ pub struct OdeSolver<'a, A> { /// Holds statistics, benchmarking and "work" variables work: Workspace, + + /// Assists in generating the output of results (steps or dense) + output: Output<'a, A>, + + /// Indicates whether the output is enabled or not + output_enabled: bool, } impl<'a, A> OdeSolver<'a, A> { @@ -164,6 +170,8 @@ impl<'a, A> OdeSolver<'a, A> { ndim, actual, work: Workspace::new(params.method), + output: Output::new(), + output_enabled: false, }) } @@ -183,14 +191,12 @@ impl<'a, A> OdeSolver<'a, A> { /// if possible, variable step sizes are automatically calculated. If automatic /// stepping is not possible (e.g., the RK method is not embedded), /// a constant (and equal) stepsize will be calculated for [N_EQUAL_STEPS] steps. - /// * `output` -- structure to hold the results at accepted steps or at specified stations (continuous/dense output) pub fn solve( &mut self, y0: &mut Vector, x0: f64, x1: f64, h_equal: Option, - mut output: Option<&mut Output>, args: &mut A, ) -> Result<(), StrError> { // check data @@ -234,12 +240,12 @@ impl<'a, A> OdeSolver<'a, A> { let y = y0; // will become y1 at the end // first output - if let Some(out) = output.as_mut() { - if out.with_dense_output() { + if self.output_enabled { + if self.output.with_dense_output() { self.actual.enable_dense_output()?; } - out.stiff_record = self.params.stiffness.save_results; - let stop = out.execute(&self.work, h, x, y, &self.actual, args)?; + self.output.stiff_record = self.params.stiffness.save_results; + let stop = self.output.execute(&self.work, h, x, y, &self.actual, args)?; if stop { return Ok(()); } @@ -263,8 +269,8 @@ impl<'a, A> OdeSolver<'a, A> { vec_all_finite(&y, self.params.debug)?; // output - if let Some(out) = output.as_mut() { - let stop = out.execute(&self.work, h, x, y, &self.actual, args)?; + if self.output_enabled { + let stop = self.output.execute(&self.work, h, x, y, &self.actual, args)?; if stop { self.work.stats.stop_sw_step(); self.work.stats.stop_sw_total(); @@ -273,8 +279,8 @@ impl<'a, A> OdeSolver<'a, A> { } self.work.stats.stop_sw_step(); } - if let Some(out) = output.as_mut() { - out.last(&self.work, h, x, y, args)?; + if self.output_enabled { + self.output.last(&self.work, h, x, y, args)?; } self.work.stats.stop_sw_total(); return Ok(()); @@ -336,8 +342,8 @@ impl<'a, A> OdeSolver<'a, A> { self.work.stats.h_accepted = self.work.h_new; // output - if let Some(out) = output.as_mut() { - let stop = out.execute(&self.work, h, x, y, &self.actual, args)?; + if self.output_enabled { + let stop = self.output.execute(&self.work, h, x, y, &self.actual, args)?; if stop { self.work.stats.stop_sw_step(); self.work.stats.stop_sw_total(); @@ -376,8 +382,8 @@ impl<'a, A> OdeSolver<'a, A> { } // last output - if let Some(out) = output.as_mut() { - out.last(&self.work, h, x, y, args)?; + if self.output_enabled { + self.output.last(&self.work, h, x, y, args)?; } // done @@ -399,6 +405,26 @@ impl<'a, A> OdeSolver<'a, A> { self.params = params; Ok(()) } + + /// Enables the output of results + /// + /// Returns an access to the output structure for further configuration + pub fn enable_output(&mut self) -> &mut Output<'a, A> { + self.output_enabled = true; + &mut self.output + } + + /// Returns a read-only access to the output struct + /// + /// # Panics + /// + /// A panic may occur if the output has not been enabled yet via [OdeSolver::enable_output]. + pub fn out(&self) -> &Output<'a, A> { + if !self.output_enabled { + panic!("the output needs to be enabled first"); + } + &self.output + } } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -406,7 +432,7 @@ impl<'a, A> OdeSolver<'a, A> { #[cfg(test)] mod tests { use super::OdeSolver; - use crate::{no_jacobian, HasJacobian, NoArgs, OutCallback, OutCount, OutData, Output}; + use crate::{no_jacobian, HasJacobian, NoArgs, OutCallback, OutCount, OutData}; use crate::{Method, Params, Samples, System}; use russell_lab::{approx_eq, array_approx_eq, vec_approx_eq, Vector}; use russell_sparse::Genie; @@ -434,17 +460,17 @@ mod tests { let mut solver = OdeSolver::new(params, &system).unwrap(); let mut y0 = Vector::new(system.ndim + 1); // wrong dim assert_eq!( - solver.solve(&mut y0, 0.0, 1.0, None, None, &mut args).err(), + solver.solve(&mut y0, 0.0, 1.0, None, &mut args).err(), Some("y0.dim() must be equal to ndim") ); let mut y0 = Vector::new(system.ndim); assert_eq!( - solver.solve(&mut y0, 0.0, 0.0, None, None, &mut args).err(), + solver.solve(&mut y0, 0.0, 0.0, None, &mut args).err(), Some("x1 must be greater than x0") ); let h_equal = Some(f64::EPSILON); // will cause an error assert_eq!( - solver.solve(&mut y0, 0.0, 1.0, h_equal, None, &mut args).err(), + solver.solve(&mut y0, 0.0, 1.0, h_equal, &mut args).err(), Some("h_equal must be ≥ 10.0 * f64::EPSILON") ); } @@ -457,13 +483,13 @@ mod tests { let params = Params::new(Method::FwEuler); let mut solver = OdeSolver::new(params, &system).unwrap(); assert_eq!( - solver.solve(&mut y0, 0.0, 9.0, Some(1.0), None, &mut args).err(), + solver.solve(&mut y0, 0.0, 9.0, Some(1.0), &mut args).err(), Some("an element of the vector is either infinite or NaN") ); let params = Params::new(Method::MdEuler); let mut solver = OdeSolver::new(params, &system).unwrap(); assert_eq!( - solver.solve(&mut y0, 0.0, 1.0, None, None, &mut args).err(), + solver.solve(&mut y0, 0.0, 1.0, None, &mut args).err(), Some("an element of the vector is either infinite or NaN") ); } @@ -475,7 +501,7 @@ mod tests { params.step.n_step_max = 1; // will make the solver to fail (too few steps) let mut solver = OdeSolver::new(params, &system).unwrap(); assert_eq!( - solver.solve(&mut y0, 0.0, 1.0, None, None, &mut args).err(), + solver.solve(&mut y0, 0.0, 1.0, None, &mut args).err(), Some("variable stepping did not converge") ); } @@ -488,7 +514,7 @@ mod tests { let params = Params::new(Method::FwEuler); let mut solver = OdeSolver::new(params, &system).unwrap(); let mut y = y0.clone(); - solver.solve(&mut y, x0, x1, None, None, &mut args).unwrap(); + solver.solve(&mut y, x0, x1, None, &mut args).unwrap(); vec_approx_eq(&y, &[1.0], 1e-15); } @@ -499,7 +525,7 @@ mod tests { let mut params = Params::new(Method::DoPri5); params.step.h_ini = 20.0; // will be truncated to 1 yielding a single step let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, 0.0, 1.0, None, None, &mut args).unwrap(); + solver.solve(&mut y0, 0.0, 1.0, None, &mut args).unwrap(); assert_eq!(solver.work.stats.n_accepted, 1); vec_approx_eq(&y0, &[1.0], 1e-15); } @@ -510,7 +536,7 @@ mod tests { let mut params = Params::new(Method::MdEuler); params.step.h_ini = 0.1; let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, 0.0, 0.3, None, None, &mut args).unwrap(); + solver.solve(&mut y0, 0.0, 0.3, None, &mut args).unwrap(); vec_approx_eq(&y0, &[0.3], 1e-15); } @@ -549,10 +575,9 @@ mod tests { let x1 = 1.0; let params = Params::new(Method::FwEuler); let mut solver = OdeSolver::new(params, &system).unwrap(); - let mut out = Output::new(); - out.set_dense_recording(true, 0.1, &[0]).unwrap(); + solver.enable_output().set_dense_recording(true, 0.1, &[0]).unwrap(); assert_eq!( - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).err(), + solver.solve(&mut y0, x0, x1, None, &mut args).err(), Some("dense output is not available for the FwEuler method") ); } @@ -565,9 +590,10 @@ mod tests { let mut solver = OdeSolver::new(params, &system).unwrap(); // output - let mut out = Output::new(); let path_key = "/tmp/russell_ode/test_solve_step_output_works"; - out.set_yx_correct(y_fn_x) + solver + .enable_output() + .set_yx_correct(y_fn_x) .set_step_file_writing(true, path_key) .set_step_recording(true, &[0]) .set_step_callback(true, |stats, h, x, y, _args| { @@ -580,16 +606,14 @@ mod tests { // solve let h_equal = Some(0.2); let mut y = y0.clone(); - solver - .solve(&mut y, 0.0, 0.4, h_equal, Some(&mut out), &mut args) - .unwrap(); + solver.solve(&mut y, 0.0, 0.4, h_equal, &mut args).unwrap(); // check vec_approx_eq(&y, &[0.4], 1e-15); - array_approx_eq(&out.step_h, &[0.2, 0.2, 0.2], 1e-15); - array_approx_eq(&out.step_x, &[0.0, 0.2, 0.4], 1e-15); - array_approx_eq(&out.step_y.get(&0).unwrap(), &[0.0, 0.2, 0.4], 1e-15); - array_approx_eq(&out.step_global_error, &[0.0, 0.0, 0.0], 1e-15); + array_approx_eq(&solver.out().step_h, &[0.2, 0.2, 0.2], 1e-15); + array_approx_eq(&solver.out().step_x, &[0.0, 0.2, 0.4], 1e-15); + array_approx_eq(&solver.out().step_y.get(&0).unwrap(), &[0.0, 0.2, 0.4], 1e-15); + array_approx_eq(&solver.out().step_global_error, &[0.0, 0.0, 0.0], 1e-15); // check count file let count = OutCount::read_json(&format!("{}_count.json", path_key)).unwrap(); @@ -608,43 +632,49 @@ mod tests { assert_eq!(cb(&solver.stats(), 0.0, 0.0, &y0, &mut args).err(), Some("unreachable")); // run again without step output - out.clear(); - out.set_step_file_writing(false, path_key) + solver + .enable_output() + .clear() + .set_step_file_writing(false, path_key) .set_step_recording(false, &[]) .set_step_callback(false, cb); let mut y = y0.clone(); - solver.solve(&mut y, 0.0, 0.4, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y, 0.0, 0.4, None, &mut args).unwrap(); vec_approx_eq(&y, &[0.4], 1e-15); - assert_eq!(out.step_h.len(), 0); - assert_eq!(out.step_x.len(), 0); - assert_eq!(out.step_y.len(), 0); - assert_eq!(out.step_global_error.len(), 0); + assert_eq!(solver.out().step_h.len(), 0); + assert_eq!(solver.out().step_x.len(), 0); + assert_eq!(solver.out().step_y.len(), 0); + assert_eq!(solver.out().step_global_error.len(), 0); // run again and stop earlier - out.clear(); - out.set_step_callback(true, |stats, _h, _x, _y, _args| { - if stats.n_accepted > 0 { - Ok(true) // stop - } else { - Ok(false) // do not stop - } - }); + solver + .enable_output() + .clear() + .set_step_callback(true, |stats, _h, _x, _y, _args| { + if stats.n_accepted > 0 { + Ok(true) // stop + } else { + Ok(false) // do not stop + } + }); let mut y = y0.clone(); - solver.solve(&mut y, 0.0, 0.4, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y, 0.0, 0.4, None, &mut args).unwrap(); assert!(y[0] > 0.0 && y[0] < 0.4); // run again and stop due to error - out.clear(); - out.set_step_callback(true, |stats, _h, _x, _y, _args| { - if stats.n_accepted > 0 { - Err("stop with error") - } else { - Ok(false) // do not stop - } - }); + solver + .enable_output() + .clear() + .set_step_callback(true, |stats, _h, _x, _y, _args| { + if stats.n_accepted > 0 { + Err("stop with error") + } else { + Ok(false) // do not stop + } + }); let mut y = y0.clone(); assert_eq!( - solver.solve(&mut y, 0.0, 0.4, None, Some(&mut out), &mut args).err(), + solver.solve(&mut y, 0.0, 0.4, None, &mut args).err(), Some("stop with error") ); } @@ -657,36 +687,41 @@ mod tests { let mut solver = OdeSolver::new(params, &system).unwrap(); // output - let mut out = Output::new(); const H_OUT: f64 = 0.1; let path_key = "/tmp/russell_ode/test_solve_dense_output_works"; - out.set_yx_correct(y_fn_x); - out.set_dense_file_writing(true, H_OUT, path_key).unwrap(); - out.set_dense_recording(true, H_OUT, &[0]).unwrap(); - out.set_dense_callback(true, H_OUT, |stats, h, x, y, _args| { - assert_eq!(h, 0.2); - if stats.n_accepted < 2 { - approx_eq(x, (stats.n_accepted as f64) * H_OUT, 1e-15); - approx_eq(y[0], (stats.n_accepted as f64) * H_OUT, 1e-15); - } else { - approx_eq(y[0], x, 1e-15); - } - Ok(false) - }) - .unwrap(); + solver + .enable_output() + .set_yx_correct(y_fn_x) + .set_dense_file_writing(true, H_OUT, path_key) + .unwrap() + .set_dense_recording(true, H_OUT, &[0]) + .unwrap() + .set_dense_callback(true, H_OUT, |stats, h, x, y, _args| { + assert_eq!(h, 0.2); + if stats.n_accepted < 2 { + approx_eq(x, (stats.n_accepted as f64) * H_OUT, 1e-15); + approx_eq(y[0], (stats.n_accepted as f64) * H_OUT, 1e-15); + } else { + approx_eq(y[0], x, 1e-15); + } + Ok(false) + }) + .unwrap(); // solve let h_equal = Some(0.2); let mut y = y0.clone(); - solver - .solve(&mut y, 0.0, 0.4, h_equal, Some(&mut out), &mut args) - .unwrap(); + solver.solve(&mut y, 0.0, 0.4, h_equal, &mut args).unwrap(); // check vec_approx_eq(&y, &[0.4], 1e-15); - assert_eq!(&out.dense_step_index, &[0, 1, 2, 2, 2]); - array_approx_eq(&out.dense_x, &[0.0, 0.1, 0.2, 0.3, 0.4], 1e-15); - array_approx_eq(&out.dense_y.get(&0).unwrap(), &[0.0, 0.1, 0.2, 0.3, 0.4], 1e-15); + assert_eq!(&solver.out().dense_step_index, &[0, 1, 2, 2, 2]); + array_approx_eq(&solver.out().dense_x, &[0.0, 0.1, 0.2, 0.3, 0.4], 1e-15); + array_approx_eq( + &solver.out().dense_y.get(&0).unwrap(), + &[0.0, 0.1, 0.2, 0.3, 0.4], + 1e-15, + ); // check count file let count = OutCount::read_json(&format!("{}_count.json", path_key)).unwrap(); @@ -705,71 +740,82 @@ mod tests { assert_eq!(cb(&solver.stats(), 0.0, 0.0, &y0, &mut args).err(), Some("unreachable")); // run again without dense output - out.clear(); - out.set_dense_file_writing(false, H_OUT, path_key).unwrap(); - out.set_dense_recording(false, H_OUT, &[]).unwrap(); - out.set_dense_callback(false, H_OUT, cb).unwrap(); + solver + .enable_output() + .clear() + .set_dense_file_writing(false, H_OUT, path_key) + .unwrap() + .set_dense_recording(false, H_OUT, &[]) + .unwrap() + .set_dense_callback(false, H_OUT, cb) + .unwrap(); let mut y = y0.clone(); - solver.solve(&mut y, 0.0, 0.4, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y, 0.0, 0.4, None, &mut args).unwrap(); vec_approx_eq(&y, &[0.4], 1e-15); - assert_eq!(out.dense_step_index.len(), 0); - assert_eq!(out.dense_x.len(), 0); - assert_eq!(out.dense_y.len(), 0); + assert_eq!(solver.out().dense_step_index.len(), 0); + assert_eq!(solver.out().dense_x.len(), 0); + assert_eq!(solver.out().dense_y.len(), 0); // run again but stop at the first output - out.clear(); - out.set_dense_callback(true, H_OUT, |_stats, _h, _x, _y, _args| { - Ok(true) // stop - }) - .unwrap(); + solver + .enable_output() + .clear() + .set_dense_callback(true, H_OUT, |_stats, _h, _x, _y, _args| { + Ok(true) // stop + }) + .unwrap(); let mut y = y0.clone(); - solver.solve(&mut y, 0.0, 0.4, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y, 0.0, 0.4, None, &mut args).unwrap(); assert_eq!(solver.work.stats.n_accepted, 0); assert_eq!(y[0], 0.0); // run again and stop earlier - out.clear(); - out.set_dense_callback(true, H_OUT, |stats, _h, _x, _y, _args| { - if stats.n_accepted > 0 { - Ok(true) // stop - } else { - Ok(false) // do not stop - } - }) - .unwrap(); - // ... equal steps - let mut y = y0.clone(); solver - .solve(&mut y, 0.0, 0.4, Some(0.2), Some(&mut out), &mut args) + .enable_output() + .clear() + .set_dense_callback(true, H_OUT, |stats, _h, _x, _y, _args| { + if stats.n_accepted > 0 { + Ok(true) // stop + } else { + Ok(false) // do not stop + } + }) .unwrap(); + // ... equal steps + let mut y = y0.clone(); + solver.solve(&mut y, 0.0, 0.4, Some(0.2), &mut args).unwrap(); assert!(y[0] > 0.0 && y[0] < 0.4); // ... variable steps let mut y = y0.clone(); - solver.solve(&mut y, 0.0, 0.4, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y, 0.0, 0.4, None, &mut args).unwrap(); assert!(y[0] > 0.0 && y[0] < 0.4); // run again and stop due to error - out.clear(); + solver.enable_output().clear(); // ... first step - out.set_dense_callback(true, H_OUT, |_stats, _h, _x, _y, _args| Err("stop with error")) + solver + .enable_output() + .set_dense_callback(true, H_OUT, |_stats, _h, _x, _y, _args| Err("stop with error")) .unwrap(); let mut y = y0.clone(); assert_eq!( - solver.solve(&mut y, 0.0, 0.4, None, Some(&mut out), &mut args).err(), + solver.solve(&mut y, 0.0, 0.4, None, &mut args).err(), Some("stop with error") ); // ... next steps - out.set_dense_callback(true, H_OUT, |stats, _h, _x, _y, _args| { - if stats.n_accepted > 0 { - Err("stop with error") - } else { - Ok(false) // do not stop - } - }) - .unwrap(); + solver + .enable_output() + .set_dense_callback(true, H_OUT, |stats, _h, _x, _y, _args| { + if stats.n_accepted > 0 { + Err("stop with error") + } else { + Ok(false) // do not stop + } + }) + .unwrap(); let mut y = y0.clone(); assert_eq!( - solver.solve(&mut y, 0.0, 0.4, None, Some(&mut out), &mut args).err(), + solver.solve(&mut y, 0.0, 0.4, None, &mut args).err(), Some("stop with error") ); } @@ -819,21 +865,22 @@ mod tests { let mut solver = OdeSolver::new(params, &system).unwrap(); // output - let mut out = Output::new(); - out.set_dense_callback(true, 0.1, |_stats, _h, _x, _y, args: &mut Args| { - if args.out_count == args.out_barrier { - return Err("out: artificial error"); - } - args.out_count += 1; - Ok(false) // do not stop - }) - .unwrap(); + solver + .enable_output() + .set_dense_callback(true, 0.1, |_stats, _h, _x, _y, args: &mut Args| { + if args.out_count == args.out_barrier { + return Err("out: artificial error"); + } + args.out_count += 1; + Ok(false) // do not stop + }) + .unwrap(); // equal steps ----------------------------------------------------------- // first error @ actual.step assert_eq!( - solver.solve(&mut y, x0, x1, Some(0.2), None, &mut args).err(), + solver.solve(&mut y, x0, x1, Some(0.2), &mut args).err(), Some("f: artificial error") ); @@ -844,14 +891,14 @@ mod tests { // 2. in computations related to the stiffness detection args.f_barrier += 12; // nstage = 12 (need to skip the next call to 'step') assert_eq!( - solver.solve(&mut y, x0, x1, Some(0.2), Some(&mut out), &mut args).err(), + solver.solve(&mut y, x0, x1, Some(0.2), &mut args).err(), Some("f: artificial error") ); // third error @ the second output args.f_barrier += 2 * 12; // skip next calls to 'step' assert_eq!( - solver.solve(&mut y, x0, x1, Some(0.2), Some(&mut out), &mut args).err(), + solver.solve(&mut y, x0, x1, Some(0.2), &mut args).err(), Some("out: artificial error") ); @@ -859,7 +906,7 @@ mod tests { args.f_barrier += 2 * 12; // skip next calls to 'step' args.out_barrier += 2; // skip first and second output assert_eq!( - solver.solve(&mut y, x0, x1, Some(0.2), Some(&mut out), &mut args).err(), + solver.solve(&mut y, x0, x1, Some(0.2), &mut args).err(), Some("out: artificial error") ); @@ -871,7 +918,7 @@ mod tests { args.out_count = 0; args.out_barrier = 2; // first and second outputs assert_eq!( - solver.solve(&mut y, x0, x1, None, None, &mut args).err(), + solver.solve(&mut y, x0, x1, None, &mut args).err(), Some("f: artificial error") ); @@ -882,7 +929,7 @@ mod tests { // 2. in computations related to the stiffness detection args.f_barrier += 12; // nstage = 12 (need to skip the next call to 'step') assert_eq!( - solver.solve(&mut y, x0, x1, None, Some(&mut out), &mut args).err(), + solver.solve(&mut y, x0, x1, None, &mut args).err(), Some("f: artificial error") ); @@ -892,7 +939,7 @@ mod tests { args.out_count = 0; args.out_barrier = 2; // first and second outputs assert_eq!( - solver.solve(&mut y, x0, x1, None, Some(&mut out), &mut args).err(), + solver.solve(&mut y, x0, x1, None, &mut args).err(), Some("out: artificial error") ); } diff --git a/russell_ode/src/output.rs b/russell_ode/src/output.rs index a56c914e..02beb898 100644 --- a/russell_ode/src/output.rs +++ b/russell_ode/src/output.rs @@ -54,10 +54,10 @@ pub struct OutCount { /// /// * `A` -- Is auxiliary argument for the `F`, `J`, `YxFunction`, and `OutCallback` functions. /// It may be simply [crate::NoArgs] indicating that no arguments are effectively used. -pub struct Output { +pub struct Output<'a, A> { // --- step -------------------------------------------------------------------------------------------- /// Holds a callback function called on an accepted step - step_callback: Option>, + step_callback: Option Result + 'a>>, /// Save the results to a file (step) step_file_key: Option, @@ -188,7 +188,7 @@ impl OutCount { } } -impl Output { +impl<'a, A> Output<'a, A> { /// Allocates a new instance pub fn new() -> Self { const EMPTY: usize = 0; @@ -234,9 +234,13 @@ impl Output { /// /// * `enable` -- Enable/disable the output /// * `callback` -- Function to be executed on an accepted step - pub fn set_step_callback(&mut self, enable: bool, callback: OutCallback) -> &mut Self { + pub fn set_step_callback( + &mut self, + enable: bool, + callback: impl Fn(&Stats, f64, f64, &Vector, &mut A) -> Result + 'a, + ) -> &mut Self { if enable { - self.step_callback = Some(callback); + self.step_callback = Some(Box::new(callback)); } else { self.step_callback = None; } @@ -388,7 +392,7 @@ impl Output { } /// Clears the results - pub fn clear(&mut self) { + pub fn clear(&mut self) -> &mut Self { // step self.step_h.clear(); self.step_x.clear(); @@ -402,10 +406,11 @@ impl Output { self.stiff_step_index.clear(); self.stiff_x.clear(); self.stiff_h_times_rho.clear(); + self } /// Executes the output at an accepted step - pub(crate) fn execute<'a>( + pub(crate) fn execute( &mut self, work: &Workspace, h: f64, @@ -417,8 +422,8 @@ impl Output { // --- step -------------------------------------------------------------------------------------------- // // step output: callback - if let Some(cb) = self.step_callback { - let stop = cb(&work.stats, h, x, y, args)?; + if let Some(cb) = self.step_callback.as_ref() { + let stop = (cb)(&work.stats, h, x, y, args)?; if stop { return Ok(stop); } @@ -592,6 +597,8 @@ impl Output { #[cfg(test)] mod tests { + use crate::NoArgs; + use super::*; #[test] @@ -657,10 +664,10 @@ mod tests { #[test] fn set_methods_handle_errors() { - struct Args {} - let mut out: Output = Output::new(); + let mut out = Output::new(); assert_eq!( - out.set_dense_callback(true, 0.0, |_, _, _, _, _| Ok(false)).err(), + out.set_dense_callback(true, 0.0, |_, _, _, _, _: &mut NoArgs| Ok(false)) + .err(), Some("h_out must be > EPSILON") ); let path_key = "/tmp/russell_ode/test_output_errors"; diff --git a/russell_ode/tests/test_bweuler.rs b/russell_ode/tests/test_bweuler.rs index f700f198..227d52d0 100644 --- a/russell_ode/tests/test_bweuler.rs +++ b/russell_ode/tests/test_bweuler.rs @@ -16,7 +16,7 @@ fn test_bweuler_hairer_wanner_eq1() { // solve the ODE system let h_equal = Some(1.875 / 50.0); - solver.solve(&mut y0, x0, x1, h_equal, None, &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, h_equal, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -59,7 +59,7 @@ fn test_bweuler_hairer_wanner_eq1_num_jac() { // solve the ODE system let mut solver = OdeSolver::new(params, &system).unwrap(); let h_equal = Some(1.875 / 50.0); - solver.solve(&mut y0, x0, x1, h_equal, None, &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, h_equal, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -102,7 +102,7 @@ fn test_bweuler_hairer_wanner_eq1_modified_newton() { // solve the ODE system let mut solver = OdeSolver::new(params, &system).unwrap(); let h_equal = Some(1.875 / 50.0); - solver.solve(&mut y0, x0, x1, h_equal, None, &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, h_equal, &mut args).unwrap(); // get statistics let stat = solver.stats(); diff --git a/russell_ode/tests/test_dopri5_arenstorf.rs b/russell_ode/tests/test_dopri5_arenstorf.rs index 2d7ad28c..9599af4c 100644 --- a/russell_ode/tests/test_dopri5_arenstorf.rs +++ b/russell_ode/tests/test_dopri5_arenstorf.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, format_fortran}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_dopri5_arenstorf() { @@ -11,14 +11,18 @@ fn test_dopri5_arenstorf() { params.step.h_ini = 1e-4; params.set_tolerances(1e-7, 1e-7, None).unwrap(); + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // enable dense output with 1.0 spacing - let mut out = Output::new(); - out.set_dense_recording(true, 1.0, &[0, 1, 2, 3]).unwrap(); + solver + .enable_output() + .set_dense_recording(true, 1.0, &[0, 1, 2, 3]) + .unwrap(); // solve the ODE system - let mut solver = OdeSolver::new(params, &system).unwrap(); let y = &mut y0; - solver.solve(y, x0, x1, None, Some(&mut out), &mut args).unwrap(); + solver.solve(y, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -31,16 +35,16 @@ fn test_dopri5_arenstorf() { approx_eq(stat.h_accepted, 5.258587607119909E-04, 1e-10); // print dense output - let n_dense = out.dense_step_index.len(); + let n_dense = solver.out().dense_step_index.len(); for i in 0..n_dense { println!( "step ={:>4}, x ={:6.2}, y ={}{}{}{}", - out.dense_step_index[i], - out.dense_x[i], - format_fortran(out.dense_y.get(&0).unwrap()[i]), - format_fortran(out.dense_y.get(&1).unwrap()[i]), - format_fortran(out.dense_y.get(&2).unwrap()[i]), - format_fortran(out.dense_y.get(&3).unwrap()[i]), + solver.out().dense_step_index[i], + solver.out().dense_x[i], + format_fortran(solver.out().dense_y.get(&0).unwrap()[i]), + format_fortran(solver.out().dense_y.get(&1).unwrap()[i]), + format_fortran(solver.out().dense_y.get(&2).unwrap()[i]), + format_fortran(solver.out().dense_y.get(&3).unwrap()[i]), ); } diff --git a/russell_ode/tests/test_dopri5_arenstorf_debug.rs b/russell_ode/tests/test_dopri5_arenstorf_debug.rs index f903aebc..b2ca3a71 100644 --- a/russell_ode/tests/test_dopri5_arenstorf_debug.rs +++ b/russell_ode/tests/test_dopri5_arenstorf_debug.rs @@ -12,10 +12,12 @@ fn test_dopri5_arenstorf_debug() { params.set_tolerances(1e-7, 1e-7, None).unwrap(); params.debug = true; - // solve the ODE system + // allocate the solver let mut solver = OdeSolver::new(params, &system).unwrap(); + + // solve the ODE system let y = &mut y0; - solver.solve(y, x0, x1, None, None, &mut args).unwrap(); + solver.solve(y, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); diff --git a/russell_ode/tests/test_dopri5_hairer_wanner_eq1.rs b/russell_ode/tests/test_dopri5_hairer_wanner_eq1.rs index 402f858f..75d97cef 100644 --- a/russell_ode/tests/test_dopri5_hairer_wanner_eq1.rs +++ b/russell_ode/tests/test_dopri5_hairer_wanner_eq1.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, format_fortran, Vector}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_dopri5_hairer_wanner_eq1() { @@ -14,13 +14,14 @@ fn test_dopri5_hairer_wanner_eq1() { let mut params = Params::new(Method::DoPri5); params.step.h_ini = 1e-4; + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // enable dense output with 0.1 spacing - let mut out = Output::new(); - out.set_dense_recording(true, 0.1, &[0]).unwrap(); + solver.enable_output().set_dense_recording(true, 0.1, &[0]).unwrap(); // solve the ODE system - let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -34,13 +35,13 @@ fn test_dopri5_hairer_wanner_eq1() { approx_eq(y0[0], y1_correct[0], 4e-5); // print dense output - let n_dense = out.dense_step_index.len(); + let n_dense = solver.out().dense_step_index.len(); for i in 0..n_dense { println!( "step ={:>4}, x ={:6.2}, y ={}", - out.dense_step_index[i], - out.dense_x[i], - format_fortran(out.dense_y.get(&0).unwrap()[i]), + solver.out().dense_step_index[i], + solver.out().dense_x[i], + format_fortran(solver.out().dense_y.get(&0).unwrap()[i]), ); } diff --git a/russell_ode/tests/test_dopri5_van_der_pol_debug.rs b/russell_ode/tests/test_dopri5_van_der_pol_debug.rs index 5fc2a0a6..4c7e41d4 100644 --- a/russell_ode/tests/test_dopri5_van_der_pol_debug.rs +++ b/russell_ode/tests/test_dopri5_van_der_pol_debug.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, array_approx_eq, format_fortran, Vector}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_dopri5_van_der_pol_debug() { @@ -17,15 +17,18 @@ fn test_dopri5_van_der_pol_debug() { params.stiffness.stop_with_error = false; params.stiffness.save_results = true; - // output (to save stiff stations) - let mut out = Output::new(); + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + + // enable output (to save stiff stations) + solver.enable_output(); // solve the ODE system let mut y0 = Vector::from(&[2.0, 0.0]); let x0 = 0.0; let x1 = 2.0; - let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).unwrap(); + + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -45,9 +48,9 @@ fn test_dopri5_van_der_pol_debug() { assert_eq!(stat.n_rejected, 20); // check stiffness results - assert_eq!(out.stiff_step_index, &[32, 189, 357]); + assert_eq!(solver.out().stiff_step_index, &[32, 189, 357]); array_approx_eq( - &out.stiff_x, + &solver.out().stiff_x, &[1.216973774601867E-02, 8.717646581250652E-01, 1.744401291692531E+00], 1e-12, ); diff --git a/russell_ode/tests/test_dopri8_van_der_pol.rs b/russell_ode/tests/test_dopri8_van_der_pol.rs index 304dfa76..8502eb59 100644 --- a/russell_ode/tests/test_dopri8_van_der_pol.rs +++ b/russell_ode/tests/test_dopri8_van_der_pol.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, format_fortran, Vector}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_dopri8_van_der_pol() { @@ -12,16 +12,17 @@ fn test_dopri8_van_der_pol() { params.step.h_ini = 1e-6; params.set_tolerances(1e-9, 1e-9, None).unwrap(); + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // enable dense output with 0.2 spacing - let mut out = Output::new(); - out.set_dense_recording(true, 0.1, &[0, 1]).unwrap(); + solver.enable_output().set_dense_recording(true, 0.1, &[0, 1]).unwrap(); // solve the ODE system let mut y0 = Vector::from(&[2.0, 0.0]); let x0 = 0.0; let x1 = 2.0; - let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -32,14 +33,14 @@ fn test_dopri8_van_der_pol() { approx_eq(stat.h_accepted, 8.656983588595286E-04, 1e-5); // print dense output - let n_dense = out.dense_step_index.len(); + let n_dense = solver.out().dense_step_index.len(); for i in 0..n_dense { println!( "step ={:>4}, x ={:5.2}, y ={}{}", - out.dense_step_index[i], - out.dense_x[i], - format_fortran(out.dense_y.get(&0).unwrap()[i]), - format_fortran(out.dense_y.get(&1).unwrap()[i]), + solver.out().dense_step_index[i], + solver.out().dense_x[i], + format_fortran(solver.out().dense_y.get(&0).unwrap()[i]), + format_fortran(solver.out().dense_y.get(&1).unwrap()[i]), ); } diff --git a/russell_ode/tests/test_dopri8_van_der_pol_debug.rs b/russell_ode/tests/test_dopri8_van_der_pol_debug.rs index ab18e8f1..121312ac 100644 --- a/russell_ode/tests/test_dopri8_van_der_pol_debug.rs +++ b/russell_ode/tests/test_dopri8_van_der_pol_debug.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, array_approx_eq, format_fortran, Vector}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_dopri8_van_der_pol_debug() { @@ -17,15 +17,17 @@ fn test_dopri8_van_der_pol_debug() { params.stiffness.stop_with_error = false; params.stiffness.save_results = true; + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // output (to save stiff stations) - let mut out = Output::new(); + solver.enable_output(); // solve the ODE system let mut y0 = Vector::from(&[2.0, 0.0]); let x0 = 0.0; let x1 = 2.0; - let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -46,9 +48,9 @@ fn test_dopri8_van_der_pol_debug() { assert_eq!(stat.n_rejected, 20); // check stiffness results - assert_eq!(out.stiff_step_index, &[21, 109, 196]); + assert_eq!(solver.out().stiff_step_index, &[21, 109, 196]); array_approx_eq( - &out.stiff_x, + &solver.out().stiff_x, &[1.563905377322407E-02, 8.759592223459979E-01, 1.749270939102191E+00], 1e-7, ); diff --git a/russell_ode/tests/test_fweuler.rs b/russell_ode/tests/test_fweuler.rs index 57ef717c..8383ca8d 100644 --- a/russell_ode/tests/test_fweuler.rs +++ b/russell_ode/tests/test_fweuler.rs @@ -16,7 +16,7 @@ fn test_fweuler_hairer_wanner_eq1() { // solve the ODE system let mut solver = OdeSolver::new(params, &system).unwrap(); let h_equal = Some(1.875 / 50.0); - solver.solve(&mut y0, x0, x1, h_equal, None, &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, h_equal, &mut args).unwrap(); // get statistics let stat = solver.stats(); diff --git a/russell_ode/tests/test_mdeuler.rs b/russell_ode/tests/test_mdeuler.rs index 653064a4..ec31ee61 100644 --- a/russell_ode/tests/test_mdeuler.rs +++ b/russell_ode/tests/test_mdeuler.rs @@ -16,7 +16,7 @@ fn test_mdeuler_hairer_wanner_eq1() { // solve the ODE system let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, None, &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); diff --git a/russell_ode/tests/test_radau5_amplifier1t.rs b/russell_ode/tests/test_radau5_amplifier1t.rs index 3d94ca73..6d277813 100644 --- a/russell_ode/tests/test_radau5_amplifier1t.rs +++ b/russell_ode/tests/test_radau5_amplifier1t.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, format_fortran, format_scientific}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_radau5_amplifier1t() { @@ -14,13 +14,17 @@ fn test_radau5_amplifier1t() { params.step.h_ini = 1e-6; params.set_tolerances(1e-4, 1e-4, None).unwrap(); + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // enable output of accepted steps - let mut out = Output::new(); - out.set_dense_recording(true, 0.001, &[0, 4]).unwrap(); + solver + .enable_output() + .set_dense_recording(true, 0.001, &[0, 4]) + .unwrap(); // solve the ODE system - let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -34,17 +38,17 @@ fn test_radau5_amplifier1t() { approx_eq(stat.h_accepted, 7.791381954171996E-04, 1e-6); // compare dense output with Mathematica - let n_dense = out.dense_step_index.len(); + let n_dense = solver.out().dense_step_index.len(); for i in 0..n_dense { - approx_eq(out.dense_x[i], X_MATH[i], 1e-15); - let diff0 = f64::abs(out.dense_y.get(&0).unwrap()[i] - Y0_MATH[i]); - let diff4 = f64::abs(out.dense_y.get(&4).unwrap()[i] - Y4_MATH[i]); + approx_eq(solver.out().dense_x[i], X_MATH[i], 1e-15); + let diff0 = f64::abs(solver.out().dense_y.get(&0).unwrap()[i] - Y0_MATH[i]); + let diff4 = f64::abs(solver.out().dense_y.get(&4).unwrap()[i] - Y4_MATH[i]); println!( "step ={:>4}, x ={:7.4}, y1and5 ={}{}, diff1and5 ={}{}", - out.dense_step_index[i], - out.dense_x[i], - format_fortran(out.dense_y.get(&0).unwrap()[i]), - format_fortran(out.dense_y.get(&4).unwrap()[i]), + solver.out().dense_step_index[i], + solver.out().dense_x[i], + format_fortran(solver.out().dense_y.get(&0).unwrap()[i]), + format_fortran(solver.out().dense_y.get(&4).unwrap()[i]), format_scientific(diff0, 8, 1), format_scientific(diff4, 8, 1) ); diff --git a/russell_ode/tests/test_radau5_brusselator_pde.rs b/russell_ode/tests/test_radau5_brusselator_pde.rs index 4200ecee..af070951 100644 --- a/russell_ode/tests/test_radau5_brusselator_pde.rs +++ b/russell_ode/tests/test_radau5_brusselator_pde.rs @@ -20,10 +20,12 @@ fn test_radau5_brusselator_pde() { let mut params = Params::new(Method::Radau5); params.set_tolerances(1e-3, 1e-3, None).unwrap(); - // solve the ODE system + // allocate the solver let mut solver = OdeSolver::new(params, &system).unwrap(); + + // solve the ODE system let yy = &mut yy0; - solver.solve(yy, t0, t1, None, None, &mut args).unwrap(); + solver.solve(yy, t0, t1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); diff --git a/russell_ode/tests/test_radau5_hairer_wanner_eq1.rs b/russell_ode/tests/test_radau5_hairer_wanner_eq1.rs index deb19f0e..394a8b7e 100644 --- a/russell_ode/tests/test_radau5_hairer_wanner_eq1.rs +++ b/russell_ode/tests/test_radau5_hairer_wanner_eq1.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, format_fortran, Vector}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_radau5_hairer_wanner_eq1() { @@ -14,13 +14,14 @@ fn test_radau5_hairer_wanner_eq1() { let mut params = Params::new(Method::Radau5); params.step.h_ini = 1e-4; + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // enable dense output with 0.2 spacing - let mut out = Output::new(); - out.set_dense_recording(true, 0.2, &[0]).unwrap(); + solver.enable_output().set_dense_recording(true, 0.2, &[0]).unwrap(); // solve the ODE system - let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -35,13 +36,13 @@ fn test_radau5_hairer_wanner_eq1() { approx_eq(y0[0], y1_correct[0], 3e-5); // print dense output - let n_dense = out.dense_step_index.len(); + let n_dense = solver.out().dense_step_index.len(); for i in 0..n_dense { println!( "step ={:>4}, x ={:5.2}, y ={}", - out.dense_step_index[i], - out.dense_x[i], - format_fortran(out.dense_y.get(&0).unwrap()[i]) + solver.out().dense_step_index[i], + solver.out().dense_x[i], + format_fortran(solver.out().dense_y.get(&0).unwrap()[i]) ); } diff --git a/russell_ode/tests/test_radau5_hairer_wanner_eq1_debug.rs b/russell_ode/tests/test_radau5_hairer_wanner_eq1_debug.rs index 7b5cc44e..027869ac 100644 --- a/russell_ode/tests/test_radau5_hairer_wanner_eq1_debug.rs +++ b/russell_ode/tests/test_radau5_hairer_wanner_eq1_debug.rs @@ -15,9 +15,11 @@ fn test_radau5_hairer_wanner_eq1_debug() { params.step.h_ini = 1e-4; params.debug = true; - // solve the ODE system + // allocate the solver let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, None, &mut args).unwrap(); + + // solve the ODE system + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); diff --git a/russell_ode/tests/test_radau5_robertson.rs b/russell_ode/tests/test_radau5_robertson.rs index b12f0e76..4a3523ba 100644 --- a/russell_ode/tests/test_radau5_robertson.rs +++ b/russell_ode/tests/test_radau5_robertson.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, format_fortran}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_radau5_robertson() { @@ -14,13 +14,14 @@ fn test_radau5_robertson() { params.step.h_ini = 1e-6; params.set_tolerances(1e-8, 1e-2, None).unwrap(); + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // enable output of accepted steps - let mut out = Output::new(); - out.set_step_recording(true, &[0, 1, 2]); + solver.enable_output().set_step_recording(true, &[0, 1, 2]); // solve the ODE system - let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -32,15 +33,15 @@ fn test_radau5_robertson() { approx_eq(stat.h_accepted, 8.160578540333708E-01, 1e-10); // print the results at accepted steps - let n_step = out.step_x.len(); + let n_step = solver.out().step_x.len(); for i in 0..n_step { println!( "step ={:>4}, x ={:5.2}, y ={}{}{}", i, - out.step_x[i], - format_fortran(out.step_y.get(&0).unwrap()[i]), - format_fortran(out.step_y.get(&1).unwrap()[i]), - format_fortran(out.step_y.get(&2).unwrap()[i]), + solver.out().step_x[i], + format_fortran(solver.out().step_y.get(&0).unwrap()[i]), + format_fortran(solver.out().step_y.get(&1).unwrap()[i]), + format_fortran(solver.out().step_y.get(&2).unwrap()[i]), ); } diff --git a/russell_ode/tests/test_radau5_robertson_debug.rs b/russell_ode/tests/test_radau5_robertson_debug.rs index 52c5179e..21ff0fb5 100644 --- a/russell_ode/tests/test_radau5_robertson_debug.rs +++ b/russell_ode/tests/test_radau5_robertson_debug.rs @@ -15,9 +15,11 @@ fn test_radau5_robertson_debug() { params.set_tolerances(1e-8, 1e-2, None).unwrap(); params.debug = true; - // solve the ODE system + // allocate the solver let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, None, &mut args).unwrap(); + + // solve the ODE system + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); diff --git a/russell_ode/tests/test_radau5_robertson_small_h.rs b/russell_ode/tests/test_radau5_robertson_small_h.rs index 231ec3fa..58918cf3 100644 --- a/russell_ode/tests/test_radau5_robertson_small_h.rs +++ b/russell_ode/tests/test_radau5_robertson_small_h.rs @@ -14,12 +14,14 @@ fn test_radau5_robertson_small_h() { params.step.h_ini = 1e-6; params.debug = true; + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // this will cause h to become too small params.set_tolerances(1e-2, 1e-2, None).unwrap(); // solve the ODE system - let mut solver = OdeSolver::new(params, &system).unwrap(); - let res = solver.solve(&mut y0, x0, x1, None, None, &mut args); + let res = solver.solve(&mut y0, x0, x1, None, &mut args); assert_eq!(res.err(), Some("the stepsize becomes too small")); println!("ERROR: THE STEPSIZE BECOMES TOO SMALL"); diff --git a/russell_ode/tests/test_radau5_van_der_pol.rs b/russell_ode/tests/test_radau5_van_der_pol.rs index 635c087d..7cfe1f75 100644 --- a/russell_ode/tests/test_radau5_van_der_pol.rs +++ b/russell_ode/tests/test_radau5_van_der_pol.rs @@ -1,5 +1,5 @@ use russell_lab::{approx_eq, format_fortran}; -use russell_ode::{Method, OdeSolver, Output, Params, Samples}; +use russell_ode::{Method, OdeSolver, Params, Samples}; #[test] fn test_radau5_van_der_pol() { @@ -11,13 +11,14 @@ fn test_radau5_van_der_pol() { let mut params = Params::new(Method::Radau5); params.step.h_ini = 1e-6; + // allocate the solver + let mut solver = OdeSolver::new(params, &system).unwrap(); + // enable dense output with 0.2 spacing - let mut out = Output::new(); - out.set_dense_recording(true, 0.2, &[0, 1]).unwrap(); + solver.enable_output().set_dense_recording(true, 0.2, &[0, 1]).unwrap(); // solve the ODE system - let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, Some(&mut out), &mut args).unwrap(); + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats(); @@ -28,14 +29,14 @@ fn test_radau5_van_der_pol() { approx_eq(stat.h_accepted, 1.510987221365367E-01, 1.2e-8); // print dense output - let n_dense = out.dense_step_index.len(); + let n_dense = solver.out().dense_step_index.len(); for i in 0..n_dense { println!( "step ={:>4}, x ={:5.2}, y ={}{}", - out.dense_step_index[i], - out.dense_x[i], - format_fortran(out.dense_y.get(&0).unwrap()[i]), - format_fortran(out.dense_y.get(&1).unwrap()[i]), + solver.out().dense_step_index[i], + solver.out().dense_x[i], + format_fortran(solver.out().dense_y.get(&0).unwrap()[i]), + format_fortran(solver.out().dense_y.get(&1).unwrap()[i]), ); } diff --git a/russell_ode/tests/test_radau5_van_der_pol_debug.rs b/russell_ode/tests/test_radau5_van_der_pol_debug.rs index 140c4778..1e1d02c1 100644 --- a/russell_ode/tests/test_radau5_van_der_pol_debug.rs +++ b/russell_ode/tests/test_radau5_van_der_pol_debug.rs @@ -12,9 +12,11 @@ fn test_radau5_van_der_pol_debug() { params.step.h_ini = 1e-6; params.debug = true; - // solve the ODE system + // allocate the solver let mut solver = OdeSolver::new(params, &system).unwrap(); - solver.solve(&mut y0, x0, x1, None, None, &mut args).unwrap(); + + // solve the ODE system + solver.solve(&mut y0, x0, x1, None, &mut args).unwrap(); // get statistics let stat = solver.stats();