Skip to content

Commit

Permalink
Add checks that buffer_size and optimize functions are called before …
Browse files Browse the repository at this point in the history
…when possible
  • Loading branch information
Rbiessy committed Aug 26, 2024
1 parent c7a4420 commit 5f8e183
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 60 deletions.
9 changes: 9 additions & 0 deletions src/sparse_blas/backends/mkl_common/mkl_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ inline sycl::event collapse_dependencies(sycl::queue &queue,
"Internal error: unsupported type " + data_type_to_str(value_type)); \
}

#define CHECK_DESCR_MATCH(descr, argument, optimize_func_name) \
do { \
if (descr->last_optimized_##argument != argument) { \
throw mkl::invalid_argument( \
"sparse_blas", __func__, \
#argument " argument must match with the previous call to " #optimize_func_name); \
} \
} while (0)

} // namespace oneapi::mkl::sparse::detail

#endif // _ONEMKL_SRC_SPARSE_BLAS_BACKENDS_MKL_COMMON_MKL_HELPER_HPP_
91 changes: 71 additions & 20 deletions src/sparse_blas/backends/mkl_common/mkl_spmm.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,31 @@
*
**************************************************************************/

// The operation descriptor is not needed as long as the backend does not have an equivalent type and does not support external workspace.
using spmm_descr = void *;
namespace oneapi::mkl::sparse {

struct spmm_descr {
bool buffer_size_called = false;
bool optimized_called = false;
oneapi::mkl::transpose last_optimized_opA;
oneapi::mkl::transpose last_optimized_opB;
oneapi::mkl::sparse::matrix_view last_optimized_A_view;
oneapi::mkl::sparse::matrix_handle_t last_optimized_A_handle;
oneapi::mkl::sparse::dense_matrix_handle_t last_optimized_B_handle;
oneapi::mkl::sparse::dense_matrix_handle_t last_optimized_C_handle;
oneapi::mkl::sparse::spmm_alg last_optimized_alg;
};

} // namespace oneapi::mkl::sparse

namespace oneapi::mkl::sparse::BACKEND {

void init_spmm_descr(sycl::queue & /*queue*/, oneapi::mkl::sparse::spmm_descr_t *p_spmm_descr) {
*p_spmm_descr = nullptr;
*p_spmm_descr = new spmm_descr();
}

sycl::event release_spmm_descr(sycl::queue &queue, oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/,
sycl::event release_spmm_descr(sycl::queue &queue, oneapi::mkl::sparse::spmm_descr_t spmm_descr,
const std::vector<sycl::event> &dependencies) {
return detail::collapse_dependencies(queue, dependencies);
return detail::submit_release(queue, spmm_descr, dependencies);
}

void check_valid_spmm(const std::string &function_name, oneapi::mkl::transpose opA,
Expand Down Expand Up @@ -87,28 +102,50 @@ void spmm_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta,
oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
oneapi::mkl::sparse::spmm_alg /*alg*/,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/,
std::size_t &temp_buffer_size) {
oneapi::mkl::sparse::spmm_descr_t spmm_descr, std::size_t &temp_buffer_size) {
// TODO: Add support for external workspace once the close-source oneMKL backend supports it.
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
temp_buffer_size = 0;
spmm_descr->buffer_size_called = true;
}

inline void common_spmm_optimize(
sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, const void *alpha,
oneapi::mkl::sparse::matrix_view A_view, oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta,
oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg alg,
oneapi::mkl::sparse::spmm_descr_t spmm_descr) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
if (!spmm_descr->buffer_size_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmm_buffer_size must be called with the same arguments before spmm_optimize.");
}
spmm_descr->optimized_called = true;
spmm_descr->last_optimized_opA = opA;
spmm_descr->last_optimized_opB = opB;
spmm_descr->last_optimized_A_view = A_view;
spmm_descr->last_optimized_A_handle = A_handle;
spmm_descr->last_optimized_B_handle = B_handle;
spmm_descr->last_optimized_C_handle = C_handle;
spmm_descr->last_optimized_alg = alg;
}

void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose /*opB*/,
void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB,
const void *alpha, oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta,
oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
oneapi::mkl::sparse::spmm_alg alg,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/,
oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg,
spmm_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
Expand All @@ -121,18 +158,16 @@ void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::
}

sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::transpose /*opB*/, const void *alpha,
oneapi::mkl::transpose opB, const void *alpha,
oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta,
oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
oneapi::mkl::sparse::spmm_alg alg,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, void * /*workspace*/,
oneapi::mkl::sparse::spmm_descr_t spmm_descr, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg,
spmm_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
Expand Down Expand Up @@ -189,8 +224,24 @@ sycl::event spmm(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);

