From 8f73f1a8f73a1fb64d052d72a3d3d2cb39b09df2 Mon Sep 17 00:00:00 2001 From: Dorival Pedroso Date: Wed, 25 Sep 2024 17:42:31 +1000 Subject: [PATCH] Make linear_fitting function generic --- russell_lab/README.md | 6 +- russell_lab/examples/algo_linear_fitting.rs | 6 +- russell_lab/src/algo/linear_fitting.rs | 95 +++++++++++++++------ 3 files changed, 75 insertions(+), 32 deletions(-) diff --git a/russell_lab/README.md b/russell_lab/README.md index af88e9d5..8749c0f4 100644 --- a/russell_lab/README.md +++ b/russell_lab/README.md @@ -308,12 +308,12 @@ Fit a line through a set of points. The line has slope `m` and intercepts the y ```rust use russell_lab::algo::linear_fitting; -use russell_lab::{approx_eq, StrError, Vector}; +use russell_lab::{approx_eq, StrError}; fn main() -> Result<(), StrError> { // model: c is the y value @ x = 0; m is the slope - let x = Vector::from(&[0.0, 1.0, 3.0, 5.0]); - let y = Vector::from(&[1.0, 0.0, 2.0, 4.0]); + let x = [0.0, 1.0, 3.0, 5.0]; + let y = [1.0, 0.0, 2.0, 4.0]; let (c, m) = linear_fitting(&x, &y, false)?; println!("c = {}, m = {}", c, m); approx_eq(c, 0.1864406779661015, 1e-15); diff --git a/russell_lab/examples/algo_linear_fitting.rs b/russell_lab/examples/algo_linear_fitting.rs index 1cbc9eb8..fb6cc130 100644 --- a/russell_lab/examples/algo_linear_fitting.rs +++ b/russell_lab/examples/algo_linear_fitting.rs @@ -11,8 +11,8 @@ const OUT_DIR: &str = "/tmp/russell_lab/"; fn main() -> Result<(), StrError> { // model: c is the y value @ x = 0; m is the slope - let x = Vector::from(&[0.0, 1.0, 3.0, 5.0]); - let y = Vector::from(&[1.0, 0.0, 2.0, 4.0]); + let x = [0.0, 1.0, 3.0, 5.0]; + let y = [1.0, 0.0, 2.0, 4.0]; let (c, m) = linear_fitting(&x, &y, false)?; println!("c = {}, m = {}", c, m); approx_eq(c, 0.1864406779661015, 1e-15); @@ -28,7 +28,7 @@ fn main() -> Result<(), StrError> { .set_marker_line_color("red") .set_marker_color("red"); curve_fit.draw_ray(0.0, c, RayEndpoint::Slope(m)); - curve_dat.draw(x.as_data(), y.as_data()); + curve_dat.draw(&x, &y); let mut plot = Plot::new(); let path = format!("{}/algo_linear_fitting_1.svg", OUT_DIR); plot.add(&curve_dat) diff --git a/russell_lab/src/algo/linear_fitting.rs b/russell_lab/src/algo/linear_fitting.rs index 964576ab..7854c9c3 100644 --- a/russell_lab/src/algo/linear_fitting.rs +++ b/russell_lab/src/algo/linear_fitting.rs @@ -1,4 +1,6 @@ -use crate::{StrError, Vector}; +use crate::StrError; +use num_traits::{cast, Num, NumCast}; +use std::ops::{AddAssign, Mul}; /// Calculates the parameters of a linear model using least squares fitting /// @@ -19,42 +21,55 @@ use crate::{StrError, Vector}; /// * If `pass_through_zero == True` and `sum(X) == 0` /// * If `pass_through_zero == False` and the line is vertical (null denominator) /// +/// # Panics +/// +/// This function may panic if the number type cannot be converted to `f64`. +/// /// # Examples /// /// ![Linear fitting](https://raw.githubusercontent.com/cpmech/russell/main/russell_lab/data/figures/algo_linear_fitting_1.svg) /// /// ``` -/// use russell_lab::{approx_eq, linear_fitting, StrError, Vector}; +/// use russell_lab::{approx_eq, linear_fitting, StrError}; /// /// fn main() -> Result<(), StrError> { /// // model: c is the y value @ x = 0; m is the slope -/// let x = Vector::from(&[0.0, 1.0, 3.0, 5.0]); -/// let y = Vector::from(&[1.0, 0.0, 2.0, 4.0]); +/// let x = [0.0, 1.0, 3.0, 5.0]; +/// let y = [1.0, 0.0, 2.0, 4.0]; /// let (c, m) = linear_fitting(&x, &y, false)?; /// approx_eq(c, 0.1864406779661015, 1e-15); /// approx_eq(m, 0.6949152542372882, 1e-15); /// Ok(()) /// } /// ``` -pub fn linear_fitting(x: &Vector, y: &Vector, pass_through_zero: bool) -> Result<(f64, f64), StrError> { +pub fn linear_fitting(x: &[T], y: &[T], pass_through_zero: bool) -> Result<(f64, f64), StrError> +where + T: AddAssign + Copy + Mul + Num + NumCast, +{ // dimension - let nn = x.dim(); - if y.dim() != nn { - return Err("vectors must have the same dimension"); + let nn = x.len(); + if y.len() != nn { + return Err("arrays must have the same lengths"); } // sums - let mut sum_x = 0.0; - let mut sum_y = 0.0; - let mut sum_xy = 0.0; - let mut sum_xx = 0.0; + let mut t_sum_x = T::zero(); + let mut t_sum_y = T::zero(); + let mut t_sum_xy = T::zero(); + let mut t_sum_xx = T::zero(); for i in 0..nn { - sum_x += x[i]; - sum_y += y[i]; - sum_xy += x[i] * y[i]; - sum_xx += x[i] * x[i]; + t_sum_x += x[i]; + t_sum_y += y[i]; + t_sum_xy += x[i] * y[i]; + t_sum_xx += x[i] * x[i]; } + // cast sums to f64 + let sum_x: f64 = cast(t_sum_x).unwrap(); + let sum_y: f64 = cast(t_sum_y).unwrap(); + let sum_xy: f64 = cast(t_sum_xy).unwrap(); + let sum_xx: f64 = cast(t_sum_xx).unwrap(); + // calculate parameters let c; let m; @@ -83,22 +98,37 @@ pub fn linear_fitting(x: &Vector, y: &Vector, pass_through_zero: bool) -> Result #[cfg(test)] mod tests { use super::linear_fitting; - use crate::{approx_eq, Vector}; + use crate::approx_eq; #[test] fn linear_fitting_handles_errors() { - let x = Vector::from(&[1.0, 2.0]); - let y = Vector::from(&[6.0, 5.0, 7.0, 10.0]); + let x = [1.0, 2.0]; + let y = [6.0, 5.0, 7.0, 10.0]; assert_eq!( linear_fitting(&x, &y, false).err(), - Some("vectors must have the same dimension") + Some("arrays must have the same lengths") ); } #[test] fn linear_fitting_works() { - let x = Vector::from(&[1.0, 2.0, 3.0, 4.0]); - let y = Vector::from(&[6.0, 5.0, 7.0, 10.0]); + // f64 (heap) + + let x = vec![1.0, 2.0, 3.0, 4.0]; + let y = vec![6.0, 5.0, 7.0, 10.0]; + + let (c, m) = linear_fitting(&x, &y, false).unwrap(); + assert_eq!(c, 3.5); + assert_eq!(m, 1.4); + + let (c, m) = linear_fitting(&x, &y, true).unwrap(); + assert_eq!(c, 0.0); + approx_eq(m, 2.566666666666667, 1e-16); + + // usize (stack) + + let x = [1, 2, 3, 4_usize]; + let y = [6, 5, 7, 10_usize]; let (c, m) = linear_fitting(&x, &y, false).unwrap(); assert_eq!(c, 3.5); @@ -107,19 +137,32 @@ mod tests { let (c, m) = linear_fitting(&x, &y, true).unwrap(); assert_eq!(c, 0.0); approx_eq(m, 2.566666666666667, 1e-16); + + // i32 (slice) + + let x = &[1, 2, 3, 4_i32]; + let y = &[6, 5, 7, 10_i32]; + + let (c, m) = linear_fitting(x, y, false).unwrap(); + assert_eq!(c, 3.5); + assert_eq!(m, 1.4); + + let (c, m) = linear_fitting(x, y, true).unwrap(); + assert_eq!(c, 0.0); + approx_eq(m, 2.566666666666667, 1e-16); } #[test] fn linear_fitting_handles_division_by_zero() { - let x = Vector::from(&[1.0, 1.0, 1.0, 1.0]); - let y = Vector::from(&[1.0, 2.0, 3.0, 4.0]); + let x = [1.0, 1.0, 1.0, 1.0]; + let y = [1.0, 2.0, 3.0, 4.0]; let (c, m) = linear_fitting(&x, &y, false).unwrap(); assert_eq!(c, 0.0); assert_eq!(m, f64::INFINITY); - let x = Vector::from(&[0.0, 0.0, 0.0, 0.0]); - let y = Vector::from(&[1.0, 2.0, 3.0, 4.0]); + let x = [0.0, 0.0, 0.0, 0.0]; + let y = [1.0, 2.0, 3.0, 4.0]; let (c, m) = linear_fitting(&x, &y, true).unwrap(); assert_eq!(c, 0.0); assert_eq!(m, f64::INFINITY);