Skip to content

Commit

Permalink
Check workspace container is compatible with the handles
Browse files Browse the repository at this point in the history
  • Loading branch information
GitHub Actions committed Jun 4, 2024
1 parent 657880c commit 28ec600
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 9 deletions.
10 changes: 8 additions & 2 deletions src/sparse_blas/backends/mkl_common/mkl_spmm.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,13 @@ void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
check_valid_spmm(__FUNCTION__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__FUNCTION__);
}
if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) {
return;
}
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
// TODO: Add support for spmm_optimize once the close-source oneMKL backend supports it.
}
Expand All @@ -112,10 +115,13 @@ sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
check_valid_spmm(__FUNCTION__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__FUNCTION__);
}
if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) {
return detail::collapse_dependencies(queue, dependencies);
}
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
// TODO: Add support for spmm_optimize once the close-source oneMKL backend supports it.
return detail::collapse_dependencies(queue, dependencies);
Expand Down
10 changes: 8 additions & 2 deletions src/sparse_blas/backends/mkl_common/mkl_spmv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,14 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__FUNCTION__);
}
if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) {
return;
}
sycl::event event;
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
if (A_view.type_view == matrix_descr::triangular) {
event = oneapi::mkl::sparse::optimize_trmv(queue, A_view.uplo_view, opA, A_view.diag_view,
Expand All @@ -111,10 +114,13 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__FUNCTION__);
}
if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) {
return detail::collapse_dependencies(queue, dependencies);
}
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
if (A_view.type_view == matrix_descr::triangular) {
return oneapi::mkl::sparse::optimize_trmv(queue, A_view.uplo_view, opA, A_view.diag_view,
Expand Down
10 changes: 8 additions & 2 deletions src/sparse_blas/backends/mkl_common/mkl_spsv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,13 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
check_valid_spsv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, alg);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__FUNCTION__);
}
if (alg == oneapi::mkl::sparse::spsv_alg::no_optimize_alg) {
return;
}
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
auto event = oneapi::mkl::sparse::optimize_trsv(queue, A_view.uplo_view, opA, A_view.diag_view,
internal_A_handle->backend_handle);
Expand All @@ -100,10 +103,13 @@ sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
check_valid_spsv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, alg);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__FUNCTION__);
}
if (alg == oneapi::mkl::sparse::spsv_alg::no_optimize_alg) {
return detail::collapse_dependencies(queue, dependencies);
}
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
return oneapi::mkl::sparse::optimize_trsv(queue, A_view.uplo_view, opA, A_view.diag_view,
internal_A_handle->backend_handle, dependencies);
Expand Down
10 changes: 7 additions & 3 deletions src/sparse_blas/generic_container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,12 @@ struct generic_sparse_handle {
}
};

inline void throw_incompatible_container(const std::string& function_name) {
throw oneapi::mkl::invalid_argument(
"sparse_blas", function_name,
"Incompatible container types. All inputs and outputs must use the same container: buffer or USM");
}

/**
* Check that all internal containers use the same container.
*/
Expand All @@ -279,9 +285,7 @@ void check_all_containers_use_buffers(const std::string& function_name,
bool first_use_buffer = first_internal_container->all_use_buffer();
for (const auto internal_container : { internal_containers... }) {
if (internal_container->all_use_buffer() != first_use_buffer) {
throw oneapi::mkl::invalid_argument(
"sparse_blas", function_name,
"Incompatible container types. All inputs and outputs must use the same container: buffer or USM");
throw_incompatible_container(function_name);
}
}
}
Expand Down

0 comments on commit 28ec600

Please sign in to comment.