From 77c439b081c20a0387b66e3d2a4dd22891e23b42 Mon Sep 17 00:00:00 2001 From: "romain.biessy" Date: Fri, 5 Jul 2024 16:34:15 +0200 Subject: [PATCH] [SPARSE] Add support for rocsparse backend --- CMakeLists.txt | 6 +- cmake/FindCompiler.cmake | 6 +- docs/building_the_project_with_dpcpp.rst | 18 +- docs/domains/sparse_linear_algebra.rst | 35 ++ .../run_time_dispatching/CMakeLists.txt | 3 + include/oneapi/mkl/detail/backends.hpp | 4 + include/oneapi/mkl/detail/backends_table.hpp | 6 + include/oneapi/mkl/sparse_blas.hpp | 3 + .../onemkl_sparse_blas_rocsparse.hpp | 35 ++ .../detail/rocsparse/sparse_blas_ct.hpp | 40 ++ src/config.hpp.in | 1 + src/sparse_blas/backends/CMakeLists.txt | 4 + .../backends/rocsparse/CMakeLists.txt | 81 +++ .../rocsparse/operations/rocsparse_spmm.cpp | 271 +++++++++ .../rocsparse/operations/rocsparse_spmv.cpp | 263 +++++++++ .../rocsparse/operations/rocsparse_spsv.cpp | 239 ++++++++ .../backends/rocsparse/rocsparse_error.hpp | 126 +++++ .../rocsparse/rocsparse_global_handle.hpp | 63 +++ .../backends/rocsparse/rocsparse_handles.cpp | 529 ++++++++++++++++++ .../backends/rocsparse/rocsparse_handles.hpp | 97 ++++ .../backends/rocsparse/rocsparse_helper.hpp | 160 ++++++ .../rocsparse/rocsparse_scope_handle.cpp | 125 +++++ .../rocsparse/rocsparse_scope_handle.hpp | 83 +++ .../backends/rocsparse/rocsparse_task.hpp | 187 +++++++ .../backends/rocsparse/rocsparse_wrappers.cpp | 32 ++ tests/unit_tests/CMakeLists.txt | 5 + tests/unit_tests/include/test_helper.hpp | 10 + tests/unit_tests/main_test.cpp | 3 +- .../sparse_blas/include/test_common.hpp | 34 +- .../sparse_blas/include/test_spmm.hpp | 40 +- .../sparse_blas/include/test_spmv.hpp | 51 +- .../sparse_blas/include/test_spsv.hpp | 59 +- .../sparse_blas/source/sparse_spsv_buffer.cpp | 3 + .../sparse_blas/source/sparse_spsv_usm.cpp | 3 + 34 files changed, 2536 insertions(+), 89 deletions(-) create mode 100644 include/oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp create mode 100644 include/oneapi/mkl/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp create mode 100644 src/sparse_blas/backends/rocsparse/CMakeLists.txt create mode 100644 src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp create mode 100644 src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp create mode 100644 src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_error.hpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_helper.hpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.cpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_task.hpp create mode 100644 src/sparse_blas/backends/rocsparse/rocsparse_wrappers.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e6a02e792..d393b6ace 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 +62,7 @@ option(ENABLE_PORTFFT_BACKEND "Enable the portFFT DFT backend for the DFT interf # sparse option(ENABLE_CUSPARSE_BACKEND "Enable the cuSPARSE backend for the SPARSE_BLAS interface" OFF) +option(ENABLE_ROCSPARSE_BACKEND "Enable the rocSPARSE backend for the SPARSE_BLAS interface" OFF) set(ONEMKL_SYCL_IMPLEMENTATION "dpc++" CACHE STRING "Name of the SYCL compiler") set(HIP_TARGETS "" CACHE STRING "Target HIP architectures") @@ -106,7 +107,8 @@ if(ENABLE_MKLGPU_BACKEND endif() if(ENABLE_MKLCPU_BACKEND OR ENABLE_MKLGPU_BACKEND - OR ENABLE_CUSPARSE_BACKEND) + OR ENABLE_CUSPARSE_BACKEND + OR ENABLE_ROCSPARSE_BACKEND) list(APPEND DOMAINS_LIST "sparse_blas") endif() @@ -134,7 +136,7 @@ if(CMAKE_CXX_COMPILER OR NOT ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++") endif() else() if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUFFT_BACKEND OR ENABLE_CUSPARSE_BACKEND - OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND) + OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND OR ENABLE_ROCSPARSE_BACKEND) set(CMAKE_CXX_COMPILER "clang++") elseif(ENABLE_MKLGPU_BACKEND) if(UNIX) diff --git a/cmake/FindCompiler.cmake b/cmake/FindCompiler.cmake index 61ea894d4..c68e79d02 100644 --- a/cmake/FindCompiler.cmake +++ b/cmake/FindCompiler.cmake @@ -45,7 +45,7 @@ if(is_dpcpp) list(APPEND UNIX_INTERFACE_LINK_OPTIONS -fsycl-targets=nvptx64-nvidia-cuda) elseif(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND - OR ENABLE_ROCSOLVER_BACKEND) + OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND) list(APPEND UNIX_INTERFACE_COMPILE_OPTIONS -fsycl-targets=amdgcn-amd-amdhsa -fsycl-unnamed-lambda -Xsycl-target-backend --offload-arch=${HIP_TARGETS}) @@ -54,7 +54,7 @@ if(is_dpcpp) --offload-arch=${HIP_TARGETS}) endif() if(ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUSPARSE_BACKEND OR ENABLE_ROCBLAS_BACKEND - OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND) + OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND) set_target_properties(ONEMKL::SYCL::SYCL PROPERTIES INTERFACE_COMPILE_OPTIONS "${UNIX_INTERFACE_COMPILE_OPTIONS}" INTERFACE_LINK_OPTIONS "${UNIX_INTERFACE_LINK_OPTIONS}" @@ -71,7 +71,7 @@ if(is_dpcpp) INTERFACE_LINK_LIBRARIES ${SYCL_LIBRARY}) endif() - if(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND) + if(ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCSPARSE_BACKEND) # Allow find_package(HIP) to find the correct path to libclang_rt.builtins.a # HIP's CMake uses the command `${HIP_CXX_COMPILER} -print-libgcc-file-name --rtlib=compiler-rt` to find this path. # This can print a non-existing file if the compiler used is icpx. diff --git a/docs/building_the_project_with_dpcpp.rst b/docs/building_the_project_with_dpcpp.rst index 44d930e7e..55fd857d0 100644 --- a/docs/building_the_project_with_dpcpp.rst +++ b/docs/building_the_project_with_dpcpp.rst @@ -122,6 +122,9 @@ The most important supported build options are: * - ENABLE_ROCRAND_BACKEND - True, False - False + * - ENABLE_ROCSPARSE_BACKEND + - True, False + - False * - ENABLE_MKLCPU_THREAD_TBB - True, False - True @@ -198,14 +201,14 @@ Building for ROCm ^^^^^^^^^^^^^^^^^ The ROCm backends can be enabled with ``ENABLE_ROCBLAS_BACKEND``, -``ENABLE_ROCFFT_BACKEND``, ``ENABLE_ROCSOLVER_BACKEND`` and -``ENABLE_ROCRAND_BACKEND``. +``ENABLE_ROCFFT_BACKEND``, ``ENABLE_ROCSOLVER_BACKEND``, +``ENABLE_ROCRAND_BACKEND``, and ``ENABLE_ROCSPARSE_BACKEND``. -For *RocBLAS*, *RocSOLVER* and *RocRAND*, the target device architecture must be -set. This can be set with using the ``HIP_TARGETS`` parameter. For example, to -enable a build for MI200 series GPUs, ``-DHIP_TARGETS=gfx90a`` should be set. -Currently, DPC++ can only build for a single HIP target at a time. This may -change in future versions. +For *RocBLAS*, *RocSOLVER*, *RocRAND*, and *RocSPARSE*, the target device +architecture must be set. This can be set with using the ``HIP_TARGETS`` +parameter. For example, to enable a build for MI200 series GPUs, +``-DHIP_TARGETS=gfx90a`` should be set. Currently, DPC++ can only build for a +single HIP target at a time. This may change in future versions. A few often-used architectures are listed below: @@ -380,6 +383,7 @@ disabled: -DENABLE_ROCFFT_BACKEND=True \ -DENABLE_ROCBLAS_BACKEND=True \ -DENABLE_ROCSOLVER_BACKEND=True \ + -DENABLE_ROCSPARSE_BACKEND=True \ -DHIP_TARGETS=gfx90a \ -DBUILD_FUNCTIONAL_TESTS=False diff --git a/docs/domains/sparse_linear_algebra.rst b/docs/domains/sparse_linear_algebra.rst index 9108b7b7a..1b8ae82e1 100644 --- a/docs/domains/sparse_linear_algebra.rst +++ b/docs/domains/sparse_linear_algebra.rst @@ -48,6 +48,23 @@ Currently known limitations: `_. +rocSPARSE backend +---------------- + +Currently known limitations: + +- Using ``spmv`` with a ``type_view`` other than ``matrix_descr::general`` will + throw an ``oneapi::mkl::unimplemented`` exception. +- The COO format requires the indices to be sorted by row then by column. It is + not required to set the property + ``oneapi::mkl::sparse::matrix_property::sorted`` to a sparse matrix handle. + See the `rocSPARSE documentation + `_. +- The same sparse matrix handle cannot be reused for multiple ``spmm`` or + ``spmv`` operations. See `#332 + `_. + + Operation algorithms mapping ---------------------------- @@ -73,41 +90,50 @@ spmm - Default algorithm. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_ALG_DEFAULT`` + | rocSPARSE: ``rocsparse_spmm_alg_default`` * - ``no_optimize_alg`` - Default algorithm but may skip some optimizations. Useful only if an operation with the same configuration is run once. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_ALG_DEFAULT`` + | rocSPARSE: ``rocsparse_spmm_alg_default`` * - ``coo_alg1`` - Should provide best performance for COO format, small ``nnz`` and column-major layout. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_COO_ALG1`` + | rocSPARSE: ``rocsparse_spmm_alg_coo_segmented`` * - ``coo_alg2`` - Should provide best performance for COO format and column-major layout. Produces deterministic results. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_COO_ALG2`` + | rocSPARSE: ``rocsparse_spmm_alg_coo_atomic`` * - ``coo_alg3`` - Should provide best performance for COO format and large ``nnz``. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_COO_ALG3`` + | rocSPARSE: ``rocsparse_spmm_alg_coo_segmented_atomic`` * - ``coo_alg4`` - Should provide best performance for COO format and row-major layout. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_COO_ALG4`` + | rocSPARSE: none * - ``csr_alg1`` - Should provide best performance for CSR format and column-major layout. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_CSR_ALG1`` + | rocSPARSE: ``rocsparse_spmm_alg_csr`` * - ``csr_alg2`` - Should provide best performance for CSR format and row-major layout. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_CSR_ALG2`` + | rocSPARSE: ``rocsparse_spmm_alg_csr_row_split`` * - ``csr_alg3`` - Deterministic algorithm for CSR format. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_CSR_ALG3`` + | rocSPARSE: ``rocsparse_spmm_alg_csr_merge`` spmv @@ -124,31 +150,38 @@ spmv - Default algorithm. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMV_ALG_DEFAULT`` + | rocSPARSE: ``rocsparse_spmv_alg_default`` * - ``no_optimize_alg`` - Default algorithm but may skip some optimizations. Useful only if an operation with the same configuration is run once. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_ALG_DEFAULT`` + | rocSPARSE: ``rocsparse_spmv_alg_default`` * - ``coo_alg1`` - Default algorithm for COO format. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMV_COO_ALG1`` + | rocSPARSE: ``rocsparse_spmv_alg_coo`` * - ``coo_alg2`` - Deterministic algorithm for COO format. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMV_COO_ALG2`` + | rocSPARSE: ``rocsparse_spmv_alg_coo_atomic`` * - ``csr_alg1`` - Default algorithm for CSR format. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMV_CSR_ALG1`` + | rocSPARSE: ``rocsparse_spmv_alg_csr_adaptive`` * - ``csr_alg2`` - Deterministic algorithm for CSR format. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMV_CSR_ALG2`` + | rocSPARSE: ``rocsparse_spmv_alg_csr_stream`` * - ``csr_alg3`` - LRB variant of the algorithm for CSR format. - | MKL: none | cuSPARSE: none + | rocSPARSE: ``rocsparse_spmv_alg_csr_lrb`` spsv @@ -165,8 +198,10 @@ spsv - Default algorithm. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_ALG_DEFAULT`` + | rocSPARSE: ``rocsparse_spsv_alg_default`` * - ``no_optimize_alg`` - Default algorithm but may skip some optimizations. Useful only if an operation with the same configuration is run once. - | MKL: none | cuSPARSE: ``CUSPARSE_SPMM_ALG_DEFAULT`` + | rocSPARSE: ``rocsparse_spsv_alg_default`` diff --git a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt index f09daf819..fb425ef16 100644 --- a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt +++ b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt @@ -36,6 +36,9 @@ endif() if(ENABLE_CUSPARSE_BACKEND) list(APPEND DEVICE_FILTERS "cuda:gpu") endif() +if(ENABLE_ROCSPARSE_BACKEND) + list(APPEND DEVICE_FILTERS "hip:gpu") +endif() message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") diff --git a/include/oneapi/mkl/detail/backends.hpp b/include/oneapi/mkl/detail/backends.hpp index ded06c2e8..98c5e7053 100644 --- a/include/oneapi/mkl/detail/backends.hpp +++ b/include/oneapi/mkl/detail/backends.hpp @@ -41,11 +41,13 @@ enum class backend { rocfft, portfft, cusparse, + rocsparse, unsupported }; typedef std::map backendmap; +// clang-format off static backendmap backend_map = { { backend::mklcpu, "mklcpu" }, { backend::mklgpu, "mklgpu" }, { backend::cublas, "cublas" }, @@ -60,7 +62,9 @@ static backendmap backend_map = { { backend::mklcpu, "mklcpu" }, { backend::rocfft, "rocfft" }, { backend::portfft, "portfft" }, { backend::cusparse, "cusparse" }, + { backend::rocsparse, "rocsparse" }, { backend::unsupported, "unsupported" } }; +// clang-format on } //namespace mkl } //namespace oneapi diff --git a/include/oneapi/mkl/detail/backends_table.hpp b/include/oneapi/mkl/detail/backends_table.hpp index 8a79c5c06..6a5d57280 100644 --- a/include/oneapi/mkl/detail/backends_table.hpp +++ b/include/oneapi/mkl/detail/backends_table.hpp @@ -192,6 +192,12 @@ static std::map>> libraries = { #ifdef ENABLE_CUSPARSE_BACKEND LIB_NAME("sparse_blas_cusparse") +#endif + } }, + { device::amdgpu, + { +#ifdef ENABLE_ROCSPARSE_BACKEND + LIB_NAME("sparse_blas_rocsparse") #endif } } } }, }; diff --git a/include/oneapi/mkl/sparse_blas.hpp b/include/oneapi/mkl/sparse_blas.hpp index 73e6753c7..7ebf76ae6 100644 --- a/include/oneapi/mkl/sparse_blas.hpp +++ b/include/oneapi/mkl/sparse_blas.hpp @@ -37,6 +37,9 @@ #ifdef ENABLE_CUSPARSE_BACKEND #include "sparse_blas/detail/cusparse/sparse_blas_ct.hpp" #endif +#ifdef ENABLE_ROCSPARSE_BACKEND +#include "sparse_blas/detail/rocsparse/sparse_blas_ct.hpp" +#endif #include "sparse_blas/detail/sparse_blas_rt.hpp" diff --git a/include/oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp b/include/oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp new file mode 100644 index 000000000..57cf5487d --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp @@ -0,0 +1,35 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMKL_SPARSE_BLAS_ROCSPARSE_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMKL_SPARSE_BLAS_ROCSPARSE_HPP_ + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/sparse_blas/detail/helper_types.hpp" +#include "oneapi/mkl/sparse_blas/types.hpp" + +namespace oneapi::mkl::sparse::rocsparse { + +namespace detail = oneapi::mkl::sparse::detail; + +#include "oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx" + +} // namespace oneapi::mkl::sparse::rocsparse + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_ONEMKL_SPARSE_BLAS_ROCSPARSE_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp b/include/oneapi/mkl/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp new file mode 100644 index 000000000..645230fa6 --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/rocsparse/sparse_blas_ct.hpp @@ -0,0 +1,40 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_ + +#include "oneapi/mkl/detail/backends.hpp" +#include "oneapi/mkl/detail/backend_selector.hpp" + +#include "onemkl_sparse_blas_rocsparse.hpp" + +namespace oneapi { +namespace mkl { +namespace sparse { + +#define BACKEND rocsparse +#include "oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx" +#undef BACKEND + +} //namespace sparse +} //namespace mkl +} //namespace oneapi + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_ROCSPARSE_SPARSE_BLAS_CT_HPP_ diff --git a/src/config.hpp.in b/src/config.hpp.in index fd55006a6..c67c3cc60 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -38,6 +38,7 @@ #cmakedefine ENABLE_ROCFFT_BACKEND #cmakedefine ENABLE_ROCRAND_BACKEND #cmakedefine ENABLE_ROCSOLVER_BACKEND +#cmakedefine ENABLE_ROCSPARSE_BACKEND #cmakedefine BUILD_SHARED_LIBS #cmakedefine REF_BLAS_LIBNAME "@REF_BLAS_LIBNAME@" #cmakedefine REF_CBLAS_LIBNAME "@REF_CBLAS_LIBNAME@" diff --git a/src/sparse_blas/backends/CMakeLists.txt b/src/sparse_blas/backends/CMakeLists.txt index baae9445d..4ee6b1dc1 100644 --- a/src/sparse_blas/backends/CMakeLists.txt +++ b/src/sparse_blas/backends/CMakeLists.txt @@ -31,3 +31,7 @@ endif() if(ENABLE_CUSPARSE_BACKEND) add_subdirectory(cusparse) endif() + +if(ENABLE_ROCSPARSE_BACKEND) + add_subdirectory(rocsparse) +endif() diff --git a/src/sparse_blas/backends/rocsparse/CMakeLists.txt b/src/sparse_blas/backends/rocsparse/CMakeLists.txt new file mode 100644 index 000000000..af26b50eb --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/CMakeLists.txt @@ -0,0 +1,81 @@ +#=============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_sparse_blas_rocsparse) +set(LIB_OBJ ${LIB_NAME}_obj) + +include(WarningsUtils) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + rocsparse_handles.cpp + rocsparse_scope_handle.cpp + operations/rocsparse_spmm.cpp + operations/rocsparse_spmv.cpp + operations/rocsparse_spsv.cpp + $<$: rocsparse_wrappers.cpp> +) +add_dependencies(onemkl_backend_libs_sparse_blas ${LIB_NAME}) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + +find_package(HIP REQUIRED) +find_package(rocsparse REQUIRED) + +target_link_libraries(${LIB_OBJ} PRIVATE hip::host roc::rocsparse) + +target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PRIVATE onemkl_warnings +) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp new file mode 100644 index 000000000..0d25c1202 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp @@ -0,0 +1,271 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp" + +#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp" +#include "sparse_blas/common_op_verification.hpp" +#include "sparse_blas/macros.hpp" +#include "sparse_blas/sycl_helper.hpp" + +namespace oneapi::mkl::sparse { + +// Complete the definition of the incomplete type +struct spmm_descr { + detail::generic_container workspace; + std::size_t temp_buffer_size = 0; +}; + +} // namespace oneapi::mkl::sparse + +namespace oneapi::mkl::sparse::rocsparse { + +void init_spmm_descr(sycl::queue& /*queue*/, spmm_descr_t* p_spmm_descr) { + *p_spmm_descr = new spmm_descr(); +} + +sycl::event release_spmm_descr(sycl::queue& queue, spmm_descr_t spmm_descr, + const std::vector& dependencies) { + return detail::submit_release(queue, spmm_descr, dependencies); +} + +inline auto get_roc_spmm_alg(spmm_alg alg) { + switch (alg) { + case spmm_alg::coo_alg1: return rocsparse_spmm_alg_coo_segmented; + case spmm_alg::coo_alg2: return rocsparse_spmm_alg_coo_atomic; + case spmm_alg::coo_alg3: return rocsparse_spmm_alg_coo_segmented_atomic; + case spmm_alg::csr_alg1: return rocsparse_spmm_alg_csr; + case spmm_alg::csr_alg2: return rocsparse_spmm_alg_csr_row_split; + case spmm_alg::csr_alg3: return rocsparse_spmm_alg_csr_merge; + default: return rocsparse_spmm_alg_default; + } +} + +inline void fallback_alg_if_needed(oneapi::mkl::sparse::spmm_alg& alg, oneapi::mkl::transpose opA, + oneapi::mkl::transpose opB) { + if (alg == oneapi::mkl::sparse::spmm_alg::csr_alg3 && + (opA != oneapi::mkl::transpose::nontrans || opB == oneapi::mkl::transpose::conjtrans)) { + // Avoid warnings printed on std::cerr + alg = oneapi::mkl::sparse::spmm_alg::default_alg; + } +} + +void spmm_buffer_size(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, + const void* alpha, oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::mkl::sparse::dense_matrix_handle_t C_handle, + oneapi::mkl::sparse::spmm_alg alg, + oneapi::mkl::sparse::spmm_descr_t spmm_descr, std::size_t& temp_buffer_size) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle, + is_alpha_host_accessible, is_beta_host_accessible); + A_handle->throw_if_already_used(__func__); + fallback_alg_if_needed(alg, opA, opB); + auto functor = [=, &temp_buffer_size](RocsparseScopedContextHandler& sc) { + auto roc_handle = sc.get_handle(queue); + auto roc_a = A_handle->backend_handle; + auto roc_b = B_handle->backend_handle; + auto roc_c = C_handle->backend_handle; + auto roc_op_a = get_roc_operation(opA); + auto roc_op_b = get_roc_operation(opB); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spmm_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta, + roc_c, roc_type, roc_alg, rocsparse_spmm_stage_buffer_size, + &temp_buffer_size, nullptr); + check_status(status, __func__); + }; + auto event = dispatch_submit(__func__, queue, functor, A_handle, B_handle, C_handle); + event.wait_and_throw(); + spmm_descr->temp_buffer_size = temp_buffer_size; +} + +void spmm_optimize_impl(rocsparse_handle roc_handle, oneapi::mkl::transpose opA, + oneapi::mkl::transpose opB, const void* alpha, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::mkl::sparse::dense_matrix_handle_t C_handle, + oneapi::mkl::sparse::spmm_alg alg, std::size_t buffer_size, + void* workspace_ptr, bool is_alpha_host_accessible) { + auto roc_a = A_handle->backend_handle; + auto roc_b = B_handle->backend_handle; + auto roc_c = C_handle->backend_handle; + auto roc_op_a = get_roc_operation(opA); + auto roc_op_b = get_roc_operation(opB); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spmm_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + // rocsparse_spmm_stage_preprocess stage is blocking + auto status = + rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta, roc_c, roc_type, + roc_alg, rocsparse_spmm_stage_preprocess, &buffer_size, workspace_ptr); + check_status(status, "optimize_spmm"); +} + +void spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, + const void* alpha, oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::mkl::sparse::dense_matrix_handle_t C_handle, + oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr, + sycl::buffer workspace) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle, + is_alpha_host_accessible, is_beta_host_accessible); + A_handle->throw_if_already_used(__func__); + if (!A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + // Copy the buffer to extend its lifetime until the descriptor is free'd. + spmm_descr->workspace.set_buffer_untyped(workspace); + if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) { + return; + } + fallback_alg_if_needed(alg, opA, opB); + std::size_t buffer_size = spmm_descr->temp_buffer_size; + + if (buffer_size > 0) { + auto functor = [=](RocsparseScopedContextHandler& sc, + sycl::accessor workspace_acc) { + auto roc_handle = sc.get_handle(queue); + auto workspace_ptr = sc.get_mem(workspace_acc); + spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta, C_handle, alg, + buffer_size, workspace_ptr, is_alpha_host_accessible); + }; + + // The accessor can only be bound to the cgh if the buffer size is + // greater than 0 + sycl::accessor workspace_placeholder_acc(workspace); + auto event = dispatch_submit(__func__, queue, functor, A_handle, workspace_placeholder_acc, + B_handle, C_handle); + event.wait_and_throw(); + } + else { + auto functor = [=](RocsparseScopedContextHandler& sc) { + auto roc_handle = sc.get_handle(queue); + spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta, C_handle, alg, + buffer_size, nullptr, is_alpha_host_accessible); + }; + + auto event = dispatch_submit(__func__, queue, functor, A_handle, B_handle, C_handle); + event.wait_and_throw(); + } +} + +sycl::event spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, + oneapi::mkl::transpose opB, const void* alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::mkl::sparse::dense_matrix_handle_t C_handle, + oneapi::mkl::sparse::spmm_alg alg, + oneapi::mkl::sparse::spmm_descr_t spmm_descr, void* workspace, + const std::vector& dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle, + is_alpha_host_accessible, is_beta_host_accessible); + A_handle->throw_if_already_used(__func__); + if (A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + spmm_descr->workspace.usm_ptr = workspace; + if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) { + return detail::collapse_dependencies(queue, dependencies); + } + fallback_alg_if_needed(alg, opA, opB); + std::size_t buffer_size = spmm_descr->temp_buffer_size; + auto functor = [=](RocsparseScopedContextHandler& sc) { + auto roc_handle = sc.get_handle(queue); + spmm_optimize_impl(roc_handle, opA, opB, alpha, A_handle, B_handle, beta, C_handle, alg, + buffer_size, workspace, is_alpha_host_accessible); + }; + + return dispatch_submit(__func__, queue, dependencies, functor, A_handle, B_handle, C_handle); +} + +sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, + const void* alpha, oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_matrix_handle_t B_handle, const void* beta, + oneapi::mkl::sparse::dense_matrix_handle_t C_handle, + oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr, + const std::vector& dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle, + is_alpha_host_accessible, is_beta_host_accessible); + if (A_handle->all_use_buffer() != spmm_descr->workspace.use_buffer()) { + detail::throw_incompatible_container(__func__); + } + A_handle->throw_if_already_used(__func__); + A_handle->mark_used(); + fallback_alg_if_needed(alg, opA, opB); + auto& buffer_size = spmm_descr->temp_buffer_size; + auto compute_functor = [=, &buffer_size](RocsparseScopedContextHandler& sc, + void* workspace_ptr) { + auto [roc_handle, roc_stream] = sc.get_handle_and_stream(queue); + auto roc_a = A_handle->backend_handle; + auto roc_b = B_handle->backend_handle; + auto roc_c = C_handle->backend_handle; + auto roc_op_a = get_roc_operation(opA); + auto roc_op_b = get_roc_operation(opB); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spmm_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta, + roc_c, roc_type, roc_alg, rocsparse_spmm_stage_compute, + &buffer_size, workspace_ptr); + check_status(status, __func__); + HIP_ERROR_FUNC(hipStreamSynchronize, roc_stream); + }; + if (A_handle->all_use_buffer() && buffer_size > 0) { + // The accessor can only be bound to the cgh if the buffer size is + // greater than 0 + auto functor_buffer = [=](RocsparseScopedContextHandler& sc, + sycl::accessor workspace_acc) { + auto workspace_ptr = sc.get_mem(workspace_acc); + compute_functor(sc, workspace_ptr); + }; + sycl::accessor workspace_placeholder_acc( + spmm_descr->workspace.get_buffer()); + return dispatch_submit(__func__, queue, dependencies, functor_buffer, A_handle, + workspace_placeholder_acc, B_handle, C_handle); + } + else { + // The same dispatch_submit can be used for USM or buffers if no + // workspace accessor is needed, workspace_ptr will be a nullptr in the + // latter case. + auto workspace_ptr = spmm_descr->workspace.usm_ptr; + auto functor_usm = [=](RocsparseScopedContextHandler& sc) { + compute_functor(sc, workspace_ptr); + }; + return dispatch_submit(__func__, queue, dependencies, functor_usm, A_handle, B_handle, + C_handle); + } +} + +} // namespace oneapi::mkl::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp new file mode 100644 index 000000000..f3e91bfd1 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp @@ -0,0 +1,263 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp" + +#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp" +#include "sparse_blas/common_op_verification.hpp" +#include "sparse_blas/macros.hpp" +#include "sparse_blas/sycl_helper.hpp" + +namespace oneapi::mkl::sparse { + +// Complete the definition of the incomplete type +struct spmv_descr { + detail::generic_container workspace; + std::size_t temp_buffer_size = 0; +}; + +} // namespace oneapi::mkl::sparse + +namespace oneapi::mkl::sparse::rocsparse { + +void init_spmv_descr(sycl::queue & /*queue*/, spmv_descr_t *p_spmv_descr) { + *p_spmv_descr = new spmv_descr(); +} + +sycl::event release_spmv_descr(sycl::queue &queue, spmv_descr_t spmv_descr, + const std::vector &dependencies) { + return detail::submit_release(queue, spmv_descr, dependencies); +} + +inline auto get_roc_spmv_alg(spmv_alg alg) { + switch (alg) { + case spmv_alg::coo_alg1: return rocsparse_spmv_alg_coo; + case spmv_alg::coo_alg2: return rocsparse_spmv_alg_coo_atomic; + case spmv_alg::csr_alg1: return rocsparse_spmv_alg_csr_adaptive; + case spmv_alg::csr_alg2: return rocsparse_spmv_alg_csr_stream; + case spmv_alg::csr_alg3: return rocsparse_spmv_alg_csr_lrb; + default: return rocsparse_spmv_alg_default; + } +} + +void check_valid_spmv(const std::string &function_name, oneapi::mkl::transpose opA, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + bool is_alpha_host_accessible, bool is_beta_host_accessible) { + detail::check_valid_spmv_common(function_name, opA, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible, is_beta_host_accessible); + A_handle->throw_if_already_used(__func__); + if (A_view.type_view != oneapi::mkl::sparse::matrix_descr::general) { + throw mkl::unimplemented( + "sparse_blas", function_name, + "The backend does not support spmv with a `type_view` other than `matrix_descr::general`."); + } +} + +void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spmv_alg alg, + oneapi::mkl::sparse::spmv_descr_t spmv_descr, std::size_t &temp_buffer_size) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); + auto functor = [=, &temp_buffer_size](RocsparseScopedContextHandler &sc) { + auto roc_handle = sc.get_handle(queue); + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + auto roc_op = get_roc_operation(opA); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spmv_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = + rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg, + rocsparse_spmv_stage_buffer_size, &temp_buffer_size, nullptr); + check_status(status, __func__); + }; + auto event = dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle); + event.wait_and_throw(); + spmv_descr->temp_buffer_size = temp_buffer_size; +} + +void spmv_optimize_impl(rocsparse_handle roc_handle, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spmv_alg alg, std::size_t buffer_size, + void *workspace_ptr, bool is_alpha_host_accessible) { + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + auto roc_op = get_roc_operation(opA); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spmv_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + // rocsparse_spmv_stage_preprocess stage is blocking + auto status = + rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg, + rocsparse_spmv_stage_preprocess, &buffer_size, workspace_ptr); + check_status(status, "optimize_spmv"); +} + +void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr, + sycl::buffer workspace) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); + if (!A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + // Copy the buffer to extend its lifetime until the descriptor is free'd. + spmv_descr->workspace.set_buffer_untyped(workspace); + if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) { + return; + } + std::size_t buffer_size = spmv_descr->temp_buffer_size; + if (buffer_size > 0) { + auto functor = [=](RocsparseScopedContextHandler &sc, + sycl::accessor workspace_acc) { + auto roc_handle = sc.get_handle(queue); + auto workspace_ptr = sc.get_mem(workspace_acc); + spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle, alg, + buffer_size, workspace_ptr, is_alpha_host_accessible); + }; + + // The accessor can only be bound to the cgh if the buffer size is + // greater than 0 + sycl::accessor workspace_placeholder_acc(workspace); + auto event = dispatch_submit(__func__, queue, functor, A_handle, workspace_placeholder_acc, + x_handle, y_handle); + event.wait_and_throw(); + } + else { + auto functor = [=](RocsparseScopedContextHandler &sc) { + auto roc_handle = sc.get_handle(queue); + spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle, alg, + buffer_size, nullptr, is_alpha_host_accessible); + }; + + auto event = dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle); + event.wait_and_throw(); + } +} + +sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spmv_alg alg, + oneapi::mkl::sparse::spmv_descr_t spmv_descr, void *workspace, + const std::vector &dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); + if (A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + spmv_descr->workspace.usm_ptr = workspace; + if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) { + return detail::collapse_dependencies(queue, dependencies); + } + std::size_t buffer_size = spmv_descr->temp_buffer_size; + auto functor = [=](RocsparseScopedContextHandler &sc) { + auto roc_handle = sc.get_handle(queue); + spmv_optimize_impl(roc_handle, opA, alpha, A_handle, x_handle, beta, y_handle, alg, + buffer_size, workspace, is_alpha_host_accessible); + }; + + return dispatch_submit(__func__, queue, dependencies, functor, A_handle, x_handle, y_handle); +} + +sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr, + const std::vector &dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta); + check_valid_spmv(__func__, opA, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible, + is_beta_host_accessible); + if (A_handle->all_use_buffer() != spmv_descr->workspace.use_buffer()) { + detail::throw_incompatible_container(__func__); + } + A_handle->mark_used(); + auto &buffer_size = spmv_descr->temp_buffer_size; + auto compute_functor = [=, &buffer_size](RocsparseScopedContextHandler &sc, + void *workspace_ptr) { + auto [roc_handle, roc_stream] = sc.get_handle_and_stream(queue); + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + auto roc_op = get_roc_operation(opA); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spmv_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = + rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg, + rocsparse_spmv_stage_compute, &buffer_size, workspace_ptr); + check_status(status, __func__); + HIP_ERROR_FUNC(hipStreamSynchronize, roc_stream); + }; + if (A_handle->all_use_buffer() && buffer_size > 0) { + // The accessor can only be bound to the cgh if the buffer size is + // greater than 0 + auto functor_buffer = [=](RocsparseScopedContextHandler &sc, + sycl::accessor workspace_acc) { + auto workspace_ptr = sc.get_mem(workspace_acc); + compute_functor(sc, workspace_ptr); + }; + sycl::accessor workspace_placeholder_acc( + spmv_descr->workspace.get_buffer()); + return dispatch_submit(__func__, queue, dependencies, functor_buffer, A_handle, + workspace_placeholder_acc, x_handle, y_handle); + } + else { + // The same dispatch_submit can be used for USM or buffers if no + // workspace accessor is needed, workspace_ptr will be a nullptr in the + // latter case. + auto workspace_ptr = spmv_descr->workspace.usm_ptr; + auto functor_usm = [=](RocsparseScopedContextHandler &sc) { + compute_functor(sc, workspace_ptr); + }; + return dispatch_submit(__func__, queue, dependencies, functor_usm, A_handle, x_handle, + y_handle); + } +} + +} // namespace oneapi::mkl::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp new file mode 100644 index 000000000..93929b21e --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/operations/rocsparse_spsv.cpp @@ -0,0 +1,239 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp" + +#include "sparse_blas/backends/rocsparse/rocsparse_error.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_helper.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_task.hpp" +#include "sparse_blas/backends/rocsparse/rocsparse_handles.hpp" +#include "sparse_blas/common_op_verification.hpp" +#include "sparse_blas/macros.hpp" +#include "sparse_blas/sycl_helper.hpp" + +namespace oneapi::mkl::sparse { + +// Complete the definition of the incomplete type +struct spsv_descr { + detail::generic_container workspace; + std::size_t temp_buffer_size = 0; +}; + +} // namespace oneapi::mkl::sparse + +namespace oneapi::mkl::sparse::rocsparse { + +void init_spsv_descr(sycl::queue & /*queue*/, spsv_descr_t *p_spsv_descr) { + *p_spsv_descr = new spsv_descr(); +} + +sycl::event release_spsv_descr(sycl::queue &queue, spsv_descr_t spsv_descr, + const std::vector &dependencies) { + return detail::submit_release(queue, spsv_descr, dependencies); +} + +inline auto get_roc_spsv_alg(spsv_alg /*alg*/) { + return rocsparse_spsv_alg_default; +} + +void spsv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spsv_alg alg, + oneapi::mkl::sparse::spsv_descr_t spsv_descr, std::size_t &temp_buffer_size) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + detail::check_valid_spsv_common(__func__, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible); + A_handle->throw_if_already_used(__func__); + auto functor = [=, &temp_buffer_size](RocsparseScopedContextHandler &sc) { + auto roc_handle = sc.get_handle(queue); + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + set_matrix_attributes(__func__, roc_a, A_view); + auto roc_op = get_roc_operation(opA); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spsv_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = + rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg, + rocsparse_spsv_stage_buffer_size, &temp_buffer_size, nullptr); + check_status(status, __func__); + }; + auto event = dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle); + event.wait_and_throw(); + spsv_descr->temp_buffer_size = temp_buffer_size; +} + +void spsv_optimize_impl(rocsparse_handle roc_handle, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spsv_alg alg, std::size_t buffer_size, + void *workspace_ptr, bool is_alpha_host_accessible) { + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + set_matrix_attributes("optimize_spsv", roc_a, A_view); + auto roc_op = get_roc_operation(opA); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spsv_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + // rocsparse_spsv_stage_preprocess stage is blocking + auto status = rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg, + rocsparse_spsv_stage_preprocess, &buffer_size, workspace_ptr); + check_status(status, "optimize_spsv"); +} + +void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr, + sycl::buffer workspace) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + detail::check_valid_spsv_common(__func__, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible); + if (!A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + A_handle->throw_if_already_used(__func__); + // Ignore spsv_alg::no_optimize_alg as this step is mandatory for rocSPARSE + // Copy the buffer to extend its lifetime until the descriptor is free'd. + spsv_descr->workspace.set_buffer_untyped(workspace); + std::size_t buffer_size = spsv_descr->temp_buffer_size; + if (buffer_size) { + auto functor = [=](RocsparseScopedContextHandler &sc, + sycl::accessor workspace_acc) { + auto roc_handle = sc.get_handle(queue); + auto workspace_ptr = sc.get_mem(workspace_acc); + spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, + buffer_size, workspace_ptr, is_alpha_host_accessible); + }; + + // The accessor can only be bound to the cgh if the buffer size is + // greater than 0 + sycl::accessor workspace_placeholder_acc(workspace); + auto event = dispatch_submit(__func__, queue, functor, A_handle, workspace_placeholder_acc, + x_handle, y_handle); + event.wait_and_throw(); + } + else { + auto functor = [=](RocsparseScopedContextHandler &sc) { + auto roc_handle = sc.get_handle(queue); + spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, + buffer_size, nullptr, is_alpha_host_accessible); + }; + + auto event = dispatch_submit(__func__, queue, functor, A_handle, x_handle, y_handle); + event.wait_and_throw(); + } +} + +sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spsv_alg alg, + oneapi::mkl::sparse::spsv_descr_t spsv_descr, void *workspace, + const std::vector &dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + detail::check_valid_spsv_common(__func__, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible); + if (A_handle->all_use_buffer()) { + detail::throw_incompatible_container(__func__); + } + A_handle->throw_if_already_used(__func__); + spsv_descr->workspace.usm_ptr = workspace; + // Ignore spsv_alg::no_optimize_alg as this step is mandatory for rocSPARSE + std::size_t buffer_size = spsv_descr->temp_buffer_size; + auto functor = [=](RocsparseScopedContextHandler &sc) { + auto roc_handle = sc.get_handle(queue); + spsv_optimize_impl(roc_handle, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, + buffer_size, workspace, is_alpha_host_accessible); + }; + + return dispatch_submit(__func__, queue, dependencies, functor, A_handle, x_handle, y_handle); +} + +sycl::event spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha, + oneapi::mkl::sparse::matrix_view A_view, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x_handle, + oneapi::mkl::sparse::dense_vector_handle_t y_handle, + oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr, + const std::vector &dependencies) { + bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha); + detail::check_valid_spsv_common(__func__, A_view, A_handle, x_handle, y_handle, + is_alpha_host_accessible); + if (A_handle->all_use_buffer() != spsv_descr->workspace.use_buffer()) { + detail::throw_incompatible_container(__func__); + } + A_handle->throw_if_already_used(__func__); + A_handle->mark_used(); + auto &buffer_size = spsv_descr->temp_buffer_size; + auto compute_functor = [=, &buffer_size](RocsparseScopedContextHandler &sc, + void *workspace_ptr) { + auto [roc_handle, roc_stream] = sc.get_handle_and_stream(queue); + auto roc_a = A_handle->backend_handle; + auto roc_x = x_handle->backend_handle; + auto roc_y = y_handle->backend_handle; + set_matrix_attributes(__func__, roc_a, A_view); + auto roc_op = get_roc_operation(opA); + auto roc_type = get_roc_value_type(A_handle->value_container.data_type); + auto roc_alg = get_roc_spsv_alg(alg); + set_pointer_mode(roc_handle, is_alpha_host_accessible); + auto status = + rocsparse_spsv(roc_handle, roc_op, alpha, roc_a, roc_x, roc_y, roc_type, roc_alg, + rocsparse_spsv_stage_compute, &buffer_size, workspace_ptr); + check_status(status, __func__); + HIP_ERROR_FUNC(hipStreamSynchronize, roc_stream); + }; + if (A_handle->all_use_buffer() && buffer_size > 0) { + // The accessor can only be bound to the cgh if the buffer size is + // greater than 0 + auto functor_buffer = [=](RocsparseScopedContextHandler &sc, + sycl::accessor workspace_acc) { + auto workspace_ptr = sc.get_mem(workspace_acc); + compute_functor(sc, workspace_ptr); + }; + sycl::accessor workspace_placeholder_acc( + spsv_descr->workspace.get_buffer()); + return dispatch_submit(__func__, queue, dependencies, functor_buffer, A_handle, + workspace_placeholder_acc, x_handle, y_handle); + } + else { + // The same dispatch_submit can be used for USM or buffers if no + // workspace accessor is needed, workspace_ptr will be a nullptr in the + // latter case. + auto workspace_ptr = spsv_descr->workspace.usm_ptr; + auto functor_usm = [=](RocsparseScopedContextHandler &sc) { + compute_functor(sc, workspace_ptr); + }; + return dispatch_submit(__func__, queue, dependencies, functor_usm, A_handle, x_handle, + y_handle); + } +} + +} // namespace oneapi::mkl::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp new file mode 100644 index 000000000..a86d4deef --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_error.hpp @@ -0,0 +1,126 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_ + +#include + +#include +#include + +#include "oneapi/mkl/exceptions.hpp" + +namespace oneapi::mkl::sparse::rocsparse { + +inline std::string hip_result_to_str(hipError_t result) { + switch (result) { +#define ONEMKL_ROCSPARSE_CASE(STATUS) \ + case STATUS: return #STATUS + ONEMKL_ROCSPARSE_CASE(hipSuccess); + ONEMKL_ROCSPARSE_CASE(hipErrorInvalidContext); + ONEMKL_ROCSPARSE_CASE(hipErrorInvalidKernelFile); + ONEMKL_ROCSPARSE_CASE(hipErrorMemoryAllocation); + ONEMKL_ROCSPARSE_CASE(hipErrorInitializationError); + ONEMKL_ROCSPARSE_CASE(hipErrorLaunchFailure); + ONEMKL_ROCSPARSE_CASE(hipErrorLaunchOutOfResources); + ONEMKL_ROCSPARSE_CASE(hipErrorInvalidDevice); + ONEMKL_ROCSPARSE_CASE(hipErrorInvalidValue); + ONEMKL_ROCSPARSE_CASE(hipErrorInvalidDevicePointer); + ONEMKL_ROCSPARSE_CASE(hipErrorInvalidMemcpyDirection); + ONEMKL_ROCSPARSE_CASE(hipErrorUnknown); + ONEMKL_ROCSPARSE_CASE(hipErrorInvalidResourceHandle); + ONEMKL_ROCSPARSE_CASE(hipErrorNotReady); + ONEMKL_ROCSPARSE_CASE(hipErrorNoDevice); + ONEMKL_ROCSPARSE_CASE(hipErrorPeerAccessAlreadyEnabled); + ONEMKL_ROCSPARSE_CASE(hipErrorPeerAccessNotEnabled); + ONEMKL_ROCSPARSE_CASE(hipErrorRuntimeMemory); + ONEMKL_ROCSPARSE_CASE(hipErrorRuntimeOther); + ONEMKL_ROCSPARSE_CASE(hipErrorHostMemoryAlreadyRegistered); + ONEMKL_ROCSPARSE_CASE(hipErrorHostMemoryNotRegistered); + ONEMKL_ROCSPARSE_CASE(hipErrorMapBufferObjectFailed); + ONEMKL_ROCSPARSE_CASE(hipErrorTbd); + default: return ""; + } +} + +#define HIP_ERROR_FUNC(func, ...) \ + do { \ + auto res = func(__VA_ARGS__); \ + if (res != hipSuccess) { \ + throw oneapi::mkl::exception("sparse_blas", #func, \ + "hip error: " + hip_result_to_str(res)); \ + } \ + } while (0) + +inline std::string rocsparse_status_to_str(rocsparse_status status) { + switch (status) { +#define ONEMKL_ROCSPARSE_CASE(STATUS) \ + case STATUS: return #STATUS + ONEMKL_ROCSPARSE_CASE(rocsparse_status_success); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_invalid_handle); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_not_implemented); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_invalid_pointer); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_invalid_size); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_memory_error); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_internal_error); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_invalid_value); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_arch_mismatch); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_zero_pivot); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_not_initialized); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_type_mismatch); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_requires_sorted_storage); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_thrown_exception); + ONEMKL_ROCSPARSE_CASE(rocsparse_status_continue); +#undef ONEMKL_ROCSPARSE_CASE + default: return ""; + } +} + +inline void check_status(rocsparse_status status, const std::string& function, + std::string error_str = "") { + if (status != rocsparse_status_success) { + if (!error_str.empty()) { + error_str += "; "; + } + error_str += "rocSPARSE status: " + rocsparse_status_to_str(status); + switch (status) { + case rocsparse_status_not_implemented: + throw oneapi::mkl::unimplemented("sparse_blas", function, error_str); + case rocsparse_status_invalid_handle: + case rocsparse_status_invalid_pointer: + case rocsparse_status_invalid_size: + case rocsparse_status_invalid_value: + throw oneapi::mkl::invalid_argument("sparse_blas", function, error_str); + case rocsparse_status_not_initialized: + throw oneapi::mkl::uninitialized("sparse_blas", function, error_str); + default: throw oneapi::mkl::exception("sparse_blas", function, error_str); + } + } +} + +#define ROCSPARSE_ERR_FUNC(func, ...) \ + do { \ + auto status = func(__VA_ARGS__); \ + check_status(status, #func); \ + } while (0) + +} // namespace oneapi::mkl::sparse::rocsparse + +#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_ERROR_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp new file mode 100644 index 000000000..283757f73 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_global_handle.hpp @@ -0,0 +1,63 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_ + +/** + * @file Similar to blas_handle.hpp + * Provides a map from a pi_context (or equivalent) to a rocsparse_handle. + * @see rocsparse_scope_handle.hpp +*/ + +#include +#include + +namespace oneapi::mkl::sparse::rocsparse { + +template +struct rocsparse_global_handle { + using handle_container_t = std::unordered_map *>; + handle_container_t rocsparse_global_handle_mapper_{}; + + ~rocsparse_global_handle() noexcept(false) { + for (auto &handle_pair : rocsparse_global_handle_mapper_) { + if (handle_pair.second != nullptr) { + auto handle = handle_pair.second->exchange(nullptr); + if (handle != nullptr) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_handle, handle); + handle = nullptr; + } + else { + // if the handle is nullptr it means the handle was already + // destroyed by the ContextCallback and we're free to delete the + // atomic object. + delete handle_pair.second; + } + + handle_pair.second = nullptr; + } + } + rocsparse_global_handle_mapper_.clear(); + } +}; + +} // namespace oneapi::mkl::sparse::rocsparse + +#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_GLOBAL_HANDLE_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp b/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp new file mode 100644 index 000000000..1e9bb45c7 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_handles.cpp @@ -0,0 +1,529 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp" + +#include "rocsparse_error.hpp" +#include "rocsparse_helper.hpp" +#include "rocsparse_handles.hpp" +#include "rocsparse_task.hpp" +#include "sparse_blas/macros.hpp" + +namespace oneapi::mkl::sparse::rocsparse { + +/** + * In this file RocsparseScopedContextHandler are used to ensure that a rocsparse_handle is created before any other rocSPARSE call, as required by the specification. +*/ + +// Dense vector +template +void init_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, std::int64_t size, + sycl::buffer val) { + auto event = queue.submit([&](sycl::handler &cgh) { + auto acc = val.template get_access(cgh); + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + auto roc_value_type = RocEnumType::value; + rocsparse_dnvec_descr roc_dvhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &roc_dvhandle, size, sc.get_mem(acc), + roc_value_type); + *p_dvhandle = new dense_vector_handle(roc_dvhandle, val, size); + }); + }); + event.wait_and_throw(); +} + +template +void init_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, std::int64_t size, + fpType *val) { + auto event = queue.submit([&](sycl::handler &cgh) { + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + auto roc_value_type = RocEnumType::value; + rocsparse_dnvec_descr roc_dvhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &roc_dvhandle, size, sc.get_mem(val), + roc_value_type); + *p_dvhandle = new dense_vector_handle(roc_dvhandle, val, size); + }); + }); + event.wait_and_throw(); +} + +template +void set_dense_vector_data(sycl::queue &queue, oneapi::mkl::sparse::dense_vector_handle_t dvhandle, + std::int64_t size, sycl::buffer val) { + detail::check_can_reset_value_handle(__func__, dvhandle, true); + auto event = queue.submit([&](sycl::handler &cgh) { + auto acc = val.template get_access(cgh); + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + if (dvhandle->size != size) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle); + auto roc_value_type = RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &dvhandle->backend_handle, size, + sc.get_mem(acc), roc_value_type); + dvhandle->size = size; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_dnvec_set_values, dvhandle->backend_handle, + sc.get_mem(acc)); + } + dvhandle->set_buffer(val); + }); + }); + event.wait_and_throw(); +} + +template +void set_dense_vector_data(sycl::queue &queue, oneapi::mkl::sparse::dense_vector_handle_t dvhandle, + std::int64_t size, fpType *val) { + detail::check_can_reset_value_handle(__func__, dvhandle, false); + auto event = queue.submit([&](sycl::handler &cgh) { + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + if (dvhandle->size != size) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle); + auto roc_value_type = RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnvec_descr, &dvhandle->backend_handle, size, + sc.get_mem(val), roc_value_type); + dvhandle->size = size; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_dnvec_set_values, dvhandle->backend_handle, + sc.get_mem(val)); + } + dvhandle->set_usm_ptr(val); + }); + }); + event.wait_and_throw(); +} + +FOR_EACH_FP_TYPE(INSTANTIATE_DENSE_VECTOR_FUNCS); + +sycl::event release_dense_vector(sycl::queue &queue, dense_vector_handle_t dvhandle, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnvec_descr, dvhandle->backend_handle); + delete dvhandle; + }); + }); +} + +// Dense matrix +template +void init_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, layout dense_layout, + sycl::buffer val) { + auto event = queue.submit([&](sycl::handler &cgh) { + auto acc = val.template get_access(cgh); + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + auto roc_value_type = RocEnumType::value; + auto roc_order = get_roc_order(dense_layout); + rocsparse_dnmat_descr roc_dmhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &roc_dmhandle, num_rows, num_cols, ld, + sc.get_mem(acc), roc_value_type, roc_order); + *p_dmhandle = + new dense_matrix_handle(roc_dmhandle, val, num_rows, num_cols, ld, dense_layout); + }); + }); + event.wait_and_throw(); +} + +template +void init_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, layout dense_layout, fpType *val) { + auto event = queue.submit([&](sycl::handler &cgh) { + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + auto roc_value_type = RocEnumType::value; + auto roc_order = get_roc_order(dense_layout); + rocsparse_dnmat_descr roc_dmhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &roc_dmhandle, num_rows, num_cols, ld, + sc.get_mem(val), roc_value_type, roc_order); + *p_dmhandle = + new dense_matrix_handle(roc_dmhandle, val, num_rows, num_cols, ld, dense_layout); + }); + }); + event.wait_and_throw(); +} + +template +void set_dense_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::dense_matrix_handle_t dmhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, + oneapi::mkl::layout dense_layout, sycl::buffer val) { + detail::check_can_reset_value_handle(__func__, dmhandle, true); + auto event = queue.submit([&](sycl::handler &cgh) { + auto acc = val.template get_access(cgh); + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + if (dmhandle->num_rows != num_rows || dmhandle->num_cols != num_cols || + dmhandle->ld != ld || dmhandle->dense_layout != dense_layout) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle); + auto roc_value_type = RocEnumType::value; + auto roc_order = get_roc_order(dense_layout); + ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &dmhandle->backend_handle, + num_rows, num_cols, ld, sc.get_mem(acc), roc_value_type, + roc_order); + dmhandle->num_rows = num_rows; + dmhandle->num_cols = num_cols; + dmhandle->ld = ld; + dmhandle->dense_layout = dense_layout; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_dnmat_set_values, dmhandle->backend_handle, + sc.get_mem(acc)); + } + dmhandle->set_buffer(val); + }); + }); + event.wait_and_throw(); +} + +template +void set_dense_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::dense_matrix_handle_t dmhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, + oneapi::mkl::layout dense_layout, fpType *val) { + detail::check_can_reset_value_handle(__func__, dmhandle, false); + auto event = queue.submit([&](sycl::handler &cgh) { + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + if (dmhandle->num_rows != num_rows || dmhandle->num_cols != num_cols || + dmhandle->ld != ld || dmhandle->dense_layout != dense_layout) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle); + auto roc_value_type = RocEnumType::value; + auto roc_order = get_roc_order(dense_layout); + ROCSPARSE_ERR_FUNC(rocsparse_create_dnmat_descr, &dmhandle->backend_handle, + num_rows, num_cols, ld, sc.get_mem(val), roc_value_type, + roc_order); + dmhandle->num_rows = num_rows; + dmhandle->num_cols = num_cols; + dmhandle->ld = ld; + dmhandle->dense_layout = dense_layout; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_dnmat_set_values, dmhandle->backend_handle, + sc.get_mem(val)); + } + dmhandle->set_usm_ptr(val); + }); + }); + event.wait_and_throw(); +} + +FOR_EACH_FP_TYPE(INSTANTIATE_DENSE_MATRIX_FUNCS); + +sycl::event release_dense_matrix(sycl::queue &queue, dense_matrix_handle_t dmhandle, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_dnmat_descr, dmhandle->backend_handle); + delete dmhandle; + }); + }); +} + +// COO matrix +template +void init_coo_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p_smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index, sycl::buffer row_ind, + sycl::buffer col_ind, sycl::buffer val) { + auto event = queue.submit([&](sycl::handler &cgh) { + auto row_acc = row_ind.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + auto roc_index_type = RocIndexEnumType::value; + auto roc_index_base = get_roc_index_base(index); + auto roc_value_type = RocEnumType::value; + rocsparse_spmat_descr roc_smhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &roc_smhandle, num_rows, num_cols, nnz, + sc.get_mem(row_acc), sc.get_mem(col_acc), sc.get_mem(val_acc), + roc_index_type, roc_index_base, roc_value_type); + *p_smhandle = new oneapi::mkl::sparse::matrix_handle( + roc_smhandle, row_ind, col_ind, val, num_rows, num_cols, nnz, index); + }); + }); + event.wait_and_throw(); +} + +template +void init_coo_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p_smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index, intType *row_ind, intType *col_ind, + fpType *val) { + auto event = queue.submit([&](sycl::handler &cgh) { + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + auto roc_index_type = RocIndexEnumType::value; + auto roc_index_base = get_roc_index_base(index); + auto roc_value_type = RocEnumType::value; + rocsparse_spmat_descr roc_smhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &roc_smhandle, num_rows, num_cols, nnz, + sc.get_mem(row_ind), sc.get_mem(col_ind), sc.get_mem(val), + roc_index_type, roc_index_base, roc_value_type); + *p_smhandle = new oneapi::mkl::sparse::matrix_handle( + roc_smhandle, row_ind, col_ind, val, num_rows, num_cols, nnz, index); + }); + }); + event.wait_and_throw(); +} + +template +void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index, sycl::buffer row_ind, + sycl::buffer col_ind, sycl::buffer val) { + detail::check_can_reset_sparse_handle(__func__, smhandle, true); + auto event = queue.submit([&](sycl::handler &cgh) { + auto row_acc = row_ind.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || + smhandle->nnz != nnz || smhandle->index != index) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + auto roc_index_type = RocIndexEnumType::value; + auto roc_index_base = get_roc_index_base(index); + auto roc_value_type = RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &smhandle->backend_handle, num_rows, + num_cols, nnz, sc.get_mem(row_acc), sc.get_mem(col_acc), + sc.get_mem(val_acc), roc_index_type, roc_index_base, + roc_value_type); + smhandle->num_rows = num_rows; + smhandle->num_cols = num_cols; + smhandle->nnz = nnz; + smhandle->index = index; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_coo_set_pointers, smhandle->backend_handle, + sc.get_mem(row_acc), sc.get_mem(col_acc), sc.get_mem(val_acc)); + } + smhandle->row_container.set_buffer(row_ind); + smhandle->col_container.set_buffer(col_ind); + smhandle->value_container.set_buffer(val); + }); + }); + event.wait_and_throw(); +} + +template +void set_coo_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index, intType *row_ind, intType *col_ind, + fpType *val) { + detail::check_can_reset_sparse_handle(__func__, smhandle, false); + auto event = queue.submit([&](sycl::handler &cgh) { + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || + smhandle->nnz != nnz || smhandle->index != index) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + auto roc_index_type = RocIndexEnumType::value; + auto roc_index_base = get_roc_index_base(index); + auto roc_value_type = RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_coo_descr, &smhandle->backend_handle, num_rows, + num_cols, nnz, sc.get_mem(row_ind), sc.get_mem(col_ind), + sc.get_mem(val), roc_index_type, roc_index_base, roc_value_type); + smhandle->num_rows = num_rows; + smhandle->num_cols = num_cols; + smhandle->nnz = nnz; + smhandle->index = index; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_coo_set_pointers, smhandle->backend_handle, + sc.get_mem(row_ind), sc.get_mem(col_ind), sc.get_mem(val)); + } + smhandle->row_container.set_usm_ptr(row_ind); + smhandle->col_container.set_usm_ptr(col_ind); + smhandle->value_container.set_usm_ptr(val); + }); + }); + event.wait_and_throw(); +} + +FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_COO_MATRIX_FUNCS); + +// CSR matrix +template +void init_csr_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p_smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index, sycl::buffer row_ptr, + sycl::buffer col_ind, sycl::buffer val) { + auto event = queue.submit([&](sycl::handler &cgh) { + auto row_acc = row_ptr.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + auto roc_index_type = RocIndexEnumType::value; + auto roc_index_base = get_roc_index_base(index); + auto roc_value_type = RocEnumType::value; + rocsparse_spmat_descr roc_smhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &roc_smhandle, num_rows, num_cols, nnz, + sc.get_mem(row_acc), sc.get_mem(col_acc), sc.get_mem(val_acc), + roc_index_type, roc_index_type, roc_index_base, roc_value_type); + *p_smhandle = new oneapi::mkl::sparse::matrix_handle( + roc_smhandle, row_ptr, col_ind, val, num_rows, num_cols, nnz, index); + }); + }); + event.wait_and_throw(); +} + +template +void init_csr_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p_smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index, intType *row_ptr, intType *col_ind, + fpType *val) { + auto event = queue.submit([&](sycl::handler &cgh) { + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + auto roc_index_type = RocIndexEnumType::value; + auto roc_index_base = get_roc_index_base(index); + auto roc_value_type = RocEnumType::value; + rocsparse_spmat_descr roc_smhandle; + ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &roc_smhandle, num_rows, num_cols, nnz, + sc.get_mem(row_ptr), sc.get_mem(col_ind), sc.get_mem(val), + roc_index_type, roc_index_type, roc_index_base, roc_value_type); + *p_smhandle = new oneapi::mkl::sparse::matrix_handle( + roc_smhandle, row_ptr, col_ind, val, num_rows, num_cols, nnz, index); + }); + }); + event.wait_and_throw(); +} + +template +void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index, sycl::buffer row_ptr, + sycl::buffer col_ind, sycl::buffer val) { + detail::check_can_reset_sparse_handle(__func__, smhandle, true); + auto event = queue.submit([&](sycl::handler &cgh) { + auto row_acc = row_ptr.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || + smhandle->nnz != nnz || smhandle->index != index) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + auto roc_index_type = RocIndexEnumType::value; + auto roc_index_base = get_roc_index_base(index); + auto roc_value_type = RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &smhandle->backend_handle, num_rows, + num_cols, nnz, sc.get_mem(row_acc), sc.get_mem(col_acc), + sc.get_mem(val_acc), roc_index_type, roc_index_type, + roc_index_base, roc_value_type); + smhandle->num_rows = num_rows; + smhandle->num_cols = num_cols; + smhandle->nnz = nnz; + smhandle->index = index; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_csr_set_pointers, smhandle->backend_handle, + sc.get_mem(row_acc), sc.get_mem(col_acc), sc.get_mem(val_acc)); + } + smhandle->row_container.set_buffer(row_ptr); + smhandle->col_container.set_buffer(col_ind); + smhandle->value_container.set_buffer(val); + }); + }); + event.wait_and_throw(); +} + +template +void set_csr_matrix_data(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index, intType *row_ptr, intType *col_ind, + fpType *val) { + detail::check_can_reset_sparse_handle(__func__, smhandle, false); + auto event = queue.submit([&](sycl::handler &cgh) { + submit_host_task(cgh, queue, [=](RocsparseScopedContextHandler &sc) { + // Ensure that a rocsparse handle is created before any other rocSPARSE function is called. + sc.get_handle(queue); + if (smhandle->num_rows != num_rows || smhandle->num_cols != num_cols || + smhandle->nnz != nnz || smhandle->index != index) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + auto roc_index_type = RocIndexEnumType::value; + auto roc_index_base = get_roc_index_base(index); + auto roc_value_type = RocEnumType::value; + ROCSPARSE_ERR_FUNC(rocsparse_create_csr_descr, &smhandle->backend_handle, num_rows, + num_cols, nnz, sc.get_mem(row_ptr), sc.get_mem(col_ind), + sc.get_mem(val), roc_index_type, roc_index_type, roc_index_base, + roc_value_type); + smhandle->num_rows = num_rows; + smhandle->num_cols = num_cols; + smhandle->nnz = nnz; + smhandle->index = index; + } + else { + ROCSPARSE_ERR_FUNC(rocsparse_csr_set_pointers, smhandle->backend_handle, + sc.get_mem(row_ptr), sc.get_mem(col_ind), sc.get_mem(val)); + } + smhandle->row_container.set_usm_ptr(row_ptr); + smhandle->col_container.set_usm_ptr(col_ind); + smhandle->value_container.set_usm_ptr(val); + }); + }); + event.wait_and_throw(); +} + +FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_CSR_MATRIX_FUNCS); + +sycl::event release_sparse_matrix(sycl::queue &queue, matrix_handle_t smhandle, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_spmat_descr, smhandle->backend_handle); + delete smhandle; + }); + }); +} + +// Matrix property +bool set_matrix_property(sycl::queue &, matrix_handle_t smhandle, matrix_property property) { + // No equivalent in rocSPARSE + // Store the matrix property internally for future usages + smhandle->set_matrix_property(property); + return false; +} + +} // namespace oneapi::mkl::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp new file mode 100644 index 000000000..59a73649d --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp @@ -0,0 +1,97 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SRC_SPARSE_BLAS_BACKENDS_ROCSPARSE_HANDLES_HPP_ +#define _ONEMKL_SRC_SPARSE_BLAS_BACKENDS_ROCSPARSE_HANDLES_HPP_ + +#include + +#include "sparse_blas/generic_container.hpp" + +namespace oneapi::mkl::sparse { + +// Complete the definition of incomplete types dense_vector_handle, dense_matrix_handle and matrix_handle. + +struct dense_vector_handle : public detail::generic_dense_vector_handle { + template + dense_vector_handle(rocsparse_dnvec_descr roc_descr, T* value_ptr, std::int64_t size) + : detail::generic_dense_vector_handle(roc_descr, value_ptr, + size) {} + + template + dense_vector_handle(rocsparse_dnvec_descr roc_descr, const sycl::buffer value_buffer, + std::int64_t size) + : detail::generic_dense_vector_handle(roc_descr, value_buffer, + size) {} +}; + +struct dense_matrix_handle : public detail::generic_dense_matrix_handle { + template + dense_matrix_handle(rocsparse_dnmat_descr roc_descr, T* value_ptr, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, layout dense_layout) + : detail::generic_dense_matrix_handle( + roc_descr, value_ptr, num_rows, num_cols, ld, dense_layout) {} + + template + dense_matrix_handle(rocsparse_dnmat_descr roc_descr, const sycl::buffer value_buffer, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, + layout dense_layout) + : detail::generic_dense_matrix_handle( + roc_descr, value_buffer, num_rows, num_cols, ld, dense_layout) {} +}; + +struct matrix_handle : public detail::generic_sparse_handle { + // A matrix handle should only be used once per operation to be safe with the rocSPARSE backend. + // An operation can store information in the handle. See details in https://github.com/ROCm/rocSPARSE/issues/332. +private: + bool used = false; + +public: + template + matrix_handle(rocsparse_spmat_descr roc_descr, intType* row_ptr, intType* col_ptr, + fpType* value_ptr, std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + oneapi::mkl::index_base index) + : detail::generic_sparse_handle( + roc_descr, row_ptr, col_ptr, value_ptr, num_rows, num_cols, nnz, index) {} + + template + matrix_handle(rocsparse_spmat_descr roc_descr, const sycl::buffer row_buffer, + const sycl::buffer col_buffer, + const sycl::buffer value_buffer, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index) + : detail::generic_sparse_handle( + roc_descr, row_buffer, col_buffer, value_buffer, num_rows, num_cols, nnz, index) { + } + + void throw_if_already_used(const std::string& function_name) { + if (used) { + throw mkl::unimplemented( + "sparse_blas", function_name, + "The backend does not support re-using the same sparse matrix handle in multiple operations."); + } + } + + void mark_used() { + used = true; + } +}; + +} // namespace oneapi::mkl::sparse + +#endif // _ONEMKL_SRC_SPARSE_BLAS_BACKENDS_ROCSPARSE_HANDLES_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_helper.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_helper.hpp new file mode 100644 index 000000000..7afb45da3 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_helper.hpp @@ -0,0 +1,160 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_HELPER_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_HELPER_HPP_ + +#include +#include +#include +#include + +#include + +#include "oneapi/mkl/sparse_blas/types.hpp" +#include "sparse_blas/enum_data_types.hpp" +#include "sparse_blas/sycl_helper.hpp" +#include "rocsparse_error.hpp" + +namespace oneapi::mkl::sparse::rocsparse { + +template +struct RocEnumType; +template <> +struct RocEnumType { + static constexpr rocsparse_datatype value = rocsparse_datatype_f32_r; +}; +template <> +struct RocEnumType { + static constexpr rocsparse_datatype value = rocsparse_datatype_f64_r; +}; +template <> +struct RocEnumType> { + static constexpr rocsparse_datatype value = rocsparse_datatype_f32_c; +}; +template <> +struct RocEnumType> { + static constexpr rocsparse_datatype value = rocsparse_datatype_f64_c; +}; + +template +struct RocIndexEnumType; +template <> +struct RocIndexEnumType { + static constexpr rocsparse_indextype value = rocsparse_indextype_i32; +}; +template <> +struct RocIndexEnumType { + static constexpr rocsparse_indextype value = rocsparse_indextype_i64; +}; + +template +inline std::string cast_enum_to_str(E e) { + return std::to_string(static_cast(e)); +} + +inline auto get_roc_value_type(detail::data_type onemkl_data_type) { + switch (onemkl_data_type) { + case detail::data_type::real_fp32: return rocsparse_datatype_f32_r; + case detail::data_type::real_fp64: return rocsparse_datatype_f64_r; + case detail::data_type::complex_fp32: return rocsparse_datatype_f32_c; + case detail::data_type::complex_fp64: return rocsparse_datatype_f64_c; + default: + throw oneapi::mkl::invalid_argument( + "sparse_blas", "get_roc_value_type", + "Invalid data type: " + cast_enum_to_str(onemkl_data_type)); + } +} + +inline auto get_roc_order(layout l) { + switch (l) { + case layout::row_major: return rocsparse_order_row; + case layout::col_major: return rocsparse_order_column; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_roc_order", + "Unknown layout: " + cast_enum_to_str(l)); + } +} + +inline auto get_roc_index_base(index_base index) { + switch (index) { + case index_base::zero: return rocsparse_index_base_zero; + case index_base::one: return rocsparse_index_base_one; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_roc_index_base", + "Unknown index_base: " + cast_enum_to_str(index)); + } +} + +inline auto get_roc_operation(transpose op) { + switch (op) { + case transpose::nontrans: return rocsparse_operation_none; + case transpose::trans: return rocsparse_operation_transpose; + case transpose::conjtrans: return rocsparse_operation_conjugate_transpose; + default: + throw oneapi::mkl::invalid_argument( + "sparse_blas", "get_roc_operation", + "Unknown transpose operation: " + cast_enum_to_str(op)); + } +} + +inline auto get_roc_uplo(uplo uplo_val) { + switch (uplo_val) { + case uplo::upper: return rocsparse_fill_mode_upper; + case uplo::lower: return rocsparse_fill_mode_lower; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_roc_uplo", + "Unknown uplo: " + cast_enum_to_str(uplo_val)); + } +} + +inline auto get_roc_diag(diag diag_val) { + switch (diag_val) { + case diag::nonunit: return rocsparse_diag_type_non_unit; + case diag::unit: return rocsparse_diag_type_unit; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_roc_diag", + "Unknown diag: " + cast_enum_to_str(diag_val)); + } +} + +inline void set_matrix_attributes(const std::string& func_name, rocsparse_spmat_descr roc_a, + oneapi::mkl::sparse::matrix_view A_view) { + auto roc_fill_mode = get_roc_uplo(A_view.uplo_view); + auto status = rocsparse_spmat_set_attribute(roc_a, rocsparse_spmat_fill_mode, &roc_fill_mode, + sizeof(roc_fill_mode)); + check_status(status, func_name + "/set_uplo"); + + auto roc_diag_type = get_roc_diag(A_view.diag_view); + status = rocsparse_spmat_set_attribute(roc_a, rocsparse_spmat_diag_type, &roc_diag_type, + sizeof(roc_diag_type)); + check_status(status, func_name + "/set_diag"); +} + +/** + * rocSPARSE requires to set the pointer mode for scalars parameters (typically alpha and beta). + */ +inline void set_pointer_mode(rocsparse_handle roc_handle, bool is_ptr_accessible_on_host) { + rocsparse_set_pointer_mode(roc_handle, is_ptr_accessible_on_host + ? rocsparse_pointer_mode_host + : rocsparse_pointer_mode_device); +} + +} // namespace oneapi::mkl::sparse::rocsparse + +#endif //_ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_HELPER_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.cpp b/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.cpp new file mode 100644 index 000000000..580d766e2 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.cpp @@ -0,0 +1,125 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +/** + * @file Similar to rocblas_scope_handle.cpp +*/ + +#include "rocsparse_scope_handle.hpp" + +namespace oneapi::mkl::sparse::rocsparse { + +/** + * Inserts a new element in the map if its key is unique. This new element + * is constructed in place using args as the arguments for the construction + * of a value_type (which is an object of a pair type). The insertion only + * takes place if no other element in the container has a key equivalent to + * the one being emplaced (keys in a map container are unique). + */ +thread_local rocsparse_handle_container RocsparseScopedContextHandler::handle_helper = + rocsparse_handle_container{}; + +// Disable warning for deprecated hipCtxGetCurrent and similar hip runtime functions +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +RocsparseScopedContextHandler::RocsparseScopedContextHandler(sycl::queue queue, + sycl::interop_handle &ih) + : ih(ih), + needToRecover_(false) { + placedContext_ = new sycl::context(queue.get_context()); + auto hipDevice = ih.get_native_device(); + hipCtx_t desired; + HIP_ERROR_FUNC(hipCtxGetCurrent, &original_); + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, &desired, hipDevice); + if (original_ != desired) { + // Sets the desired context as the active one for the thread + HIP_ERROR_FUNC(hipCtxSetCurrent, desired); + // No context is installed and the suggested context is primary + // This is the most common case. We can activate the context in the + // thread and leave it there until all the PI context referring to the + // same underlying rocsparse primary context are destroyed. This emulates + // the behaviour of the rocsparse runtime api, and avoids costly context + // switches. No action is required on this side of the if. + needToRecover_ = !(original_ == nullptr); + } +} + +RocsparseScopedContextHandler::~RocsparseScopedContextHandler() noexcept(false) { + if (needToRecover_) { + HIP_ERROR_FUNC(hipCtxSetCurrent, original_); + } + delete placedContext_; +} + +void ContextCallback(void *userData) { + auto *atomic_ptr = static_cast *>(userData); + if (!atomic_ptr) { + return; + } + auto handle = atomic_ptr->exchange(nullptr); + if (handle != nullptr) { + ROCSPARSE_ERR_FUNC(rocsparse_destroy_handle, handle); + } + delete atomic_ptr; +} + +std::pair RocsparseScopedContextHandler::get_handle_and_stream( + const sycl::queue &queue) { + auto hipDevice = ih.get_native_device(); + hipCtx_t desired; + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, &desired, hipDevice); + auto piPlacedContext_ = reinterpret_cast(desired); + hipStream_t streamId = sycl::get_native(queue); + auto it = handle_helper.rocsparse_handle_container_mapper_.find(piPlacedContext_); + if (it != handle_helper.rocsparse_handle_container_mapper_.end()) { + if (it->second == nullptr) { + handle_helper.rocsparse_handle_container_mapper_.erase(it); + } + else { + auto handle = it->second->load(); + if (handle != nullptr) { + hipStream_t currentStreamId; + ROCSPARSE_ERR_FUNC(rocsparse_get_stream, handle, ¤tStreamId); + if (currentStreamId != streamId) { + ROCSPARSE_ERR_FUNC(rocsparse_set_stream, handle, streamId); + } + return { handle, streamId }; + } + else { + handle_helper.rocsparse_handle_container_mapper_.erase(it); + } + } + } + + rocsparse_handle handle; + ROCSPARSE_ERR_FUNC(rocsparse_create_handle, &handle); + ROCSPARSE_ERR_FUNC(rocsparse_set_stream, handle, streamId); + + auto atomic_ptr = new std::atomic(handle); + handle_helper.rocsparse_handle_container_mapper_.insert( + std::make_pair(piPlacedContext_, atomic_ptr)); + + sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback, atomic_ptr); + return { handle, streamId }; +} + +#pragma clang diagnostic pop + +} // namespace oneapi::mkl::sparse::rocsparse diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp new file mode 100644 index 000000000..86ef459ea --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_scope_handle.hpp @@ -0,0 +1,83 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_SCOPE_HANDLE_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_SCOPE_HANDLE_HPP_ + +/** + * @file Similar to rocblas_scope_handle.hpp +*/ + +#if __has_include() +#include +#else +#include +#endif + +#include + +#include "rocsparse_error.hpp" +#include "rocsparse_global_handle.hpp" +#include "rocsparse_helper.hpp" + +namespace oneapi::mkl::sparse::rocsparse { + +template +struct rocsparse_handle_container { + using handle_container_t = std::unordered_map *>; + handle_container_t rocsparse_handle_container_mapper_{}; + + // Do not free any pointer nor handle in this destructor. The resources are + // free'd via the PI ContextCallback to ensure the context is still alive. + ~rocsparse_handle_container() = default; +}; + +class RocsparseScopedContextHandler { + HIPcontext original_; + sycl::context *placedContext_; + sycl::interop_handle &ih; + bool needToRecover_; + static thread_local rocsparse_handle_container handle_helper; + +public: + RocsparseScopedContextHandler(sycl::queue queue, sycl::interop_handle &ih); + ~RocsparseScopedContextHandler() noexcept(false); + + std::pair get_handle_and_stream(const sycl::queue &queue); + + rocsparse_handle get_handle(const sycl::queue &queue) { + return get_handle_and_stream(queue).first; + } + + // This is a work-around function for reinterpret_casting the memory. This + // will be fixed when SYCL-2020 has been implemented for Pi backend. + template + inline void *get_mem(AccT acc) { + auto hipPtr = ih.get_native_mem(acc); + return reinterpret_cast(hipPtr); + } + + template + inline void *get_mem(T *ptr) { + return reinterpret_cast(ptr); + } +}; + +} // namespace oneapi::mkl::sparse::rocsparse + +#endif //_ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_SCOPE_HANDLE_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_task.hpp b/src/sparse_blas/backends/rocsparse/rocsparse_task.hpp new file mode 100644 index 000000000..a13f24aeb --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_task.hpp @@ -0,0 +1,187 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_TASKS_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_TASKS_HPP_ + +#include "rocsparse_handles.hpp" +#include "rocsparse_scope_handle.hpp" + +/// This file provide a helper function to submit host_task using buffers or USM seamlessly + +namespace oneapi::mkl::sparse::rocsparse { + +template +auto get_value_accessor(sycl::handler &cgh, Container container) { + auto buffer_ptr = + reinterpret_cast *>(container->value_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_fp_accessors(sycl::handler &cgh, Ts... containers) { + return std::array, sizeof...(containers)>{ get_value_accessor( + cgh, containers)... }; +} + +template +auto get_row_accessor(sycl::handler &cgh, matrix_handle_t smhandle) { + auto buffer_ptr = + reinterpret_cast *>(smhandle->row_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_col_accessor(sycl::handler &cgh, matrix_handle_t smhandle) { + auto buffer_ptr = + reinterpret_cast *>(smhandle->col_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_int_accessors(sycl::handler &cgh, matrix_handle_t smhandle) { + return std::array, 2>{ get_row_accessor(cgh, smhandle), + get_col_accessor(cgh, smhandle) }; +} + +template +void submit_host_task(sycl::handler &cgh, sycl::queue &queue, Functor functor, + CaptureOnlyAcc... capture_only_accessors) { + // Only capture the accessors to ensure the dependencies are properly handled + // The accessors's pointer have already been set to the native container types in previous functions + cgh.host_task([functor, queue, capture_only_accessors...](sycl::interop_handle ih) { + auto unused = std::make_tuple(capture_only_accessors...); + (void)unused; + auto sc = RocsparseScopedContextHandler(queue, ih); + functor(sc); + }); +} + +template +void submit_host_task_with_acc(sycl::handler &cgh, sycl::queue &queue, Functor functor, + sycl::accessor workspace_placeholder_acc, + CaptureOnlyAcc... capture_only_accessors) { + // Only capture the accessors to ensure the dependencies are properly handled + // The accessors's pointer have already been set to the native container types in previous functions + cgh.require(workspace_placeholder_acc); + cgh.host_task([functor, queue, workspace_placeholder_acc, + capture_only_accessors...](sycl::interop_handle ih) { + auto unused = std::make_tuple(capture_only_accessors...); + (void)unused; + auto sc = RocsparseScopedContextHandler(queue, ih); + functor(sc, workspace_placeholder_acc); + }); +} + +/// Helper submit functions to capture all accessors from the generic containers \p other_containers and ensure the dependencies of buffers are respected. +/// The accessors are not directly used as the underlying data pointer has already been captured in previous functions. +/// \p workspace_placeholder_acc is a placeholder accessor that will be bound to the cgh if not empty and given to the functor as a last argument +template +sycl::event dispatch_submit(const std::string &function_name, sycl::queue queue, + const std::vector &dependencies, Functor functor, + matrix_handle_t sm_handle, + sycl::accessor workspace_placeholder_acc, + Ts... other_containers) { + if (sm_handle->all_use_buffer()) { + detail::data_type value_type = sm_handle->get_value_type(); + detail::data_type int_type = sm_handle->get_int_type(); + +#define ONEMKL_ROCSPARSE_SUBMIT(FP_TYPE, INT_TYPE) \ + return queue.submit([&](sycl::handler &cgh) { \ + cgh.depends_on(dependencies); \ + auto fp_accs = get_fp_accessors(cgh, sm_handle, other_containers...); \ + auto int_accs = get_int_accessors(cgh, sm_handle); \ + if constexpr (UseWorkspace) { \ + submit_host_task_with_acc(cgh, queue, functor, workspace_placeholder_acc, fp_accs, \ + int_accs); \ + } \ + else { \ + (void)workspace_placeholder_acc; \ + submit_host_task(cgh, queue, functor, fp_accs, int_accs); \ + } \ + }) +#define ONEMKL_ROCSPARSE_SUBMIT_INT(FP_TYPE) \ + if (int_type == detail::data_type::int32) { \ + ONEMKL_ROCSPARSE_SUBMIT(FP_TYPE, std::int32_t); \ + } \ + else if (int_type == detail::data_type::int64) { \ + ONEMKL_ROCSPARSE_SUBMIT(FP_TYPE, std::int64_t); \ + } + + if (value_type == detail::data_type::real_fp32) { + ONEMKL_ROCSPARSE_SUBMIT_INT(float) + } + else if (value_type == detail::data_type::real_fp64) { + ONEMKL_ROCSPARSE_SUBMIT_INT(double) + } + else if (value_type == detail::data_type::complex_fp32) { + ONEMKL_ROCSPARSE_SUBMIT_INT(std::complex) + } + else if (value_type == detail::data_type::complex_fp64) { + ONEMKL_ROCSPARSE_SUBMIT_INT(std::complex) + } + +#undef ONEMKL_ROCSPARSE_SUBMIT_INT +#undef ONEMKL_ROCSPARSE_SUBMIT + + throw oneapi::mkl::exception("sparse_blas", function_name, + "Could not dispatch buffer kernel to a supported type"); + } + else { + // USM submit does not need to capture accessors + if constexpr (!UseWorkspace) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + submit_host_task(cgh, queue, functor); + }); + } + else { + throw oneapi::mkl::exception("sparse_blas", function_name, + "Internal error: Cannot use accessor workspace with USM"); + } + } +} + +template +sycl::event dispatch_submit(const std::string &function_name, sycl::queue queue, Functor functor, + matrix_handle_t sm_handle, + sycl::accessor workspace_placeholder_acc, + Ts... other_containers) { + return dispatch_submit(function_name, queue, {}, functor, sm_handle, + workspace_placeholder_acc, other_containers...); +} + +template +sycl::event dispatch_submit(const std::string &function_name, sycl::queue queue, + const std::vector &dependencies, Functor functor, + matrix_handle_t sm_handle, Ts... other_containers) { + return dispatch_submit(function_name, queue, dependencies, functor, sm_handle, {}, + other_containers...); +} + +template +sycl::event dispatch_submit(const std::string &function_name, sycl::queue queue, Functor functor, + matrix_handle_t sm_handle, Ts... other_containers) { + return dispatch_submit(function_name, queue, {}, functor, sm_handle, {}, + other_containers...); +} + +} // namespace oneapi::mkl::sparse::rocsparse + +#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_ROCSPARSE_TASKS_HPP_ diff --git a/src/sparse_blas/backends/rocsparse/rocsparse_wrappers.cpp b/src/sparse_blas/backends/rocsparse/rocsparse_wrappers.cpp new file mode 100644 index 000000000..eaa8d82d2 --- /dev/null +++ b/src/sparse_blas/backends/rocsparse/rocsparse_wrappers.cpp @@ -0,0 +1,32 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/types.hpp" + +#include "oneapi/mkl/sparse_blas/detail/rocsparse/onemkl_sparse_blas_rocsparse.hpp" + +#include "sparse_blas/function_table.hpp" + +#define WRAPPER_VERSION 1 +#define BACKEND rocsparse + +extern "C" sparse_blas_function_table_t mkl_sparse_blas_table = { + WRAPPER_VERSION, +#include "sparse_blas/backends/backend_wrappers.cxx" +}; diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index 5fc56d04a..59cbc4ec6 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -183,6 +183,11 @@ foreach(domain ${TARGET_DOMAINS}) list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_cusparse) endif() + if(domain STREQUAL "sparse_blas" AND ENABLE_ROCSPARSE_BACKEND) + add_dependencies(test_main_${domain}_ct onemkl_${domain}_rocsparse) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_rocsparse) + endif() + target_link_libraries(test_main_${domain}_ct PUBLIC gtest gtest_main diff --git a/tests/unit_tests/include/test_helper.hpp b/tests/unit_tests/include/test_helper.hpp index 5457079e0..7bbc14482 100644 --- a/tests/unit_tests/include/test_helper.hpp +++ b/tests/unit_tests/include/test_helper.hpp @@ -183,6 +183,13 @@ #define TEST_RUN_NVIDIAGPU_CUSPARSE_SELECT(q, func, ...) #endif +#ifdef ENABLE_ROCSPARSE_BACKEND +#define TEST_RUN_AMDGPU_ROCSPARSE_SELECT(q, func, ...) \ + func(oneapi::mkl::backend_selector{ q }, __VA_ARGS__) +#else +#define TEST_RUN_AMDGPU_ROCSPARSE_SELECT(q, func, ...) +#endif + #ifndef __HIPSYCL__ #define CHECK_HOST_OR_CPU(q) q.get_device().is_cpu() #else @@ -278,6 +285,9 @@ else if (vendor_id == NVIDIA_ID) { \ TEST_RUN_NVIDIAGPU_CUSPARSE_SELECT(q, func, __VA_ARGS__); \ } \ + else if (vendor_id == AMD_ID) { \ + TEST_RUN_AMDGPU_ROCSPARSE_SELECT(q, func, __VA_ARGS__); \ + } \ } \ } while (0); diff --git a/tests/unit_tests/main_test.cpp b/tests/unit_tests/main_test.cpp index fc208da09..fe79876a3 100644 --- a/tests/unit_tests/main_test.cpp +++ b/tests/unit_tests/main_test.cpp @@ -129,7 +129,8 @@ int main(int argc, char** argv) { #endif #if !defined(ENABLE_ROCBLAS_BACKEND) && !defined(ENABLE_ROCRAND_BACKEND) && \ !defined(ENABLE_ROCSOLVER_BACKEND) && !defined(ENABLE_PORTBLAS_BACKEND_AMD_GPU) && \ - !defined(ENABLE_ROCFFT_BACKEND) && !defined(ENABLE_PORTFFT_BACKEND) + !defined(ENABLE_ROCFFT_BACKEND) && !defined(ENABLE_PORTFFT_BACKEND) && \ + !defined(ENABLE_ROCSPARSE_BACKEND) if (dev.is_gpu() && vendor_id == AMD_ID) continue; #endif diff --git a/tests/unit_tests/sparse_blas/include/test_common.hpp b/tests/unit_tests/sparse_blas/include/test_common.hpp index 180c355cd..4700758ec 100644 --- a/tests/unit_tests/sparse_blas/include/test_common.hpp +++ b/tests/unit_tests/sparse_blas/include/test_common.hpp @@ -59,12 +59,28 @@ enum sparse_matrix_format_t { COO, }; -static std::vector> test_matrix_properties{ - { oneapi::mkl::sparse::matrix_property::sorted }, - { oneapi::mkl::sparse::matrix_property::symmetric }, - { oneapi::mkl::sparse::matrix_property::sorted, - oneapi::mkl::sparse::matrix_property::symmetric } -}; +inline std::set get_default_matrix_properties( + sycl::queue queue, sparse_matrix_format_t format) { + auto vendor_id = oneapi::mkl::get_device_id(queue); + if (vendor_id == oneapi::mkl::device::amdgpu && format == sparse_matrix_format_t::COO) { + return { oneapi::mkl::sparse::matrix_property::sorted }; + } + return {}; +} + +/// Return the combinations of matrix_properties to test other than the default +inline std::vector> +get_all_matrix_properties_combinations(sycl::queue queue, sparse_matrix_format_t format) { + auto vendor_id = oneapi::mkl::get_device_id(queue); + if (vendor_id == oneapi::mkl::device::amdgpu && format == sparse_matrix_format_t::COO) { + return { { oneapi::mkl::sparse::matrix_property::sorted, + oneapi::mkl::sparse::matrix_property::symmetric } }; + } + return { { oneapi::mkl::sparse::matrix_property::sorted }, + { oneapi::mkl::sparse::matrix_property::symmetric }, + { oneapi::mkl::sparse::matrix_property::sorted, + oneapi::mkl::sparse::matrix_property::symmetric } }; +} void print_error_code(sycl::exception const &e); @@ -216,9 +232,9 @@ template fpType generate_data(bool is_diag) { rand_scalar rand_data; if (is_diag) { - // Guarantee an amplitude >= 0.1 + // Guarantee a large amplitude fpType sign = (std::rand() % 2) * 2 - 1; - return rand_data(0.1, 0.5) * sign; + return rand_data(10, 20) * sign; } return rand_data(-0.5, 0.5); } @@ -426,6 +442,8 @@ void set_matrix_data(sycl::queue &queue, sparse_matrix_format_t format, oneapi::mkl::sparse::matrix_handle_t smhandle, std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index, ContainerIndexT rows, ContainerIndexT cols, ContainerValueT vals) { + // Ensure to wait for previous operations to finish before resetting the data + queue.wait_and_throw(); if (format == sparse_matrix_format_t::CSR) { CALL_RT_OR_CT(oneapi::mkl::sparse::set_csr_matrix_data, queue, smhandle, num_rows, num_cols, nnz, index, rows, cols, vals); diff --git a/tests/unit_tests/sparse_blas/include/test_spmm.hpp b/tests/unit_tests/sparse_blas/include/test_spmm.hpp index 6188d4268..d47b1732c 100644 --- a/tests/unit_tests/sparse_blas/include/test_spmm.hpp +++ b/tests/unit_tests/sparse_blas/include/test_spmm.hpp @@ -65,10 +65,13 @@ void test_helper_with_format_with_transpose( oneapi::mkl::layout col_major = oneapi::mkl::layout::col_major; oneapi::mkl::sparse::spmm_alg default_alg = oneapi::mkl::sparse::spmm_alg::default_alg; oneapi::mkl::sparse::matrix_view default_A_view; - std::set no_properties; bool no_reset_data = false; bool no_scalars_on_device = false; + // Queue is only used to get which matrix_property should be used for the tests. + sycl::queue properties_queue(*dev); + auto default_properties = get_default_matrix_properties(properties_queue, format); + { int m = 4, k = 6, n = 5; int nrows_A = (transpose_A != oneapi::mkl::transpose::nontrans) ? k : m; @@ -84,34 +87,34 @@ void test_helper_with_format_with_transpose( EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc, - default_alg, default_A_view, no_properties, no_reset_data, + default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Reset data EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc, - default_alg, default_A_view, no_properties, true, + default_alg, default_A_view, default_properties, true, no_scalars_on_device), num_passed, num_skipped); // Test alpha and beta on the device EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc, - default_alg, default_A_view, no_properties, no_reset_data, true), + default_alg, default_A_view, default_properties, no_reset_data, true), num_passed, num_skipped); // Test index_base 1 EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, oneapi::mkl::index_base::one, col_major, transpose_A, transpose_B, - fp_one, fp_zero, ldb, ldc, default_alg, default_A_view, no_properties, - no_reset_data, no_scalars_on_device), + fp_one, fp_zero, ldb, ldc, default_alg, default_A_view, + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test non-default alpha EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, set_fp_value()(2.f, 1.5f), - fp_zero, ldb, ldc, default_alg, default_A_view, no_properties, + fp_zero, ldb, ldc, default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test non-default beta @@ -119,42 +122,43 @@ void test_helper_with_format_with_transpose( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, set_fp_value()(3.2f, 1.f), ldb, ldc, default_alg, - default_A_view, no_properties, no_reset_data, no_scalars_on_device), + default_A_view, default_properties, no_reset_data, + no_scalars_on_device), num_passed, num_skipped); // Test 0 alpha EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_zero, fp_one, ldb, ldc, - default_alg, default_A_view, no_properties, no_reset_data, + default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test 0 alpha and beta EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_zero, fp_zero, ldb, ldc, - default_alg, default_A_view, no_properties, no_reset_data, + default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test non-default ldb EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb + 5, ldc, - default_alg, default_A_view, no_properties, no_reset_data, + default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test non-default ldc EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc + 6, - default_alg, default_A_view, no_properties, no_reset_data, + default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test row major layout EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, oneapi::mkl::layout::row_major, transpose_A, transpose_B, fp_one, - fp_zero, ncols_B, ncols_C, default_alg, default_A_view, no_properties, - no_reset_data, no_scalars_on_device), + fp_zero, ncols_B, ncols_C, default_alg, default_A_view, + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test int64 indices long long_nrows_A = 27, long_ncols_A = 13, long_ncols_C = 6; @@ -163,19 +167,19 @@ void test_helper_with_format_with_transpose( test_functor_i64(dev, format, long_nrows_A, long_ncols_A, long_ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, long_ldb, long_ldc, default_alg, default_A_view, - no_properties, no_reset_data, no_scalars_on_device), + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test other algorithms for (auto alg : non_default_algorithms) { EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, - ldb, ldc, alg, default_A_view, no_properties, no_reset_data, + ldb, ldc, alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); } // Test matrix properties - for (auto properties : test_matrix_properties) { + for (auto properties : get_all_matrix_properties_combinations(properties_queue, format)) { EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, @@ -197,7 +201,7 @@ void test_helper_with_format_with_transpose( EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc, - default_alg, default_A_view, no_properties, no_reset_data, + default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); } diff --git a/tests/unit_tests/sparse_blas/include/test_spmv.hpp b/tests/unit_tests/sparse_blas/include/test_spmv.hpp index 6ee256adb..5f9484de8 100644 --- a/tests/unit_tests/sparse_blas/include/test_spmv.hpp +++ b/tests/unit_tests/sparse_blas/include/test_spmv.hpp @@ -63,62 +63,65 @@ void test_helper_with_format_with_transpose( oneapi::mkl::index_base index_zero = oneapi::mkl::index_base::zero; oneapi::mkl::sparse::spmv_alg default_alg = oneapi::mkl::sparse::spmv_alg::default_alg; oneapi::mkl::sparse::matrix_view default_A_view; - std::set no_properties; bool no_reset_data = false; bool no_scalars_on_device = false; + // Queue is only used to get which matrix_property should be used for the tests. + sycl::queue properties_queue(*dev); + auto default_properties = get_default_matrix_properties(properties_queue, format); + // Basic test EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, default_A_view, no_properties, no_reset_data, - no_scalars_on_device), + fp_one, fp_zero, default_alg, default_A_view, default_properties, + no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Reset data EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, default_A_view, no_properties, true, + fp_one, fp_zero, default_alg, default_A_view, default_properties, true, no_scalars_on_device), num_passed, num_skipped); // Test alpha and beta on the device EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, default_A_view, no_properties, no_reset_data, - true), + fp_one, fp_zero, default_alg, default_A_view, default_properties, + no_reset_data, true), num_passed, num_skipped); // Test index_base 1 EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, oneapi::mkl::index_base::one, transpose_val, fp_one, fp_zero, default_alg, - default_A_view, no_properties, no_reset_data, no_scalars_on_device), + default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test non-default alpha EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, set_fp_value()(2.f, 1.5f), fp_zero, default_alg, default_A_view, - no_properties, no_reset_data, no_scalars_on_device), + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test non-default beta EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, fp_one, set_fp_value()(3.2f, 1.f), default_alg, default_A_view, - no_properties, no_reset_data, no_scalars_on_device), + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test 0 alpha EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_zero, fp_one, default_alg, default_A_view, no_properties, no_reset_data, - no_scalars_on_device), + fp_zero, fp_one, default_alg, default_A_view, default_properties, + no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test 0 alpha and beta EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_zero, fp_zero, default_alg, default_A_view, no_properties, + fp_zero, fp_zero, default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test int64 indices EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i64(dev, format, 27L, 13L, density_A_matrix, index_zero, transpose_val, fp_one, - fp_zero, default_alg, default_A_view, no_properties, no_reset_data, + fp_zero, default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Lower triangular @@ -126,14 +129,14 @@ void test_helper_with_format_with_transpose( oneapi::mkl::sparse::matrix_descr::triangular); EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, triangular_A_view, no_properties, + fp_one, fp_zero, default_alg, triangular_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Upper triangular triangular_A_view.uplo_view = oneapi::mkl::uplo::upper; EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, triangular_A_view, no_properties, + fp_one, fp_zero, default_alg, triangular_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Lower triangular unit diagonal @@ -142,14 +145,14 @@ void test_helper_with_format_with_transpose( triangular_unit_A_view.diag_view = oneapi::mkl::diag::unit; EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, triangular_unit_A_view, no_properties, + fp_one, fp_zero, default_alg, triangular_unit_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Upper triangular unit diagonal triangular_A_view.uplo_view = oneapi::mkl::uplo::upper; EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, - fp_one, fp_zero, default_alg, triangular_unit_A_view, no_properties, + fp_one, fp_zero, default_alg, triangular_unit_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); if (transpose_val != oneapi::mkl::transpose::conjtrans) { @@ -160,14 +163,14 @@ void test_helper_with_format_with_transpose( EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, default_alg, symmetric_view, - no_properties, no_reset_data, no_scalars_on_device), + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Upper symmetric symmetric_view.uplo_view = oneapi::mkl::uplo::upper; EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, default_alg, symmetric_view, - no_properties, no_reset_data, no_scalars_on_device), + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Lower hermitian oneapi::mkl::sparse::matrix_view hermitian_view( @@ -175,26 +178,26 @@ void test_helper_with_format_with_transpose( EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, default_alg, hermitian_view, - no_properties, no_reset_data, no_scalars_on_device), + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Upper hermitian hermitian_view.uplo_view = oneapi::mkl::uplo::upper; EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, default_alg, hermitian_view, - no_properties, no_reset_data, no_scalars_on_device), + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); } // Test other algorithms for (auto alg : non_default_algorithms) { EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, - transpose_val, fp_one, fp_zero, alg, default_A_view, no_properties, - no_reset_data, no_scalars_on_device), + transpose_val, fp_one, fp_zero, alg, default_A_view, + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); } // Test matrix properties - for (auto properties : test_matrix_properties) { + for (auto properties : get_all_matrix_properties_combinations(properties_queue, format)) { EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, default_alg, default_A_view, diff --git a/tests/unit_tests/sparse_blas/include/test_spsv.hpp b/tests/unit_tests/sparse_blas/include/test_spsv.hpp index bdf9210f8..ca58dfd7a 100644 --- a/tests/unit_tests/sparse_blas/include/test_spsv.hpp +++ b/tests/unit_tests/sparse_blas/include/test_spsv.hpp @@ -60,76 +60,83 @@ void test_helper_with_format(testFunctorI32 test_functor_i32, testFunctorI64 tes oneapi::mkl::sparse::matrix_view default_A_view(oneapi::mkl::sparse::matrix_descr::triangular); oneapi::mkl::sparse::matrix_view upper_A_view(oneapi::mkl::sparse::matrix_descr::triangular); upper_A_view.uplo_view = oneapi::mkl::uplo::upper; - std::set no_properties; bool no_reset_data = false; bool no_scalars_on_device = false; + // Queue is only used to get which matrix_property should be used for the tests. + sycl::queue properties_queue(*dev); + auto default_properties = get_default_matrix_properties(properties_queue, format); + // Basic test - EXPECT_TRUE_OR_FUTURE_SKIP(test_functor_i32(dev, format, m, density_A_matrix, index_zero, - transpose_val, alpha, default_alg, default_A_view, - no_properties, no_reset_data, no_scalars_on_device), - num_passed, num_skipped); - // Reset data EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, alpha, - default_alg, default_A_view, no_properties, true, no_scalars_on_device), + default_alg, default_A_view, default_properties, no_reset_data, + no_scalars_on_device), num_passed, num_skipped); + // Reset data + EXPECT_TRUE_OR_FUTURE_SKIP(test_functor_i32(dev, format, m, density_A_matrix, index_zero, + transpose_val, alpha, default_alg, default_A_view, + default_properties, true, no_scalars_on_device), + num_passed, num_skipped); // Test alpha on the device EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, alpha, - default_alg, default_A_view, no_properties, no_reset_data, true), + default_alg, default_A_view, default_properties, no_reset_data, true), num_passed, num_skipped); // Test index_base 1 EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, m, density_A_matrix, oneapi::mkl::index_base::one, - transpose_val, alpha, default_alg, default_A_view, no_properties, + transpose_val, alpha, default_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test upper triangular matrix - EXPECT_TRUE_OR_FUTURE_SKIP(test_functor_i32(dev, format, m, density_A_matrix, index_zero, - transpose_val, alpha, default_alg, upper_A_view, - no_properties, no_reset_data, no_scalars_on_device), - num_passed, num_skipped); + EXPECT_TRUE_OR_FUTURE_SKIP( + test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, alpha, + default_alg, upper_A_view, default_properties, no_reset_data, + no_scalars_on_device), + num_passed, num_skipped); // Test lower triangular unit diagonal matrix oneapi::mkl::sparse::matrix_view triangular_unit_A_view( oneapi::mkl::sparse::matrix_descr::triangular); triangular_unit_A_view.diag_view = oneapi::mkl::diag::unit; EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, alpha, - default_alg, triangular_unit_A_view, no_properties, no_reset_data, + default_alg, triangular_unit_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test upper triangular unit diagonal matrix triangular_unit_A_view.uplo_view = oneapi::mkl::uplo::upper; EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, alpha, - default_alg, triangular_unit_A_view, no_properties, no_reset_data, + default_alg, triangular_unit_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test non-default alpha EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, set_fp_value()(2.f, 1.5f), default_alg, default_A_view, - no_properties, no_reset_data, no_scalars_on_device), + default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test int64 indices - EXPECT_TRUE_OR_FUTURE_SKIP(test_functor_i64(dev, format, 15L, density_A_matrix, index_zero, - transpose_val, alpha, default_alg, default_A_view, - no_properties, no_reset_data, no_scalars_on_device), - num_passed, num_skipped); + EXPECT_TRUE_OR_FUTURE_SKIP( + test_functor_i64(dev, format, 15L, density_A_matrix, index_zero, transpose_val, alpha, + default_alg, default_A_view, default_properties, no_reset_data, + no_scalars_on_device), + num_passed, num_skipped); // Test lower no_optimize_alg EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, alpha, - no_optimize_alg, default_A_view, no_properties, no_reset_data, + no_optimize_alg, default_A_view, default_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); // Test upper no_optimize_alg - EXPECT_TRUE_OR_FUTURE_SKIP(test_functor_i32(dev, format, m, density_A_matrix, index_zero, - transpose_val, alpha, no_optimize_alg, upper_A_view, - no_properties, no_reset_data, no_scalars_on_device), - num_passed, num_skipped); + EXPECT_TRUE_OR_FUTURE_SKIP( + test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, alpha, + no_optimize_alg, upper_A_view, default_properties, no_reset_data, + no_scalars_on_device), + num_passed, num_skipped); // Test matrix properties - for (auto properties : test_matrix_properties) { + for (auto properties : get_all_matrix_properties_combinations(properties_queue, format)) { // Basic test with matrix properties EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, m, density_A_matrix, index_zero, transpose_val, alpha, diff --git a/tests/unit_tests/sparse_blas/source/sparse_spsv_buffer.cpp b/tests/unit_tests/sparse_blas/source/sparse_spsv_buffer.cpp index 778439162..522327e65 100644 --- a/tests/unit_tests/sparse_blas/source/sparse_spsv_buffer.cpp +++ b/tests/unit_tests/sparse_blas/source/sparse_spsv_buffer.cpp @@ -47,6 +47,9 @@ int test_spsv(sycl::device *dev, sparse_matrix_format_t format, intType m, doubl matrix_properties.find(oneapi::mkl::sparse::matrix_property::symmetric) != matrix_properties.cend(); + // Use a fixed seed for operations very sensible to the input data + std::srand(1); + // Input matrix std::vector ia_host, ja_host; std::vector a_host; diff --git a/tests/unit_tests/sparse_blas/source/sparse_spsv_usm.cpp b/tests/unit_tests/sparse_blas/source/sparse_spsv_usm.cpp index 54e496036..e69146d57 100644 --- a/tests/unit_tests/sparse_blas/source/sparse_spsv_usm.cpp +++ b/tests/unit_tests/sparse_blas/source/sparse_spsv_usm.cpp @@ -43,6 +43,9 @@ int test_spsv(sycl::device *dev, sparse_matrix_format_t format, intType m, doubl matrix_properties.find(oneapi::mkl::sparse::matrix_property::symmetric) != matrix_properties.cend(); + // Use a fixed seed for operations very sensible to the input data + std::srand(1); + // Input matrix std::vector ia_host, ja_host; std::vector a_host;