Skip to content

Commit

Permalink
Merge pull request #120 from cpmech/tensor-improve-for-pmsim
Browse files Browse the repository at this point in the history
Tensor improve for pmsim
  • Loading branch information
cpmech authored Jun 3, 2024
2 parents 11dafc3 + 0154f3d commit 840c1a8
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 36 deletions.
214 changes: 188 additions & 26 deletions russell_lab/src/vector/num_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::cmp;
use std::fmt::{self, Write};
use std::ops::{Index, IndexMut};
use std::ops::{Index, IndexMut, MulAssign};

/// Implements a vector with numeric components for linear algebra
///
Expand Down Expand Up @@ -90,15 +90,15 @@ use std::ops::{Index, IndexMut};
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
#[serde(bound(deserialize = "Vec<T>: Deserialize<'de>"))]
data: Vec<T>,
}

impl<T> NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
/// Creates a new (zeroed) vector
///
Expand Down Expand Up @@ -387,17 +387,17 @@ where

/// Returns the i-th component
///
/// # Panics
///
/// This function may panic if the index is out-of-bounds.
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let u = NumVector::<f64>::from(&[1.0, 2.0]);
/// assert_eq!(u.get(1), 2.0);
/// ```
///
/// # Panics
///
/// This function may panic if the index is out-of-bounds.
#[inline]
pub fn get(&self, i: usize) -> T {
assert!(i < self.data.len());
Expand All @@ -406,6 +406,10 @@ where

/// Change the i-th component
///
/// # Panics
///
/// This function may panic if the index is out-of-bounds.
///
/// # Examples
///
/// ```
Expand All @@ -418,16 +422,116 @@ where
/// └ ┘";
/// assert_eq!(format!("{}", u), correct);
/// ```
///
/// # Panics
///
/// This function may panic if the index is out-of-bounds.
#[inline]
pub fn set(&mut self, i: usize, value: T) {
assert!(i < self.data.len());
self.data[i] = value;
}

/// Copy another vector into this one
///
/// # Panics
///
/// This function may panic if the other vector has a different length than this one
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let mut u = NumVector::<f64>::from(&[1.0, 2.0]);
/// u.set_vector(&[-3.0, -4.0]);
/// let correct = "┌ ┐\n\
/// │ -3 │\n\
/// │ -4 │\n\
/// └ ┘";
/// assert_eq!(format!("{}", u), correct);
/// ```
pub fn set_vector(&mut self, other: &[T]) {
assert_eq!(other.len(), self.data.len());
self.data.copy_from_slice(other);
}

/// Splits this vector into another two vectors
///
/// **Requirements:** `u.len() + v.len() == self.len()`
///
/// This function is the opposite of [NumVector::join2()]
///
/// # Panics
///
/// This function may panic if the sum of the lengths of u and v are different that this vector's length
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let w = NumVector::<f64>::from(&[1.0, 2.0, 3.0]);
/// let mut u = NumVector::<f64>::new(2);
/// let mut v = NumVector::<f64>::new(1);
///
/// w.split2(u.as_mut_data(), v.as_mut_data());
///
/// assert_eq!(u.as_data(), &[1.0, 2.0]);
/// assert_eq!(v.as_data(), &[3.0]);
/// ```
pub fn split2(&self, u: &mut [T], v: &mut [T]) {
assert_eq!(u.len() + v.len(), self.data.len());
u.copy_from_slice(&self.data[..u.len()]);
v.copy_from_slice(&self.data[u.len()..]);
}

/// Joins two vectors into this one
///
/// **Requirements:** `u.len() + v.len() == self.len()`
///
/// This function is the opposite of [NumVector::split2()]
///
/// # Panics
///
/// This function may panic if the sum of the lengths of u and v are different that this vector's length
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let mut w = NumVector::<f64>::new(3);
/// let u = NumVector::<f64>::from(&[1.0, 2.0]);
/// let v = NumVector::<f64>::from(&[3.0]);
///
/// w.join2(u.as_data(), v.as_data());
///
/// assert_eq!(w.as_data(), &[1.0, 2.0, 3.0]);
/// ```
pub fn join2(&mut self, u: &[T], v: &[T]) {
assert_eq!(u.len() + v.len(), self.data.len());
(&mut self.data[..u.len()]).copy_from_slice(u);
(&mut self.data[u.len()..]).copy_from_slice(v);
}

/// Scales this vector
///
/// ```text
/// u := alpha * u
/// ```
///
/// # Examples
///
/// ```
/// # use russell_lab::NumVector;
/// let mut u = NumVector::<f64>::from(&[1.0, 2.0]);
/// u.scale(2.0);
/// let correct = "┌ ┐\n\
/// │ 2 │\n\
/// │ 4 │\n\
/// └ ┘";
/// assert_eq!(format!("{}", u), correct);
/// ```
pub fn scale(&mut self, alpha: T) {
for i in 0..self.data.len() {
self.data[i] *= alpha;
}
}

/// Applies a function over all components of this vector
///
/// ```text
Expand Down Expand Up @@ -524,7 +628,7 @@ where

impl<T> fmt::Display for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize + fmt::Display,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize + fmt::Display,
{
/// Generates a string representation of the NumVector
///
Expand Down Expand Up @@ -584,6 +688,10 @@ where

/// Allows to access NumVector components using indices
///
/// # Panics
///
/// The index function may panic if the index is out-of-bounds.
///
/// # Examples
///
/// ```
Expand All @@ -593,13 +701,9 @@ where
/// assert_eq!(u[1], 1.2);
/// assert_eq!(u[2], 2.0);
/// ```
///
/// # Panics
///
/// The index function may panic if the index is out-of-bounds.
impl<T> Index<usize> for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
type Output = T;
#[inline]
Expand All @@ -610,6 +714,10 @@ where

/// Allows to change NumVector components using indices
///
/// # Panics
///
/// The index function may panic if the index is out-of-bounds.
///
/// # Examples
///
/// ```
Expand All @@ -622,13 +730,9 @@ where
/// assert_eq!(u[1], 11.2);
/// assert_eq!(u[2], 22.0);
/// ```
///
/// # Panics
///
/// The index function may panic if the index is out-of-bounds.
impl<T> IndexMut<usize> for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
#[inline]
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
Expand All @@ -649,7 +753,7 @@ where
/// ```
impl<T> IntoIterator for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
type Item = T;
type IntoIter = std::vec::IntoIter<Self::Item>;
Expand All @@ -673,7 +777,7 @@ where
/// ```
impl<'a, T> IntoIterator for &'a NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
type Item = &'a T;
type IntoIter = std::slice::Iter<'a, T>;
Expand All @@ -698,7 +802,7 @@ where
/// ```
impl<'a, T> IntoIterator for &'a mut NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
type Item = &'a mut T;
type IntoIter = std::slice::IterMut<'a, T>;
Expand All @@ -710,7 +814,7 @@ where
/// Allows accessing NumVector as an Array1D
impl<'a, T: 'a> AsArray1D<'a, T> for NumVector<T>
where
T: Num + NumCast + Copy + DeserializeOwned + Serialize,
T: MulAssign + Num + NumCast + Copy + DeserializeOwned + Serialize,
{
#[inline]
fn size(&self) -> usize {
Expand Down Expand Up @@ -887,6 +991,64 @@ mod tests {
assert_eq!(u.data, &[-1.0, -2.0]);
}

#[test]
#[should_panic]
fn set_vector_panics_on_wrong_len() {
let mut u = NumVector::<f64>::from(&[1.0, 2.0]);
u.set_vector(&[8.0, 9.0, 10.0]);
}

#[test]
fn set_vector_works() {
let mut u = NumVector::<f64>::from(&[1.0, 2.0]);
u.set_vector(&[8.0, 9.0]);
assert_eq!(u.data, &[8.0, 9.0]);
}

#[test]
#[should_panic]
fn split2_panics_on_wrong_lengths() {
let w = NumVector::<f64>::from(&[1.0, 2.0, 3.0]);
let mut u = NumVector::<f64>::new(2);
let mut v = NumVector::<f64>::new(2); // WRONG length
w.split2(u.as_mut_data(), v.as_mut_data());
}

#[test]
fn split2_works() {
let w = NumVector::<f64>::from(&[4.0, 5.0, -6.0]);
let mut u = NumVector::<f64>::new(2);
let mut v = NumVector::<f64>::new(1);
w.split2(u.as_mut_data(), v.as_mut_data());
assert_eq!(u.as_data(), &[4.0, 5.0]);
assert_eq!(v.as_data(), &[-6.0]);
}

#[test]
#[should_panic]
fn join2_panics_on_wrong_lengths() {
let mut w = NumVector::<f64>::new(2); // WRONG length
let u = NumVector::<f64>::from(&[1.0, 2.0]);
let v = NumVector::<f64>::from(&[3.0]);
w.join2(u.as_data(), v.as_data());
}

#[test]
fn join2_works() {
let mut w = NumVector::<f64>::new(4);
let u = NumVector::<f64>::from(&[9.0, -1.0, 7.0]);
let v = NumVector::<f64>::from(&[8.0]);
w.join2(u.as_data(), v.as_data());
assert_eq!(w.as_data(), &[9.0, -1.0, 7.0, 8.0]);
}

#[test]
fn scale_works() {
let mut u = NumVector::<f64>::from(&[2.0, 4.0]);
u.scale(0.5);
assert_eq!(u.data, &[1.0, 2.0]);
}

#[test]
fn map_works() {
let mut u = NumVector::<f64>::from(&[-1.0, -2.0, -3.0]);
Expand Down
2 changes: 1 addition & 1 deletion russell_ode/src/ode_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl<'a, A> OdeSolver<'a, A> {
/// # Generics
///
/// * `A` -- generic argument to assist in the f(x,y) and Jacobian functions.
/// It may be simply [NoArgs] indicating that no arguments are needed.
/// It may be simply [crate::NoArgs] indicating that no arguments are needed.
pub fn new(params: Params, system: System<'a, A>) -> Result<Self, StrError>
where
A: 'a,
Expand Down
10 changes: 9 additions & 1 deletion russell_ode/src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl<'a, A> Output<'a, A> {
if self.with_dense_output() {
if let Some(h_out) = self.dense_h_out {
// uniform spacing
let n = ((x1 - x0) / h_out) as usize + 1;
let n = usize::max(2, ((x1 - x0) / h_out) as usize + 1); // at least 2 (first and last) are required
if self.dense_x.len() != n {
self.dense_x.resize(n, 0.0);
}
Expand Down Expand Up @@ -808,6 +808,14 @@ mod tests {
assert_eq!(y0_out.len(), 4);
}

#[test]
fn initialize_with_dense_output_works_at_least_two_stations() {
let mut out = Output::<'_, NoArgs>::new();
out.set_dense_h_out(0.5).unwrap().set_dense_recording(&[0]);
out.initialize(0.99, 1.0, false).unwrap();
assert_eq!(out.dense_x.len(), 2);
}

#[test]
fn initialize_with_step_output_works() {
let mut out = Output::<'_, NoArgs>::new();
Expand Down
Loading

0 comments on commit 840c1a8

Please sign in to comment.