-
Notifications
You must be signed in to change notification settings - Fork 157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[lapack][blas][cuda] Update host task impl to use enqueue_native_command #572
base: develop
Are you sure you want to change the base?
Changes from 4 commits
e19072c
44867dc
a28cd4d
b33b9e3
8fba319
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,12 +167,25 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran | |
auto b_ = sc.get_mem<cuTypeB *>(b_acc); | ||
auto c_ = sc.get_mem<cuTypeC *>(c_acc); | ||
cublasStatus_t err; | ||
CUBLAS_ERROR_FUNC_T_SYNC( | ||
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, | ||
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, | ||
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(), | ||
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size, | ||
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo); | ||
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND | ||
CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, | ||
err, handle, get_cublas_operation(transa), | ||
get_cublas_operation(transb), m, n, k, &alpha, a_, | ||
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, | ||
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, c_, | ||
get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size, | ||
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo); | ||
#else | ||
CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", | ||
[Review thread]
Reviewer: Why is this formatted differently from the thing it's replacing?
Author: I used clang-format on it, and previously it either didn't use clang-format or used a different configuration. I can change it back.
Reviewer: There is a local clang-format config here. You should make sure you are using this config.
Reviewer: I am delighted to see that we now have CI on the oneMKL tests. Would another clang-format job in CI be a good idea, @Rbiessy?
Rbiessy: We've not done anything on the GitHub CI, but it is nice indeed — a clang-format job is a good idea. I think we would need to discuss which clang-format version to use; we have seen differences between versions, and the internal CI uses clang-format-9 AFAIK. I think it would be easier to use the one shipped with DPC++ instead. I'll make a note to create an issue about that.
||
cublasGemmStridedBatchedEx, err, handle, | ||
get_cublas_operation(transa), | ||
get_cublas_operation(transb), m, n, k, &alpha, a_, | ||
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, | ||
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, | ||
c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, | ||
batch_size, get_cublas_datatype<cuTypeS>(), | ||
cublas_gemm_algo); | ||
#endif | ||
Rbiessy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}); | ||
}); | ||
} | ||
|
@@ -608,12 +621,25 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra | |
onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { | ||
auto handle = sc.get_handle(queue); | ||
cublasStatus_t err; | ||
CUBLAS_ERROR_FUNC_T_SYNC( | ||
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, | ||
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, | ||
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(), | ||
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size, | ||
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo); | ||
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND | ||
CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, | ||
err, handle, get_cublas_operation(transa), | ||
get_cublas_operation(transb), m, n, k, &alpha, a, | ||
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, | ||
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, c, | ||
get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size, | ||
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo); | ||
#else | ||
CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", | ||
cublasGemmStridedBatchedEx, err, handle, | ||
get_cublas_operation(transa), | ||
get_cublas_operation(transb), m, n, k, &alpha, a, | ||
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, | ||
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, | ||
c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, | ||
batch_size, get_cublas_datatype<cuTypeS>(), | ||
cublas_gemm_algo); | ||
#endif | ||
}); | ||
}); | ||
return done; | ||
|
@@ -687,14 +713,28 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr | |
int64_t offset = 0; | ||
cublasStatus_t err; | ||
for (int64_t i = 0; i < group_count; i++) { | ||
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND | ||
CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, | ||
get_cublas_operation(transa[i]), | ||
get_cublas_operation(transb[i]), (int)m[i], (int)n[i], | ||
(int)k[i], &alpha[i], (const void *const *)(a + offset), | ||
get_cublas_datatype<cuTypeA>(), (int)lda[i], | ||
(const void *const *)(b + offset), | ||
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i], | ||
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), | ||
(int)ldc[i], (int)group_size[i], | ||
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo); | ||
#else | ||
CUBLAS_ERROR_FUNC_T_SYNC( | ||
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, | ||
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], | ||
(int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), | ||
get_cublas_datatype<cuTypeA>(), (int)lda[i], (const void *const *)(b + offset), | ||
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i], | ||
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i], | ||
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo); | ||
get_cublas_datatype<cuTypeA>(), (int)lda[i], | ||
(const void *const *)(b + offset), get_cublas_datatype<cuTypeB>(), | ||
(int)ldb[i], &beta[i], (void *const *)(c + offset), | ||
get_cublas_datatype<cuTypeC>(), (int)ldc[i], (int)group_size[i], | ||
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo); | ||
#endif | ||
offset += group_size[i]; | ||
} | ||
}); | ||
|
@@ -792,12 +832,24 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que | |
for (int64_t i = 0; i < group_count; i++) { | ||
auto **a_ = reinterpret_cast<const cuDataType **>(a); | ||
auto **b_ = reinterpret_cast<cuDataType **>(b); | ||
CUBLAS_ERROR_FUNC_T_SYNC( | ||
func_name, func, err, handle, get_cublas_side_mode(left_right[i]), | ||
get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]), | ||
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], | ||
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], | ||
(int)group_size[i]); | ||
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND | ||
CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, | ||
get_cublas_side_mode(left_right[i]), | ||
get_cublas_fill_mode(upper_lower[i]), | ||
get_cublas_operation(trans[i]), | ||
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], | ||
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i], | ||
b_ + offset, (int)ldb[i], (int)group_size[i]); | ||
#else | ||
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, | ||
get_cublas_side_mode(left_right[i]), | ||
get_cublas_fill_mode(upper_lower[i]), | ||
get_cublas_operation(trans[i]), | ||
get_cublas_diag_type(unit_diag[i]), (int)m[i], | ||
(int)n[i], (cuDataType *)&alpha[i], a_ + offset, | ||
(int)lda[i], b_ + offset, (int)ldb[i], | ||
(int)group_size[i]); | ||
#endif | ||
offset += group_size[i]; | ||
} | ||
}); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use `cublas_native_named_func("cublasGemmStridedBatchedEx", ...)` to avoid the `#ifdef`? Here and in a few similar places below.