From 7190c6ad8446060643022fe163f57b4c1d324cb3 Mon Sep 17 00:00:00 2001 From: "romain.biessy" Date: Fri, 7 Jun 2024 08:25:26 +0100 Subject: [PATCH] Disallow symmetric/hermitian conjtrans configurations for spmv --- .../backends/mkl_common/mkl_spmv.cxx | 22 ++++++++---- .../include/common_sparse_reference.hpp | 4 +-- .../sparse_blas/include/test_spmv.hpp | 34 ++++++++++--------- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx index 6950dc700..7f809c75e 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx @@ -30,7 +30,7 @@ sycl::event release_spmv_descr(sycl::queue &queue, oneapi::mkl::sparse::spmv_des } void check_valid_spmv(const std::string function_name, sycl::queue &queue, - oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::transpose opA, oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_vector_handle_t x_handle, oneapi::mkl::sparse::dense_vector_handle_t y_handle, const void *alpha, @@ -51,14 +51,22 @@ void check_valid_spmv(const std::string function_name, sycl::queue &queue, } if (A_view.type_view != oneapi::mkl::sparse::matrix_descr::triangular && - A_view.diag_view != oneapi::mkl::diag::nonunit) { + A_view.diag_view == oneapi::mkl::diag::unit) { throw mkl::invalid_argument( "sparse_blas", function_name, "`unit` diag_view can only be used with a triangular type_view."); } + + if ((A_view.type_view == oneapi::mkl::sparse::matrix_descr::symmetric || + A_view.type_view == oneapi::mkl::sparse::matrix_descr::hermitian) && + opA == oneapi::mkl::transpose::conjtrans) { + throw mkl::invalid_argument( + "sparse_blas", function_name, + "Symmetric or Hermitian matrix cannot be conjugated with `conjtrans`."); + } } -void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose /*opA*/, const void *alpha, +void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, @@ -67,7 +75,7 @@ void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose /*opA*/, const oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, std::size_t &temp_buffer_size) { // TODO: Add support for external workspace once the close-source oneMKL backend supports it. - check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta); + check_valid_spmv(__FUNCTION__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta); temp_buffer_size = 0; } @@ -79,7 +87,7 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, sycl::buffer /*workspace*/) { - check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta); + check_valid_spmv(__FUNCTION__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__FUNCTION__); @@ -113,7 +121,7 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, void * /*workspace*/, const std::vector &dependencies) { - check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta); + check_valid_spmv(__FUNCTION__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__FUNCTION__); @@ -196,7 +204,7 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr, const std::vector &dependencies) { - check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta); + check_valid_spmv(__FUNCTION__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta); auto value_type = detail::get_internal_handle(A_handle)->get_value_type(); DISPATCH_MKL_OPERATION("spmv", value_type, internal_spmv, queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg, spmv_descr, dependencies); diff --git a/tests/unit_tests/sparse_blas/include/common_sparse_reference.hpp b/tests/unit_tests/sparse_blas/include/common_sparse_reference.hpp index 7949342d3..d8b11e6b7 100644 --- a/tests/unit_tests/sparse_blas/include/common_sparse_reference.hpp +++ b/tests/unit_tests/sparse_blas/include/common_sparse_reference.hpp @@ -161,9 +161,7 @@ std::vector sparse_to_dense(sparse_matrix_format_t format, const intType const bool is_symmetric_or_hermitian_view = type_view == oneapi::mkl::sparse::matrix_descr::symmetric || type_view == oneapi::mkl::sparse::matrix_descr::hermitian; - // Matrices are not conjugated if they are symmetric - const bool apply_conjugate = - !is_symmetric_or_hermitian_view && transpose_val == oneapi::mkl::transpose::conjtrans; + const bool apply_conjugate = transpose_val == oneapi::mkl::transpose::conjtrans; std::vector dense_a(a_nrows * a_ncols, fpType(0)); auto write_to_dense_if_needed = [&](std::size_t a_idx, std::size_t row, std::size_t col) { diff --git a/tests/unit_tests/sparse_blas/include/test_spmv.hpp b/tests/unit_tests/sparse_blas/include/test_spmv.hpp index eee9ec124..70738dd02 100644 --- a/tests/unit_tests/sparse_blas/include/test_spmv.hpp +++ b/tests/unit_tests/sparse_blas/include/test_spmv.hpp @@ -143,22 +143,24 @@ void test_helper_with_format( fp_one, fp_zero, default_alg, triangular_unit_A_view, no_properties, no_reset_data), num_passed, num_skipped); - // Lower symmetric or hermitian - oneapi::mkl::sparse::matrix_view symmetric_view( - complex_info::is_complex ? oneapi::mkl::sparse::matrix_descr::hermitian - : oneapi::mkl::sparse::matrix_descr::symmetric); - EXPECT_TRUE_OR_FUTURE_SKIP( - test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, symmetric_view, no_properties, - no_reset_data), - num_passed, num_skipped); - // Upper symmetric or hermitian - symmetric_view.uplo_view = oneapi::mkl::uplo::upper; - EXPECT_TRUE_OR_FUTURE_SKIP( - test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, symmetric_view, no_properties, - no_reset_data), - num_passed, num_skipped); + if (transpose_val != oneapi::mkl::transpose::conjtrans) { + // Lower symmetric or hermitian + oneapi::mkl::sparse::matrix_view symmetric_view( + complex_info::is_complex ? oneapi::mkl::sparse::matrix_descr::hermitian + : oneapi::mkl::sparse::matrix_descr::symmetric); + EXPECT_TRUE_OR_FUTURE_SKIP( + test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, + transpose_val, fp_one, fp_zero, default_alg, symmetric_view, + no_properties, no_reset_data), + num_passed, num_skipped); + // Upper symmetric or hermitian + symmetric_view.uplo_view = oneapi::mkl::uplo::upper; + EXPECT_TRUE_OR_FUTURE_SKIP( + test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, + transpose_val, fp_one, fp_zero, default_alg, symmetric_view, + no_properties, no_reset_data), + num_passed, num_skipped); + } // Test other algorithms for (auto alg : non_default_algorithms) { EXPECT_TRUE_OR_FUTURE_SKIP(test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix,