diff --git a/src/sparse_blas/backends/mkl_common/mkl_helper.hpp b/src/sparse_blas/backends/mkl_common/mkl_helper.hpp index d1303d949..c76af5cb6 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_helper.hpp +++ b/src/sparse_blas/backends/mkl_common/mkl_helper.hpp @@ -43,10 +43,10 @@ inline bool is_ptr_accessible_on_host(sycl::queue &queue, const T *host_or_devic } /// Throw an exception if the scalar is not accessible in the host -template -void check_ptr_is_host_accessible(const std::string &function_name, const std::string &scalar_name, - sycl::queue &queue, const T *host_or_device_ptr) { - if (!is_ptr_accessible_on_host(queue, host_or_device_ptr)) { +inline void check_ptr_is_host_accessible(const std::string &function_name, + const std::string &scalar_name, + bool is_ptr_accessible_on_host) { + if (!is_ptr_accessible_on_host) { throw mkl::invalid_argument( "sparse_blas", function_name, "Scalar " + scalar_name + " must be accessible on the host for buffer functions."); @@ -56,8 +56,9 @@ void check_ptr_is_host_accessible(const std::string &function_name, const std::s /// Return a scalar on the host from a pointer to host or device memory /// Used for USM functions template -inline T get_scalar_on_host(sycl::queue &queue, const T *host_or_device_ptr) { - if (is_ptr_accessible_on_host(queue, host_or_device_ptr)) { +inline T get_scalar_on_host(sycl::queue &queue, const T *host_or_device_ptr, + bool is_ptr_accessible_on_host) { + if (is_ptr_accessible_on_host) { return *host_or_device_ptr; } T scalar; diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx index 604db11a7..3c2a9f161 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx @@ -29,12 +29,12 @@ sycl::event release_spmm_descr(sycl::queue &queue, oneapi::mkl::sparse::spmm_des return detail::collapse_dependencies(queue, dependencies); } -void check_valid_spmm(const std::string &function_name, sycl::queue &queue, - oneapi::mkl::transpose opA, oneapi::mkl::sparse::matrix_view A_view, +void check_valid_spmm(const std::string &function_name, oneapi::mkl::transpose opA, + oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_matrix_handle_t B_handle, - oneapi::mkl::sparse::dense_matrix_handle_t C_handle, const void *alpha, - const void *beta) { + oneapi::mkl::sparse::dense_matrix_handle_t C_handle, + bool is_alpha_host_accessible, bool is_beta_host_accessible) { THROW_IF_NULLPTR(function_name, A_handle); THROW_IF_NULLPTR(function_name, B_handle); THROW_IF_NULLPTR(function_name, C_handle); @@ -42,11 +42,10 @@ void check_valid_spmm(const std::string &function_name, sycl::queue &queue, auto internal_A_handle = detail::get_internal_handle(A_handle); detail::check_all_containers_compatible(function_name, internal_A_handle, B_handle, C_handle); if (internal_A_handle->all_use_buffer()) { - detail::check_ptr_is_host_accessible("spmm", "alpha", queue, alpha); - detail::check_ptr_is_host_accessible("spmm", "beta", queue, beta); + detail::check_ptr_is_host_accessible("spmm", "alpha", is_alpha_host_accessible); + detail::check_ptr_is_host_accessible("spmm", "beta", is_beta_host_accessible); } - if (detail::is_ptr_accessible_on_host(queue, alpha) != - detail::is_ptr_accessible_on_host(queue, beta)) { + if (is_alpha_host_accessible != is_beta_host_accessible) { throw mkl::invalid_argument( "sparse_blas", function_name, "Alpha and beta must both be placed on host memory or device memory."); @@ -91,7 +90,10 @@ void spmm_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, std::size_t &temp_buffer_size) { // TODO: Add support for external workspace once the close-source oneMKL backend supports it. - check_valid_spmm(__func__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, + is_beta_host_accessible); temp_buffer_size = 0; } @@ -103,7 +105,10 @@ void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl:: oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, sycl::buffer /*workspace*/) { - check_valid_spmm(__func__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, + is_beta_host_accessible); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -124,7 +129,10 @@ sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, void * /*workspace*/, const std::vector &dependencies) { - check_valid_spmm(__func__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, + is_beta_host_accessible); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -138,17 +146,17 @@ sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, } template -sycl::event internal_spmm(sycl::queue &queue, oneapi::mkl::transpose opA, - oneapi::mkl::transpose opB, const void *alpha, - oneapi::mkl::sparse::matrix_view /*A_view*/, - oneapi::mkl::sparse::matrix_handle_t A_handle, - oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta, - oneapi::mkl::sparse::dense_matrix_handle_t C_handle, - oneapi::mkl::sparse::spmm_alg /*alg*/, - oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, - const std::vector &dependencies) { - T host_alpha = detail::get_scalar_on_host(queue, static_cast(alpha)); - T host_beta = detail::get_scalar_on_host(queue, static_cast(beta)); +sycl::event internal_spmm( + sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, const void *alpha, + oneapi::mkl::sparse::matrix_view /*A_view*/, oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta, + oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg /*alg*/, + oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, const std::vector &dependencies, + bool is_alpha_host_accessible, bool is_beta_host_accessible) { + T host_alpha = + detail::get_scalar_on_host(queue, static_cast(alpha), is_alpha_host_accessible); + T host_beta = + detail::get_scalar_on_host(queue, static_cast(beta), is_beta_host_accessible); auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; auto layout = B_handle->dense_layout; @@ -177,8 +185,12 @@ sycl::event spmm(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::tr oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr, const std::vector &dependencies) { - check_valid_spmm(__func__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, + is_beta_host_accessible); auto value_type = detail::get_internal_handle(A_handle)->get_value_type(); DISPATCH_MKL_OPERATION("spmm", value_type, internal_spmm, queue, opA, opB, alpha, A_view, - A_handle, B_handle, beta, C_handle, alg, spmm_descr, dependencies); + A_handle, B_handle, beta, C_handle, alg, spmm_descr, dependencies, + is_alpha_host_accessible, is_beta_host_accessible); } diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx index b35ad0847..930e1ec87 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx @@ -29,12 +29,12 @@ sycl::event release_spmv_descr(sycl::queue &queue, oneapi::mkl::sparse::spmv_des return detail::collapse_dependencies(queue, dependencies); } -void check_valid_spmv(const std::string &function_name, sycl::queue &queue, - oneapi::mkl::transpose opA, oneapi::mkl::sparse::matrix_view A_view, +void check_valid_spmv(const std::string &function_name, oneapi::mkl::transpose opA, + oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_vector_handle_t x_handle, - oneapi::mkl::sparse::dense_vector_handle_t y_handle, const void *alpha, - const void *beta) { + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + bool is_alpha_host_accessible, bool is_beta_host_accessible) { THROW_IF_NULLPTR(function_name, A_handle); THROW_IF_NULLPTR(function_name, x_handle); THROW_IF_NULLPTR(function_name, y_handle); @@ -42,11 +42,10 @@ void check_valid_spmv(const std::string &function_name, sycl::queue &queue, auto internal_A_handle = detail::get_internal_handle(A_handle); detail::check_all_containers_compatible(function_name, internal_A_handle, x_handle, y_handle); if (internal_A_handle->all_use_buffer()) { - detail::check_ptr_is_host_accessible("spmv", "alpha", queue, alpha); - detail::check_ptr_is_host_accessible("spmv", "beta", queue, beta); + detail::check_ptr_is_host_accessible("spmv", "alpha", is_alpha_host_accessible); + detail::check_ptr_is_host_accessible("spmv", "beta", is_beta_host_accessible); } - if (detail::is_ptr_accessible_on_host(queue, alpha) != - detail::is_ptr_accessible_on_host(queue, beta)) { + if (is_alpha_host_accessible != is_beta_host_accessible) { throw mkl::invalid_argument( "sparse_blas", function_name, "Alpha and beta must both be placed on host memory or device memory."); @@ -81,7 +80,10 @@ void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, std::size_t &temp_buffer_size) { // TODO: Add support for external workspace once the close-source oneMKL backend supports it. - check_valid_spmv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); temp_buffer_size = 0; } @@ -93,7 +95,10 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, sycl::buffer /*workspace*/) { - check_valid_spmv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -127,7 +132,10 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, void * /*workspace*/, const std::vector &dependencies) { - check_valid_spmv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -158,9 +166,12 @@ sycl::event internal_spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spmv_alg /*alg*/, oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, - const std::vector &dependencies) { - T host_alpha = detail::get_scalar_on_host(queue, static_cast(alpha)); - T host_beta = detail::get_scalar_on_host(queue, static_cast(beta)); + const std::vector &dependencies, + bool is_alpha_host_accessible, bool is_beta_host_accessible) { + T host_alpha = + detail::get_scalar_on_host(queue, static_cast(alpha), is_alpha_host_accessible); + T host_beta = + detail::get_scalar_on_host(queue, static_cast(beta), is_beta_host_accessible); auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; auto backend_handle = internal_A_handle->backend_handle; @@ -210,8 +221,12 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr, const std::vector &dependencies) { - check_valid_spmv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); auto value_type = detail::get_internal_handle(A_handle)->get_value_type(); DISPATCH_MKL_OPERATION("spmv", value_type, internal_spmv, queue, opA, alpha, A_view, A_handle, - x_handle, beta, y_handle, alg, spmv_descr, dependencies); + x_handle, beta, y_handle, alg, spmv_descr, dependencies, + is_alpha_host_accessible, is_beta_host_accessible); } diff --git a/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx index 4ca4ee9d8..849919f12 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx @@ -29,12 +29,12 @@ sycl::event release_spsv_descr(sycl::queue &queue, oneapi::mkl::sparse::spsv_des return detail::collapse_dependencies(queue, dependencies); } -void check_valid_spsv(const std::string &function_name, sycl::queue &queue, - oneapi::mkl::transpose opA, oneapi::mkl::sparse::matrix_view A_view, +void check_valid_spsv(const std::string &function_name, oneapi::mkl::transpose opA, + oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_vector_handle_t x_handle, - oneapi::mkl::sparse::dense_vector_handle_t y_handle, const void *alpha, - oneapi::mkl::sparse::spsv_alg alg) { + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + bool is_alpha_host_accessible, oneapi::mkl::sparse::spsv_alg alg) { THROW_IF_NULLPTR(function_name, A_handle); THROW_IF_NULLPTR(function_name, x_handle); THROW_IF_NULLPTR(function_name, y_handle); @@ -67,7 +67,7 @@ void check_valid_spsv(const std::string &function_name, sycl::queue &queue, } if (internal_A_handle->all_use_buffer()) { - detail::check_ptr_is_host_accessible("spsv", "alpha", queue, alpha); + detail::check_ptr_is_host_accessible("spsv", "alpha", is_alpha_host_accessible); } } @@ -80,7 +80,9 @@ void spsv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, std::size_t &temp_buffer_size) { // TODO: Add support for external workspace once the close-source oneMKL backend supports it. - check_valid_spsv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, alg); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + alg); temp_buffer_size = 0; } @@ -92,7 +94,9 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, sycl::buffer /*workspace*/) { - check_valid_spsv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, alg); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + alg); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -115,7 +119,9 @@ sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, void * /*workspace*/, const std::vector &dependencies) { - check_valid_spsv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, alg); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + alg); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -136,8 +142,10 @@ sycl::event internal_spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spsv_alg /*alg*/, oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, - const std::vector &dependencies) { - T host_alpha = detail::get_scalar_on_host(queue, static_cast(alpha)); + const std::vector &dependencies, + bool is_alpha_host_accessible) { + T host_alpha = + detail::get_scalar_on_host(queue, static_cast(alpha), is_alpha_host_accessible); auto internal_A_handle = detail::get_internal_handle(A_handle); internal_A_handle->can_be_reset = false; if (internal_A_handle->all_use_buffer()) { @@ -162,8 +170,11 @@ sycl::event spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr, const std::vector &dependencies) { - check_valid_spsv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, alg); + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + alg); auto value_type = detail::get_internal_handle(A_handle)->get_value_type(); DISPATCH_MKL_OPERATION("spsv", value_type, internal_spsv, queue, opA, alpha, A_view, A_handle, - x_handle, y_handle, alg, spsv_descr, dependencies); + x_handle, y_handle, alg, spsv_descr, dependencies, + is_alpha_host_accessible); }