From fc1f08d8bb1f08c3bf44dce2146ea802e218c99b Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Tue, 28 May 2024 11:24:16 +0100 Subject: [PATCH] Do not allow changing data types of dense handles --- .../backends/mkl_common/mkl_handles.cxx | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/src/sparse_blas/backends/mkl_common/mkl_handles.cxx b/src/sparse_blas/backends/mkl_common/mkl_handles.cxx index 38d102768..f3ff5afa2 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_handles.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_handles.cxx @@ -32,10 +32,23 @@ void init_dense_vector(sycl::queue & /*queue*/, *p_dvhandle = new oneapi::mkl::sparse::dense_vector_handle(val, size); } +template +void check_can_reset_value_handle(const std::string &function_name, + InternalHandleT *internal_handle) { + if (internal_handle->get_value_type() != detail::get_data_type()) { + throw oneapi::mkl::invalid_argument( + "sparse_blas", function_name, + "Incompatible data types expected " + + data_type_to_str(internal_handle->get_value_type()) + " but got " + + data_type_to_str(detail::get_data_type())); + } +} + template void set_dense_vector_data(sycl::queue & /*queue*/, oneapi::mkl::sparse::dense_vector_handle_t dvhandle, std::int64_t size, sycl::buffer val) { + check_can_reset_value_handle(__FUNCTION__, dvhandle); dvhandle->size = size; dvhandle->set_buffer(val); } @@ -44,6 +57,7 @@ template void set_dense_vector_data(sycl::queue & /*queue*/, oneapi::mkl::sparse::dense_vector_handle_t dvhandle, std::int64_t size, fpType *val) { + check_can_reset_value_handle(__FUNCTION__, dvhandle); dvhandle->size = size; dvhandle->set_usm_ptr(val); } @@ -94,6 +108,7 @@ void set_dense_matrix_data(sycl::queue & /*queue*/, oneapi::mkl::sparse::dense_matrix_handle_t dmhandle, std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, oneapi::mkl::layout dense_layout, sycl::buffer val) { + check_can_reset_value_handle(__FUNCTION__, dmhandle); dmhandle->num_rows = num_rows; dmhandle->num_cols = num_cols; dmhandle->ld = ld; @@ -106,6 +121,7 @@ void set_dense_matrix_data(sycl::queue & /*queue*/, oneapi::mkl::sparse::dense_matrix_handle_t dmhandle, std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, oneapi::mkl::layout dense_layout, fpType *val) { + check_can_reset_value_handle(__FUNCTION__, dmhandle); dmhandle->num_rows = num_rows; dmhandle->num_cols = num_cols; dmhandle->ld = ld; @@ -173,15 +189,9 @@ void init_coo_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p } template -void check_can_reset(const std::string &function_name, - detail::sparse_matrix_handle *internal_smhandle) { - if (internal_smhandle->get_value_type() != detail::get_data_type()) { - throw oneapi::mkl::invalid_argument( - "sparse_blas", function_name, - "Incompatible data types expected " + - data_type_to_str(internal_smhandle->get_value_type()) + " but got " + - data_type_to_str(detail::get_data_type())); - } +void check_can_reset_sparse_handle(const std::string &function_name, + detail::sparse_matrix_handle *internal_smhandle) { + check_can_reset_value_handle(function_name, internal_smhandle); if (internal_smhandle->get_int_type() != detail::get_data_type()) { throw oneapi::mkl::invalid_argument( "sparse_blas", function_name, @@ -202,7 +212,7 @@ void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_ oneapi::mkl::index_base index, sycl::buffer row_ind, sycl::buffer col_ind, sycl::buffer val) { auto internal_smhandle = detail::get_internal_handle(smhandle); - check_can_reset(__FUNCTION__, internal_smhandle); + check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle); internal_smhandle->row_container.set_buffer(row_ind); internal_smhandle->col_container.set_buffer(col_ind); internal_smhandle->value_container.set_buffer(val); @@ -221,7 +231,7 @@ void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_ oneapi::mkl::index_base index, intType *row_ind, intType *col_ind, fpType *val) { auto internal_smhandle = detail::get_internal_handle(smhandle); - check_can_reset(__FUNCTION__, internal_smhandle); + check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle); internal_smhandle->row_container.set_usm_ptr(row_ind); internal_smhandle->col_container.set_usm_ptr(col_ind); internal_smhandle->value_container.set_usm_ptr(val); @@ -298,7 +308,7 @@ void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_ oneapi::mkl::index_base index, sycl::buffer row_ptr, sycl::buffer col_ind, sycl::buffer val) { auto internal_smhandle = detail::get_internal_handle(smhandle); - check_can_reset(__FUNCTION__, internal_smhandle); + check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle); internal_smhandle->row_container.set_buffer(row_ptr); internal_smhandle->col_container.set_buffer(col_ind); internal_smhandle->value_container.set_buffer(val); @@ -318,7 +328,7 @@ void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_ oneapi::mkl::index_base index, intType *row_ptr, intType *col_ind, fpType *val) { auto internal_smhandle = detail::get_internal_handle(smhandle); - check_can_reset(__FUNCTION__, internal_smhandle); + check_can_reset_sparse_handle(__FUNCTION__, internal_smhandle); internal_smhandle->row_container.set_usm_ptr(row_ptr); internal_smhandle->col_container.set_usm_ptr(col_ind); internal_smhandle->value_container.set_usm_ptr(val);