sparse by dense parallel products sparsemat#298
aujxn committed Apr 26, 2022
1 parent 25e35d2 commit 31cea4e
Showing 3 changed files with 83 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -17,7 +17,7 @@ exclude = [
 
 [features]
 default = ["alga", "multi_thread"]
-multi_thread = ["rayon", "num_cpus"]
+multi_thread = ["rayon", "num_cpus", "ndarray/rayon"]
 
 [dependencies]
 num-traits = "0.2.0"
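
The `ndarray/rayon` feature addition is what makes `ndarray::Zip::par_for_each` available to the kernels below; it is pulled in whenever `multi_thread` is enabled, which it is by default. A minimal sketch of a product that takes the parallel path (the crate name `sprs` and default features are assumptions):

use ndarray::Array2;
use sprs::CsMat;

fn main() {
    // 4x4 identity in CSR form times a dense 4x2 matrix.
    let sparse: CsMat<f64> = CsMat::eye(4);
    let dense = Array2::<f64>::ones((4, 2));
    // With `multi_thread` enabled this dispatches to the rayon-backed
    // kernels in src/sparse/prod.rs; otherwise to the sequential loops.
    let out = &sparse * &dense;
    assert_eq!(out.shape(), &[4, 2]);
}
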
16 changes: 11 additions & 5 deletions src/sparse/csmat.rs
@@ -1901,7 +1901,7 @@ where
 impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix2>>
     for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
 where
-    N: 'a + crate::MulAcc + num_traits::Zero + Clone,
+    N: 'a + crate::MulAcc + num_traits::Zero + Clone + Send + Sync,
     I: 'a + SpIndex,
     Iptr: 'a + SpIndex,
     IpS: 'a + Deref<Target = [Iptr]>,
@@ -1962,7 +1962,13 @@ where
 impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot<CsMatBase<N, I, IpS, IS, DS>>
     for ArrayBase<DS2, Ix2>
 where
-    N: 'a + Clone + crate::MulAcc + num_traits::Zero + std::fmt::Debug,
+    N: 'a
+        + Clone
+        + crate::MulAcc
+        + num_traits::Zero
+        + std::fmt::Debug
+        + Send
+        + Sync,
     I: 'a + SpIndex,
     IpS: 'a + Deref<Target = [I]>,
     IS: 'a + Deref<Target = [I]>,
@@ -2013,7 +2019,7 @@ where
 impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix2>>
     for CsMatBase<N, I, IpS, IS, DS, Iptr>
 where
-    N: 'a + Clone + crate::MulAcc + num_traits::Zero,
+    N: 'a + Clone + crate::MulAcc + num_traits::Zero + Send + Sync,
     I: 'a + SpIndex,
     Iptr: 'a + SpIndex,
     IpS: 'a + Deref<Target = [Iptr]>,
@@ -2031,7 +2037,7 @@ where
 impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix1>>
     for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
 where
-    N: 'a + Clone + crate::MulAcc + num_traits::Zero,
+    N: 'a + Clone + crate::MulAcc + num_traits::Zero + Send + Sync,
     I: 'a + SpIndex,
     Iptr: 'a + SpIndex,
     IpS: 'a + Deref<Target = [Iptr]>,
@@ -2072,7 +2078,7 @@ where
 impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix1>>
     for CsMatBase<N, I, IpS, IS, DS, Iptr>
 where
-    N: 'a + Clone + crate::MulAcc + num_traits::Zero,
+    N: 'a + Clone + crate::MulAcc + num_traits::Zero + Send + Sync,
     I: 'a + SpIndex,
     Iptr: 'a + SpIndex,
     IpS: 'a + Deref<Target = [Iptr]>,
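
The `Send + Sync` bounds added across these impls are rayon's requirements surfacing in the public signatures: `par_for_each` shares the operands across worker threads. Primitive element types such as `f64` already satisfy both, so typical callers compile unchanged. A small sketch of the interop these impls back, assuming `Dot` here is `ndarray::linalg::Dot`, the trait sprs implements for ndarray interop:

use ndarray::{arr2, linalg::Dot, Array2};
use sprs::CsMat;

fn main() {
    let eye: CsMat<f64> = CsMat::eye(3);
    let dense = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
    // CsMatBase::dot with a dense Ix2 rhs now requires N: Send + Sync.
    let prod: Array2<f64> = eye.dot(&dense);
    assert_eq!(prod, dense);
}
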
75 changes: 71 additions & 4 deletions src/sparse/prod.rs
@@ -190,7 +190,9 @@ pub fn csr_mulacc_dense_rowmaj<'a, N, A, B, I, Iptr>(
     rhs: ArrayView<B, Ix2>,
     mut out: ArrayViewMut<'a, N, Ix2>,
 ) where
-    N: 'a + crate::MulAcc<A, B>,
+    A: Send + Sync,
+    B: Send + Sync,
+    N: 'a + crate::MulAcc<A, B> + Send + Sync,
     I: 'a + SpIndex,
     Iptr: 'a + SpIndex,
 {
@@ -200,6 +202,21 @@
     assert!(lhs.is_csr(), "Storage mismatch");
 
     let axis0 = Axis(0);
+    #[cfg(feature = "multi_thread")]
+    for (line, mut oline) in lhs.outer_iterator().zip(out.axis_iter_mut(axis0))
+    {
+        for (col_ind, lval) in line.iter() {
+            let rline = rhs.row(col_ind);
+            // TODO: call an axpy primitive to benefit from vectorisation?
+            ndarray::Zip::from(&mut oline).and(rline).par_for_each(
+                |oval, rval| {
+                    oval.mul_acc(lval, rval);
+                },
+            );
+        }
+    }
+
+    #[cfg(not(feature = "multi_thread"))]
     for (line, mut oline) in lhs.outer_iterator().zip(out.axis_iter_mut(axis0))
     {
         for (col_ind, lval) in line.iter() {
@@ -220,7 +237,9 @@ pub fn csc_mulacc_dense_rowmaj<'a, N, A, B, I, Iptr>(
     rhs: ArrayView<B, Ix2>,
     mut out: ArrayViewMut<'a, N, Ix2>,
 ) where
-    N: 'a + crate::MulAcc<A, B>,
+    A: Send + Sync,
+    B: Send + Sync,
+    N: 'a + crate::MulAcc<A, B> + Send + Sync,
     I: 'a + SpIndex,
     Iptr: 'a + SpIndex,
 {
@@ -229,6 +248,19 @@
     assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
     assert!(lhs.is_csc(), "Storage mismatch");
 
+    #[cfg(feature = "multi_thread")]
+    for (lcol, rline) in lhs.outer_iterator().zip(rhs.outer_iter()) {
+        for (orow, lval) in lcol.iter() {
+            let oline = out.row_mut(orow);
+            ndarray::Zip::from(oline)
+                .and(rline)
+                .par_for_each(|oval, rval| {
+                    oval.mul_acc(lval, rval);
+                });
+        }
+    }
+
+    #[cfg(not(feature = "multi_thread"))]
     for (lcol, rline) in lhs.outer_iterator().zip(rhs.outer_iter()) {
         for (orow, lval) in lcol.iter() {
             let mut oline = out.row_mut(orow);
@@ -247,7 +279,9 @@ pub fn csc_mulacc_dense_colmaj<'a, N, A, B, I, Iptr>(
     rhs: ArrayView<B, Ix2>,
     mut out: ArrayViewMut<'a, N, Ix2>,
 ) where
-    N: 'a + crate::MulAcc<A, B>,
+    A: Send + Sync,
+    B: Send + Sync,
+    N: 'a + crate::MulAcc<A, B> + Send + Sync,
     I: 'a + SpIndex,
     Iptr: 'a + SpIndex,
 {
@@ -257,6 +291,20 @@
     assert!(lhs.is_csc(), "Storage mismatch");
 
     let axis1 = Axis(1);
+    // NOTE: See csr_mulacc_dense_colmaj, same issue.
+    #[cfg(feature = "multi_thread")]
+    ndarray::Zip::from(out.axis_iter_mut(axis1))
+        .and(rhs.axis_iter(axis1))
+        .par_for_each(|mut ocol, rcol| {
+            for (rrow, lcol) in lhs.outer_iterator().enumerate() {
+                let rval = &rcol[[rrow]];
+                for (orow, lval) in lcol.iter() {
+                    ocol[[orow]].mul_acc(lval, rval);
+                }
+            }
+        });
+
+    #[cfg(not(feature = "multi_thread"))]
     for (mut ocol, rcol) in out.axis_iter_mut(axis1).zip(rhs.axis_iter(axis1)) {
         for (rrow, lcol) in lhs.outer_iterator().enumerate() {
             let rval = &rcol[[rrow]];
@@ -275,7 +323,9 @@ pub fn csr_mulacc_dense_colmaj<'a, N, A, B, I, Iptr>(
     rhs: ArrayView<B, Ix2>,
     mut out: ArrayViewMut<'a, N, Ix2>,
 ) where
-    N: 'a + crate::MulAcc<A, B>,
+    A: Send + Sync,
+    B: Send + Sync,
+    N: 'a + crate::MulAcc<A, B> + Send + Sync,
     I: 'a + SpIndex,
     Iptr: 'a + SpIndex,
 {
@@ -285,6 +335,23 @@
     assert!(lhs.is_csr(), "Storage mismatch");
 
     let axis1 = Axis(1);
+    // NOTE: This is parallel over the columns of the output and rhs,
+    // which isn't ideal: it is still sequential for a dense vector product.
+    // Ideally CsMat.outer_iterator() should get a par_iter rayon impl.
+    #[cfg(feature = "multi_thread")]
+    ndarray::Zip::from(out.axis_iter_mut(axis1))
+        .and(rhs.axis_iter(axis1))
+        .par_for_each(|mut ocol, rcol| {
+            for (orow, lrow) in lhs.outer_iterator().enumerate() {
+                let oval = &mut ocol[[orow]];
+                for (rrow, lval) in lrow.iter() {
+                    let rval = &rcol[[rrow]];
+                    oval.mul_acc(lval, rval);
+                }
+            }
+        });
+
+    #[cfg(not(feature = "multi_thread"))]
     for (mut ocol, rcol) in out.axis_iter_mut(axis1).zip(rhs.axis_iter(axis1)) {
         for (orow, lrow) in lhs.outer_iterator().enumerate() {
             let oval = &mut ocol[[orow]];
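
Note the asymmetry in what gets parallelized: the rowmaj kernels spread work across the entries of each output row, while the colmaj kernels spread work across output columns, so a single dense column (a matrix-vector product) still runs sequentially, as the NOTE above says. These kernels accumulate into `out` rather than overwrite it, so a direct caller must zero the output first. A sketch of that, assuming the function is reachable at `sprs::prod::csr_mulacc_dense_rowmaj` in the crate's public API:

use ndarray::Array2;
use sprs::{prod::csr_mulacc_dense_rowmaj, CsMat};

fn main() {
    let lhs: CsMat<f64> = CsMat::eye(3);
    let rhs = Array2::<f64>::from_elem((3, 4), 2.0);
    // The kernel computes out += lhs * rhs, so start from zeros; the
    // Mul/Dot wrappers allocate this zeroed buffer before calling in.
    let mut out = Array2::<f64>::zeros((3, 4));
    csr_mulacc_dense_rowmaj(lhs.view(), rhs.view(), out.view_mut());
    assert_eq!(out, rhs);
}
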
