Skip to content

Commit

Permalink
Do not allow changing data types of dense handles
Browse files Browse the repository at this point in the history
  • Loading branch information
GitHub Actions committed May 28, 2024
1 parent 3442318 commit fc1f08d
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions src/sparse_blas/backends/mkl_common/mkl_handles.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,23 @@ void init_dense_vector(sycl::queue & /*queue*/,
*p_dvhandle = new oneapi::mkl::sparse::dense_vector_handle(val, size);
}

template <typename fpType, typename InternalHandleT>
void check_can_reset_value_handle(const std::string &function_name,
InternalHandleT *internal_handle) {
if (internal_handle->get_value_type() != detail::get_data_type<fpType>()) {
throw oneapi::mkl::invalid_argument(
"sparse_blas", function_name,
"Incompatible data types expected " +
data_type_to_str(internal_handle->get_value_type()) + " but got " +
data_type_to_str(detail::get_data_type<fpType>()));
}
}

template <typename fpType>
void set_dense_vector_data(sycl::queue & /*queue*/,
oneapi::mkl::sparse::dense_vector_handle_t dvhandle, std::int64_t size,
sycl::buffer<fpType, 1> val) {
check_can_reset_value_handle<fpType>(__FUNCTION__, dvhandle);
dvhandle->size = size;
dvhandle->set_buffer(val);
}
Expand All @@ -44,6 +57,7 @@ template <typename fpType>
void set_dense_vector_data(sycl::queue & /*queue*/,
oneapi::mkl::sparse::dense_vector_handle_t dvhandle, std::int64_t size,
fpType *val) {
check_can_reset_value_handle<fpType>(__FUNCTION__, dvhandle);
dvhandle->size = size;
dvhandle->set_usm_ptr(val);
}
Expand Down Expand Up @@ -94,6 +108,7 @@ void set_dense_matrix_data(sycl::queue & /*queue*/,
oneapi::mkl::sparse::dense_matrix_handle_t dmhandle,
std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld,
oneapi::mkl::layout dense_layout, sycl::buffer<fpType, 1> val) {
check_can_reset_value_handle<fpType>(__FUNCTION__, dmhandle);
dmhandle->num_rows = num_rows;
dmhandle->num_cols = num_cols;
dmhandle->ld = ld;
Expand All @@ -106,6 +121,7 @@ void set_dense_matrix_data(sycl::queue & /*queue*/,
oneapi::mkl::sparse::dense_matrix_handle_t dmhandle,
std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld,
oneapi::mkl::layout dense_layout, fpType *val) {
check_can_reset_value_handle<fpType>(__FUNCTION__, dmhandle);
dmhandle->num_rows = num_rows;
dmhandle->num_cols = num_cols;
dmhandle->ld = ld;
Expand Down Expand Up @@ -173,15 +189,9 @@ void init_coo_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p
}

template <typename fpType, typename intType>
void check_can_reset(const std::string &function_name,
detail::sparse_matrix_handle *internal_smhandle) {
if (internal_smhandle->get_value_type() != detail::get_data_type<fpType>()) {
throw oneapi::mkl::invalid_argument(
"sparse_blas", function_name,
"Incompatible data types expected " +
data_type_to_str(internal_smhandle->get_value_type()) + " but got " +
data_type_to_str(detail::get_data_type<fpType>()));
}
void check_can_reset_sparse_handle(const std::string &function_name,
detail::sparse_matrix_handle *internal_smhandle) {
check_can_reset_value_handle<fpType>(function_name, internal_smhandle);
if (internal_smhandle->get_int_type() != detail::get_data_type<intType>()) {
throw oneapi::mkl::invalid_argument(
"sparse_blas", function_name,
Expand All @@ -202,7 +212,7 @@ void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_
oneapi::mkl::index_base index, sycl::buffer<intType, 1> row_ind,
sycl::buffer<intType, 1> col_ind, sycl::buffer<fpType, 1> val) {
auto internal_smhandle = detail::get_internal_handle(smhandle);
check_can_reset<fpType, intType>(__FUNCTION__, internal_smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle);
internal_smhandle->row_container.set_buffer(row_ind);
internal_smhandle->col_container.set_buffer(col_ind);
internal_smhandle->value_container.set_buffer(val);
Expand All @@ -221,7 +231,7 @@ void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_
oneapi::mkl::index_base index, intType *row_ind, intType *col_ind,
fpType *val) {
auto internal_smhandle = detail::get_internal_handle(smhandle);
check_can_reset<fpType, intType>(__FUNCTION__, internal_smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle);
internal_smhandle->row_container.set_usm_ptr(row_ind);
internal_smhandle->col_container.set_usm_ptr(col_ind);
internal_smhandle->value_container.set_usm_ptr(val);
Expand Down Expand Up @@ -298,7 +308,7 @@ void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_
oneapi::mkl::index_base index, sycl::buffer<intType, 1> row_ptr,
sycl::buffer<intType, 1> col_ind, sycl::buffer<fpType, 1> val) {
auto internal_smhandle = detail::get_internal_handle(smhandle);
check_can_reset<fpType, intType>(__FUNCTION__, internal_smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle);
internal_smhandle->row_container.set_buffer(row_ptr);
internal_smhandle->col_container.set_buffer(col_ind);
internal_smhandle->value_container.set_buffer(val);
Expand All @@ -318,7 +328,7 @@ void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_
oneapi::mkl::index_base index, intType *row_ptr, intType *col_ind,
fpType *val) {
auto internal_smhandle = detail::get_internal_handle(smhandle);
check_can_reset<fpType, intType>(__FUNCTION__, internal_smhandle);
check_can_reset_sparse_handle<fpType, intType>(__FUNCTION__, internal_smhandle);
internal_smhandle->row_container.set_usm_ptr(row_ptr);
internal_smhandle->col_container.set_usm_ptr(col_ind);
internal_smhandle->value_container.set_usm_ptr(val);
Expand Down

0 comments on commit fc1f08d

Please sign in to comment.