if (!spmm_descr->optimized_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmm_optimize must be called with the same arguments before spmm.");
}
CHECK_DESCR_MATCH(spmm_descr, opA, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, opB, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, A_view, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, A_handle, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, B_handle, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, C_handle, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, alg, "spmm_optimize");

auto value_type = detail::get_internal_handle(A_handle)->get_value_type();
DISPATCH_MKL_OPERATION("spmm", value_type, internal_spmm, queue, opA, opB, alpha, A_view,
A_handle, B_handle, beta, C_handle, alg, spmm_descr, dependencies,
is_alpha_host_accessible, is_beta_host_accessible);
}

} // namespace oneapi::mkl::sparse::BACKEND
86 changes: 68 additions & 18 deletions src/sparse_blas/backends/mkl_common/mkl_spmv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,30 @@
*
**************************************************************************/

// The operation descriptor is not needed as long as the backend does not have an equivalent type and does not support external workspace.
using spmv_descr = void *;
namespace oneapi::mkl::sparse {

struct spmv_descr {
bool buffer_size_called = false;
bool optimized_called = false;
oneapi::mkl::transpose last_optimized_opA;
oneapi::mkl::sparse::matrix_view last_optimized_A_view;
oneapi::mkl::sparse::matrix_handle_t last_optimized_A_handle;
oneapi::mkl::sparse::dense_vector_handle_t last_optimized_x_handle;
oneapi::mkl::sparse::dense_vector_handle_t last_optimized_y_handle;
oneapi::mkl::sparse::spmv_alg last_optimized_alg;
};

} // namespace oneapi::mkl::sparse

namespace oneapi::mkl::sparse::BACKEND {

void init_spmv_descr(sycl::queue & /*queue*/, oneapi::mkl::sparse::spmv_descr_t *p_spmv_descr) {
*p_spmv_descr = nullptr;
*p_spmv_descr = new spmv_descr();
}

sycl::event release_spmv_descr(sycl::queue &queue, oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
sycl::event release_spmv_descr(sycl::queue &queue, oneapi::mkl::sparse::spmv_descr_t spmv_descr,
const std::vector<sycl::event> &dependencies) {
return detail::collapse_dependencies(queue, dependencies);
return detail::submit_release(queue, spmv_descr, dependencies);
}

void check_valid_spmv(const std::string &function_name, oneapi::mkl::transpose opA,
Expand Down Expand Up @@ -77,28 +91,51 @@ void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void
oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta,
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spmv_alg /*alg*/,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
std::size_t &temp_buffer_size) {
oneapi::mkl::sparse::spmv_descr_t spmv_descr, std::size_t &temp_buffer_size) {
// TODO: Add support for external workspace once the close-source oneMKL backend supports it.
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
temp_buffer_size = 0;
spmv_descr->buffer_size_called = true;
}

inline void common_spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_vector_handle_t x_handle,
const void *beta,
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spmv_alg alg,
oneapi::mkl::sparse::spmv_descr_t spmv_descr) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
if (!spmv_descr->buffer_size_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmv_buffer_size must be called with the same arguments before spmv_optimize.");
}
spmv_descr->optimized_called = true;
spmv_descr->last_optimized_opA = opA;
spmv_descr->last_optimized_A_view = A_view;
spmv_descr->last_optimized_A_handle = A_handle;
spmv_descr->last_optimized_x_handle = x_handle;
spmv_descr->last_optimized_y_handle = y_handle;
spmv_descr->last_optimized_alg = alg;
}

void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta,
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spmv_alg alg,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg,
spmv_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
Expand Down Expand Up @@ -127,12 +164,10 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta,
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spmv_alg alg,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, void * /*workspace*/,
oneapi::mkl::sparse::spmv_descr_t spmv_descr, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg,
spmv_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
Expand Down Expand Up @@ -222,8 +257,23 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);

if (!spmv_descr->optimized_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmv_optimize must be called with the same arguments before spmv.");
}
CHECK_DESCR_MATCH(spmv_descr, opA, "spmv_optimize");
CHECK_DESCR_MATCH(spmv_descr, A_view, "spmv_optimize");
CHECK_DESCR_MATCH(spmv_descr, A_handle, "spmv_optimize");
CHECK_DESCR_MATCH(spmv_descr, x_handle, "spmv_optimize");
CHECK_DESCR_MATCH(spmv_descr, y_handle, "spmv_optimize");
CHECK_DESCR_MATCH(spmv_descr, alg, "spmv_optimize");

auto value_type = detail::get_internal_handle(A_handle)->get_value_type();
DISPATCH_MKL_OPERATION("spmv", value_type, internal_spmv, queue, opA, alpha, A_view, A_handle,
x_handle, beta, y_handle, alg, spmv_descr, dependencies,
is_alpha_host_accessible, is_beta_host_accessible);
}

} // namespace oneapi::mkl::sparse::BACKEND
Loading

0 comments on commit 5f8e183

Please sign in to comment.