Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rocBLAS] Fix issues with rocBLAS 4.0.0 #448

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion cmake/FindrocBLAS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ list(APPEND CMAKE_PREFIX_PATH

find_package(HIP QUIET)
find_package(rocblas REQUIRED)
set(ROCBLAS_VERSION ${rocblas_VERSION})

# this is work around to avoid duplication half creation in both HIP and SYCL
add_compile_definitions(HIP_NO_HALF)
Expand All @@ -47,12 +48,21 @@ find_package_handle_standard_args(rocBLAS
HIP_LIBRARIES
ROCBLAS_INCLUDE_DIR
ROCBLAS_LIBRARIES
VERSION_VAR
ROCBLAS_VERSION
)

if (DEFINED rocblas_INCLUDE_DIR)
set(ROCBLAS_LIB_PATH "${rocblas_INCLUDE_DIR}/../lib/librocblas.so")
else()
set(ROCBLAS_LIB_PATH "${HIP_PATH}/../rocblas/lib/librocblas.so")
endif()

# OPENCL_INCLUDE_DIR
if(NOT TARGET ONEMKL::rocBLAS::rocBLAS)
add_library(ONEMKL::rocBLAS::rocBLAS SHARED IMPORTED)
set_target_properties(ONEMKL::rocBLAS::rocBLAS PROPERTIES
IMPORTED_LOCATION "${HIP_PATH}/../rocblas/lib/librocblas.so"
IMPORTED_LOCATION "${ROCBLAS_LIB_PATH}"
INTERFACE_INCLUDE_DIRECTORIES "${ROCBLAS_INCLUDE_DIR};${HIP_INCLUDE_DIRS};"
INTERFACE_LINK_LIBRARIES "Threads::Threads;${ROCBLAS_LIBRARIES};"
)
Expand Down
4 changes: 4 additions & 0 deletions src/blas/backends/rocblas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ target_include_directories(${LIB_OBJ}
${ONEMKL_GENERATED_INCLUDE_PATH}
)

if (${ROCBLAS_VERSION} VERSION_GREATER_EQUAL "4.0")
target_compile_definitions(${LIB_OBJ} PRIVATE ROCBLAS_NO_LEGACY_TRMM)
endif()

if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl")
target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT})
target_compile_options(ONEMKL::SYCL::SYCL INTERFACE
Expand Down
18 changes: 18 additions & 0 deletions src/blas/backends/rocblas/rocblas_level3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,19 @@ inline void trmm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe
auto a_ = sc.get_mem<rocDataType *>(a_acc);
auto b_ = sc.get_mem<rocDataType *>(b_acc);
rocblas_status err;

// rocblas version 4.0.0 removed the legacy BLAS trmm implementation
#ifdef ROCBLAS_NO_LEGACY_TRMM
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#ifdef ROCBLAS_NO_LEGACY_TRMM
#if ROCBLAS_VERSION_MAJOR >= 3

We can use ROCBLAS_VERSION_MAJOR, that way we don't need to add a new define, also since it's deprecated in rocBLAS 3 it gives warning so we could also use the new interface for rocBLAS 3.x.

I ran into the same issue and ended up with a similar patch before finding this PR. I didn't have any ROCBLAS_LIB_PATH issues with rocBLAS 4.x though.

ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb, b_, ldb);
#else
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb);
#endif
});
});
}
Expand Down Expand Up @@ -805,10 +814,19 @@ inline sycl::event trmm(Func func, sycl::queue &queue, side left_right, uplo upp
auto a_ = reinterpret_cast<const rocDataType *>(a);
auto b_ = reinterpret_cast<rocDataType *>(b);
rocblas_status err;

// rocblas version 4.0.0 removed the legacy BLAS trmm implementation
#ifdef ROCBLAS_NO_LEGACY_TRMM
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb, b_, ldb);
#else
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, b_, ldb);
#endif
});
});

Expand Down