From 5d6371bd8c2744ad496e0ca09b5bb71b7f62b81e Mon Sep 17 00:00:00 2001 From: Dorival Pedroso Date: Wed, 26 Jun 2024 17:40:15 +1000 Subject: [PATCH] Simplify MultiRootSolverChevy --- .../examples/algo_multi_root_solver_cheby.rs | 6 +- .../src/algo/multi_root_solver_cheby.rs | 290 +++++++----------- 2 files changed, 119 insertions(+), 177 deletions(-) diff --git a/russell_lab/examples/algo_multi_root_solver_cheby.rs b/russell_lab/examples/algo_multi_root_solver_cheby.rs index 4ef1fa74..fc4e1ef0 100644 --- a/russell_lab/examples/algo_multi_root_solver_cheby.rs +++ b/russell_lab/examples/algo_multi_root_solver_cheby.rs @@ -18,15 +18,15 @@ fn main() -> Result<(), StrError> { println!("N = {}", nn); // find all roots in the interval - let mut solver = MultiRootSolverCheby::new(nn)?; + let solver = MultiRootSolverCheby::new(); let roots = Vector::from(&solver.find(&interp)?); let f_at_roots = roots.get_mapped(|x| f(x, args).unwrap()); println!("roots =\n{}", roots); println!("f @ roots =\n{}", print_vec_exp(&f_at_roots)); // polish the roots - let mut roots_polished = Vector::new(roots.dim()); - solver.polish_roots_newton(roots_polished.as_mut_data(), roots.as_data(), xa, xb, args, f)?; + let mut roots_polished = roots.clone(); + solver.polish_roots_newton(roots_polished.as_mut_data(), xa, xb, args, f)?; let f_at_roots_polished = roots_polished.get_mapped(|x| f(x, args).unwrap()); println!("polished roots =\n{}", roots_polished); println!("f @ polished roots =\n{}", print_vec_exp(&f_at_roots_polished)); diff --git a/russell_lab/src/algo/multi_root_solver_cheby.rs b/russell_lab/src/algo/multi_root_solver_cheby.rs index f499d9de..3656194e 100644 --- a/russell_lab/src/algo/multi_root_solver_cheby.rs +++ b/russell_lab/src/algo/multi_root_solver_cheby.rs @@ -58,21 +58,6 @@ pub struct MultiRootSolverCheby { /// Default = 15 pub newton_max_iterations: usize, - /// Holds the polynomial degree N - nn: usize, - - /// Holds the companion matrix A - aa: Matrix, - - /// Holds the real part of the eigenvalues - l_real: Vector, - - /// Holds the imaginary part of the eigenvalues - l_imag: Vector, - - /// Holds all possible roots (dim == N) - roots: Vector, - /// Stepsize for one-sided differences h_osd: f64, @@ -82,42 +67,17 @@ pub struct MultiRootSolverCheby { impl MultiRootSolverCheby { /// Allocates a new instance - /// - /// # Input - /// - /// * `nn` -- polynomial degree N (must be ≥ 1) - pub fn new(nn: usize) -> Result { - // check - if nn < 1 { - return Err("the degree N must be ≥ 1"); - } - - // companion matrix (except last row) - let mut aa = Matrix::new(nn, nn); - if nn > 1 { - aa.set(0, 1, 1.0); - for r in 1..(nn - 1) { - aa.set(r, r + 1, 0.5); // upper diagonal - aa.set(r, r - 1, 0.5); // lower diagonal - } - } - - // done - Ok(MultiRootSolverCheby { + pub fn new() -> Self { + MultiRootSolverCheby { tol_zero_an: 1e-13, tol_rel_imag: 1.0e-8, tol_abs_boundary: TOL_RANGE / 10.0, newton_tol_zero_dx: 1e-13, newton_tol_zero_fx: 1e-13, newton_max_iterations: 15, - nn, - aa, - l_real: Vector::new(nn), - l_imag: Vector::new(nn), - roots: Vector::new(nn), h_osd: f64::powf(f64::EPSILON, 1.0 / 2.0), h_cen: f64::powf(f64::EPSILON, 1.0 / 3.0), - }) + } } /// Find all roots in the interval @@ -151,71 +111,76 @@ impl MultiRootSolverCheby { /// interp.set_function(nn, args, f)?; /// /// // find all roots in the interval - /// let mut solver = MultiRootSolverCheby::new(nn)?; - /// let roots = Vector::from(&solver.find(&interp)?); - /// vec_approx_eq(&roots, &[-1.0, 1.0], 1e-15); + /// let mut solver = MultiRootSolverCheby::new(); + /// let roots = solver.find(&interp)?; + /// array_approx_eq(&roots, &[-1.0, 1.0], 1e-15); /// Ok(()) /// } /// ``` - pub fn find(&mut self, interp: &InterpChebyshev) -> Result<&[f64], StrError> { + pub fn find(&self, interp: &InterpChebyshev) -> Result, StrError> { // check - let nn = interp.get_degree(); - if nn != self.nn { - return Err("the interpolant must have the same degree N as the solver"); - } if !interp.is_ready() { - return Err("the interpolant must have the U vector already computed"); + return Err("the interpolant must initialized first"); + } + + // handle constant function + let nn = interp.get_degree(); + if nn == 0 { + return Ok(Vec::new()); } - // last expansion coefficient + // expansion coefficients let a = interp.get_coefficients(); let an = a[nn]; if f64::abs(an) < self.tol_zero_an { return Err("the trailing Chebyshev coefficient vanishes; try a smaller degree N"); } - // linear function + // handle linear function let (xa, xb, dx) = interp.get_range(); if nn == 1 { let z = -a[0] / a[1]; - let nr = if f64::abs(z) <= 1.0 + self.tol_abs_boundary { - self.roots[0] = (xb + xa + dx * z) / 2.0; - 1 + if f64::abs(z) <= 1.0 + self.tol_abs_boundary { + let root = (xb + xa + dx * z) / 2.0; + return Ok(vec![root]); } else { - 0 - }; - return Ok(&self.roots.as_data()[..nr]); + return Ok(Vec::new()); + } } - // last row of the companion matrix + // companion matrix + let mut aa = Matrix::new(nn, nn); + aa.set(0, 1, 1.0); + for r in 1..(nn - 1) { + aa.set(r, r + 1, 0.5); // upper diagonal + aa.set(r, r - 1, 0.5); // lower diagonal + } for k in 0..nn { - self.aa.set(nn - 1, k, -0.5 * a[k] / an); + aa.set(nn - 1, k, -0.5 * a[k] / an); } - self.aa.add(nn - 1, nn - 2, 0.5); + aa.add(nn - 1, nn - 2, 0.5); // eigenvalues - mat_eigenvalues(&mut self.l_real, &mut self.l_imag, &mut self.aa).unwrap(); + let mut l_real = Vector::new(nn); + let mut l_imag = Vector::new(nn); + mat_eigenvalues(&mut l_real, &mut l_imag, &mut aa).unwrap(); // roots = real eigenvalues within the interval - let mut nroot = 0; + let mut roots = Vec::new(); for i in 0..nn { - if f64::abs(self.l_imag[i]) < self.tol_rel_imag * f64::abs(self.l_real[i]) { - if f64::abs(self.l_real[i]) <= 1.0 + self.tol_abs_boundary { - let x = (xb + xa + dx * self.l_real[i]) / 2.0; - self.roots[nroot] = f64::max(xa, f64::min(xb, x)); - nroot += 1; + if f64::abs(l_imag[i]) < self.tol_rel_imag * f64::abs(l_real[i]) { + if f64::abs(l_real[i]) <= 1.0 + self.tol_abs_boundary { + let x = (xb + xa + dx * l_real[i]) / 2.0; + roots.push(f64::max(xa, f64::min(xb, x))); } } } // sort roots - for i in nroot..nn { - self.roots[i] = f64::MAX; + if roots.len() > 0 { + roots.sort_by(|a, b| a.partial_cmp(b).unwrap()); } - self.roots.as_mut_data().sort_by(|a, b| a.partial_cmp(b).unwrap()); - - // results - Ok(&self.roots.as_data()[..nroot]) + Ok(roots) } /// Polishes the roots using Newton's method @@ -237,21 +202,19 @@ impl MultiRootSolverCheby { /// interp.set_function(nn, args, f)?; /// /// // find all roots in the interval - /// let mut solver = MultiRootSolverCheby::new(nn)?; - /// let roots = Vector::from(&solver.find(&interp)?); - /// vec_approx_eq(&roots, &[-0.5, 0.5], 1e-15); // inaccurate + /// let mut solver = MultiRootSolverCheby::new(); + /// let mut roots = solver.find(&interp)?; + /// array_approx_eq(&roots, &[-0.5, 0.5], 1e-15); // inaccurate /// /// // polish the roots - /// let mut roots_polished = Vector::new(roots.dim()); - /// solver.polish_roots_newton(roots_polished.as_mut_data(), roots.as_data(), xa, xb, args, f)?; - /// vec_approx_eq(&roots_polished, &[-1.0, 1.0], 1e-15); // accurate + /// solver.polish_roots_newton(&mut roots, xa, xb, args, f)?; + /// array_approx_eq(&roots, &[-1.0, 1.0], 1e-15); // accurate /// Ok(()) /// } - ///``` + /// ``` pub fn polish_roots_newton( &self, - roots_out: &mut [f64], - roots_in: &[f64], + roots: &mut [f64], xa: f64, xb: f64, args: &mut A, @@ -261,18 +224,15 @@ impl MultiRootSolverCheby { F: FnMut(f64, &mut A) -> Result, { // check - let nr = roots_in.len(); + let nr = roots.len(); if nr < 1 { return Err("at least one root is required"); } - if roots_out.len() != roots_in.len() { - return Err("root_in and root_out must have the same lengths"); - } // Newton's method with approximate Jacobian let h_cen_2 = self.h_cen * 2.0; for r in 0..nr { - let mut x = roots_in[r]; + let mut x = roots[r]; let mut converged = false; for _ in 0..self.newton_max_iterations { // check convergence on f(x) @@ -311,7 +271,7 @@ impl MultiRootSolverCheby { if !converged { return Err("Newton's method did not converge"); } - roots_out[r] = x; + roots[r] = x; } Ok(()) } @@ -325,7 +285,6 @@ mod tests { use crate::algo::NoArgs; use crate::InterpChebyshev; use crate::{approx_eq, array_approx_eq, get_test_functions}; - use crate::{mat_approx_eq, Matrix}; #[allow(unused)] use crate::{StrError, Vector}; @@ -397,30 +356,15 @@ mod tests { } */ - #[test] - fn new_captures_errors() { - let nn = 0; - assert_eq!(MultiRootSolverCheby::new(nn).err(), Some("the degree N must be ≥ 1")); - } - - #[test] - fn new_works() { - let nn = 2; - let solver = MultiRootSolverCheby::new(nn).unwrap(); - let aa_correct = Matrix::from(&[[0.0, 1.0000], [0.0, 0.0]]); - mat_approx_eq(&solver.aa, &aa_correct, 1e-15); - } - #[test] fn find_captures_errors() { let (xa, xb) = (-4.0, 4.0); let nn = 2; let interp = InterpChebyshev::new(nn, xa, xb).unwrap(); - let nn_wrong = 3; - let mut solver = MultiRootSolverCheby::new(nn_wrong).unwrap(); + let solver = MultiRootSolverCheby::new(); assert_eq!( solver.find(&interp).err(), - Some("the interpolant must have the same degree N as the solver") + Some("the interpolant must initialized first") ); } @@ -432,7 +376,7 @@ mod tests { let args = &mut 0; let mut interp = InterpChebyshev::new(nn, xa, xb).unwrap(); interp.set_function(nn, args, f).unwrap(); - let mut solver = MultiRootSolverCheby::new(nn).unwrap(); + let solver = MultiRootSolverCheby::new(); assert_eq!( solver.find(&interp).err(), Some("the trailing Chebyshev coefficient vanishes; try a smaller degree N") @@ -452,11 +396,11 @@ mod tests { interp.set_function(nn, args, f).unwrap(); // find roots - let mut solver = MultiRootSolverCheby::new(nn).unwrap(); - let roots_unpolished = Vec::from(solver.find(&interp).unwrap()); - let mut roots_polished = vec![0.0; roots_unpolished.len()]; + let solver = MultiRootSolverCheby::new(); + let roots_unpolished = solver.find(&interp).unwrap(); + let mut roots_polished = roots_unpolished.clone(); solver - .polish_roots_newton(&mut roots_polished, &roots_unpolished, xa, xb, args, f) + .polish_roots_newton(&mut roots_polished, xa, xb, args, f) .unwrap(); println!("n_roots = {}", roots_polished.len()); println!("roots_unpolished = {:?}", roots_unpolished); @@ -480,6 +424,26 @@ mod tests { array_approx_eq(&roots_polished, &[-1.0, 1.0], 1e-14); } + #[test] + fn polish_roots_newton_captures_errors() { + let f = |_, _: &mut NoArgs| Ok(0.0); + let args = &mut 0; + let _ = f(0.0, args); + let (xa, xb) = (-1.0, 1.0); + let mut solver = MultiRootSolverCheby::new(); + let mut roots = Vec::new(); + assert_eq!( + solver.polish_roots_newton(&mut roots, xa, xb, args, f).err(), + Some("at least one root is required") + ); + let mut roots = [0.0]; + solver.newton_max_iterations = 0; + assert_eq!( + solver.polish_roots_newton(&mut roots, xa, xb, args, f).err(), + Some("Newton's method did not converge") + ); + } + #[test] fn polish_roots_newton_works() { // function @@ -493,11 +457,11 @@ mod tests { interp.set_function(nn, args, f).unwrap(); // find roots - let mut solver = MultiRootSolverCheby::new(nn).unwrap(); - let roots_unpolished = Vec::from(solver.find(&interp).unwrap()); - let mut roots_polished = vec![0.0; roots_unpolished.len()]; + let solver = MultiRootSolverCheby::new(); + let roots_unpolished = solver.find(&interp).unwrap(); + let mut roots_polished = roots_unpolished.clone(); solver - .polish_roots_newton(&mut roots_polished, &roots_unpolished, xa, xb, args, f) + .polish_roots_newton(&mut roots_polished, xa, xb, args, f) .unwrap(); println!("n_roots = {}", roots_polished.len()); println!("roots_unpolished = {:?}", roots_unpolished); @@ -534,12 +498,11 @@ mod tests { let (xa, xb) = test.range; let mut interp = InterpChebyshev::new(nn_max, xa, xb).unwrap(); interp.adapt_function(tol, args, test.f).unwrap(); - let nn = interp.get_degree(); - let mut solver = MultiRootSolverCheby::new(nn).unwrap(); - let roots_unpolished = Vec::from(solver.find(&interp).unwrap()); - let mut roots_polished = vec![0.0; roots_unpolished.len()]; + let solver = MultiRootSolverCheby::new(); + let roots_unpolished = solver.find(&interp).unwrap(); + let mut roots_polished = roots_unpolished.clone(); solver - .polish_roots_newton(&mut roots_polished, &roots_unpolished, xa, xb, args, test.f) + .polish_roots_newton(&mut roots_polished, xa, xb, args, test.f) .unwrap(); for xr in &roots_polished { let fx = (test.f)(*xr, args).unwrap(); @@ -564,8 +527,8 @@ mod tests { if *id == 9 { assert_eq!(roots_unpolished.len(), 93); } - /* // figure + /* let (nstation, fig_width) = if *id == 9 { (1001, 2048.0) } else { (101, 600.0) }; graph( &format!("test_multi_root_solver_cheby_{:0>3}", id), @@ -582,54 +545,39 @@ mod tests { } #[test] - fn polish_roots_newton_captures_errors() { - let f = |_, _: &mut NoArgs| Ok(0.0); - let args = &mut 0; - let _ = f(0.0, args); - let (xa, xb) = (-1.0, 1.0); - let mut solver = MultiRootSolverCheby::new(2).unwrap(); - let roots_in = Vec::new(); - let mut roots_out = [0.0]; - assert_eq!( - solver - .polish_roots_newton(&mut roots_out, &roots_in, xa, xb, args, f) - .err(), - Some("at least one root is required") - ); - let roots_in = [0.0, 1.0]; - assert_eq!( - solver - .polish_roots_newton(&mut roots_out, &roots_in, xa, xb, args, f) - .err(), - Some("root_in and root_out must have the same lengths") - ); - let roots_in = [0.0]; - solver.newton_max_iterations = 0; - assert_eq!( - solver - .polish_roots_newton(&mut roots_out, &roots_in, xa, xb, args, f) - .err(), - Some("Newton's method did not converge") - ); + fn constant_function_works() { + // data + let (xa, xb) = (0.0, 1.0); + let uu = &[0.5]; + + // interpolant + let nn_max = 10; + let mut interp = InterpChebyshev::new(nn_max, xa, xb).unwrap(); + interp.set_data(uu).unwrap(); + + // find all roots in the interval + let solver = MultiRootSolverCheby::new(); + let roots = &solver.find(&interp).unwrap(); + let nroot = roots.len(); + assert_eq!(nroot, 0) } #[test] fn linear_function_no_roots_works() { // data let (xa, xb) = (0.0, 1.0); - let uu = Vector::from(&[0.5, 3.0]); + let uu = &[0.5, 3.0]; // interpolant - let nn_max = 100; + let nn_max = 10; let tol = 1e-8; let mut interp = InterpChebyshev::new(nn_max, xa, xb).unwrap(); - interp.adapt_data(tol, uu.as_data()).unwrap(); - let nn = interp.get_degree(); + interp.adapt_data(tol, uu).unwrap(); // find all roots in the interval - let mut solver = MultiRootSolverCheby::new(nn).unwrap(); - let roots = Vector::from(&solver.find(&interp).unwrap()); - let nroot = roots.dim(); + let solver = MultiRootSolverCheby::new(); + let roots = &solver.find(&interp).unwrap(); + let nroot = roots.len(); assert_eq!(nroot, 0) } @@ -638,8 +586,8 @@ mod tests { // data let (xa, xb) = (0.0, 1.0); let dx = xb - xa; - let uu = Vector::from(&[-7.0, -4.5, 0.5, 3.0]); - let np = uu.dim(); // number of points + let uu = &[-7.0, -4.5, 0.5, 3.0]; + let np = uu.len(); // number of points let nn = np - 1; // degree let mut xx_dat = Vector::new(np); let zz = InterpChebyshev::points(nn); @@ -651,21 +599,15 @@ mod tests { let nn_max = 100; let tol = 1e-8; let mut interp = InterpChebyshev::new(nn_max, xa, xb).unwrap(); - interp.adapt_data(tol, uu.as_data()).unwrap(); - let nn = interp.get_degree(); + interp.adapt_data(tol, uu).unwrap(); // find all roots in the interval - let mut solver = MultiRootSolverCheby::new(nn).unwrap(); - let roots = Vector::from(&solver.find(&interp).unwrap()); - let f_at_roots = roots.get_mapped(|x| interp.eval(x).unwrap()); - let nroot = roots.dim(); - println!("roots =\n{}", roots); - for i in 0..nroot { - println!("xr = {}, f(xr) = {:.2e}", roots[i], f_at_roots[i]); - } + let solver = MultiRootSolverCheby::new(); + let roots = solver.find(&interp).unwrap(); + let nroot = roots.len(); assert_eq!(nroot, 1); approx_eq(roots[0], 0.7, 1e-15); - approx_eq(f_at_roots[0], 0.0, 1e-15); + approx_eq(interp.eval(roots[0]).unwrap(), 0.0, 1e-15); // plot /* @@ -685,7 +627,7 @@ mod tests { .set_marker_void(true); curve_dat.draw(xx_dat.as_data(), uu.as_data()); curve_int.draw(xx.as_data(), yy_int.as_data()); - curve_xr.draw(roots.as_data(), f_at_roots.as_data()); + curve_xr.draw(&roots, &vec![0.0]); let mut plot = Plot::new(); let mut legend = Legend::new(); legend.set_num_col(4);