Skip to content

Commit

Permalink
Use cublas_native_named_func more
Browse files Browse the repository at this point in the history
Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
  • Loading branch information
JackAKirk committed Sep 20, 2024
1 parent b33b9e3 commit 8fba319
Showing 1 changed file with 1 addition and 12 deletions.
13 changes: 1 addition & 12 deletions src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,24 +832,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const cuDataType **>(a);
auto **b_ = reinterpret_cast<cuDataType **>(b);
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(func_name, func, err, handle,
cublas_native_named_func(func_name, func, err, handle,
get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]),
get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i],
b_ + offset, (int)ldb[i], (int)group_size[i]);
#else
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle,
get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]),
get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i],
(int)n[i], (cuDataType *)&alpha[i], a_ + offset,
(int)lda[i], b_ + offset, (int)ldb[i],
(int)group_size[i]);
#endif
offset += group_size[i];
}
});
Expand Down

0 comments on commit 8fba319

Please sign in to comment.