From 711f8772500033553063e09673fe87e2558513ef Mon Sep 17 00:00:00 2001 From: Dorival Pedroso Date: Wed, 20 Sep 2023 17:30:20 +1000 Subject: [PATCH] Simplify the COO to dense conversion --- russell_sparse/src/coo_matrix.rs | 71 +++++++++----------------------- russell_sparse/src/csr_matrix.rs | 1 - 2 files changed, 19 insertions(+), 53 deletions(-) diff --git a/russell_sparse/src/coo_matrix.rs b/russell_sparse/src/coo_matrix.rs index c91941b3..66fe32a3 100644 --- a/russell_sparse/src/coo_matrix.rs +++ b/russell_sparse/src/coo_matrix.rs @@ -182,12 +182,11 @@ impl CooMatrix { self.pos = 0; } - /// Converts the CooMatrix to a dense matrix and returns the matrix - /// - /// Note: this function calls [CooMatrix::to_matrix]. + /// Converts this CooMatrix to a dense matrix /// /// ``` - /// use russell_sparse::{CooMatrix, Layout, StrError}; + /// use russell_sparse::prelude::*; + /// use russell_sparse::StrError; /// /// fn main() -> Result<(), StrError> { /// // define (4 x 4) sparse matrix with 6+1 non-zero values @@ -202,7 +201,7 @@ impl CooMatrix { /// coo.put(2, 2, 5.0)?; /// coo.put(3, 3, 6.0)?; /// - /// // convert to matrix + /// // convert to dense /// let a = coo.as_matrix(); /// let correct = "┌ ┐\n\ /// │ 1 2 0 0 │\n\ @@ -220,20 +219,18 @@ impl CooMatrix { a } - /// Converts the CooMatrix to a dense matrix, up to a limit - /// - /// Note: see the function [CooMatrix::as_matrix] that returns the Matrix. + /// Converts this CooMatrix to a dense matrix /// /// # Input /// - /// `a` -- (nrow_max, ncol_max) matrix to hold the triplet data. - /// The output matrix may have fewer rows or fewer columns than the triplet data. + /// * `a` -- where to store the dense matrix; must be (nrow, ncol) /// - /// # Example + /// # Examples /// /// ``` - /// use russell_lab::{Matrix}; - /// use russell_sparse::{CooMatrix, Layout, StrError}; + /// use russell_lab::Matrix; + /// use russell_sparse::prelude::*; + /// use russell_sparse::StrError; /// /// fn main() -> Result<(), StrError> { /// // define (4 x 4) sparse matrix with 6+1 non-zero values @@ -248,47 +245,31 @@ impl CooMatrix { /// coo.put(2, 2, 5.0)?; /// coo.put(3, 3, 6.0)?; /// - /// // convert the first (3 x 3) values - /// let mut a = Matrix::new(3, 3); + /// // convert to dense + /// let mut a = Matrix::new(4, 4); /// coo.to_matrix(&mut a)?; - /// let correct = "┌ ┐\n\ - /// │ 1 2 0 │\n\ - /// │ 3 4 0 │\n\ - /// │ 0 0 5 │\n\ - /// └ ┘"; - /// assert_eq!(format!("{}", a), correct); - /// - /// // convert the first (4 x 4) values - /// let mut b = Matrix::new(4, 4); - /// coo.to_matrix(&mut b)?; /// let correct = "┌ ┐\n\ /// │ 1 2 0 0 │\n\ /// │ 3 4 0 0 │\n\ /// │ 0 0 5 0 │\n\ /// │ 0 0 0 6 │\n\ /// └ ┘"; - /// assert_eq!(format!("{}", b), correct); + /// assert_eq!(format!("{}", a), correct); /// Ok(()) /// } /// ``` pub fn to_matrix(&self, a: &mut Matrix) -> Result<(), StrError> { let (m, n) = a.dims(); - if m > self.nrow || n > self.ncol { + if m != self.nrow || n != self.ncol { return Err("wrong matrix dimensions"); } - if self.layout != Layout::Full && m != n { - return Err("the resulting matrix must be square when the layout is either lower of upper triangular"); - } - let m_i32 = to_i32(m); - let n_i32 = to_i32(n); a.fill(0.0); for p in 0..self.pos { - if self.indices_i[p] < m_i32 && self.indices_j[p] < n_i32 { - let (i, j) = (self.indices_i[p] as usize, self.indices_j[p] as usize); - a.add(i, j, self.values_aij[p]); - if self.layout != Layout::Full && i != j { - a.add(j, i, self.values_aij[p]); - } + let i = self.indices_i[p] as usize; + let j = self.indices_j[p] as usize; + a.add(i, j, self.values_aij[p]); + if self.layout != Layout::Full && i != j { + a.add(j, i, self.values_aij[p]); } } Ok(()) @@ -481,10 +462,6 @@ mod tests { assert_eq!(a.get(1, 0), 3.0); assert_eq!(a.get(1, 1), 4.0); assert_eq!(a.get(2, 2), 5.0); - let mut b = Matrix::new(2, 1); - coo.to_matrix(&mut b).unwrap(); - assert_eq!(b.get(0, 0), 1.0); - assert_eq!(b.get(1, 0), 3.0); // using as_matrix let bb = coo.as_matrix(); assert_eq!(bb.get(0, 0), 1.0); @@ -538,11 +515,6 @@ mod tests { │ 0 4 0 │\n\ └ ┘"; assert_eq!(format!("{}", a), correct); - let mut b = Matrix::new(2, 1); - assert_eq!( - coo.to_matrix(&mut b).err(), - Some("the resulting matrix must be square when the layout is either lower of upper triangular") - ); } #[test] @@ -560,11 +532,6 @@ mod tests { │ 0 4 0 │\n\ └ ┘"; assert_eq!(format!("{}", a), correct); - let mut b = Matrix::new(2, 1); - assert_eq!( - coo.to_matrix(&mut b).err(), - Some("the resulting matrix must be square when the layout is either lower of upper triangular") - ); } #[test] diff --git a/russell_sparse/src/csr_matrix.rs b/russell_sparse/src/csr_matrix.rs index 790ea58f..84f4c6b2 100644 --- a/russell_sparse/src/csr_matrix.rs +++ b/russell_sparse/src/csr_matrix.rs @@ -170,7 +170,6 @@ impl CsrMatrix { /// # Examples /// /// ``` - /// use russell_lab::Matrix; /// use russell_sparse::prelude::*; /// use russell_sparse::StrError; ///