Skip to content

Commit

Permalink
Reduce number of calls to get_pointer_type
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Jul 19, 2024
1 parent 43f4669 commit 2f59edc
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 58 deletions.
13 changes: 7 additions & 6 deletions src/sparse_blas/backends/mkl_common/mkl_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ inline bool is_ptr_accessible_on_host(sycl::queue &queue, const T *host_or_devic
}

/// Throw an exception if the scalar is not accessible on the host
template <typename T>
void check_ptr_is_host_accessible(const std::string &function_name, const std::string &scalar_name,
sycl::queue &queue, const T *host_or_device_ptr) {
if (!is_ptr_accessible_on_host(queue, host_or_device_ptr)) {
inline void check_ptr_is_host_accessible(const std::string &function_name,
const std::string &scalar_name,
bool is_ptr_accessible_on_host) {
if (!is_ptr_accessible_on_host) {
throw mkl::invalid_argument(
"sparse_blas", function_name,
"Scalar " + scalar_name + " must be accessible on the host for buffer functions.");
Expand All @@ -56,8 +56,9 @@ void check_ptr_is_host_accessible(const std::string &function_name, const std::s
/// Return a scalar on the host from a pointer to host or device memory
/// Used for USM functions
template <typename T>
inline T get_scalar_on_host(sycl::queue &queue, const T *host_or_device_ptr) {
if (is_ptr_accessible_on_host(queue, host_or_device_ptr)) {
inline T get_scalar_on_host(sycl::queue &queue, const T *host_or_device_ptr,
bool is_ptr_accessible_on_host) {
if (is_ptr_accessible_on_host) {
return *host_or_device_ptr;
}
T scalar;
Expand Down
60 changes: 36 additions & 24 deletions src/sparse_blas/backends/mkl_common/mkl_spmm.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,23 @@ sycl::event release_spmm_descr(sycl::queue &queue, oneapi::mkl::sparse::spmm_des
return detail::collapse_dependencies(queue, dependencies);
}

void check_valid_spmm(const std::string &function_name, sycl::queue &queue,
oneapi::mkl::transpose opA, oneapi::mkl::sparse::matrix_view A_view,
void check_valid_spmm(const std::string &function_name, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_matrix_handle_t B_handle,
oneapi::mkl::sparse::dense_matrix_handle_t C_handle, const void *alpha,
const void *beta) {
oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
bool is_alpha_host_accessible, bool is_beta_host_accessible) {
THROW_IF_NULLPTR(function_name, A_handle);
THROW_IF_NULLPTR(function_name, B_handle);
THROW_IF_NULLPTR(function_name, C_handle);

auto internal_A_handle = detail::get_internal_handle(A_handle);
detail::check_all_containers_compatible(function_name, internal_A_handle, B_handle, C_handle);
if (internal_A_handle->all_use_buffer()) {
detail::check_ptr_is_host_accessible("spmm", "alpha", queue, alpha);
detail::check_ptr_is_host_accessible("spmm", "beta", queue, beta);
detail::check_ptr_is_host_accessible("spmm", "alpha", is_alpha_host_accessible);
detail::check_ptr_is_host_accessible("spmm", "beta", is_beta_host_accessible);
}
if (detail::is_ptr_accessible_on_host(queue, alpha) !=
detail::is_ptr_accessible_on_host(queue, beta)) {
if (is_alpha_host_accessible != is_beta_host_accessible) {
throw mkl::invalid_argument(
"sparse_blas", function_name,
"Alpha and beta must both be placed on host memory or device memory.");
Expand Down Expand Up @@ -91,7 +90,10 @@ void spmm_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/,
std::size_t &temp_buffer_size) {
// TODO: Add support for external workspace once the closed-source oneMKL backend supports it.
check_valid_spmm(__func__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta);
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
temp_buffer_size = 0;
}

Expand All @@ -103,7 +105,10 @@ void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::
oneapi::mkl::sparse::spmm_alg alg,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
check_valid_spmm(__func__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta);
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
Expand All @@ -124,7 +129,10 @@ sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::spmm_alg alg,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
check_valid_spmm(__func__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta);
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
Expand All @@ -138,17 +146,17 @@ sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA,
}

template <typename T>
sycl::event internal_spmm(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::transpose opB, const void *alpha,
oneapi::mkl::sparse::matrix_view /*A_view*/,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta,
oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
oneapi::mkl::sparse::spmm_alg /*alg*/,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/,
const std::vector<sycl::event> &dependencies) {
T host_alpha = detail::get_scalar_on_host(queue, static_cast<const T *>(alpha));
T host_beta = detail::get_scalar_on_host(queue, static_cast<const T *>(beta));
sycl::event internal_spmm(
sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, const void *alpha,
oneapi::mkl::sparse::matrix_view /*A_view*/, oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void *beta,
oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg /*alg*/,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/, const std::vector<sycl::event> &dependencies,
bool is_alpha_host_accessible, bool is_beta_host_accessible) {
T host_alpha =
detail::get_scalar_on_host(queue, static_cast<const T *>(alpha), is_alpha_host_accessible);
T host_beta =
detail::get_scalar_on_host(queue, static_cast<const T *>(beta), is_beta_host_accessible);
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
auto layout = B_handle->dense_layout;
Expand Down Expand Up @@ -177,8 +185,12 @@ sycl::event spmm(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr,
const std::vector<sycl::event> &dependencies) {
check_valid_spmm(__func__, queue, opA, A_view, A_handle, B_handle, C_handle, alpha, beta);
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
auto value_type = detail::get_internal_handle(A_handle)->get_value_type();
DISPATCH_MKL_OPERATION("spmm", value_type, internal_spmm, queue, opA, opB, alpha, A_view,
A_handle, B_handle, beta, C_handle, alg, spmm_descr, dependencies);
A_handle, B_handle, beta, C_handle, alg, spmm_descr, dependencies,
is_alpha_host_accessible, is_beta_host_accessible);
}
47 changes: 31 additions & 16 deletions src/sparse_blas/backends/mkl_common/mkl_spmv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,23 @@ sycl::event release_spmv_descr(sycl::queue &queue, oneapi::mkl::sparse::spmv_des
return detail::collapse_dependencies(queue, dependencies);
}

void check_valid_spmv(const std::string &function_name, sycl::queue &queue,
oneapi::mkl::transpose opA, oneapi::mkl::sparse::matrix_view A_view,
void check_valid_spmv(const std::string &function_name, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_vector_handle_t x_handle,
oneapi::mkl::sparse::dense_vector_handle_t y_handle, const void *alpha,
const void *beta) {
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
bool is_alpha_host_accessible, bool is_beta_host_accessible) {
THROW_IF_NULLPTR(function_name, A_handle);
THROW_IF_NULLPTR(function_name, x_handle);
THROW_IF_NULLPTR(function_name, y_handle);

auto internal_A_handle = detail::get_internal_handle(A_handle);
detail::check_all_containers_compatible(function_name, internal_A_handle, x_handle, y_handle);
if (internal_A_handle->all_use_buffer()) {
detail::check_ptr_is_host_accessible("spmv", "alpha", queue, alpha);
detail::check_ptr_is_host_accessible("spmv", "beta", queue, beta);
detail::check_ptr_is_host_accessible("spmv", "alpha", is_alpha_host_accessible);
detail::check_ptr_is_host_accessible("spmv", "beta", is_beta_host_accessible);
}
if (detail::is_ptr_accessible_on_host(queue, alpha) !=
detail::is_ptr_accessible_on_host(queue, beta)) {
if (is_alpha_host_accessible != is_beta_host_accessible) {
throw mkl::invalid_argument(
"sparse_blas", function_name,
"Alpha and beta must both be placed on host memory or device memory.");
Expand Down Expand Up @@ -81,7 +80,10 @@ void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
std::size_t &temp_buffer_size) {
// TODO: Add support for external workspace once the closed-source oneMKL backend supports it.
check_valid_spmv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta);
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
temp_buffer_size = 0;
}

Expand All @@ -93,7 +95,10 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
oneapi::mkl::sparse::spmv_alg alg,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
check_valid_spmv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta);
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
Expand Down Expand Up @@ -127,7 +132,10 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::spmv_alg alg,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
check_valid_spmv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta);
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
Expand Down Expand Up @@ -158,9 +166,12 @@ sycl::event internal_spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spmv_alg /*alg*/,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
const std::vector<sycl::event> &dependencies) {
T host_alpha = detail::get_scalar_on_host(queue, static_cast<const T *>(alpha));
T host_beta = detail::get_scalar_on_host(queue, static_cast<const T *>(beta));
const std::vector<sycl::event> &dependencies,
bool is_alpha_host_accessible, bool is_beta_host_accessible) {
T host_alpha =
detail::get_scalar_on_host(queue, static_cast<const T *>(alpha), is_alpha_host_accessible);
T host_beta =
detail::get_scalar_on_host(queue, static_cast<const T *>(beta), is_beta_host_accessible);
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
auto backend_handle = internal_A_handle->backend_handle;
Expand Down Expand Up @@ -210,8 +221,12 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr,
const std::vector<sycl::event> &dependencies) {
check_valid_spmv(__func__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta);
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
auto value_type = detail::get_internal_handle(A_handle)->get_value_type();
DISPATCH_MKL_OPERATION("spmv", value_type, internal_spmv, queue, opA, alpha, A_view, A_handle,
x_handle, beta, y_handle, alg, spmv_descr, dependencies);
x_handle, beta, y_handle, alg, spmv_descr, dependencies,
is_alpha_host_accessible, is_beta_host_accessible);
}
Loading

0 comments on commit 2f59edc

Please sign in to comment.