Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lapack][blas][cuda] Update host task impl to use enqueue_native_command #572

Merged
merged 11 commits into from
Oct 8, 2024
32 changes: 31 additions & 1 deletion src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,21 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran
auto b_ = sc.get_mem<cuTypeB *>(b_acc);
auto c_ = sc.get_mem<cuTypeC *>(c_acc);
cublasStatus_t err;
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
JackAKirk marked this conversation as resolved.
Show resolved Hide resolved
CUBLAS_ERROR_FUNC_T(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
});
});
}
Expand Down Expand Up @@ -608,12 +617,21 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra
onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
cublasStatus_t err;
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
});
});
return done;
Expand Down Expand Up @@ -687,6 +705,16 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr
int64_t offset = 0;
cublasStatus_t err;
for (int64_t i = 0; i < group_count; i++) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i],
(int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i], (const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i],
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i],
Expand All @@ -695,6 +723,7 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i],
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
offset += group_size[i];
}
});
Expand Down Expand Up @@ -792,12 +821,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const cuDataType **>(a);
auto **b_ = reinterpret_cast<cuDataType **>(b);
CUBLAS_ERROR_FUNC_T_SYNC(
cublas_native_named_func(
func_name, func, err, handle, get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i],
(int)group_size[i]);

offset += group_size[i];
}
});
Expand Down
27 changes: 27 additions & 0 deletions src/blas/backends/cublas/cublas_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ class cuda_error : virtual public std::runtime_error {
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

#define CUBLAS_ERROR_FUNC_T(name, func, err, handle, ...) \
err = func(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
throw cublas_error(std::string(name) + std::string(" : "), err); \
}

#define CUBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \
err = func(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
Expand All @@ -199,6 +205,27 @@ class cuda_error : virtual public std::runtime_error {
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

template <class Func, class... Types>
inline void cublas_native_func(Func func, cublasStatus_t err,
cublasHandle_t handle, Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC(func, err, handle, args...)
#else
CUBLAS_ERROR_FUNC_SYNC(func, err, handle, args...)
#endif
};

template <class Func, class... Types>
inline void cublas_native_named_func(const char *func_name, Func func,
cublasStatus_t err, cublasHandle_t handle,
Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, args...)
#else
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...)
#endif
};

inline cublasOperation_t get_cublas_operation(oneapi::mkl::transpose trn) {
switch (trn) {
case oneapi::mkl::transpose::nontrans: return CUBLAS_OP_N;
Expand Down
Loading