diff --git a/src/sparse_blas/backends/mkl_common/mkl_helper.hpp b/src/sparse_blas/backends/mkl_common/mkl_helper.hpp index c76af5cb6..ca15c5b4f 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_helper.hpp +++ b/src/sparse_blas/backends/mkl_common/mkl_helper.hpp @@ -97,6 +97,15 @@ inline sycl::event collapse_dependencies(sycl::queue &queue, "Internal error: unsupported type " + data_type_to_str(value_type)); \ } +#define CHECK_DESCR_MATCH(descr, argument, optimize_func_name) \ + do { \ + if (descr->last_optimized_##argument != argument) { \ + throw mkl::invalid_argument( \ + "sparse_blas", __func__, \ + #argument " argument must match with the previous call to " #optimize_func_name); \ + } \ + } while (0) + } // namespace oneapi::mkl::sparse::detail #endif // _ONEMKL_SRC_SPARSE_BLAS_BACKENDS_MKL_COMMON_MKL_HELPER_HPP_ diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx index 3c2a9f161..eb1b45ebf 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx @@ -17,16 +17,31 @@ * **************************************************************************/ -// The operation descriptor is not needed as long as the backend does not have an equivalent type and does not support external workspace. -using spmm_descr = void *; +namespace oneapi::mkl::sparse { + +struct spmm_descr { + bool buffer_size_called = false; + bool optimized_called = false; + oneapi::mkl::transpose last_optimized_opA; + oneapi::mkl::transpose last_optimized_opB; + oneapi::mkl::sparse::matrix_view last_optimized_A_view; + oneapi::mkl::sparse::matrix_handle_t last_optimized_A_handle; + oneapi::mkl::sparse::dense_matrix_handle_t last_optimized_B_handle; + oneapi::mkl::sparse::dense_matrix_handle_t last_optimized_C_handle; + oneapi::mkl::sparse::spmm_alg last_optimized_alg; +}; + +} // namespace oneapi::mkl::sparse + +namespace oneapi::mkl::sparse::BACKEND { void init_spmm_descr(sycl::queue & /*queue*/, oneapi::mkl::sparse::spmm_descr_t *p_spmm_descr) { - *p_spmm_descr = nullptr; + *p_spmm_descr = new spmm_descr(); } -sycl::event release_spmm_descr(sycl::queue &queue, oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, +sycl::event release_spmm_descr(sycl::queue &queue, oneapi::mkl::sparse::spmm_descr_t spmm_descr, const std::vector &dependencies) { - return detail::collapse_dependencies(queue, dependencies); + return detail::submit_release(queue, spmm_descr, dependencies); } void check_valid_spmm(const std::string &function_name, oneapi::mkl::transpose opA, @@ -87,28 +102,50 @@ void spmm_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta, oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg /*alg*/, - oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, - std::size_t &temp_buffer_size) { + oneapi::mkl::sparse::spmm_descr_t spmm_descr, std::size_t &temp_buffer_size) { // TODO: Add support for external workspace once the close-source oneMKL backend supports it. bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, is_beta_host_accessible); temp_buffer_size = 0; + spmm_descr->buffer_size_called = true; +} + +inline void common_spmm_optimize( + sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta, + oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg alg, + oneapi::mkl::sparse::spmm_descr_t spmm_descr) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, + is_beta_host_accessible); + if (!spmm_descr->buffer_size_called) { + throw mkl::uninitialized( + "sparse_blas", __func__, + "spmm_buffer_size must be called with the same arguments before spmm_optimize."); + } + spmm_descr->optimized_called = true; + spmm_descr->last_optimized_opA = opA; + spmm_descr->last_optimized_opB = opB; + spmm_descr->last_optimized_A_view = A_view; + spmm_descr->last_optimized_A_handle = A_handle; + spmm_descr->last_optimized_B_handle = B_handle; + spmm_descr->last_optimized_C_handle = C_handle; + spmm_descr->last_optimized_alg = alg; } -void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose /*opB*/, +void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, const void *alpha, oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta, oneapi::mkl::sparse::dense_matrix_handle_t C_handle, - oneapi::mkl::sparse::spmm_alg alg, - oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, + oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr, sycl::buffer /*workspace*/) { - bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); - bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); - check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, - is_beta_host_accessible); + common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg, + spmm_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -121,18 +158,16 @@ void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl:: } sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, - oneapi::mkl::transpose /*opB*/, const void *alpha, + oneapi::mkl::transpose opB, const void *alpha, oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta, oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg alg, - oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, void * /*workspace*/, + oneapi::mkl::sparse::spmm_descr_t spmm_descr, void * /*workspace*/, const std::vector &dependencies) { - bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); - bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); - check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, - is_beta_host_accessible); + common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg, + spmm_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -189,8 +224,24 @@ sycl::event spmm(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::tr bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible, is_beta_host_accessible); + + if (!spmm_descr->optimized_called) { + throw mkl::uninitialized( + "sparse_blas", __func__, + "spmm_optimize must be called with the same arguments before spmm."); + } + CHECK_DESCR_MATCH(spmm_descr, opA, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, opB, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, A_view, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, A_handle, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, B_handle, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, C_handle, "spmm_optimize"); + CHECK_DESCR_MATCH(spmm_descr, alg, "spmm_optimize"); + auto value_type = detail::get_internal_handle(A_handle)->get_value_type(); DISPATCH_MKL_OPERATION("spmm", value_type, internal_spmm, queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg, spmm_descr, dependencies, is_alpha_host_accessible, is_beta_host_accessible); } + +} // namespace oneapi::mkl::sparse::BACKEND diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx index d5da0c433..4e5aeffdb 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx @@ -17,16 +17,30 @@ * **************************************************************************/ -// The operation descriptor is not needed as long as the backend does not have an equivalent type and does not support external workspace. -using spmv_descr = void *; +namespace oneapi::mkl::sparse { + +struct spmv_descr { + bool buffer_size_called = false; + bool optimized_called = false; + oneapi::mkl::transpose last_optimized_opA; + oneapi::mkl::sparse::matrix_view last_optimized_A_view; + oneapi::mkl::sparse::matrix_handle_t last_optimized_A_handle; + oneapi::mkl::sparse::dense_vector_handle_t last_optimized_x_handle; + oneapi::mkl::sparse::dense_vector_handle_t last_optimized_y_handle; + oneapi::mkl::sparse::spmv_alg last_optimized_alg; +}; + +} // namespace oneapi::mkl::sparse + +namespace oneapi::mkl::sparse::BACKEND { void init_spmv_descr(sycl::queue & /*queue*/, oneapi::mkl::sparse::spmv_descr_t *p_spmv_descr) { - *p_spmv_descr = nullptr; + *p_spmv_descr = new spmv_descr(); } -sycl::event release_spmv_descr(sycl::queue &queue, oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, +sycl::event release_spmv_descr(sycl::queue &queue, oneapi::mkl::sparse::spmv_descr_t spmv_descr, const std::vector &dependencies) { - return detail::collapse_dependencies(queue, dependencies); + return detail::submit_release(queue, spmv_descr, dependencies); } void check_valid_spmv(const std::string &function_name, oneapi::mkl::transpose opA, @@ -77,14 +91,40 @@ void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spmv_alg /*alg*/, - oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, - std::size_t &temp_buffer_size) { + oneapi::mkl::sparse::spmv_descr_t spmv_descr, std::size_t &temp_buffer_size) { // TODO: Add support for external workspace once the close-source oneMKL backend supports it. bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, is_beta_host_accessible); temp_buffer_size = 0; + spmv_descr->buffer_size_called = true; +} + +inline void common_spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, + const void *beta, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spmv_alg alg, + oneapi::mkl::sparse::spmv_descr_t spmv_descr) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); + if (!spmv_descr->buffer_size_called) { + throw mkl::uninitialized( + "sparse_blas", __func__, + "spmv_buffer_size must be called with the same arguments before spmv_optimize."); + } + spmv_descr->optimized_called = true; + spmv_descr->last_optimized_opA = opA; + spmv_descr->last_optimized_A_view = A_view; + spmv_descr->last_optimized_A_handle = A_handle; + spmv_descr->last_optimized_x_handle = x_handle; + spmv_descr->last_optimized_y_handle = y_handle; + spmv_descr->last_optimized_alg = alg; } void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, @@ -92,13 +132,10 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, oneapi::mkl::sparse::dense_vector_handle_t y_handle, - oneapi::mkl::sparse::spmv_alg alg, - oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, + oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr, sycl::buffer /*workspace*/) { - bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); - bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); - check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, - is_beta_host_accessible); + common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg, + spmv_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -127,12 +164,10 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spmv_alg alg, - oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, void * /*workspace*/, + oneapi::mkl::sparse::spmv_descr_t spmv_descr, void * /*workspace*/, const std::vector &dependencies) { - bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); - bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); - check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, - is_beta_host_accessible); + common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg, + spmv_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -222,8 +257,23 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, is_beta_host_accessible); + + if (!spmv_descr->optimized_called) { + throw mkl::uninitialized( + "sparse_blas", __func__, + "spmv_optimize must be called with the same arguments before spmv."); + } + CHECK_DESCR_MATCH(spmv_descr, opA, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, A_view, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, A_handle, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, x_handle, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, y_handle, "spmv_optimize"); + CHECK_DESCR_MATCH(spmv_descr, alg, "spmv_optimize"); + auto value_type = detail::get_internal_handle(A_handle)->get_value_type(); DISPATCH_MKL_OPERATION("spmv", value_type, internal_spmv, queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg, spmv_descr, dependencies, is_alpha_host_accessible, is_beta_host_accessible); } + +} // namespace oneapi::mkl::sparse::BACKEND diff --git a/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx index 718ab0f19..371fac38b 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx @@ -17,16 +17,30 @@ * **************************************************************************/ -// The operation descriptor is not needed as long as the backend does not have an equivalent type and does not support external workspace. -using spsv_descr = void *; +namespace oneapi::mkl::sparse { + +struct spsv_descr { + bool buffer_size_called = false; + bool optimized_called = false; + oneapi::mkl::transpose last_optimized_opA; + oneapi::mkl::sparse::matrix_view last_optimized_A_view; + oneapi::mkl::sparse::matrix_handle_t last_optimized_A_handle; + oneapi::mkl::sparse::dense_vector_handle_t last_optimized_x_handle; + oneapi::mkl::sparse::dense_vector_handle_t last_optimized_y_handle; + oneapi::mkl::sparse::spsv_alg last_optimized_alg; +}; + +} // namespace oneapi::mkl::sparse + +namespace oneapi::mkl::sparse::BACKEND { void init_spsv_descr(sycl::queue & /*queue*/, oneapi::mkl::sparse::spsv_descr_t *p_spsv_descr) { - *p_spsv_descr = nullptr; + *p_spsv_descr = new spsv_descr(); } -sycl::event release_spsv_descr(sycl::queue &queue, oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, +sycl::event release_spsv_descr(sycl::queue &queue, oneapi::mkl::sparse::spsv_descr_t spsv_descr, const std::vector &dependencies) { - return detail::collapse_dependencies(queue, dependencies); + return detail::submit_release(queue, spsv_descr, dependencies); } void check_valid_spsv(const std::string &function_name, oneapi::mkl::transpose opA, @@ -76,13 +90,37 @@ void spsv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void oneapi::mkl::sparse::dense_vector_handle_t x_handle, oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spsv_alg alg, - oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, - std::size_t &temp_buffer_size) { + oneapi::mkl::sparse::spsv_descr_t spsv_descr, std::size_t &temp_buffer_size) { // TODO: Add support for external workspace once the close-source oneMKL backend supports it. bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, alg); temp_buffer_size = 0; + spsv_descr->buffer_size_called = true; +} + +inline void common_spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spsv_alg alg, + oneapi::mkl::sparse::spsv_descr_t spsv_descr) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + alg); + if (!spsv_descr->buffer_size_called) { + throw mkl::uninitialized( + "sparse_blas", __func__, + "spsv_buffer_size must be called with the same arguments before spsv_optimize."); + } + spsv_descr->optimized_called = true; + spsv_descr->last_optimized_opA = opA; + spsv_descr->last_optimized_A_view = A_view; + spsv_descr->last_optimized_A_handle = A_handle; + spsv_descr->last_optimized_x_handle = x_handle; + spsv_descr->last_optimized_y_handle = y_handle; + spsv_descr->last_optimized_alg = alg; } void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, @@ -90,12 +128,9 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::matrix_handle_t A_handle, oneapi::mkl::sparse::dense_vector_handle_t x_handle, oneapi::mkl::sparse::dense_vector_handle_t y_handle, - oneapi::mkl::sparse::spsv_alg alg, - oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, + oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr, sycl::buffer /*workspace*/) { - bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); - check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, - alg); + common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -114,11 +149,9 @@ sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::dense_vector_handle_t x_handle, oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spsv_alg alg, - oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/, void * /*workspace*/, + oneapi::mkl::sparse::spsv_descr_t spsv_descr, void * /*workspace*/, const std::vector &dependencies) { - bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); - check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, - alg); + common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); @@ -170,8 +203,23 @@ sycl::event spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, alg); + + if (!spsv_descr->optimized_called) { + throw mkl::uninitialized( + "sparse_blas", __func__, + "spsv_optimize must be called with the same arguments before spsv."); + } + CHECK_DESCR_MATCH(spsv_descr, opA, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, A_view, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, A_handle, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, x_handle, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, y_handle, "spsv_optimize"); + CHECK_DESCR_MATCH(spsv_descr, alg, "spsv_optimize"); + auto value_type = detail::get_internal_handle(A_handle)->get_value_type(); DISPATCH_MKL_OPERATION("spsv", value_type, internal_spsv, queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr, dependencies, is_alpha_host_accessible); } + +} // namespace oneapi::mkl::sparse::BACKEND diff --git a/src/sparse_blas/backends/mklcpu/mklcpu_operations.cpp b/src/sparse_blas/backends/mklcpu/mklcpu_operations.cpp index 4e0242c2d..0929a7ef4 100644 --- a/src/sparse_blas/backends/mklcpu/mklcpu_operations.cpp +++ b/src/sparse_blas/backends/mklcpu/mklcpu_operations.cpp @@ -17,16 +17,17 @@ * **************************************************************************/ +#include "sparse_blas/backends/mkl_common/mkl_handles.hpp" #include "sparse_blas/backends/mkl_common/mkl_helper.hpp" #include "sparse_blas/macros.hpp" -#include "sparse_blas/backends/mkl_common/mkl_handles.hpp" +#include "sparse_blas/matrix_view_comparison.hpp" #include "oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp" -namespace oneapi::mkl::sparse::mklcpu { +#define BACKEND mklcpu #include "sparse_blas/backends/mkl_common/mkl_spmm.cxx" #include "sparse_blas/backends/mkl_common/mkl_spmv.cxx" #include "sparse_blas/backends/mkl_common/mkl_spsv.cxx" -} // namespace oneapi::mkl::sparse::mklcpu +#undef BACKEND diff --git a/src/sparse_blas/backends/mklgpu/mklgpu_operations.cpp b/src/sparse_blas/backends/mklgpu/mklgpu_operations.cpp index 0c5a73fb0..be5e0c0aa 100644 --- a/src/sparse_blas/backends/mklgpu/mklgpu_operations.cpp +++ b/src/sparse_blas/backends/mklgpu/mklgpu_operations.cpp @@ -17,16 +17,17 @@ * **************************************************************************/ +#include "sparse_blas/backends/mkl_common/mkl_handles.hpp" #include "sparse_blas/backends/mkl_common/mkl_helper.hpp" #include "sparse_blas/macros.hpp" -#include "sparse_blas/backends/mkl_common/mkl_handles.hpp" +#include "sparse_blas/matrix_view_comparison.hpp" #include "oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp" -namespace oneapi::mkl::sparse::mklgpu { +#define BACKEND mklgpu #include "sparse_blas/backends/mkl_common/mkl_spmm.cxx" #include "sparse_blas/backends/mkl_common/mkl_spmv.cxx" #include "sparse_blas/backends/mkl_common/mkl_spsv.cxx" -} // namespace oneapi::mkl::sparse::mklgpu +#undef BACKEND diff --git a/src/sparse_blas/matrix_view_comparison.hpp b/src/sparse_blas/matrix_view_comparison.hpp new file mode 100644 index 000000000..e01be7311 --- /dev/null +++ b/src/sparse_blas/matrix_view_comparison.hpp @@ -0,0 +1,36 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SRC_SPARSE_BLAS_MATRIX_VIEW_COMPARISON_HPP_ +#define _ONEMKL_SRC_SPARSE_BLAS_MATRIX_VIEW_COMPARISON_HPP_ + +#include "oneapi/mkl/sparse_blas/matrix_view.hpp" + +inline bool operator==(const oneapi::mkl::sparse::matrix_view& lhs, + const oneapi::mkl::sparse::matrix_view& rhs) { + return lhs.type_view == rhs.type_view && lhs.uplo_view == rhs.uplo_view && + lhs.diag_view == rhs.diag_view; +} + +inline bool operator!=(const oneapi::mkl::sparse::matrix_view& lhs, + const oneapi::mkl::sparse::matrix_view& rhs) { + return !(lhs == rhs); +} + +#endif // _ONEMKL_SRC_SPARSE_BLAS_MATRIX_VIEW_COMPARISON_HPP_ \ No newline at end of file