Skip to content

Commit

Permalink
Test scalars on device memory
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Jul 2, 2024
1 parent a79c5af commit d452846
Show file tree
Hide file tree
Showing 13 changed files with 209 additions and 109 deletions.
2 changes: 1 addition & 1 deletion src/sparse_blas/backends/mkl_common/mkl_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void check_ptr_is_host_accessible(const std::string &function_name, const std::s
/// Return a scalar on the host from a pointer to host or device memory
/// Used for USM functions
template <typename T>
inline T get_scalar(sycl::queue &queue, const T *host_or_device_ptr) {
inline T get_scalar_on_host(sycl::queue &queue, const T *host_or_device_ptr) {
if (is_ptr_accessible_on_host(queue, host_or_device_ptr)) {
return *host_or_device_ptr;
}
Expand Down
18 changes: 12 additions & 6 deletions src/sparse_blas/backends/mkl_common/mkl_spmm.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ void check_valid_spmm(const std::string function_name, sycl::queue &queue,
detail::check_ptr_is_host_accessible("spmm", "alpha", queue, alpha);
detail::check_ptr_is_host_accessible("spmm", "beta", queue, beta);
}
if (detail::is_ptr_accessible_on_host(queue, alpha) !=
detail::is_ptr_accessible_on_host(queue, beta)) {
throw mkl::invalid_argument(
"sparse_blas", function_name,
"Alpha and beta must both be placed on host memory or device memory.");
}
if (B_handle->dense_layout != C_handle->dense_layout) {
throw mkl::invalid_argument("sparse_blas", function_name,
"B and C matrices must used the same layout.");
Expand Down Expand Up @@ -138,25 +144,25 @@ sycl::event internal_spmm(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::spmm_alg /*alg*/,
oneapi::mkl::sparse::spmm_descr_t /*spmm_descr*/,
const std::vector<sycl::event> &dependencies) {
T cast_alpha = *static_cast<const T *>(alpha);
T cast_beta = *static_cast<const T *>(beta);
T host_alpha = detail::get_scalar_on_host(queue, static_cast<const T *>(alpha));
T host_beta = detail::get_scalar_on_host(queue, static_cast<const T *>(beta));
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
auto layout = B_handle->dense_layout;
auto columns = C_handle->num_cols;
auto ldb = B_handle->ld;
auto ldc = C_handle->ld;
if (internal_A_handle->all_use_buffer()) {
oneapi::mkl::sparse::gemm(queue, layout, opA, opB, cast_alpha,
oneapi::mkl::sparse::gemm(queue, layout, opA, opB, host_alpha,
internal_A_handle->backend_handle, B_handle->get_buffer<T>(),
columns, ldb, cast_beta, C_handle->get_buffer<T>(), ldc);
columns, ldb, host_beta, C_handle->get_buffer<T>(), ldc);
// Dependencies are not used for buffers
return {};
}
else {
return oneapi::mkl::sparse::gemm(queue, layout, opA, opB, cast_alpha,
return oneapi::mkl::sparse::gemm(queue, layout, opA, opB, host_alpha,
internal_A_handle->backend_handle,
B_handle->get_usm_ptr<T>(), columns, ldb, cast_beta,
B_handle->get_usm_ptr<T>(), columns, ldb, host_beta,
C_handle->get_usm_ptr<T>(), ldc, dependencies);
}
}
Expand Down
30 changes: 18 additions & 12 deletions src/sparse_blas/backends/mkl_common/mkl_spmv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ void check_valid_spmv(const std::string function_name, sycl::queue &queue,
detail::check_ptr_is_host_accessible("spmv", "alpha", queue, alpha);
detail::check_ptr_is_host_accessible("spmv", "beta", queue, beta);
}
if (detail::is_ptr_accessible_on_host(queue, alpha) !=
detail::is_ptr_accessible_on_host(queue, beta)) {
throw mkl::invalid_argument(
"sparse_blas", function_name,
"Alpha and beta must both be placed on host memory or device memory.");
}
if (A_view.type_view == oneapi::mkl::sparse::matrix_descr::diagonal) {
throw mkl::invalid_argument("sparse_blas", function_name,
"Matrix view's type cannot be diagonal.");
Expand Down Expand Up @@ -153,25 +159,25 @@ sycl::event internal_spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::spmv_alg /*alg*/,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
const std::vector<sycl::event> &dependencies) {
T cast_alpha = *static_cast<const T *>(alpha);
T cast_beta = *static_cast<const T *>(beta);
T host_alpha = detail::get_scalar_on_host(queue, static_cast<const T *>(alpha));
T host_beta = detail::get_scalar_on_host(queue, static_cast<const T *>(beta));
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
auto backend_handle = internal_A_handle->backend_handle;
if (internal_A_handle->all_use_buffer()) {
auto x_buffer = x_handle->get_buffer<T>();
auto y_buffer = y_handle->get_buffer<T>();
if (A_view.type_view == matrix_descr::triangular) {
oneapi::mkl::sparse::trmv(queue, A_view.uplo_view, opA, A_view.diag_view, cast_alpha,
backend_handle, x_buffer, cast_beta, y_buffer);
oneapi::mkl::sparse::trmv(queue, A_view.uplo_view, opA, A_view.diag_view, host_alpha,
backend_handle, x_buffer, host_beta, y_buffer);
}
else if (A_view.type_view == matrix_descr::symmetric ||
A_view.type_view == matrix_descr::hermitian) {
oneapi::mkl::sparse::symv(queue, A_view.uplo_view, cast_alpha, backend_handle, x_buffer,
cast_beta, y_buffer);
oneapi::mkl::sparse::symv(queue, A_view.uplo_view, host_alpha, backend_handle, x_buffer,
host_beta, y_buffer);
}
else {
oneapi::mkl::sparse::gemv(queue, opA, cast_alpha, backend_handle, x_buffer, cast_beta,
oneapi::mkl::sparse::gemv(queue, opA, host_alpha, backend_handle, x_buffer, host_beta,
y_buffer);
}
// Dependencies are not used for buffers
Expand All @@ -182,17 +188,17 @@ sycl::event internal_spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const
auto y_usm = y_handle->get_usm_ptr<T>();
if (A_view.type_view == matrix_descr::triangular) {
return oneapi::mkl::sparse::trmv(queue, A_view.uplo_view, opA, A_view.diag_view,
cast_alpha, backend_handle, x_usm, cast_beta, y_usm,
host_alpha, backend_handle, x_usm, host_beta, y_usm,
dependencies);
}
else if (A_view.type_view == matrix_descr::symmetric ||
A_view.type_view == matrix_descr::hermitian) {
return oneapi::mkl::sparse::symv(queue, A_view.uplo_view, cast_alpha, backend_handle,
x_usm, cast_beta, y_usm, dependencies);
return oneapi::mkl::sparse::symv(queue, A_view.uplo_view, host_alpha, backend_handle,
x_usm, host_beta, y_usm, dependencies);
}
else {
return oneapi::mkl::sparse::gemv(queue, opA, cast_alpha, backend_handle, x_usm,
cast_beta, y_usm, dependencies);
return oneapi::mkl::sparse::gemv(queue, opA, host_alpha, backend_handle, x_usm,
host_beta, y_usm, dependencies);
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/sparse_blas/backends/mkl_common/mkl_spsv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,18 @@ sycl::event internal_spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::spsv_alg /*alg*/,
oneapi::mkl::sparse::spsv_descr_t /*spsv_descr*/,
const std::vector<sycl::event> &dependencies) {
T cast_alpha = *static_cast<const T *>(alpha);
T host_alpha = detail::get_scalar_on_host(queue, static_cast<const T *>(alpha));
auto internal_A_handle = detail::get_internal_handle(A_handle);
internal_A_handle->can_be_reset = false;
if (internal_A_handle->all_use_buffer()) {
oneapi::mkl::sparse::trsv(queue, A_view.uplo_view, opA, A_view.diag_view, cast_alpha,
oneapi::mkl::sparse::trsv(queue, A_view.uplo_view, opA, A_view.diag_view, host_alpha,
internal_A_handle->backend_handle, x_handle->get_buffer<T>(),
y_handle->get_buffer<T>());
// Dependencies are not used for buffers
return {};
}
else {
return oneapi::mkl::sparse::trsv(queue, A_view.uplo_view, opA, A_view.diag_view, cast_alpha,
return oneapi::mkl::sparse::trsv(queue, A_view.uplo_view, opA, A_view.diag_view, host_alpha,
internal_A_handle->backend_handle,
x_handle->get_usm_ptr<T>(), y_handle->get_usm_ptr<T>(),
dependencies);
Expand Down
44 changes: 30 additions & 14 deletions tests/unit_tests/sparse_blas/include/test_spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ void test_helper_with_format_with_transpose(
oneapi::mkl::sparse::matrix_view default_A_view;
std::set<oneapi::mkl::sparse::matrix_property> no_properties;
bool no_reset_data = false;
bool no_scalars_on_device = false;

{
int m = 4, k = 6, n = 5;
Expand All @@ -83,65 +84,77 @@ void test_helper_with_format_with_transpose(
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc,
default_alg, default_A_view, no_properties, no_reset_data),
default_alg, default_A_view, no_properties, no_reset_data,
no_scalars_on_device),
num_passed, num_skipped);
// Reset data
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc,
default_alg, default_A_view, no_properties, true),
default_alg, default_A_view, no_properties, true,
no_scalars_on_device),
num_passed, num_skipped);
// Test alpha and beta on the device
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc,
default_alg, default_A_view, no_properties, no_reset_data, true),
num_passed, num_skipped);
// Test index_base 1
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix,
oneapi::mkl::index_base::one, col_major, transpose_A, transpose_B,
fp_one, fp_zero, ldb, ldc, default_alg, default_A_view, no_properties,
no_reset_data),
no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
// Test non-default alpha
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, set_fp_value<fpType>()(2.f, 1.5f),
fp_zero, ldb, ldc, default_alg, default_A_view, no_properties,
no_reset_data),
no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
// Test non-default beta
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_one,
set_fp_value<fpType>()(3.2f, 1.f), ldb, ldc, default_alg,
default_A_view, no_properties, no_reset_data),
default_A_view, no_properties, no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
// Test 0 alpha
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_zero, fp_one, ldb, ldc,
default_alg, default_A_view, no_properties, no_reset_data),
default_alg, default_A_view, no_properties, no_reset_data,
no_scalars_on_device),
num_passed, num_skipped);
// Test 0 alpha and beta
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_zero, fp_zero, ldb, ldc,
default_alg, default_A_view, no_properties, no_reset_data),
default_alg, default_A_view, no_properties, no_reset_data,
no_scalars_on_device),
num_passed, num_skipped);
// Test non-default ldb
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb + 5, ldc,
default_alg, default_A_view, no_properties, no_reset_data),
default_alg, default_A_view, no_properties, no_reset_data,
no_scalars_on_device),
num_passed, num_skipped);
// Test non-default ldc
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc + 6,
default_alg, default_A_view, no_properties, no_reset_data),
default_alg, default_A_view, no_properties, no_reset_data,
no_scalars_on_device),
num_passed, num_skipped);
// Test row major layout
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
oneapi::mkl::layout::row_major, transpose_A, transpose_B, fp_one,
fp_zero, ncols_B, ncols_C, default_alg, default_A_view, no_properties,
no_reset_data),
no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
// Test int64 indices
long long_nrows_A = 27, long_ncols_A = 13, long_ncols_C = 6;
Expand All @@ -150,22 +163,24 @@ void test_helper_with_format_with_transpose(
test_functor_i64(dev, format, long_nrows_A, long_ncols_A, long_ncols_C,
density_A_matrix, index_zero, col_major, transpose_A, transpose_B,
fp_one, fp_zero, long_ldb, long_ldc, default_alg, default_A_view,
no_properties, no_reset_data),
no_properties, no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
// Test other algorithms
for (auto alg : non_default_algorithms) {
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix,
index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero,
ldb, ldc, alg, default_A_view, no_properties, no_reset_data),
ldb, ldc, alg, default_A_view, no_properties, no_reset_data,
no_scalars_on_device),
num_passed, num_skipped);
}
// Test matrix properties
for (auto properties : test_matrix_properties) {
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix,
index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero,
ldb, ldc, default_alg, default_A_view, properties, no_reset_data),
ldb, ldc, default_alg, default_A_view, properties, no_reset_data,
no_scalars_on_device),
num_passed, num_skipped);
}
}
Expand All @@ -182,7 +197,8 @@ void test_helper_with_format_with_transpose(
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero,
col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc,
default_alg, default_A_view, no_properties, no_reset_data),
default_alg, default_A_view, no_properties, no_reset_data,
no_scalars_on_device),
num_passed, num_skipped);
}
}
Expand Down
Loading

0 comments on commit d452846

Please sign in to comment.