Skip to content

Commit

Permalink
Merge branch 'romain/update_sparse_mkl' into romain/cusparse
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 3, 2024
2 parents 42356ba + 0ecb032 commit 2850276
Show file tree
Hide file tree
Showing 17 changed files with 48 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ int run_sparse_matrix_vector_multiply_example(const sycl::device &dev) {
std::size_t sizeja = static_cast<std::size_t>(27 * nrows);
std::size_t sizeia = static_cast<std::size_t>(nrows + 1);
std::size_t sizevec = static_cast<std::size_t>(nrows);
auto sizevec_i64 = static_cast<std::int64_t>(sizevec);

ia = (intType *)sycl::malloc_shared(sizeia * sizeof(intType), main_queue);
ja = (intType *)sycl::malloc_shared(sizeja * sizeof(intType), main_queue);
Expand Down Expand Up @@ -148,10 +149,8 @@ int run_sparse_matrix_vector_multiply_example(const sycl::device &dev) {
// Create and initialize dense vector handles
oneapi::mkl::sparse::dense_vector_handle_t x_handle = nullptr;
oneapi::mkl::sparse::dense_vector_handle_t y_handle = nullptr;
oneapi::mkl::sparse::init_dense_vector(main_queue, &x_handle,
static_cast<std::int64_t>(sizevec), x);
oneapi::mkl::sparse::init_dense_vector(main_queue, &y_handle,
static_cast<std::int64_t>(sizevec), y);
oneapi::mkl::sparse::init_dense_vector(main_queue, &x_handle, sizevec_i64, x);
oneapi::mkl::sparse::init_dense_vector(main_queue, &y_handle, sizevec_i64, y);

// Create operation descriptor
oneapi::mkl::sparse::spmv_descr_t descr = nullptr;
Expand Down
10 changes: 4 additions & 6 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,8 @@ inline void common_spmm_optimize(
detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
if (!spmm_descr->buffer_size_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmm_buffer_size must be called with the same arguments before spmm_optimize.");
throw mkl::uninitialized("sparse_blas", __func__,
"spmm_buffer_size must be called before spmm_optimize.");
}
spmm_descr->optimized_called = true;
spmm_descr->last_optimized_opA = opA;
Expand Down Expand Up @@ -238,9 +237,8 @@ sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
}

if (!spmm_descr->optimized_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmm_optimize must be called with the same arguments before spmm.");
throw mkl::uninitialized("sparse_blas", __func__,
"spmm_optimize must be called before spmm.");
}
CHECK_DESCR_MATCH(spmm_descr, opA, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, opB, "spmm_optimize");
Expand Down
10 changes: 4 additions & 6 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ inline void common_spmv_optimize(oneapi::mkl::transpose opA, bool is_alpha_host_
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
if (!spmv_descr->buffer_size_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmv_buffer_size must be called with the same arguments before spmv_optimize.");
throw mkl::uninitialized("sparse_blas", __func__,
"spmv_buffer_size must be called before spmv_optimize.");
}
spmv_descr->optimized_called = true;
spmv_descr->last_optimized_opA = opA;
Expand Down Expand Up @@ -258,9 +257,8 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
}

if (!spmv_descr->optimized_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmv_optimize must be called with the same arguments before spmv.");
throw mkl::uninitialized("sparse_blas", __func__,
"spmv_optimize must be called before spmv.");
}
CHECK_DESCR_MATCH(spmv_descr, opA, "spmv_optimize");
CHECK_DESCR_MATCH(spmv_descr, A_view, "spmv_optimize");
Expand Down
10 changes: 4 additions & 6 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ inline void common_spsv_optimize(oneapi::mkl::transpose opA, bool is_alpha_host_
detail::check_valid_spsv_common(__func__, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
if (!spsv_descr->buffer_size_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spsv_buffer_size must be called with the same arguments before spsv_optimize.");
throw mkl::uninitialized("sparse_blas", __func__,
"spsv_buffer_size must be called before spsv_optimize.");
}
spsv_descr->optimized_called = true;
spsv_descr->last_optimized_opA = opA;
Expand Down Expand Up @@ -228,9 +227,8 @@ sycl::event spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
}

if (!spsv_descr->optimized_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spsv_optimize must be called with the same arguments before spsv.");
throw mkl::uninitialized("sparse_blas", __func__,
"spsv_optimize must be called before spsv.");
}
CHECK_DESCR_MATCH(spsv_descr, opA, "spsv_optimize");
CHECK_DESCR_MATCH(spsv_descr, A_view, "spsv_optimize");
Expand Down
10 changes: 4 additions & 6 deletions src/sparse_blas/backends/mkl_common/mkl_spmm.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,8 @@ inline void common_spmm_optimize(
check_valid_spmm(__func__, opA, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
if (!spmm_descr->buffer_size_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmm_buffer_size must be called with the same arguments before spmm_optimize.");
throw mkl::uninitialized("sparse_blas", __func__,
"spmm_buffer_size must be called before spmm_optimize.");
}
spmm_descr->optimized_called = true;
spmm_descr->last_optimized_opA = opA;
Expand Down Expand Up @@ -200,9 +199,8 @@ sycl::event spmm(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
is_beta_host_accessible);

if (!spmm_descr->optimized_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmm_optimize must be called with the same arguments before spmm.");
throw mkl::uninitialized("sparse_blas", __func__,
"spmm_optimize must be called before spmm.");
}
CHECK_DESCR_MATCH(spmm_descr, opA, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, opB, "spmm_optimize");
Expand Down
27 changes: 11 additions & 16 deletions src/sparse_blas/backends/mkl_common/mkl_spmv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,11 @@ inline void common_spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::spmv_descr_t spmv_descr) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible, is_beta_host_accessible);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);
if (!spmv_descr->buffer_size_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmv_buffer_size must be called with the same arguments before spmv_optimize.");
throw mkl::uninitialized("sparse_blas", __func__,
"spmv_buffer_size must be called before spmv_optimize.");
}
spmv_descr->optimized_called = true;
spmv_descr->last_optimized_opA = opA;
Expand All @@ -120,22 +119,19 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) {
return;
}
sycl::event event;
internal_A_handle->can_be_reset = false;
if (A_view.type_view == matrix_descr::triangular) {
event = oneapi::mkl::sparse::optimize_trmv(queue, A_view.uplo_view, opA, A_view.diag_view,
internal_A_handle->backend_handle);
oneapi::mkl::sparse::optimize_trmv(queue, A_view.uplo_view, opA, A_view.diag_view,
internal_A_handle->backend_handle);
}
else if (A_view.type_view == matrix_descr::symmetric ||
A_view.type_view == matrix_descr::hermitian) {
// No optimize_symv currently
return;
}
else {
event = oneapi::mkl::sparse::optimize_gemv(queue, opA, internal_A_handle->backend_handle);
oneapi::mkl::sparse::optimize_gemv(queue, opA, internal_A_handle->backend_handle);
}
// spmv_optimize is not asynchronous for buffers as the backend optimize functions don't take buffers.
event.wait_and_throw();
}

sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
Expand Down Expand Up @@ -235,13 +231,12 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
const std::vector<sycl::event> &dependencies) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible, is_beta_host_accessible);
check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
is_beta_host_accessible);

if (!spmv_descr->optimized_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spmv_optimize must be called with the same arguments before spmv.");
throw mkl::uninitialized("sparse_blas", __func__,
"spmv_optimize must be called before spmv.");
}
CHECK_DESCR_MATCH(spmv_descr, opA, "spmv_optimize");
CHECK_DESCR_MATCH(spmv_descr, A_view, "spmv_optimize");
Expand Down
16 changes: 6 additions & 10 deletions src/sparse_blas/backends/mkl_common/mkl_spsv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ inline void common_spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA,
check_valid_spsv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible,
alg);
if (!spsv_descr->buffer_size_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spsv_buffer_size must be called with the same arguments before spsv_optimize.");
throw mkl::uninitialized("sparse_blas", __func__,
"spsv_buffer_size must be called before spsv_optimize.");
}
spsv_descr->optimized_called = true;
spsv_descr->last_optimized_opA = opA;
Expand All @@ -128,10 +127,8 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
return;
}
internal_A_handle->can_be_reset = false;
auto event = oneapi::mkl::sparse::optimize_trsv(queue, A_view.uplo_view, opA, A_view.diag_view,
internal_A_handle->backend_handle);
// spsv_optimize is not asynchronous for buffers as the backend optimize functions don't take buffers.
event.wait_and_throw();
oneapi::mkl::sparse::optimize_trsv(queue, A_view.uplo_view, opA, A_view.diag_view,
internal_A_handle->backend_handle);
}

sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
Expand Down Expand Up @@ -196,9 +193,8 @@ sycl::event spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
alg);

if (!spsv_descr->optimized_called) {
throw mkl::uninitialized(
"sparse_blas", __func__,
"spsv_optimize must be called with the same arguments before spsv.");
throw mkl::uninitialized("sparse_blas", __func__,
"spsv_optimize must be called before spsv.");
}
CHECK_DESCR_MATCH(spsv_descr, opA, "spsv_optimize");
CHECK_DESCR_MATCH(spsv_descr, A_view, "spsv_optimize");
Expand Down
2 changes: 1 addition & 1 deletion src/sparse_blas/backends/mklcpu/mklcpu_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
*
**************************************************************************/

#include "sparse_blas/backends/mkl_common/mkl_handles.hpp"
#include "sparse_blas/backends/mkl_common/mkl_dispatch.hpp"
#include "sparse_blas/backends/mkl_common/mkl_handles.hpp"
#include "sparse_blas/common_op_verification.hpp"
#include "sparse_blas/macros.hpp"
#include "sparse_blas/matrix_view_comparison.hpp"
Expand Down
2 changes: 1 addition & 1 deletion src/sparse_blas/backends/mklgpu/mklgpu_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

