From a435f75ce6738116c4e3ab0b4405236b25d40a5f Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Tue, 19 Nov 2024 16:19:29 +0000 Subject: [PATCH 1/2] Extend sparse linalg interface with multiply-add. --- .../linalg/sparse/SparseMatrixMultiply.h | 43 +++++- .../linalg/sparse/SparseMatrixMultiply.tcc | 83 +++++++++++- .../SparseMatrixMultiply_EckitLinalg.cc | 51 ++++++- .../sparse/SparseMatrixMultiply_EckitLinalg.h | 12 +- .../sparse/SparseMatrixMultiply_OpenMP.cc | 125 ++++++++++++++---- .../sparse/SparseMatrixMultiply_OpenMP.h | 24 +++- src/tests/linalg/test_linalg_sparse.cc | 96 ++++++++++++++ 7 files changed, 384 insertions(+), 50 deletions(-) diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply.h b/src/atlas/linalg/sparse/SparseMatrixMultiply.h index d1776975b..9e337fe08 100644 --- a/src/atlas/linalg/sparse/SparseMatrixMultiply.h +++ b/src/atlas/linalg/sparse/SparseMatrixMultiply.h @@ -38,6 +38,19 @@ template void sparse_matrix_multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing, const Configuration& config); +template +void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt); + +template +void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, const Configuration& config); + +template +void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing); + +template +void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing, + const Configuration& config); + class SparseMatrixMultiply { public: SparseMatrixMultiply() = default; @@ -46,14 +59,34 @@ class SparseMatrixMultiply { template void operator()(const Matrix& matrix, const SourceView& src, TargetView& tgt) const { - sparse_matrix_multiply(matrix, src, tgt, backend()); + multiply(matrix, src, tgt); } template void operator()(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const { + multiply(matrix, src, tgt, indexing); + } + + template + void multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt) const { + sparse_matrix_multiply(matrix, src, tgt, backend()); + } + + template + void multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const { sparse_matrix_multiply(matrix, src, tgt, indexing, backend()); } + template + void multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt) const { + sparse_matrix_multiply_add(matrix, src, tgt, backend()); + } + + template + void multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const { + sparse_matrix_multiply_add(matrix, src, tgt, indexing, backend()); + } + const sparse::Backend& backend() const { return backend_; } private: @@ -65,8 +98,12 @@ namespace sparse { // Template class which needs (full or partial) specialization for concrete template parameters template struct SparseMatrixMultiply { - static void apply(const SparseMatrix&, const View&, View&, - const Configuration&) { + static void multiply(const SparseMatrix&, const View&, View&, + const Configuration&) { + throw_NotImplemented("SparseMatrixMultiply needs a template specialization with the implementation", Here()); + } + static void multiply_add(const SparseMatrix&, const View&, View&, + const Configuration&) { throw_NotImplemented("SparseMatrixMultiply needs a template specialization with the implementation", Here()); } }; diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply.tcc b/src/atlas/linalg/sparse/SparseMatrixMultiply.tcc index b4e1ed9cd..4b44b984c 100644 --- a/src/atlas/linalg/sparse/SparseMatrixMultiply.tcc +++ b/src/atlas/linalg/sparse/SparseMatrixMultiply.tcc @@ -33,14 +33,25 @@ namespace { template struct SparseMatrixMultiplyHelper { template - static void apply( const SparseMatrix& W, const SourceView& src, TargetView& tgt, + static void multiply( const SparseMatrix& W, const SourceView& src, TargetView& tgt, + const eckit::Configuration& config ) { + using SourceValue = const typename std::remove_const::type; + using TargetValue = typename std::remove_const::type; + constexpr int src_rank = introspection::rank(); + constexpr int tgt_rank = introspection::rank(); + static_assert( src_rank == tgt_rank, "src and tgt need same rank" ); + SparseMatrixMultiply::multiply( W, src, tgt, config ); + } + + template + static void multiply_add( const SparseMatrix& W, const SourceView& src, TargetView& tgt, const eckit::Configuration& config ) { using SourceValue = const typename std::remove_const::type; using TargetValue = typename std::remove_const::type; constexpr int src_rank = introspection::rank(); constexpr int tgt_rank = introspection::rank(); static_assert( src_rank == tgt_rank, "src and tgt need same rank" ); - SparseMatrixMultiply::apply( W, src, tgt, config ); + SparseMatrixMultiply::multiply_add( W, src, tgt, config ); } }; @@ -53,14 +64,38 @@ void dispatch_sparse_matrix_multiply( const Matrix& matrix, const SourceView& sr if ( introspection::layout_right( src ) || introspection::layout_right( tgt ) ) { ATLAS_ASSERT( introspection::layout_right( src ) && introspection::layout_right( tgt ) ); // Override layout with known layout given by introspection - SparseMatrixMultiplyHelper::apply( matrix, src_v, tgt_v, config ); + SparseMatrixMultiplyHelper::multiply( matrix, src_v, tgt_v, config ); + } + else { + if( indexing == Indexing::layout_left ) { + SparseMatrixMultiplyHelper::multiply( matrix, src_v, tgt_v, config ); + } + else if( indexing == Indexing::layout_right ) { + SparseMatrixMultiplyHelper::multiply( matrix, src_v, tgt_v, config ); + } + else { + throw_NotImplemented( "indexing not implemented", Here() ); + } + } +} + +template +void dispatch_sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing, + const eckit::Configuration& config ) { + auto src_v = make_view( src ); + auto tgt_v = make_view( tgt ); + + if ( introspection::layout_right( src ) || introspection::layout_right( tgt ) ) { + ATLAS_ASSERT( introspection::layout_right( src ) && introspection::layout_right( tgt ) ); + // Override layout with known layout given by introspection + SparseMatrixMultiplyHelper::multiply_add( matrix, src_v, tgt_v, config ); } else { if( indexing == Indexing::layout_left ) { - SparseMatrixMultiplyHelper::apply( matrix, src_v, tgt_v, config ); + SparseMatrixMultiplyHelper::multiply_add( matrix, src_v, tgt_v, config ); } else if( indexing == Indexing::layout_right ) { - SparseMatrixMultiplyHelper::apply( matrix, src_v, tgt_v, config ); + SparseMatrixMultiplyHelper::multiply_add( matrix, src_v, tgt_v, config ); } else { throw_NotImplemented( "indexing not implemented", Here() ); @@ -108,6 +143,44 @@ void sparse_matrix_multiply( const Matrix& matrix, const SourceView& src, Target sparse_matrix_multiply( matrix, src, tgt, Indexing::layout_left ); } +template +void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing, + const eckit::Configuration& config ) { + std::string type = config.getString( "type", sparse::current_backend() ); + if ( type == sparse::backend::openmp::type() ) { + sparse::dispatch_sparse_matrix_multiply_add( matrix, src, tgt, indexing, config ); + } + else if ( type == sparse::backend::eckit_linalg::type() ) { + sparse::dispatch_sparse_matrix_multiply_add( matrix, src, tgt, indexing, config ); + } +#if ATLAS_ECKIT_HAVE_ECKIT_585 + else if( eckit::linalg::LinearAlgebraSparse::hasBackend(type) ) { +#else + else if( eckit::linalg::LinearAlgebra::hasBackend(type) ) { +#endif + sparse::dispatch_sparse_matrix_multiply_add( matrix, src, tgt, indexing, util::Config("backend",type) ); + } + else { + throw_NotImplemented( "sparse_matrix_multiply_add cannot be performed with unsupported backend [" + type + "]", + Here() ); + } +} + +template +void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, const eckit::Configuration& config ) { + sparse_matrix_multiply_add( matrix, src, tgt, Indexing::layout_left, config ); +} + +template +void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing ) { + sparse_matrix_multiply_add( matrix, src, tgt, indexing, sparse::Backend() ); +} + +template +void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt ) { + sparse_matrix_multiply_add( matrix, src, tgt, Indexing::layout_left ); +} + } // namespace linalg } // namespace atlas diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc b/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc index 72d55c1a9..4c4ec1c58 100644 --- a/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc +++ b/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc @@ -20,6 +20,7 @@ #include "SparseMatrixMultiply_EckitLinalg.h" +#include "atlas/array.h" #include "atlas/library/config.h" #if ATLAS_ECKIT_HAVE_ECKIT_585 @@ -62,9 +63,15 @@ const eckit::linalg::LinearAlgebra& eckit_linalg_backend(const Configuration& co } #endif +template +auto linalg_make_view(atlas::array::ArrayT& array) { + auto v_array = array::make_view(array); + return atlas::linalg::make_view(v_array); +} + } // namespace -void SparseMatrixMultiply::apply( +void SparseMatrixMultiply::multiply( const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { ATLAS_ASSERT(src.contiguous()); ATLAS_ASSERT(tgt.contiguous()); @@ -73,7 +80,7 @@ void SparseMatrixMultiply::apply( +void SparseMatrixMultiply::multiply( const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { ATLAS_ASSERT(src.contiguous()); ATLAS_ASSERT(tgt.contiguous()); @@ -84,9 +91,45 @@ void SparseMatrixMultiply::apply( +void SparseMatrixMultiply::multiply( + const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { + SparseMatrixMultiply::multiply(W, src, tgt, + config); +} + +void SparseMatrixMultiply::multiply_add( + const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { + + array::ArrayT tmp(src.shape(0)); + auto v_tmp = linalg_make_view(tmp); + v_tmp.assign(0.); + + SparseMatrixMultiply::multiply(W, src, v_tmp, config); + + for (idx_t t = 0; t < tmp.shape(0); ++t) { + tgt(t) += v_tmp(t); + } +} + +void SparseMatrixMultiply::multiply_add( + const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { + + array::ArrayT tmp(src.shape(0), src.shape(1)); + auto v_tmp = linalg_make_view(tmp); + v_tmp.assign(0.); + + SparseMatrixMultiply::multiply(W, src, v_tmp, config); + + for (idx_t t = 0; t < tmp.shape(0); ++t) { + for (idx_t k = 0; k < tmp.shape(1); ++k) { + tgt(t, k) += v_tmp(t, k); + } + } +} + +void SparseMatrixMultiply::multiply_add( const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { - SparseMatrixMultiply::apply(W, src, tgt, + SparseMatrixMultiply::multiply_add(W, src, tgt, config); } diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.h b/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.h index af99d57df..6915a10f3 100644 --- a/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.h +++ b/src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.h @@ -28,20 +28,26 @@ namespace sparse { template <> struct SparseMatrixMultiply { - static void apply(const SparseMatrix&, const View& src, View& tgt, + static void multiply(const SparseMatrix&, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix&, const View& src, View& tgt, const Configuration&); }; template <> struct SparseMatrixMultiply { - static void apply(const SparseMatrix&, const View& src, View& tgt, + static void multiply(const SparseMatrix&, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix&, const View& src, View& tgt, const Configuration&); }; template <> struct SparseMatrixMultiply { - static void apply(const SparseMatrix&, const View& src, View& tgt, + static void multiply(const SparseMatrix&, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix&, const View& src, View& tgt, const Configuration&); }; diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply_OpenMP.cc b/src/atlas/linalg/sparse/SparseMatrixMultiply_OpenMP.cc index 7a610b892..41d01737f 100644 --- a/src/atlas/linalg/sparse/SparseMatrixMultiply_OpenMP.cc +++ b/src/atlas/linalg/sparse/SparseMatrixMultiply_OpenMP.cc @@ -17,9 +17,8 @@ namespace atlas { namespace linalg { namespace sparse { -template -void SparseMatrixMultiply::apply( - const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { +template +void spmv_layout_left(const SparseMatrix& W, const View& src, View& tgt) { using Value = TargetValue; const auto outer = W.outer(); const auto index = W.inner(); @@ -30,7 +29,9 @@ void SparseMatrixMultiply= W.rows()); atlas_omp_parallel_for(idx_t r = 0; r < rows; ++r) { - tgt[r] = 0.; + if constexpr (SetZero) { + tgt[r] = 0.; + } for (idx_t c = outer[r]; c < outer[r + 1]; ++c) { idx_t n = index[c]; Value w = static_cast(weight[c]); @@ -39,10 +40,20 @@ void SparseMatrixMultiply +void SparseMatrixMultiply::multiply( + const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { + spmv_layout_left(W, src, tgt); +} template -void SparseMatrixMultiply::apply( - const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { +void SparseMatrixMultiply::multiply_add( + const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { + spmv_layout_left(W, src, tgt); +} + +template +void spmm_layout_left(const SparseMatrix& W, const View& src, View& tgt) { using Value = TargetValue; const auto outer = W.outer(); const auto index = W.inner(); @@ -54,8 +65,10 @@ void SparseMatrixMultiply= W.rows()); atlas_omp_parallel_for(idx_t r = 0; r < rows; ++r) { - for (idx_t k = 0; k < Nk; ++k) { - tgt(r, k) = 0.; + if constexpr (SetZero) { + for (idx_t k = 0; k < Nk; ++k) { + tgt(r, k) = 0.; + } } for (idx_t c = outer[r]; c < outer[r + 1]; ++c) { idx_t n = index[c]; @@ -68,14 +81,24 @@ void SparseMatrixMultiply -void SparseMatrixMultiply::apply( - const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { +void SparseMatrixMultiply::multiply( + const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { + spmm_layout_left(W, src, tgt); +} + +template +void SparseMatrixMultiply::multiply_add( + const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { + spmm_layout_left(W, src, tgt); +} + +template +void spmt_layout_left(const SparseMatrix& W, const View& src, View& tgt) { if (src.contiguous() && tgt.contiguous()) { // We can take a more optimized route by reducing rank auto src_v = View(src.data(), array::make_shape(src.shape(0), src.stride(0))); auto tgt_v = View(tgt.data(), array::make_shape(tgt.shape(0), tgt.stride(0))); - SparseMatrixMultiply::apply(W, src_v, - tgt_v, config); + spmm_layout_left(W, src_v, tgt_v); return; } using Value = TargetValue; @@ -87,9 +110,11 @@ void SparseMatrixMultiply -void SparseMatrixMultiply::apply( +void SparseMatrixMultiply::multiply( + const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { + spmt_layout_left(W, src, tgt); +} + +template +void SparseMatrixMultiply::multiply_add( + const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { + spmt_layout_left(W, src, tgt); +} + +template +void SparseMatrixMultiply::multiply( const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { - return SparseMatrixMultiply::apply(W, src, tgt, - config); + return SparseMatrixMultiply::multiply(W, src, tgt, config); } template -void SparseMatrixMultiply::apply( - const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { +void SparseMatrixMultiply::multiply_add( + const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { + return SparseMatrixMultiply::multiply_add(W, src, tgt, config); +} + +template +void spmm_layout_right(const SparseMatrix& W, const View& src, View& tgt) { using Value = TargetValue; const auto outer = W.outer(); const auto index = W.inner(); @@ -125,8 +166,10 @@ void SparseMatrixMultiply= W.rows()); atlas_omp_parallel_for(idx_t r = 0; r < rows; ++r) { - for (idx_t k = 0; k < Nk; ++k) { - tgt(k, r) = 0.; + if constexpr (SetZero) { + for (idx_t k = 0; k < Nk; ++k) { + tgt(k, r) = 0.; + } } for (idx_t c = outer[r]; c < outer[r + 1]; ++c) { idx_t n = index[c]; @@ -139,14 +182,24 @@ void SparseMatrixMultiply -void SparseMatrixMultiply::apply( - const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { +void SparseMatrixMultiply::multiply( + const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { + spmm_layout_right(W, src, tgt); +} + +template +void SparseMatrixMultiply::multiply_add( + const SparseMatrix& W, const View& src, View& tgt, const Configuration&) { + spmm_layout_right(W, src, tgt); +} + +template +void spmt_layout_right(const SparseMatrix& W, const View& src, View& tgt) { if (src.contiguous() && tgt.contiguous()) { // We can take a more optimized route by reducing rank auto src_v = View(src.data(), array::make_shape(src.shape(0), src.stride(0))); auto tgt_v = View(tgt.data(), array::make_shape(tgt.shape(0), tgt.stride(0))); - SparseMatrixMultiply::apply( - W, src_v, tgt_v, config); + spmm_layout_right(W, src_v, tgt_v); return; } using Value = TargetValue; @@ -158,9 +211,11 @@ void SparseMatrixMultiply +void SparseMatrixMultiply::multiply( + const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { + spmt_layout_right(W, src, tgt); +} + +template +void SparseMatrixMultiply::multiply_add( + const SparseMatrix& W, const View& src, View& tgt, const Configuration& config) { + spmt_layout_right(W, src, tgt); +} + #define EXPLICIT_TEMPLATE_INSTANTIATION(TYPE) \ template struct SparseMatrixMultiply; \ template struct SparseMatrixMultiply; \ diff --git a/src/atlas/linalg/sparse/SparseMatrixMultiply_OpenMP.h b/src/atlas/linalg/sparse/SparseMatrixMultiply_OpenMP.h index f6cd0a17a..a3bdcea44 100644 --- a/src/atlas/linalg/sparse/SparseMatrixMultiply_OpenMP.h +++ b/src/atlas/linalg/sparse/SparseMatrixMultiply_OpenMP.h @@ -19,37 +19,49 @@ namespace sparse { template struct SparseMatrixMultiply { - static void apply(const SparseMatrix& W, const View& src, View& tgt, + static void multiply(const SparseMatrix& W, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix& W, const View& src, View& tgt, const Configuration&); }; template struct SparseMatrixMultiply { - static void apply(const SparseMatrix& W, const View& src, View& tgt, + static void multiply(const SparseMatrix& W, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix& W, const View& src, View& tgt, const Configuration&); }; template struct SparseMatrixMultiply { - static void apply(const SparseMatrix& W, const View& src, View& tgt, + static void multiply(const SparseMatrix& W, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix& W, const View& src, View& tgt, const Configuration&); }; template struct SparseMatrixMultiply { - static void apply(const SparseMatrix& W, const View& src, View& tgt, + static void multiply(const SparseMatrix& W, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix& W, const View& src, View& tgt, const Configuration&); }; template struct SparseMatrixMultiply { - static void apply(const SparseMatrix& W, const View& src, View& tgt, + static void multiply(const SparseMatrix& W, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix& W, const View& src, View& tgt, const Configuration&); }; template struct SparseMatrixMultiply { - static void apply(const SparseMatrix& W, const View& src, View& tgt, + static void multiply(const SparseMatrix& W, const View& src, View& tgt, + const Configuration&); + static void multiply_add(const SparseMatrix& W, const View& src, View& tgt, const Configuration&); }; diff --git a/src/tests/linalg/test_linalg_sparse.cc b/src/tests/linalg/test_linalg_sparse.cc index ee1b10302..0e6f8c81f 100644 --- a/src/tests/linalg/test_linalg_sparse.cc +++ b/src/tests/linalg/test_linalg_sparse.cc @@ -258,6 +258,56 @@ CASE("sparse_matrix vector multiply (spmv)") { EXPECT_THROWS_AS(sparse_matrix_multiply(A, x2.view(), y.view()), eckit::AssertionFailed); } } + + SECTION("View of atlas::Array [backend=" + backend + "]") { + ArrayVector x(Vector{1., 2., 3.}); + ArrayVector y(3); + auto spmm = SparseMatrixMultiply{backend}; + spmm(A, x.view(), y.view()); + expect_equal(y.view(), Vector{-7., 4., 6.}); + } + + SECTION("View of atlas::Array [backend=" + backend + "]") { + ArrayVector x(Vector{1., 2., 3.}); + ArrayVector y(3); + auto spmm = SparseMatrixMultiply{backend}; + spmm.multiply(A, x.view(), y.view()); + expect_equal(y.view(), Vector{-7., 4., 6.}); + } + } +} + +CASE("sparse_matrix vector multiply-add (spmv)") { + // "square" matrix + // A = 2 . -3 + // . 2 . + // . . 2 + // x = 1 2 3 + // y = 1 2 3 + SparseMatrix A{3, 3, {{0, 0, 2.}, {0, 2, -3.}, {1, 1, 2.}, {2, 2, 2.}}}; + + for (std::string backend : {openmp, eckit_linalg}) { + sparse::current_backend(backend); + + SECTION("View of atlas::Array [backend=" + backend + "]") { + ArrayVector x(Vector{1., 2., 3.}); + ArrayVector y(Vector{4., 5., 6.}); + sparse_matrix_multiply_add(A, x.view(), y.view()); + expect_equal(y.view(), Vector{-3., 9., 12.}); + // sparse_matrix_multiply_add of sparse matrix and vector of non-matching sizes should fail + { + ArrayVector x2(2); + EXPECT_THROWS_AS(sparse_matrix_multiply_add(A, x2.view(), y.view()), eckit::AssertionFailed); + } + } + + SECTION("sparse_matrix_multiply_add [backend=" + backend + "]") { + ArrayVector x(Vector{1., 2., 3.}); + ArrayVector y(Vector{1., 2., 3.}); + auto spmm = SparseMatrixMultiply{sparse::backend::openmp()}; + spmm.multiply_add(A, x.view(), y.view()); + expect_equal(y.view(), Vector{-6., 6., 9.}); + } } } @@ -326,8 +376,54 @@ CASE("sparse_matrix matrix multiply (spmm)") { spmm(A, ma.view(), c.view()); expect_equal(c.view(), ArrayMatrix(c_exp).view()); } + + SECTION("SparseMatrixMultiply::multiply [backend=openmp]") { + sparse::current_backend(eckit_linalg); // expected to be ignored + auto spmm = SparseMatrixMultiply{openmp}; + ArrayMatrix ma(m); + ArrayMatrix c(3, 2); + spmm.multiply(A, ma.view(), c.view()); + expect_equal(c.view(), ArrayMatrix(c_exp).view()); + } } +CASE("sparse_matrix matrix multiply-add (spmm)") { + // "square" + // A = 2 . -3 + // . 2 . + // . . 2 + SparseMatrix A{3, 3, {{0, 0, 2.}, {0, 2, -3.}, {1, 1, 2.}, {2, 2, 2.}}}; + Matrix m{{1., 2.}, {3., 4.}, {5., 6.}}; + Matrix y_exp{{-12., -12.}, {9., 12.}, {15., 18.}}; + + for (std::string backend : {openmp, eckit_linalg}) { + sparse::current_backend(backend); + + SECTION("View of atlas::Array PointsRight [backend=" + sparse::current_backend().type() + "]") { + ArrayMatrix x(m); + ArrayMatrix y(m); + sparse_matrix_multiply_add(A, x.view(), y.view(), Indexing::layout_right); + expect_equal(y.view(), y_exp); + } + } + + SECTION("sparse_matrix_multiply_add [backend=openmp]") { + ArrayMatrix x(m); + ArrayMatrix y(m); + sparse_matrix_multiply_add(A, x.view(), y.view(), sparse::backend::openmp()); + expect_equal(y.view(), ArrayMatrix(y_exp).view()); + } + + SECTION("SparseMatrixMultiply::multiply_add [backend=openmp]") { + auto spmm = SparseMatrixMultiply{sparse::backend::openmp()}; + ArrayMatrix x(m); + ArrayMatrix y(m); + spmm.multiply_add(A, x.view(), y.view()); + expect_equal(y.view(), ArrayMatrix(y_exp).view()); + } +} + + //---------------------------------------------------------------------------------------------------------------------- } // namespace test From e639a064e3e6fdc2e19cd596630c950e86e3441c Mon Sep 17 00:00:00 2001 From: Liam Adams Date: Tue, 19 Nov 2024 16:20:38 +0000 Subject: [PATCH 2/2] Update Method::adjoint_interpolate_* function to use multiply-add. --- src/atlas/interpolation/method/Method.cc | 43 +++--------------------- 1 file changed, 4 insertions(+), 39 deletions(-) diff --git a/src/atlas/interpolation/method/Method.cc b/src/atlas/interpolation/method/Method.cc index 3a2349a1d..1195a7d53 100644 --- a/src/atlas/interpolation/method/Method.cc +++ b/src/atlas/interpolation/method/Method.cc @@ -168,65 +168,30 @@ void Method::interpolate_field_rank3(const Field& src, Field& tgt, const Matrix& template void Method::adjoint_interpolate_field_rank1(Field& src, const Field& tgt, const Matrix& W) const { - array::ArrayT tmp(src.shape()); + auto backend = std::is_same::value ? sparse::backend::openmp() : sparse::Backend{linalg_backend_}; - auto tmp_v = array::make_view(tmp); auto src_v = array::make_view(src); auto tgt_v = array::make_view(tgt); - tmp_v.assign(0.); - - if (std::is_same::value) { - sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp()); - } - else { - sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::Backend{linalg_backend_}); - } - - - for (idx_t t = 0; t < tmp.shape(0); ++t) { - src_v(t) += tmp_v(t); - } + sparse_matrix_multiply_add(W, tgt_v, src_v, backend); } template void Method::adjoint_interpolate_field_rank2(Field& src, const Field& tgt, const Matrix& W) const { - array::ArrayT tmp(src.shape()); - auto tmp_v = array::make_view(tmp); auto src_v = array::make_view(src); auto tgt_v = array::make_view(tgt); - tmp_v.assign(0.); - - sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp()); - - for (idx_t t = 0; t < tmp.shape(0); ++t) { - for (idx_t k = 0; k < tmp.shape(1); ++k) { - src_v(t, k) += tmp_v(t, k); - } - } + sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp()); } template void Method::adjoint_interpolate_field_rank3(Field& src, const Field& tgt, const Matrix& W) const { - array::ArrayT tmp(src.shape()); - auto tmp_v = array::make_view(tmp); auto src_v = array::make_view(src); auto tgt_v = array::make_view(tgt); - tmp_v.assign(0.); - - sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp()); - - for (idx_t t = 0; t < tmp.shape(0); ++t) { - for (idx_t j = 0; j < tmp.shape(1); ++j) { - for (idx_t k = 0; k < tmp.shape(2); ++k) { - src_v(t, j, k) += tmp_v(t, j, k); - } - } - } + sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp()); } void Method::check_compatibility(const Field& src, const Field& tgt, const Matrix& W) const {