Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lapack][blas][cuda] Update host task impl to use enqueue_native_command #572

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 63 additions & 22 deletions src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,25 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran
auto b_ = sc.get_mem<cuTypeB *>(b_acc);
auto c_ = sc.get_mem<cuTypeC *>(c_acc);
cublasStatus_t err;
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use cublas_native_named_func("cublasGemmStridedBatchedEx", ...) to avoid the #ifdef? Here and in a few similar places below.

CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx,
err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, c_,
get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this formatted differently from the thing it's replacing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used clang-format on it and previously it either didn't use clang-format or had a different setting etc. I can change it back.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a local clang format here. You should make sure you are using this config.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am delighted to see that we now have CI on oneMKL tests. Would another clang format job in CI be a good idea @Rbiessy ?

Copy link
Contributor

@Rbiessy Rbiessy Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've not done anything on the GitHub CI but it is nice indeed. A clang format job is a good idea. I think we would need to discuss which clang-format version to use. We have seen differences between different versions, the internal CI uses clang-format-9 AFAIK. I think it would be easier to use the one shipped with DPC++ instead. I'll make a note to create an issue about that.

cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta,
c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c,
batch_size, get_cublas_datatype<cuTypeS>(),
cublas_gemm_algo);
#endif
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
});
});
}
Expand Down Expand Up @@ -608,12 +621,25 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra
onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
cublasStatus_t err;
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx,
err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, c,
get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx",
cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta,
c, get_cublas_datatype<cuTypeC>(), ldc, stride_c,
batch_size, get_cublas_datatype<cuTypeS>(),
cublas_gemm_algo);
#endif
});
});
return done;
Expand Down Expand Up @@ -687,14 +713,28 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr
int64_t offset = 0;
cublasStatus_t err;
for (int64_t i = 0; i < group_count; i++) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]),
get_cublas_operation(transb[i]), (int)m[i], (int)n[i],
(int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i],
(const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(),
(int)ldc[i], (int)group_size[i],
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i],
(int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i], (const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i],
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
get_cublas_datatype<cuTypeA>(), (int)lda[i],
(const void *const *)(b + offset), get_cublas_datatype<cuTypeB>(),
(int)ldb[i], &beta[i], (void *const *)(c + offset),
get_cublas_datatype<cuTypeC>(), (int)ldc[i], (int)group_size[i],
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
offset += group_size[i];
}
});
Expand Down Expand Up @@ -792,12 +832,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const cuDataType **>(a);
auto **b_ = reinterpret_cast<cuDataType **>(b);
CUBLAS_ERROR_FUNC_T_SYNC(
func_name, func, err, handle, get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i],
(int)group_size[i]);
cublas_native_named_func(func_name, func, err, handle,
get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]),
get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i],
b_ + offset, (int)ldb[i], (int)group_size[i]);
offset += group_size[i];
}
});
Expand Down
27 changes: 27 additions & 0 deletions src/blas/backends/cublas/cublas_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ class cuda_error : virtual public std::runtime_error {
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

// Invokes the cuBLAS entry point `func(handle, ...)`, storing the status in
// `err`, and throws a cublas_error tagged with the human-readable `name` if
// the call does not return CUBLAS_STATUS_SUCCESS. Unlike
// CUBLAS_ERROR_FUNC_T_SYNC below, it performs no stream synchronization, so
// the call is left asynchronous (used on the enqueue_native_command path).
// NOTE(review): multi-statement macro not wrapped in do{...}while(0) —
// existing call sites invoke it without a trailing semicolon, so wrapping
// would break them; avoid expanding it inside an unbraced if/else.
#define CUBLAS_ERROR_FUNC_T(name, func, err, handle, ...) \
err = func(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
throw cublas_error(std::string(name) + std::string(" : "), err); \
}

#define CUBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \
err = func(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
Expand All @@ -199,6 +205,27 @@ class cuda_error : virtual public std::runtime_error {
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

// Dispatches a cuBLAS call through the appropriate error-checking macro.
//
// When the SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND extension is available the
// call is left asynchronous (CUBLAS_ERROR_FUNC); otherwise the legacy
// host-task path synchronizes the cuBLAS stream after the call
// (CUBLAS_ERROR_FUNC_SYNC).
//
// func:   cuBLAS API entry point to invoke.
// err:    status out-parameter. Taken by reference so the caller's variable
//         actually receives the cuBLAS return code — callers declare an
//         uninitialized cublasStatus_t and pass it in, so the previous
//         by-value parameter both read an indeterminate value and discarded
//         the result in a dead local copy.
// handle: cuBLAS handle, passed as the first argument of `func`.
// args:   remaining arguments forwarded to `func`.
//
// Throws (via the macros) when the call does not return
// CUBLAS_STATUS_SUCCESS.
template <class Func, class... Types>
inline void cublas_native_func(Func func, cublasStatus_t &err,
                               cublasHandle_t handle, Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    CUBLAS_ERROR_FUNC(func, err, handle, args...)
#else
    CUBLAS_ERROR_FUNC_SYNC(func, err, handle, args...)
#endif
}

// Dispatches a cuBLAS call through the appropriate *named* error-checking
// macro, so any thrown cublas_error carries `func_name` in its message.
//
// When the SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND extension is available the
// call is left asynchronous (CUBLAS_ERROR_FUNC_T); otherwise the legacy
// host-task path synchronizes the cuBLAS stream after the call
// (CUBLAS_ERROR_FUNC_T_SYNC).
//
// func_name: human-readable name used in the exception message.
// func:      cuBLAS API entry point to invoke.
// err:       status out-parameter. Taken by reference so the caller's
//            variable actually receives the cuBLAS return code — callers
//            declare an uninitialized cublasStatus_t and pass it in, so the
//            previous by-value parameter both read an indeterminate value
//            and discarded the result in a dead local copy.
// handle:    cuBLAS handle, passed as the first argument of `func`.
// args:      remaining arguments forwarded to `func`.
//
// Throws cublas_error (via the macros) when the call does not return
// CUBLAS_STATUS_SUCCESS.
template <class Func, class... Types>
inline void cublas_native_named_func(const char *func_name, Func func,
                                     cublasStatus_t &err, cublasHandle_t handle,
                                     Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, args...)
#else
    CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...)
#endif
}

inline cublasOperation_t get_cublas_operation(oneapi::mkl::transpose trn) {
switch (trn) {
case oneapi::mkl::transpose::nontrans: return CUBLAS_OP_N;
Expand Down
Loading