Skip to content

Commit

Permalink
Make linear_fitting function generic
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmech committed Sep 25, 2024
1 parent 03c8ac3 commit 8f73f1a
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 32 deletions.
6 changes: 3 additions & 3 deletions russell_lab/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,12 @@ Fit a line through a set of points. The line has slope `m` and intercepts the y

```rust
use russell_lab::algo::linear_fitting;
use russell_lab::{approx_eq, StrError, Vector};
use russell_lab::{approx_eq, StrError};

fn main() -> Result<(), StrError> {
// model: c is the y value @ x = 0; m is the slope
let x = Vector::from(&[0.0, 1.0, 3.0, 5.0]);
let y = Vector::from(&[1.0, 0.0, 2.0, 4.0]);
let x = [0.0, 1.0, 3.0, 5.0];
let y = [1.0, 0.0, 2.0, 4.0];
let (c, m) = linear_fitting(&x, &y, false)?;
println!("c = {}, m = {}", c, m);
approx_eq(c, 0.1864406779661015, 1e-15);
Expand Down
6 changes: 3 additions & 3 deletions russell_lab/examples/algo_linear_fitting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ const OUT_DIR: &str = "/tmp/russell_lab/";

fn main() -> Result<(), StrError> {
// model: c is the y value @ x = 0; m is the slope
let x = Vector::from(&[0.0, 1.0, 3.0, 5.0]);
let y = Vector::from(&[1.0, 0.0, 2.0, 4.0]);
let x = [0.0, 1.0, 3.0, 5.0];
let y = [1.0, 0.0, 2.0, 4.0];
let (c, m) = linear_fitting(&x, &y, false)?;
println!("c = {}, m = {}", c, m);
approx_eq(c, 0.1864406779661015, 1e-15);
Expand All @@ -28,7 +28,7 @@ fn main() -> Result<(), StrError> {
.set_marker_line_color("red")
.set_marker_color("red");
curve_fit.draw_ray(0.0, c, RayEndpoint::Slope(m));
curve_dat.draw(x.as_data(), y.as_data());
curve_dat.draw(&x, &y);
let mut plot = Plot::new();
let path = format!("{}/algo_linear_fitting_1.svg", OUT_DIR);
plot.add(&curve_dat)
Expand Down
95 changes: 69 additions & 26 deletions russell_lab/src/algo/linear_fitting.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{StrError, Vector};
use crate::StrError;
use num_traits::{cast, Num, NumCast};
use std::ops::{AddAssign, Mul};

/// Calculates the parameters of a linear model using least squares fitting
///
Expand All @@ -19,42 +21,55 @@ use crate::{StrError, Vector};
/// * If `pass_through_zero == True` and `sum(X) == 0`
/// * If `pass_through_zero == False` and the line is vertical (null denominator)
///
/// # Panics
///
/// This function may panic if the number type cannot be converted to `f64`.
///
/// # Examples
///
/// ![Linear fitting](https://raw.githubusercontent.com/cpmech/russell/main/russell_lab/data/figures/algo_linear_fitting_1.svg)
///
/// ```
/// use russell_lab::{approx_eq, linear_fitting, StrError, Vector};
/// use russell_lab::{approx_eq, linear_fitting, StrError};
///
/// fn main() -> Result<(), StrError> {
/// // model: c is the y value @ x = 0; m is the slope
/// let x = Vector::from(&[0.0, 1.0, 3.0, 5.0]);
/// let y = Vector::from(&[1.0, 0.0, 2.0, 4.0]);
/// let x = [0.0, 1.0, 3.0, 5.0];
/// let y = [1.0, 0.0, 2.0, 4.0];
/// let (c, m) = linear_fitting(&x, &y, false)?;
/// approx_eq(c, 0.1864406779661015, 1e-15);
/// approx_eq(m, 0.6949152542372882, 1e-15);
/// Ok(())
/// }
/// ```
pub fn linear_fitting(x: &Vector, y: &Vector, pass_through_zero: bool) -> Result<(f64, f64), StrError> {
pub fn linear_fitting<T>(x: &[T], y: &[T], pass_through_zero: bool) -> Result<(f64, f64), StrError>
where
T: AddAssign + Copy + Mul + Num + NumCast,
{
// dimension
let nn = x.dim();
if y.dim() != nn {
return Err("vectors must have the same dimension");
let nn = x.len();
if y.len() != nn {
return Err("arrays must have the same lengths");
}

// sums
let mut sum_x = 0.0;
let mut sum_y = 0.0;
let mut sum_xy = 0.0;
let mut sum_xx = 0.0;
let mut t_sum_x = T::zero();
let mut t_sum_y = T::zero();
let mut t_sum_xy = T::zero();
let mut t_sum_xx = T::zero();
for i in 0..nn {
sum_x += x[i];
sum_y += y[i];
sum_xy += x[i] * y[i];
sum_xx += x[i] * x[i];
t_sum_x += x[i];
t_sum_y += y[i];
t_sum_xy += x[i] * y[i];
t_sum_xx += x[i] * x[i];
}

// cast sums to f64
let sum_x: f64 = cast(t_sum_x).unwrap();
let sum_y: f64 = cast(t_sum_y).unwrap();
let sum_xy: f64 = cast(t_sum_xy).unwrap();
let sum_xx: f64 = cast(t_sum_xx).unwrap();

// calculate parameters
let c;
let m;
Expand Down Expand Up @@ -83,22 +98,37 @@ pub fn linear_fitting(x: &Vector, y: &Vector, pass_through_zero: bool) -> Result
#[cfg(test)]
mod tests {
use super::linear_fitting;
use crate::{approx_eq, Vector};
use crate::approx_eq;

#[test]
fn linear_fitting_handles_errors() {
let x = Vector::from(&[1.0, 2.0]);
let y = Vector::from(&[6.0, 5.0, 7.0, 10.0]);
let x = [1.0, 2.0];
let y = [6.0, 5.0, 7.0, 10.0];
assert_eq!(
linear_fitting(&x, &y, false).err(),
Some("vectors must have the same dimension")
Some("arrays must have the same lengths")
);
}

#[test]
fn linear_fitting_works() {
let x = Vector::from(&[1.0, 2.0, 3.0, 4.0]);
let y = Vector::from(&[6.0, 5.0, 7.0, 10.0]);
// f64 (heap)

let x = vec![1.0, 2.0, 3.0, 4.0];
let y = vec![6.0, 5.0, 7.0, 10.0];

let (c, m) = linear_fitting(&x, &y, false).unwrap();
assert_eq!(c, 3.5);
assert_eq!(m, 1.4);

let (c, m) = linear_fitting(&x, &y, true).unwrap();
assert_eq!(c, 0.0);
approx_eq(m, 2.566666666666667, 1e-16);

// usize (stack)

let x = [1, 2, 3, 4_usize];
let y = [6, 5, 7, 10_usize];

let (c, m) = linear_fitting(&x, &y, false).unwrap();
assert_eq!(c, 3.5);
Expand All @@ -107,19 +137,32 @@ mod tests {
let (c, m) = linear_fitting(&x, &y, true).unwrap();
assert_eq!(c, 0.0);
approx_eq(m, 2.566666666666667, 1e-16);

// i32 (slice)

let x = &[1, 2, 3, 4_i32];
let y = &[6, 5, 7, 10_i32];

let (c, m) = linear_fitting(x, y, false).unwrap();
assert_eq!(c, 3.5);
assert_eq!(m, 1.4);

let (c, m) = linear_fitting(x, y, true).unwrap();
assert_eq!(c, 0.0);
approx_eq(m, 2.566666666666667, 1e-16);
}

#[test]
fn linear_fitting_handles_division_by_zero() {
let x = Vector::from(&[1.0, 1.0, 1.0, 1.0]);
let y = Vector::from(&[1.0, 2.0, 3.0, 4.0]);
let x = [1.0, 1.0, 1.0, 1.0];
let y = [1.0, 2.0, 3.0, 4.0];

let (c, m) = linear_fitting(&x, &y, false).unwrap();
assert_eq!(c, 0.0);
assert_eq!(m, f64::INFINITY);

let x = Vector::from(&[0.0, 0.0, 0.0, 0.0]);
let y = Vector::from(&[1.0, 2.0, 3.0, 4.0]);
let x = [0.0, 0.0, 0.0, 0.0];
let y = [1.0, 2.0, 3.0, 4.0];
let (c, m) = linear_fitting(&x, &y, true).unwrap();
assert_eq!(c, 0.0);
assert_eq!(m, f64::INFINITY);
Expand Down

0 comments on commit 8f73f1a

Please sign in to comment.