#include "oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp"

#include "sparse_blas/backends/mkl_common/mkl_handles.hpp"
#include "sparse_blas/backends/mkl_common/mkl_dispatch.hpp"
#include "sparse_blas/backends/mkl_common/mkl_handles.hpp"

namespace oneapi::mkl::sparse::mklgpu {

Expand Down
2 changes: 1 addition & 1 deletion src/sparse_blas/backends/mklgpu/mklgpu_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
*
**************************************************************************/

#include "sparse_blas/backends/mkl_common/mkl_handles.hpp"
#include "sparse_blas/backends/mkl_common/mkl_dispatch.hpp"
#include "sparse_blas/backends/mkl_common/mkl_handles.hpp"
#include "sparse_blas/common_op_verification.hpp"
#include "sparse_blas/macros.hpp"
#include "sparse_blas/matrix_view_comparison.hpp"
Expand Down
9 changes: 0 additions & 9 deletions tests/unit_tests/sparse_blas/include/test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,6 @@ void copy_host_to_buffer(sycl::queue queue, const std::vector<T> &src, sycl::buf
});
}

template <typename T>
void fill_buffer_to_0(sycl::queue queue, sycl::buffer<T, 1> dst) {
queue.submit([&](sycl::handler &cgh) {
auto dst_acc = dst.template get_access<sycl::access::mode::discard_write>(
cgh, sycl::range<1>(dst.size()));
cgh.fill(dst_acc, T(0));
});
}

template <typename OutT, typename XT, typename YT>
std::pair<OutT, OutT> swap_if_cond(bool swap, XT x, YT y) {
if (swap) {
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/sparse_blas/source/sparse_spmm_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ int test_spmm(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
copy_host_to_buffer(main_queue, a_host, a_buf);
}
nnz = reset_nnz;
fill_buffer_to_0(main_queue, c_buf);
copy_host_to_buffer(main_queue, c_ref_host, c_buf);
set_matrix_data(main_queue, format, A_handle, nrows_A, ncols_A, nnz, index, ia_buf,
ja_buf, a_buf);

Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/sparse_blas/source/sparse_spmm_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ int test_spmm(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
a_host.data(), reset_nnz, static_cast<std::size_t>(nrows_A));
}
if (reset_nnz > nnz) {
// Wait before freeing usm pointers
ev_spmm.wait_and_throw();
ia_usm_uptr = malloc_device_uptr<intType>(main_queue, ia_host.size());
ja_usm_uptr = malloc_device_uptr<intType>(main_queue, ja_host.size());
a_usm_uptr = malloc_device_uptr<fpType>(main_queue, a_host.size());
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/sparse_blas/source/sparse_spmv_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ int test_spmv(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
copy_host_to_buffer(main_queue, ja_host, ja_buf);
copy_host_to_buffer(main_queue, a_host, a_buf);
}
fill_buffer_to_0(main_queue, y_buf);
copy_host_to_buffer(main_queue, y_ref_host, y_buf);
nnz = reset_nnz;
set_matrix_data(main_queue, format, A_handle, nrows_A, ncols_A, nnz, index, ia_buf,
ja_buf, a_buf);
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/sparse_blas/source/sparse_spmv_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ int test_spmv(sycl::device *dev, sparse_matrix_format_t format, intType nrows_A,
a_host.data(), reset_nnz, static_cast<std::size_t>(nrows_A));
}
if (reset_nnz > nnz) {
// Wait before freeing usm pointers
ev_spmv.wait_and_throw();
ia_usm_uptr = malloc_device_uptr<intType>(main_queue, ia_host.size());
ja_usm_uptr = malloc_device_uptr<intType>(main_queue, ja_host.size());
a_usm_uptr = malloc_device_uptr<fpType>(main_queue, a_host.size());
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/sparse_blas/source/sparse_spsv_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ int test_spsv(sycl::device *dev, sparse_matrix_format_t format, intType m, doubl
copy_host_to_buffer(main_queue, ja_host, ja_buf);
copy_host_to_buffer(main_queue, a_host, a_buf);
}
fill_buffer_to_0(main_queue, y_buf);
copy_host_to_buffer(main_queue, y_ref_host, y_buf);
nnz = reset_nnz;
set_matrix_data(main_queue, format, A_handle, m, m, nnz, index, ia_buf, ja_buf, a_buf);

Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/sparse_blas/source/sparse_spsv_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ int test_spsv(sycl::device *dev, sparse_matrix_format_t format, intType m, doubl
a_host.data(), reset_nnz, mu);
}
if (reset_nnz > nnz) {
// Wait before freeing usm pointers
ev_spsv.wait_and_throw();
ia_usm_uptr = malloc_device_uptr<intType>(main_queue, ia_host.size());
ja_usm_uptr = malloc_device_uptr<intType>(main_queue, ja_host.size());
a_usm_uptr = malloc_device_uptr<fpType>(main_queue, a_host.size());
Expand Down

0 comments on commit 2850276

Please sign in to comment.