From e19072c0cc9c08dbe3b3526d3148b4992ab5c8a0 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 18 Sep 2024 11:10:38 -0700 Subject: [PATCH 01/11] Implemented cusolver native_command. See SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND for details. Signed-off-by: JackAKirk --- .../backends/cusolver/cusolver_batch.cpp | 50 +++++---- .../backends/cusolver/cusolver_helper.hpp | 11 ++ .../backends/cusolver/cusolver_lapack.cpp | 100 +++++++++--------- .../backends/cusolver/cusolver_task.hpp | 5 +- 4 files changed, 95 insertions(+), 71 deletions(-) diff --git a/src/lapack/backends/cusolver/cusolver_batch.cpp b/src/lapack/backends/cusolver/cusolver_batch.cpp index 59fa47f84..a2fdc4b99 100644 --- a/src/lapack/backends/cusolver/cusolver_batch.cpp +++ b/src/lapack/backends/cusolver/cusolver_batch.cpp @@ -53,7 +53,7 @@ inline void geqrf_batch(const char *func_name, Func func, sycl::queue &queue, st // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -137,8 +137,8 @@ inline void getri_batch(const char *func_name, Func func, sycl::queue &queue, st sizeof(T *) * batch_size); auto **scratch_dev_ = reinterpret_cast(scratch_dev); - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32_, - scratch_dev_, lda, info_, batch_size) + blas::cublas::cublas_native_func(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32_, + scratch_dev_, lda, info_, batch_size); free(a_batched); free(scratch_batched); @@ -227,7 +227,9 @@ inline void getrs_batch(const char *func_name, Func func, sycl::queue &queue, nrhs, a_ + stride_a * i, lda, ipiv_ + stride_ipiv * i, b_ + stride_b * i, ldb, nullptr); } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif }); }); } @@ -283,7 +285,7 @@ inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, st // Uses scratch so sync between each cuSolver call for (std::int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_ + stride_a * i, lda, scratch_, ipiv_ + stride_ipiv * i, devInfo_ + i); } }); @@ -340,7 +342,7 @@ inline void orgqr_batch(const char *func_name, Func func, sycl::queue &queue, st // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -388,7 +390,7 @@ inline void potrf_batch(const char *func_name, Func func, sycl::queue &queue, auto **a_dev_ = reinterpret_cast(a_dev); - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), (int)n, a_dev_, (int)lda, nullptr, (int)batch_size); free(a_batched); @@ -452,7 +454,7 @@ inline void potrs_batch(const char *func_name, Func func, sycl::queue &queue, auto **a_dev_ = reinterpret_cast(a_dev); auto **b_dev_ = reinterpret_cast(b_dev); - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), (int)n, (int)nrhs, a_dev_, (int)lda, b_dev_, ldb, nullptr, (int)batch_size); @@ -506,7 +508,7 @@ inline void ungqr_batch(const char *func_name, Func func, sycl::queue &queue, st // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -551,7 +553,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -605,7 +607,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t group_id = 0; group_id < group_count; ++group_id) { for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + cusolver_native_named_func(func_name, func, err, handle, m[group_id], n[group_id], a_[global_id], lda[group_id], tau_[global_id], scratch_, scratchpad_size, nullptr); } @@ -661,7 +663,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_ + stride_a * i, lda, scratchpad_, ipiv_ + stride_ipiv * i, devInfo_ + i); } }); @@ -744,7 +746,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t group_id = 0; group_id < group_count; ++group_id) { for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + cusolver_native_named_func(func_name, func, err, handle, m[group_id], n[group_id], a_[global_id], lda[group_id], scratch_, ipiv32[global_id], devInfo + global_id); } @@ -857,8 +859,8 @@ sycl::event getri_batch(const char *func_name, Func func, sycl::queue &queue, st sizeof(T *) * batch_size); auto **scratch_dev_ = reinterpret_cast(scratch_dev); - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32, - scratch_dev_, lda, devInfo, batch_size) + blas::cublas::cublas_native_func(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32, + scratch_dev_, lda, devInfo, batch_size); free(a_batched); free(scratch_batched); @@ -972,7 +974,9 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu nrhs, a_ + stride_a * i, lda, ipiv_ + stride_ipiv * i, b_ + stride_b * i, ldb, nullptr); } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif sycl::free(ipiv32, queue); }); @@ -1062,7 +1066,9 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu ipiv32[global_id], b_[global_id], ldb[group_id], nullptr); } } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif for (int64_t i = 0; i < batch_size; ++i) sycl::free(ipiv32[i], queue); @@ -1112,7 +1118,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -1165,7 +1171,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t group_id = 0; group_id < group_count; ++group_id) { for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + cusolver_native_named_func(func_name, func, err, handle, m[group_id], n[group_id], k[group_id], a_[global_id], lda[group_id], tau_[global_id], scratch_, scratchpad_size, nullptr); @@ -1219,7 +1225,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu auto **a_dev_ = reinterpret_cast(a_dev); - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), (int)n, a_dev_, (int)lda, nullptr, (int)batch_size); free(a_batched); @@ -1281,7 +1287,9 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu (int)group_sizes[i]); offset += group_sizes[i]; } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif cuMemFree(a_dev); }); @@ -1342,7 +1350,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu auto **a_dev_ = reinterpret_cast(a_dev); auto **b_dev_ = reinterpret_cast(b_dev); - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), (int)n, (int)nrhs, a_dev_, (int)lda, b_dev_, ldb, nullptr, (int)batch_size); @@ -1421,7 +1429,9 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu b_ + offset, (int)ldb[i], info_, (int)group_sizes[i]); offset += group_sizes[i]; } +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUSOLVER_SYNC(err, handle) +#endif }); }); return done; @@ -1467,7 +1477,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu // Uses scratch so sync between each cuSolver call for (int64_t i = 0; i < batch_size; ++i) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_ + stride_a * i, lda, tau_ + stride_tau * i, scratch_, scratchpad_size, nullptr); } @@ -1520,7 +1530,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu for (int64_t group_id = 0; group_id < group_count; ++group_id) { for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + cusolver_native_named_func(func_name, func, err, handle, m[group_id], n[group_id], k[group_id], a_[global_id], lda[group_id], tau_[global_id], scratch_, scratchpad_size, nullptr); diff --git a/src/lapack/backends/cusolver/cusolver_helper.hpp b/src/lapack/backends/cusolver/cusolver_helper.hpp index e10f56b36..954d41246 100644 --- a/src/lapack/backends/cusolver/cusolver_helper.hpp +++ b/src/lapack/backends/cusolver/cusolver_helper.hpp @@ -200,6 +200,17 @@ class cuda_error : virtual public std::runtime_error { } \ CUSOLVER_SYNC(err, handle) +template +inline void cusolver_native_named_func(const char *func_name, Func func, + cusolverStatus_t err, + cusolverDnHandle_t handle, Types... args){ +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, args...) +#else + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...) +#endif +}; + inline cusolverEigType_t get_cusolver_itype(std::int64_t itype) { switch (itype) { case 1: return CUSOLVER_EIG_TYPE_1; diff --git a/src/lapack/backends/cusolver/cusolver_lapack.cpp b/src/lapack/backends/cusolver/cusolver_lapack.cpp index 0c7aaefc8..3d54c403a 100644 --- a/src/lapack/backends/cusolver/cusolver_lapack.cpp +++ b/src/lapack/backends/cusolver/cusolver_lapack.cpp @@ -57,7 +57,7 @@ inline void gebrd(const char *func_name, Func func, sycl::queue &queue, std::int auto taup_ = sc.get_mem(taup_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, taup_, scratch_, scratchpad_size, nullptr); }); }); @@ -117,7 +117,7 @@ inline void geqrf(const char *func_name, Func func, sycl::queue &queue, std::int auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -164,7 +164,7 @@ void getrf(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, scratch_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, scratch_, ipiv32_, devInfo_); }); }); @@ -243,7 +243,7 @@ inline void getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = sc.get_mem(ipiv_acc); auto b_ = sc.get_mem(b_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), n, nrhs, a_, lda, ipiv_, b_, ldb, nullptr); }); }); @@ -292,7 +292,7 @@ inline void gesvd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; // rwork is set to nullptr. If set it is filled with information from the superdiagonal. - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_jobsvd(jobu), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_jobsvd(jobu), get_cusolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, ldvt, scratch_, scratchpad_size, nullptr, devInfo_); }); @@ -338,7 +338,7 @@ inline void heevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, scratchpad_size, devInfo_); }); @@ -383,7 +383,7 @@ inline void hegvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, scratchpad_size, devInfo_); }); @@ -430,7 +430,7 @@ inline void hetrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); }); }); @@ -480,7 +480,7 @@ inline void orgbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -515,7 +515,7 @@ inline void orgqr(const char *func_name, Func func, sycl::queue &queue, std::int auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -550,7 +550,7 @@ inline void orgtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -589,7 +589,7 @@ inline void ormtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); @@ -644,7 +644,7 @@ inline void ormqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); }); @@ -682,7 +682,7 @@ inline void potrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, scratch_, scratchpad_size, devInfo_); }); }); @@ -720,7 +720,7 @@ inline void potri(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, scratch_, scratchpad_size, devInfo_); }); }); @@ -757,7 +757,7 @@ inline void potrs(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nrhs, a_, lda, b_, ldb, nullptr); }); }); @@ -797,7 +797,7 @@ inline void syevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, scratchpad_size, devInfo_); }); @@ -840,7 +840,7 @@ inline void sygvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, scratchpad_size, devInfo_); }); @@ -886,7 +886,7 @@ inline void sytrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); }); }); @@ -934,7 +934,7 @@ inline void sytrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, ipiv32_, scratch_, scratchpad_size, devInfo_); }); }); @@ -1009,7 +1009,7 @@ inline void ungbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1044,7 +1044,7 @@ inline void ungqr(const char *func_name, Func func, sycl::queue &queue, std::int auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1079,7 +1079,7 @@ inline void ungtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1132,7 +1132,7 @@ inline void unmqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); }); @@ -1173,7 +1173,7 @@ inline void unmtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); @@ -1224,7 +1224,7 @@ inline sycl::event gebrd(const char *func_name, Func func, sycl::queue &queue, s auto taup_ = reinterpret_cast(taup); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, taup_, scratch_, scratchpad_size, nullptr); }); }); @@ -1286,7 +1286,7 @@ inline sycl::event geqrf(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1335,7 +1335,7 @@ inline sycl::event getrf(const char *func_name, Func func, sycl::queue &queue, s auto scratch_ = reinterpret_cast(scratchpad); auto ipiv_ = reinterpret_cast(ipiv32); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, scratch_, ipiv_, + cusolver_native_named_func(func_name, func, err, handle, m, n, a_, lda, scratch_, ipiv_, devInfo_); }); }); @@ -1422,7 +1422,7 @@ inline sycl::event getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto b_ = reinterpret_cast(b); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), n, nrhs, a_, lda, ipiv_, b_, ldb, nullptr); }); }); @@ -1475,7 +1475,7 @@ inline sycl::event gesvd(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; // rwork is set to nullptr. If set it is filled with information from the superdiagonal. - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_jobsvd(jobu), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_jobsvd(jobu), get_cusolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, ldvt, scratch_, scratchpad_size, nullptr, devInfo_); }); @@ -1523,7 +1523,7 @@ inline sycl::event heevd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, scratchpad_size, devInfo_); }); @@ -1570,7 +1570,7 @@ inline sycl::event hegvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, scratchpad_size, devInfo); }); @@ -1618,7 +1618,7 @@ inline sycl::event hetrd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); }); }); @@ -1673,7 +1673,7 @@ inline sycl::event orgbr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1712,7 +1712,7 @@ inline sycl::event orgqr(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1750,7 +1750,7 @@ inline sycl::event orgtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -1791,7 +1791,7 @@ inline sycl::event ormtr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); @@ -1848,7 +1848,7 @@ inline sycl::event ormqr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); }); @@ -1890,7 +1890,7 @@ inline sycl::event potrf(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, scratch_, scratchpad_size, devInfo_); }); }); @@ -1933,7 +1933,7 @@ inline sycl::event potri(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, scratch_, scratchpad_size, devInfo_); }); }); @@ -1976,7 +1976,7 @@ inline sycl::event potrs(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nrhs, a_, lda, b_, ldb, nullptr); }); }); @@ -2019,7 +2019,7 @@ inline sycl::event syevd(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, scratchpad_size, devInfo_); }); @@ -2065,7 +2065,7 @@ inline sycl::event sygvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), + cusolver_native_named_func(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, scratch_, scratchpad_size, devInfo); }); @@ -2111,7 +2111,7 @@ inline sycl::event sytrd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); }); }); @@ -2161,7 +2161,7 @@ inline sycl::event sytrf(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, ipiv_, scratch_, scratchpad_size, devInfo_); }); }); @@ -2245,7 +2245,7 @@ inline sycl::event ungbr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -2284,7 +2284,7 @@ inline sycl::event ungqr(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + cusolver_native_named_func(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -2322,7 +2322,7 @@ inline sycl::event ungtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + cusolver_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, lda, tau_, scratch_, scratchpad_size, nullptr); }); }); @@ -2377,7 +2377,7 @@ inline sycl::event unmqr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); }); @@ -2421,7 +2421,7 @@ inline sycl::event unmtr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), + cusolver_native_named_func(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); diff --git a/src/lapack/backends/cusolver/cusolver_task.hpp b/src/lapack/backends/cusolver/cusolver_task.hpp index 00e6e26be..6a35dea84 100644 --- a/src/lapack/backends/cusolver/cusolver_task.hpp +++ b/src/lapack/backends/cusolver/cusolver_task.hpp @@ -50,10 +50,13 @@ namespace cusolver { template static inline void host_task_internal(H &cgh, sycl::queue queue, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([f, queue](sycl::interop_handle ih){ +#else cgh.host_task([f, queue](sycl::interop_handle ih) { +#endif auto sc = CusolverScopedContextHandler(queue, ih); f(sc); - sc.wait_stream(queue); }); } From 44867dcec76ca0ffb26a5b5a010670c2325b11ae Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 18 Sep 2024 11:17:22 -0700 Subject: [PATCH 02/11] Impl native command for cublas See SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND extension document for details. Generalize helpers funcs and use them for blas l1, l2, l3, batch Signed-off-by: JackAKirk --- src/blas/backends/cublas/cublas_batch.cpp | 96 +++++++++++++++++----- src/blas/backends/cublas/cublas_helper.hpp | 27 ++++++ src/blas/backends/cublas/cublas_level1.cpp | 58 ++++++------- src/blas/backends/cublas/cublas_level2.cpp | 92 ++++++++++----------- src/blas/backends/cublas/cublas_level3.cpp | 50 +++++++---- src/blas/backends/cublas/cublas_task.hpp | 4 + 6 files changed, 212 insertions(+), 115 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 009bb9541..8b10c7744 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -167,12 +167,25 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC( - "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, - get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, - get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), - ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, + err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, + get_cublas_datatype(), ldb, stride_b, &beta, c_, + get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); +#else + CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", + cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, + get_cublas_datatype(), ldb, stride_b, &beta, + c_, get_cublas_datatype(), ldc, stride_c, + batch_size, get_cublas_datatype(), + cublas_gemm_algo); +#endif }); }); } @@ -608,12 +621,25 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC( - "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, - get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, - get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), - ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, + err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, + get_cublas_datatype(), ldb, stride_b, &beta, c, + get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); +#else + CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", + cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, + get_cublas_datatype(), ldb, stride_b, &beta, + c, get_cublas_datatype(), ldc, stride_c, + batch_size, get_cublas_datatype(), + cublas_gemm_algo); +#endif }); }); return done; @@ -687,14 +713,28 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr int64_t offset = 0; cublasStatus_t err; for (int64_t i = 0; i < group_count; i++) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), + get_cublas_operation(transb[i]), (int)m[i], (int)n[i], + (int)k[i], &alpha[i], (const void *const *)(a + offset), + get_cublas_datatype(), (int)lda[i], + (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), + (int)ldc[i], (int)group_size[i], + get_cublas_datatype(), cublas_gemm_algo); +#else CUBLAS_ERROR_FUNC_T_SYNC( "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), - get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), - get_cublas_datatype(), (int)ldb[i], &beta[i], - (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], - (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); + get_cublas_datatype(), (int)lda[i], + (const void *const *)(b + offset), get_cublas_datatype(), + (int)ldb[i], &beta[i], (void *const *)(c + offset), + get_cublas_datatype(), (int)ldc[i], (int)group_size[i], + get_cublas_datatype(), cublas_gemm_algo); +#endif offset += group_size[i]; } }); @@ -792,12 +832,24 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a); auto **b_ = reinterpret_cast(b); - CUBLAS_ERROR_FUNC_T_SYNC( - func_name, func, err, handle, get_cublas_side_mode(left_right[i]), - get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]), - get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], - (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], - (int)group_size[i]); +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, + get_cublas_side_mode(left_right[i]), + get_cublas_fill_mode(upper_lower[i]), + get_cublas_operation(trans[i]), + get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], + (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], + b_ + offset, (int)ldb[i], (int)group_size[i]); +#else + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(left_right[i]), + get_cublas_fill_mode(upper_lower[i]), + get_cublas_operation(trans[i]), + get_cublas_diag_type(unit_diag[i]), (int)m[i], + (int)n[i], (cuDataType *)&alpha[i], a_ + offset, + (int)lda[i], b_ + offset, (int)ldb[i], + (int)group_size[i]); +#endif offset += group_size[i]; } }); diff --git a/src/blas/backends/cublas/cublas_helper.hpp b/src/blas/backends/cublas/cublas_helper.hpp index 0fe7e7c5a..0bd4e6274 100644 --- a/src/blas/backends/cublas/cublas_helper.hpp +++ b/src/blas/backends/cublas/cublas_helper.hpp @@ -190,6 +190,12 @@ class cuda_error : virtual public std::runtime_error { CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, ¤tStreamId); \ cuStreamSynchronize(currentStreamId); +#define CUBLAS_ERROR_FUNC_T(name, func, err, handle, ...) \ + err = func(handle, __VA_ARGS__); \ + if (err != CUBLAS_STATUS_SUCCESS) { \ + throw cublas_error(std::string(name) + std::string(" : "), err); \ + } + #define CUBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \ err = func(handle, __VA_ARGS__); \ if (err != CUBLAS_STATUS_SUCCESS) { \ @@ -199,6 +205,27 @@ class cuda_error : virtual public std::runtime_error { CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, ¤tStreamId); \ cuStreamSynchronize(currentStreamId); +template +inline void cublas_native_func(Func func, cublasStatus_t err, + cublasHandle_t handle, Types... args) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC(func, err, handle, args...) +#else + CUBLAS_ERROR_FUNC_SYNC(func, err, handle, args...) +#endif +}; + +template +inline void cublas_native_named_func(const char *func_name, Func func, + cublasStatus_t err, cublasHandle_t handle, + Types... args) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, args...) +#else + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...) +#endif +}; + inline cublasOperation_t get_cublas_operation(oneapi::mkl::transpose trn) { switch (trn) { case oneapi::mkl::transpose::nontrans: return CUBLAS_OP_N; diff --git a/src/blas/backends/cublas/cublas_level1.cpp b/src/blas/backends/cublas/cublas_level1.cpp index 5f7087727..3b0699c87 100644 --- a/src/blas/backends/cublas/cublas_level1.cpp +++ b/src/blas/backends/cublas/cublas_level1.cpp @@ -53,7 +53,7 @@ inline void asum(const char *func_name, Func func, sycl::queue &queue, int64_t n auto res_ = sc.get_mem(res_acc); cublasStatus_t err; // ASUM does not support negative index - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, std::abs(incx), res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -86,7 +86,7 @@ inline void scal(const char *func_name, Func func, sycl::queue &queue, int64_t n auto x_ = sc.get_mem(x_acc); cublasStatus_t err; // SCAL does not support negative incx - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, (cuDataType1 *)&a, x_, + cublas_native_named_func(func_name, func, err, handle, n, (cuDataType1 *)&a, x_, std::abs(incx)); }); }); @@ -117,7 +117,7 @@ inline void axpy(const char *func_name, Func func, sycl::queue &queue, int64_t n auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, (cuDataType *)&alpha, x_, + cublas_native_named_func(func_name, func, err, handle, n, (cuDataType *)&alpha, x_, incx, y_, incy); }); }); @@ -180,7 +180,7 @@ inline void rotg(const char *func_name, Func func, sycl::queue &queue, sycl::buf auto c_ = sc.get_mem(c_acc); auto s_ = sc.get_mem(s_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, a_, b_, c_, s_); + cublas_native_named_func(func_name, func, err, handle, a_, b_, c_, s_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -223,7 +223,7 @@ inline void rotm(const char *func_name, Func func, sycl::queue &queue, int64_t n auto y_ = sc.get_mem(y_acc); auto param_ = sc.get_mem(param_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, param_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, param_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -255,7 +255,7 @@ inline void copy(const char *func_name, Func func, sycl::queue &queue, int64_t n auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy); }); }); } @@ -294,7 +294,7 @@ inline void dot(const char *func_name, Func func, sycl::queue &queue, int64_t n, auto y_ = sc.get_mem(y_acc); auto res_ = sc.get_mem(res_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -338,7 +338,7 @@ inline void rot(const char *func_name, Func func, sycl::queue &queue, int64_t n, auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, (cuDataType2 *)&c, (cuDataType3 *)&s); }); }); @@ -376,7 +376,7 @@ void sdsdot(sycl::queue &queue, int64_t n, float sb, sycl::buffer &x, auto y_ = sc.get_mem(y_acc); auto res_ = sc.get_mem(res_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_SYNC(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); + cublas_native_func(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -418,7 +418,7 @@ inline void rotmg(const char *func_name, Func func, sycl::queue &queue, sycl::bu auto y1_ = sc.get_mem(y1_acc); auto param_ = sc.get_mem(param_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); + cublas_native_named_func(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -466,7 +466,7 @@ inline void iamax(const char *func_name, Func func, sycl::queue &queue, int64_t cublasStatus_t err; // For negative incx, iamax returns 0. This behaviour is similar to that of // reference netlib BLAS. - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, int_res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -506,7 +506,7 @@ inline void swap(const char *func_name, Func func, sycl::queue &queue, int64_t n auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy); }); }); } @@ -552,7 +552,7 @@ inline void iamin(const char *func_name, Func func, sycl::queue &queue, int64_t cublasStatus_t err; // For negative incx, iamin returns 0. This behaviour is similar to that of // implemented as a reference IAMIN. - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, int_res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -601,7 +601,7 @@ inline void nrm2(const char *func_name, Func func, sycl::queue &queue, int64_t n auto res_ = sc.get_mem(res_acc); cublasStatus_t err; // NRM2 does not support negative index - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, std::abs(incx), res_); // Higher level BLAS functions expect CUBLAS_POINTER_MODE_HOST // to be set, therfore we need to reset this to the default value // in order to avoid CUDA_ERROR_ILLEGAL_ADRESS errors @@ -648,7 +648,7 @@ inline sycl::event asum(const char *func_name, Func func, sycl::queue &queue, in } cublasStatus_t err; // ASUM does not support negative index - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, std::abs(incx), res_); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -684,7 +684,7 @@ inline sycl::event scal(const char *func_name, Func func, sycl::queue &queue, in auto x_ = reinterpret_cast(x); cublasStatus_t err; // SCAL does not support negative incx - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, (cuDataType1 *)&a, x_, + cublas_native_named_func(func_name, func, err, handle, n, (cuDataType1 *)&a, x_, std::abs(incx)); }); }); @@ -720,7 +720,7 @@ inline sycl::event axpy(const char *func_name, Func func, sycl::queue &queue, in auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, (cuDataType *)&alpha, x_, + cublas_native_named_func(func_name, func, err, handle, n, (cuDataType *)&alpha, x_, incx, y_, incy); }); }); @@ -798,7 +798,7 @@ inline sycl::event rotg(const char *func_name, Func func, sycl::queue &queue, T1 cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); } cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, a_, b_, c_, s_); + cublas_native_named_func(func_name, func, err, handle, a_, b_, c_, s_); if (results_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -836,7 +836,7 @@ inline sycl::event rotm(const char *func_name, Func func, sycl::queue &queue, in auto y_ = reinterpret_cast(y); auto param_ = reinterpret_cast(param); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, param_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, param_); }); }); return done; @@ -869,7 +869,7 @@ inline sycl::event copy(const char *func_name, Func func, sycl::queue &queue, in auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy); }); }); return done; @@ -909,7 +909,7 @@ inline sycl::event dot(const char *func_name, Func func, sycl::queue &queue, int cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); } cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, res_); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -951,7 +951,7 @@ inline sycl::event rot(const char *func_name, Func func, sycl::queue &queue, int auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy, (cuDataType2 *)&c, (cuDataType3 *)&s); }); }); @@ -993,7 +993,7 @@ sycl::event sdsdot(sycl::queue &queue, int64_t n, float sb, const float *x, int6 cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); } cublasStatus_t err; - CUBLAS_ERROR_FUNC_SYNC(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); + cublas_native_func(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -1058,12 +1058,12 @@ inline sycl::event rotmg(const char *func_name, Func func, sycl::queue &queue, T cublasStatus_t err; if (results_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); + cublas_native_named_func(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } else { auto y1_c = reinterpret_cast(&y1); - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_c, param_); + cublas_native_named_func(func_name, func, err, handle, d1_, d2_, x1_, y1_c, param_); } }); }); @@ -1120,7 +1120,7 @@ inline sycl::event iamax(const char *func_name, Func func, sycl::queue &queue, i cublasStatus_t err; // For negative incx, iamax returns 0. This behaviour is similar to that of // reference iamax. - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_p); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, int_res_p); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -1168,7 +1168,7 @@ inline sycl::event swap(const char *func_name, Func func, sycl::queue &queue, in auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, y_, incy); }); }); return done; @@ -1221,7 +1221,7 @@ inline sycl::event iamin(const char *func_name, Func func, sycl::queue &queue, i cublasStatus_t err; // For negative incx, iamin returns 0. This behaviour is similar to that of // implemented iamin. - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_p); + cublas_native_named_func(func_name, func, err, handle, n, x_, incx, int_res_p); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } @@ -1277,7 +1277,7 @@ inline sycl::event nrm2(const char *func_name, Func func, sycl::queue &queue, in } cublasStatus_t err; // NRM2 does not support negative index - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + cublas_native_named_func(func_name, func, err, handle, n, x_, std::abs(incx), res_); if (result_on_device) { cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); } diff --git a/src/blas/backends/cublas/cublas_level2.cpp b/src/blas/backends/cublas/cublas_level2.cpp index 8f711243b..5ce6e5eaf 100644 --- a/src/blas/backends/cublas/cublas_level2.cpp +++ b/src/blas/backends/cublas/cublas_level2.cpp @@ -46,7 +46,7 @@ inline void gemv(const char *func_name, Func func, sycl::queue &queue, transpose auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), m, + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), m, n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -83,7 +83,7 @@ inline void gbmv(const char *func_name, Func func, sycl::queue &queue, transpose auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), m, + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), m, n, kl, ku, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -120,7 +120,7 @@ inline void ger(const char *func_name, Func func, sycl::queue &queue, int64_t m, auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, (cuDataType *)&alpha, x_, + cublas_native_named_func(func_name, func, err, handle, m, n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); }); @@ -157,7 +157,7 @@ inline void hbmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, k, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -192,7 +192,7 @@ inline void hemv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -227,7 +227,7 @@ inline void her(const char *func_name, Func func, sycl::queue &queue, uplo upper auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuScalarType *)&alpha, x_, incx, a_, lda); }); @@ -262,7 +262,7 @@ inline void her2(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); @@ -298,7 +298,7 @@ inline void hpmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -333,7 +333,7 @@ inline void hpr(const char *func_name, Func func, sycl::queue &queue, uplo upper auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuScalarType *)&alpha, x_, incx, a_); }); @@ -367,7 +367,7 @@ inline void hpr2(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_); }); @@ -402,7 +402,7 @@ inline void sbmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, k, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -438,7 +438,7 @@ inline void symv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -471,7 +471,7 @@ inline void syr(const char *func_name, Func func, sycl::queue &queue, uplo upper auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, a_, lda); }); @@ -507,7 +507,7 @@ inline void syr2(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); @@ -546,7 +546,7 @@ inline void spmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -579,7 +579,7 @@ inline void spr(const char *func_name, Func func, sycl::queue &queue, uplo upper auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, a_); }); @@ -613,7 +613,7 @@ inline void spr2(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto x_ = sc.get_mem(x_acc); auto y_ = sc.get_mem(y_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_); }); @@ -646,7 +646,7 @@ inline void tbmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, a_, lda, x_, incx); }); @@ -682,7 +682,7 @@ inline void tbsv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, a_, lda, x_, incx); }); @@ -718,7 +718,7 @@ inline void tpmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, x_, incx); }); @@ -753,7 +753,7 @@ inline void tpsv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, x_, incx); }); @@ -788,7 +788,7 @@ inline void trmv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, lda, x_, incx); }); @@ -823,7 +823,7 @@ inline void trsv(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto x_ = sc.get_mem(x_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, lda, x_, incx); }); @@ -864,7 +864,7 @@ inline sycl::event gemv(const char *func_name, Func func, sycl::queue &queue, tr auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), m, + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), m, n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -904,7 +904,7 @@ inline sycl::event gbmv(const char *func_name, Func func, sycl::queue &queue, tr auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), m, + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(trans), m, n, kl, ku, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -944,7 +944,7 @@ inline sycl::event ger(const char *func_name, Func func, sycl::queue &queue, int auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, (cuDataType *)&alpha, x_, + cublas_native_named_func(func_name, func, err, handle, m, n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); }); @@ -985,7 +985,7 @@ inline sycl::event hbmv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, k, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1022,7 +1022,7 @@ inline sycl::event hemv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1060,7 +1060,7 @@ inline sycl::event her(const char *func_name, Func func, sycl::queue &queue, upl auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuScalarType *)&alpha, x_, incx, a_, lda); }); @@ -1098,7 +1098,7 @@ inline sycl::event her2(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); @@ -1136,7 +1136,7 @@ inline sycl::event hpmv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1174,7 +1174,7 @@ inline sycl::event hpr(const char *func_name, Func func, sycl::queue &queue, upl auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuScalarType *)&alpha, x_, incx, a_); }); @@ -1212,7 +1212,7 @@ inline sycl::event hpr2(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_); }); @@ -1251,7 +1251,7 @@ inline sycl::event sbmv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, k, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1289,7 +1289,7 @@ inline sycl::event symv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, lda, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1326,7 +1326,7 @@ inline sycl::event syr(const char *func_name, Func func, sycl::queue &queue, upl auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, a_, lda); }); @@ -1366,7 +1366,7 @@ inline sycl::event syr2(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); @@ -1407,7 +1407,7 @@ inline sycl::event spmv(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, a_, x_, incx, (cuDataType *)&beta, y_, incy); }); @@ -1444,7 +1444,7 @@ inline sycl::event spr(const char *func_name, Func func, sycl::queue &queue, upl auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, a_); }); @@ -1481,7 +1481,7 @@ inline sycl::event spr2(const char *func_name, Func func, sycl::queue &queue, up auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), n, (cuDataType *)&alpha, x_, incx, y_, incy, a_); }); @@ -1519,7 +1519,7 @@ inline sycl::event tbmv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, a_, lda, x_, incx); }); @@ -1559,7 +1559,7 @@ inline sycl::event tbsv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, k, a_, lda, x_, incx); }); @@ -1598,7 +1598,7 @@ inline sycl::event tpmv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, x_, incx); }); @@ -1637,7 +1637,7 @@ inline sycl::event tpsv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, x_, incx); }); @@ -1676,7 +1676,7 @@ inline sycl::event trmv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, lda, x_, incx); }); @@ -1715,7 +1715,7 @@ inline sycl::event trsv(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto x_ = reinterpret_cast(x); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), n, a_, lda, x_, incx); }); diff --git a/src/blas/backends/cublas/cublas_level3.cpp b/src/blas/backends/cublas/cublas_level3.cpp index 5ea4e2152..be634a15c 100644 --- a/src/blas/backends/cublas/cublas_level3.cpp +++ b/src/blas/backends/cublas/cublas_level3.cpp @@ -47,7 +47,7 @@ inline void gemm(const char *func_name, Func func, sycl::queue &queue, transpose auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -94,10 +94,17 @@ inline void gemm_ex(DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, sycl::que auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUBLAS_ERROR_FUNC_SYNC(cublasGemmEx, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha, a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C, ldc, DT_C, CUBLAS_GEMM_DEFAULT); +#else + CUBLAS_ERROR_FUNC(cublasGemmEx, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha, + a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C, + ldc, DT_C, CUBLAS_GEMM_DEFAULT); +#endif }); }); } @@ -139,7 +146,7 @@ inline void symm(const char *func_name, Func func, sycl::queue &queue, side left auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -178,7 +185,7 @@ inline void hemm(const char *func_name, Func func, sycl::queue &queue, side left auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -211,7 +218,7 @@ inline void syrk(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, (cuDataType *)&beta, c_, ldc); @@ -250,7 +257,7 @@ inline void herk(const char *func_name, Func func, sycl::queue &queue, uplo uppe auto a_ = sc.get_mem(a_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuScalarType *)&alpha, a_, lda, (cuScalarType *)&beta, c_, ldc); @@ -288,7 +295,7 @@ inline void syr2k(const char *func_name, Func func, sycl::queue &queue, uplo upp auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); @@ -328,7 +335,7 @@ inline void her2k(const char *func_name, Func func, sycl::queue &queue, uplo upp auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuScalarType *)&beta, c_, ldc); @@ -368,7 +375,7 @@ inline void trmm(const char *func_name, Func func, sycl::queue &queue, side left auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, b_, ldb); @@ -404,7 +411,7 @@ inline void trsm(const char *func_name, Func func, sycl::queue &queue, side left auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb); @@ -446,7 +453,7 @@ inline sycl::event gemm(const char *func_name, Func func, sycl::queue &queue, tr auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), + cublas_native_named_func(func_name, func, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -492,10 +499,17 @@ inline sycl::event gemm_ex_usm(DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND CUBLAS_ERROR_FUNC_SYNC(cublasGemmEx, err, handle, get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha, a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C, ldc, DT_C, CUBLAS_GEMM_DEFAULT); +#else + CUBLAS_ERROR_FUNC(cublasGemmEx, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, (cuDataType_C *)&alpha, + a_, DT_A, lda, b_, DT_B, ldb, (cuDataType_C *)&beta, c_, DT_C, + ldc, DT_C, CUBLAS_GEMM_DEFAULT); +#endif }); }); return done; @@ -541,7 +555,7 @@ inline sycl::event symm(const char *func_name, Func func, sycl::queue &queue, si auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -583,7 +597,7 @@ inline sycl::event hemm(const char *func_name, Func func, sycl::queue &queue, si auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); }); @@ -620,7 +634,7 @@ inline sycl::event syrk(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, (cuDataType *)&beta, c_, ldc); @@ -662,7 +676,7 @@ inline sycl::event herk(const char *func_name, Func func, sycl::queue &queue, up auto a_ = reinterpret_cast(a); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuScalarType *)&alpha, a_, lda, (cuScalarType *)&beta, c_, ldc); @@ -703,7 +717,7 @@ inline sycl::event syr2k(const char *func_name, Func func, sycl::queue &queue, u auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuDataType *)&beta, c_, ldc); @@ -747,7 +761,7 @@ inline sycl::event her2k(const char *func_name, Func func, sycl::queue &queue, u auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), n, k, (cuDataType *)&alpha, a_, lda, b_, ldb, (cuScalarType *)&beta, c_, ldc); @@ -791,7 +805,7 @@ inline sycl::event trmm(const char *func_name, Func func, sycl::queue &queue, si auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb, b_, ldb); @@ -831,7 +845,7 @@ inline sycl::event trsm(const char *func_name, Func func, sycl::queue &queue, si auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(left_right), + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right), get_cublas_fill_mode(upper_lower), get_cublas_operation(trans), get_cublas_diag_type(unit_diag), m, n, (cuDataType *)&alpha, a_, lda, b_, ldb); diff --git a/src/blas/backends/cublas/cublas_task.hpp b/src/blas/backends/cublas/cublas_task.hpp index a486aafee..4fbdfdda2 100644 --- a/src/blas/backends/cublas/cublas_task.hpp +++ b/src/blas/backends/cublas/cublas_task.hpp @@ -67,7 +67,11 @@ static inline void host_task_internal(H &cgh, sycl::queue queue, F f) { #else template static inline void host_task_internal(H &cgh, sycl::queue queue, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([f, queue](sycl::interop_handle ih){ +#else cgh.host_task([f, queue](sycl::interop_handle ih) { +#endif auto sc = CublasScopedContextHandler(queue, ih); f(sc); }); From a28cd4dbd2df3b566c6614502b225adf00eb7c74 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 18 Sep 2024 11:32:34 -0700 Subject: [PATCH 03/11] Update lapack tests to manage queue syncs Signed-off-by: JackAKirk --- tests/unit_tests/lapack/source/gebrd.cpp | 3 +++ tests/unit_tests/lapack/source/geqrf.cpp | 3 +++ tests/unit_tests/lapack/source/geqrf_batch_group.cpp | 3 +++ tests/unit_tests/lapack/source/geqrf_batch_stride.cpp | 3 +++ tests/unit_tests/lapack/source/gerqf.cpp | 3 +++ tests/unit_tests/lapack/source/gesvd.cpp | 3 +++ tests/unit_tests/lapack/source/getrf.cpp | 3 +++ tests/unit_tests/lapack/source/getrf_batch_group.cpp | 3 +++ tests/unit_tests/lapack/source/getrf_batch_stride.cpp | 3 +++ tests/unit_tests/lapack/source/getri.cpp | 3 +++ tests/unit_tests/lapack/source/getri_batch_group.cpp | 3 +++ tests/unit_tests/lapack/source/getri_batch_stride.cpp | 3 +++ tests/unit_tests/lapack/source/getrs.cpp | 3 +++ tests/unit_tests/lapack/source/getrs_batch_group.cpp | 3 +++ tests/unit_tests/lapack/source/getrs_batch_stride.cpp | 3 +++ tests/unit_tests/lapack/source/heevd.cpp | 3 +++ tests/unit_tests/lapack/source/hegvd.cpp | 3 +++ tests/unit_tests/lapack/source/hetrd.cpp | 3 +++ tests/unit_tests/lapack/source/hetrf.cpp | 3 +++ tests/unit_tests/lapack/source/orgbr.cpp | 3 +++ tests/unit_tests/lapack/source/orgqr.cpp | 3 +++ tests/unit_tests/lapack/source/orgqr_batch_group.cpp | 3 +++ tests/unit_tests/lapack/source/orgqr_batch_stride.cpp | 3 +++ tests/unit_tests/lapack/source/orgtr.cpp | 3 +++ tests/unit_tests/lapack/source/ormqr.cpp | 3 +++ tests/unit_tests/lapack/source/ormtr.cpp | 3 +++ tests/unit_tests/lapack/source/potrf.cpp | 3 +++ tests/unit_tests/lapack/source/potrf_batch_group.cpp | 3 +++ tests/unit_tests/lapack/source/potrf_batch_stride.cpp | 3 +++ tests/unit_tests/lapack/source/potri.cpp | 3 +++ tests/unit_tests/lapack/source/potrs.cpp | 3 +++ tests/unit_tests/lapack/source/potrs_batch_group.cpp | 3 +++ tests/unit_tests/lapack/source/potrs_batch_stride.cpp | 3 +++ tests/unit_tests/lapack/source/syevd.cpp | 3 +++ tests/unit_tests/lapack/source/sygvd.cpp | 3 +++ tests/unit_tests/lapack/source/sytrd.cpp | 3 +++ tests/unit_tests/lapack/source/sytrf.cpp | 3 +++ tests/unit_tests/lapack/source/trtrs.cpp | 3 +++ tests/unit_tests/lapack/source/ungbr.cpp | 3 +++ tests/unit_tests/lapack/source/ungqr.cpp | 3 +++ tests/unit_tests/lapack/source/ungqr_batch_group.cpp | 3 +++ tests/unit_tests/lapack/source/ungqr_batch_stride.cpp | 3 +++ tests/unit_tests/lapack/source/ungtr.cpp | 3 +++ tests/unit_tests/lapack/source/unmqr.cpp | 3 +++ tests/unit_tests/lapack/source/unmrq.cpp | 3 +++ tests/unit_tests/lapack/source/unmtr.cpp | 3 +++ 46 files changed, 138 insertions(+) diff --git a/tests/unit_tests/lapack/source/gebrd.cpp b/tests/unit_tests/lapack/source/gebrd.cpp index 66eb0b231..8695a6ce2 100644 --- a/tests/unit_tests/lapack/source/gebrd.cpp +++ b/tests/unit_tests/lapack/source/gebrd.cpp @@ -76,6 +76,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::gebrd_scratchpad_size, m, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -149,6 +150,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::gebrd_scratchpad_size, m, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -166,6 +168,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, d_dev, e_dev, tauq_dev, taup_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/geqrf.cpp b/tests/unit_tests/lapack/source/geqrf.cpp index 27577e972..ea4a5fab3 100644 --- a/tests/unit_tests/lapack/source/geqrf.cpp +++ b/tests/unit_tests/lapack/source/geqrf.cpp @@ -68,6 +68,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::geqrf_scratchpad_size, m, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -125,6 +126,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::geqrf_scratchpad_size, m, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -142,6 +144,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/geqrf_batch_group.cpp b/tests/unit_tests/lapack/source/geqrf_batch_group.cpp index 416466028..087598bd3 100644 --- a/tests/unit_tests/lapack/source/geqrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/geqrf_batch_group.cpp @@ -103,6 +103,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -241,6 +242,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -271,6 +273,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { tau_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp b/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp index 16ceef63a..112f5a673 100644 --- a/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp @@ -69,6 +69,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, int64_ queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m, n, lda, stride_a, stride_tau, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -142,6 +143,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m, n, lda, stride_a, stride_tau, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -159,6 +161,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, lda, stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/gerqf.cpp b/tests/unit_tests/lapack/source/gerqf.cpp index dac6d79aa..cb83781dc 100644 --- a/tests/unit_tests/lapack/source/gerqf.cpp +++ b/tests/unit_tests/lapack/source/gerqf.cpp @@ -68,6 +68,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::gerqf_scratchpad_size, m, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -125,6 +126,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::gerqf_scratchpad_size, m, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -142,6 +144,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/gesvd.cpp b/tests/unit_tests/lapack/source/gesvd.cpp index 1e143315b..a5b94eab5 100644 --- a/tests/unit_tests/lapack/source/gesvd.cpp +++ b/tests/unit_tests/lapack/source/gesvd.cpp @@ -79,6 +79,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::mkl::jo scratchpad_size = oneapi::mkl::lapack::gesvd_scratchpad_size, jobu, jobvt, m, n, lda, ldu, ldvt); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -201,6 +202,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::m scratchpad_size = oneapi::mkl::lapack::gesvd_scratchpad_size, jobu, jobvt, m, n, lda, ldu, ldvt); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -218,6 +220,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::m A_dev, lda, s_dev, U_dev, ldu, Vt_dev, ldvt, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrf.cpp b/tests/unit_tests/lapack/source/getrf.cpp index 4537ef665..3521cd373 100644 --- a/tests/unit_tests/lapack/source/getrf.cpp +++ b/tests/unit_tests/lapack/source/getrf.cpp @@ -71,6 +71,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size, m, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -128,6 +129,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size, m, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -145,6 +147,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, ipiv_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrf_batch_group.cpp b/tests/unit_tests/lapack/source/getrf_batch_group.cpp index 12e651746..0410597e3 100644 --- a/tests/unit_tests/lapack/source/getrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getrf_batch_group.cpp @@ -107,6 +107,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -250,6 +251,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -280,6 +282,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { A_dev_ptrs, lda_vec.data(), ipiv_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrf_batch_stride.cpp b/tests/unit_tests/lapack/source/getrf_batch_stride.cpp index 3e4ef6589..1dde4bed3 100644 --- a/tests/unit_tests/lapack/source/getrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getrf_batch_stride.cpp @@ -69,6 +69,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, int64_ queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m, n, lda, stride_a, stride_ipiv, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -142,6 +143,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m, n, lda, stride_a, stride_ipiv, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -159,6 +161,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, lda, stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getri.cpp b/tests/unit_tests/lapack/source/getri.cpp index a1aa2deda..f3b682683 100644 --- a/tests/unit_tests/lapack/source/getri.cpp +++ b/tests/unit_tests/lapack/source/getri.cpp @@ -76,6 +76,7 @@ bool accuracy(const sycl::device& dev, int64_t n, int64_t lda, uint64_t seed) { TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -138,6 +139,7 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, uint64_t se TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -156,6 +158,7 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, uint64_t se ipiv_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getri_batch_group.cpp b/tests/unit_tests/lapack/source/getri_batch_group.cpp index 244acfcc8..6a23549b8 100644 --- a/tests/unit_tests/lapack/source/getri_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getri_batch_group.cpp @@ -114,6 +114,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -262,6 +263,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -295,6 +297,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getri_batch_stride.cpp b/tests/unit_tests/lapack/source/getri_batch_stride.cpp index 5a71d2d7e..6dbe5908e 100644 --- a/tests/unit_tests/lapack/source/getri_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getri_batch_stride.cpp @@ -76,6 +76,7 @@ bool accuracy(const sycl::device& dev, int64_t n, int64_t lda, int64_t stride_a, queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n, lda, stride_a, stride_ipiv, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -156,6 +157,7 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, int64_t str queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n, lda, stride_a, stride_ipiv, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -174,6 +176,7 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, int64_t str lda, stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrs.cpp b/tests/unit_tests/lapack/source/getrs.cpp index bfc271758..6a39ce656 100644 --- a/tests/unit_tests/lapack/source/getrs.cpp +++ b/tests/unit_tests/lapack/source/getrs.cpp @@ -77,6 +77,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::transpose trans, int64_t n, scratchpad_size = oneapi::mkl::lapack::getrs_scratchpad_size, trans, n, nrhs, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -149,6 +150,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 scratchpad_size = oneapi::mkl::lapack::getrs_scratchpad_size, trans, n, nrhs, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -168,6 +170,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 A_dev, lda, ipiv_dev, B_dev, ldb, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrs_batch_group.cpp b/tests/unit_tests/lapack/source/getrs_batch_group.cpp index 2027663e4..bc1d4aaee 100644 --- a/tests/unit_tests/lapack/source/getrs_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getrs_batch_group.cpp @@ -137,6 +137,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { trans_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -323,6 +324,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { trans_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -362,6 +364,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrs_batch_stride.cpp b/tests/unit_tests/lapack/source/getrs_batch_stride.cpp index 1faf3d3e6..609a67dda 100644 --- a/tests/unit_tests/lapack/source/getrs_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getrs_batch_stride.cpp @@ -82,6 +82,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::transpose trans, int64_t n, queue, scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -171,6 +172,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 queue, scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -192,6 +194,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 stride_b, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/heevd.cpp b/tests/unit_tests/lapack/source/heevd.cpp index 62c23c3ad..10d927c6d 100644 --- a/tests/unit_tests/lapack/source/heevd.cpp +++ b/tests/unit_tests/lapack/source/heevd.cpp @@ -66,6 +66,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl::uplo scratchpad_size = oneapi::mkl::lapack::heevd_scratchpad_size, jobz, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -124,6 +125,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: scratchpad_size = oneapi::mkl::lapack::heevd_scratchpad_size, jobz, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -141,6 +143,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: A_dev, lda, w_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/hegvd.cpp b/tests/unit_tests/lapack/source/hegvd.cpp index 9a109e6b8..b58fb0199 100644 --- a/tests/unit_tests/lapack/source/hegvd.cpp +++ b/tests/unit_tests/lapack/source/hegvd.cpp @@ -72,6 +72,7 @@ bool accuracy(const sycl::device& dev, int64_t itype, oneapi::mkl::job jobz, one scratchpad_size = oneapi::mkl::lapack::hegvd_scratchpad_size, itype, jobz, uplo, n, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -260,6 +261,7 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job scratchpad_size = oneapi::mkl::lapack::hegvd_scratchpad_size, itype, jobz, uplo, n, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -278,6 +280,7 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job n, A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/hetrd.cpp b/tests/unit_tests/lapack/source/hetrd.cpp index 13172d64f..6ee81c54f 100644 --- a/tests/unit_tests/lapack/source/hetrd.cpp +++ b/tests/unit_tests/lapack/source/hetrd.cpp @@ -69,6 +69,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::hetrd_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -182,6 +183,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::hetrd_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -202,6 +204,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/hetrf.cpp b/tests/unit_tests/lapack/source/hetrf.cpp index 73535a77f..cb4589fe9 100644 --- a/tests/unit_tests/lapack/source/hetrf.cpp +++ b/tests/unit_tests/lapack/source/hetrf.cpp @@ -69,6 +69,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::hetrf_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -242,6 +243,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::hetrf_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -260,6 +262,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, ipiv_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgbr.cpp b/tests/unit_tests/lapack/source/orgbr.cpp index 274cafce0..40f853611 100644 --- a/tests/unit_tests/lapack/source/orgbr.cpp +++ b/tests/unit_tests/lapack/source/orgbr.cpp @@ -86,6 +86,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::generate vect, int64_t m, in scratchpad_size = oneapi::mkl::lapack::orgbr_scratchpad_size, vect, m, n, k, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -161,6 +162,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t scratchpad_size = oneapi::mkl::lapack::orgbr_scratchpad_size, vect, m, n, k, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -179,6 +181,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgqr.cpp b/tests/unit_tests/lapack/source/orgqr.cpp index 9d62daf5f..d16ea3992 100644 --- a/tests/unit_tests/lapack/source/orgqr.cpp +++ b/tests/unit_tests/lapack/source/orgqr.cpp @@ -74,6 +74,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size, m, n, k, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -135,6 +136,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size, m, n, k, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -153,6 +155,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgqr_batch_group.cpp b/tests/unit_tests/lapack/source/orgqr_batch_group.cpp index 3af796e7d..2e7e6a489 100644 --- a/tests/unit_tests/lapack/source/orgqr_batch_group.cpp +++ b/tests/unit_tests/lapack/source/orgqr_batch_group.cpp @@ -111,6 +111,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -256,6 +257,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -289,6 +291,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp b/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp index 1cf3471c5..b8b37673b 100644 --- a/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp @@ -75,6 +75,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t queue, scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, m, n, k, lda, stride_a, stride_tau, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -153,6 +154,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in queue, scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, m, n, k, lda, stride_a, stride_tau, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -172,6 +174,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgtr.cpp b/tests/unit_tests/lapack/source/orgtr.cpp index 5a01745d5..309a90b6e 100644 --- a/tests/unit_tests/lapack/source/orgtr.cpp +++ b/tests/unit_tests/lapack/source/orgtr.cpp @@ -72,6 +72,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::orgtr_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -136,6 +137,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::orgtr_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -154,6 +156,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ormqr.cpp b/tests/unit_tests/lapack/source/ormqr.cpp index e2ed49b96..e8192c5db 100644 --- a/tests/unit_tests/lapack/source/ormqr.cpp +++ b/tests/unit_tests/lapack/source/ormqr.cpp @@ -83,6 +83,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -170,6 +171,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -189,6 +191,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ormtr.cpp b/tests/unit_tests/lapack/source/ormtr.cpp index 4e8dd95b9..1038f155e 100644 --- a/tests/unit_tests/lapack/source/ormtr.cpp +++ b/tests/unit_tests/lapack/source/ormtr.cpp @@ -81,6 +81,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, int64_ scratchpad_size = oneapi::mkl::lapack::ormtr_scratchpad_size, side, uplo, trans, m, n, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -169,6 +170,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, scratchpad_size = oneapi::mkl::lapack::ormtr_scratchpad_size, side, uplo, trans, m, n, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -188,6 +190,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, m, n, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrf.cpp b/tests/unit_tests/lapack/source/potrf.cpp index 7d2df8ea9..82de4884c 100644 --- a/tests/unit_tests/lapack/source/potrf.cpp +++ b/tests/unit_tests/lapack/source/potrf.cpp @@ -67,6 +67,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potrf_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -120,6 +121,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potrf_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -137,6 +139,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrf_batch_group.cpp b/tests/unit_tests/lapack/source/potrf_batch_group.cpp index 4a5b8dd58..33a233b73 100644 --- a/tests/unit_tests/lapack/source/potrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/potrf_batch_group.cpp @@ -99,6 +99,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -222,6 +223,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -249,6 +251,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrf_batch_stride.cpp b/tests/unit_tests/lapack/source/potrf_batch_stride.cpp index fae4f0bcc..deb31fb43 100644 --- a/tests/unit_tests/lapack/source/potrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/potrf_batch_stride.cpp @@ -66,6 +66,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo, n, lda, stride_a, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -132,6 +133,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo, n, lda, stride_a, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -149,6 +151,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, A_dev, lda, stride_a, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potri.cpp b/tests/unit_tests/lapack/source/potri.cpp index cd2f86449..bcef465e4 100644 --- a/tests/unit_tests/lapack/source/potri.cpp +++ b/tests/unit_tests/lapack/source/potri.cpp @@ -71,6 +71,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potri_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -154,6 +155,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potri_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -171,6 +173,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrs.cpp b/tests/unit_tests/lapack/source/potrs.cpp index c534ec8ba..83400241c 100644 --- a/tests/unit_tests/lapack/source/potrs.cpp +++ b/tests/unit_tests/lapack/source/potrs.cpp @@ -75,6 +75,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ scratchpad_size = oneapi::mkl::lapack::potrs_scratchpad_size, uplo, n, nrhs, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -142,6 +143,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_size = oneapi::mkl::lapack::potrs_scratchpad_size, uplo, n, nrhs, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -160,6 +162,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrs_batch_group.cpp b/tests/unit_tests/lapack/source/potrs_batch_group.cpp index 35c5ead0c..2f76d9cac 100644 --- a/tests/unit_tests/lapack/source/potrs_batch_group.cpp +++ b/tests/unit_tests/lapack/source/potrs_batch_group.cpp @@ -124,6 +124,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { uplo_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -287,6 +288,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { uplo_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -320,6 +322,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrs_batch_stride.cpp b/tests/unit_tests/lapack/source/potrs_batch_stride.cpp index de2568e86..079f32564 100644 --- a/tests/unit_tests/lapack/source/potrs_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/potrs_batch_stride.cpp @@ -80,6 +80,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ queue, scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, uplo, n, nrhs, lda, stride_a, ldb, stride_b, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -165,6 +166,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, queue, scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, uplo, n, nrhs, lda, stride_a, ldb, stride_b, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -184,6 +186,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/syevd.cpp b/tests/unit_tests/lapack/source/syevd.cpp index 291713354..6e166843e 100644 --- a/tests/unit_tests/lapack/source/syevd.cpp +++ b/tests/unit_tests/lapack/source/syevd.cpp @@ -66,6 +66,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl::uplo scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size, jobz, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -124,6 +125,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size, jobz, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -141,6 +143,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: A_dev, lda, w_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/sygvd.cpp b/tests/unit_tests/lapack/source/sygvd.cpp index f800b03dd..726b8da8a 100644 --- a/tests/unit_tests/lapack/source/sygvd.cpp +++ b/tests/unit_tests/lapack/source/sygvd.cpp @@ -72,6 +72,7 @@ bool accuracy(const sycl::device& dev, int64_t itype, oneapi::mkl::job jobz, one scratchpad_size = oneapi::mkl::lapack::sygvd_scratchpad_size, itype, jobz, uplo, n, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -265,6 +266,7 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job scratchpad_size = oneapi::mkl::lapack::sygvd_scratchpad_size, itype, jobz, uplo, n, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -283,6 +285,7 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job n, A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/sytrd.cpp b/tests/unit_tests/lapack/source/sytrd.cpp index 01ffe0dff..78ed3aaa3 100644 --- a/tests/unit_tests/lapack/source/sytrd.cpp +++ b/tests/unit_tests/lapack/source/sytrd.cpp @@ -69,6 +69,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::sytrd_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -182,6 +183,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::sytrd_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -202,6 +204,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/sytrf.cpp b/tests/unit_tests/lapack/source/sytrf.cpp index 81d7fdb2d..42419fcec 100644 --- a/tests/unit_tests/lapack/source/sytrf.cpp +++ b/tests/unit_tests/lapack/source/sytrf.cpp @@ -66,6 +66,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::sytrf_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -239,6 +240,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::sytrf_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -257,6 +259,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, ipiv_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/trtrs.cpp b/tests/unit_tests/lapack/source/trtrs.cpp index 4018a2c51..7c0d86c09 100644 --- a/tests/unit_tests/lapack/source/trtrs.cpp +++ b/tests/unit_tests/lapack/source/trtrs.cpp @@ -78,6 +78,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl::tran scratchpad_size = oneapi::mkl::lapack::trtrs_scratchpad_size, uplo, trans, diag, n, nrhs, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -141,6 +142,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl scratchpad_size = oneapi::mkl::lapack::trtrs_scratchpad_size, uplo, trans, diag, n, nrhs, lda, ldb); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -159,6 +161,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl n, nrhs, A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungbr.cpp b/tests/unit_tests/lapack/source/ungbr.cpp index 7cdf8e52a..4e1919348 100644 --- a/tests/unit_tests/lapack/source/ungbr.cpp +++ b/tests/unit_tests/lapack/source/ungbr.cpp @@ -86,6 +86,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::generate vect, int64_t m, in scratchpad_size = oneapi::mkl::lapack::ungbr_scratchpad_size, vect, m, n, k, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -161,6 +162,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t scratchpad_size = oneapi::mkl::lapack::ungbr_scratchpad_size, vect, m, n, k, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -179,6 +181,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungqr.cpp b/tests/unit_tests/lapack/source/ungqr.cpp index 08b8b1192..14e775b8d 100644 --- a/tests/unit_tests/lapack/source/ungqr.cpp +++ b/tests/unit_tests/lapack/source/ungqr.cpp @@ -73,6 +73,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size, m, n, k, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -134,6 +135,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size, m, n, k, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -152,6 +154,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungqr_batch_group.cpp b/tests/unit_tests/lapack/source/ungqr_batch_group.cpp index ddb350828..c03f89837 100644 --- a/tests/unit_tests/lapack/source/ungqr_batch_group.cpp +++ b/tests/unit_tests/lapack/source/ungqr_batch_group.cpp @@ -111,6 +111,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -256,6 +257,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -289,6 +291,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp b/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp index e656b9fb7..f9b4e5f73 100644 --- a/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp @@ -75,6 +75,7 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t queue, scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, m, n, k, lda, stride_a, stride_tau, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -153,6 +154,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in queue, scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, m, n, k, lda, stride_a, stride_tau, batch_size); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -172,6 +174,7 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungtr.cpp b/tests/unit_tests/lapack/source/ungtr.cpp index b0ad8e8f2..e5edc7ef5 100644 --- a/tests/unit_tests/lapack/source/ungtr.cpp +++ b/tests/unit_tests/lapack/source/ungtr.cpp @@ -72,6 +72,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::ungtr_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -136,6 +137,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::ungtr_scratchpad_size, uplo, n, lda); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -154,6 +156,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/unmqr.cpp b/tests/unit_tests/lapack/source/unmqr.cpp index 2f555c1ca..51cfdd004 100644 --- a/tests/unit_tests/lapack/source/unmqr.cpp +++ b/tests/unit_tests/lapack/source/unmqr.cpp @@ -83,6 +83,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -170,6 +171,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -189,6 +191,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/unmrq.cpp b/tests/unit_tests/lapack/source/unmrq.cpp index 628063837..6526533b5 100644 --- a/tests/unit_tests/lapack/source/unmrq.cpp +++ b/tests/unit_tests/lapack/source/unmrq.cpp @@ -93,6 +93,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl scratchpad_size = oneapi::mkl::lapack::unmrq_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -179,6 +180,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size = oneapi::mkl::lapack::unmrq_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -198,6 +200,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/unmtr.cpp b/tests/unit_tests/lapack/source/unmtr.cpp index 8148c644d..fa8ee60d6 100644 --- a/tests/unit_tests/lapack/source/unmtr.cpp +++ b/tests/unit_tests/lapack/source/unmtr.cpp @@ -81,6 +81,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, int64_ scratchpad_size = oneapi::mkl::lapack::unmtr_scratchpad_size, side, uplo, trans, m, n, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -169,6 +170,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, scratchpad_size = oneapi::mkl::lapack::unmtr_scratchpad_size, side, uplo, trans, m, n, lda, ldc); #endif + queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -188,6 +190,7 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, m, n, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif + queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); From b33b9e3011453a7e092a2228e22b74ff22ef5c7c Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Thu, 19 Sep 2024 08:10:03 -0700 Subject: [PATCH 04/11] Update name of function cublas_native_named_func Signed-off-by: JackAKirk --- src/lapack/backends/cusolver/cusolver_batch.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lapack/backends/cusolver/cusolver_batch.cpp b/src/lapack/backends/cusolver/cusolver_batch.cpp index a2fdc4b99..f4017f873 100644 --- a/src/lapack/backends/cusolver/cusolver_batch.cpp +++ b/src/lapack/backends/cusolver/cusolver_batch.cpp @@ -137,7 +137,7 @@ inline void getri_batch(const char *func_name, Func func, sycl::queue &queue, st sizeof(T *) * batch_size); auto **scratch_dev_ = reinterpret_cast(scratch_dev); - blas::cublas::cublas_native_func(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32_, + blas::cublas::cublas_native_named_func(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32_, scratch_dev_, lda, info_, batch_size); free(a_batched); @@ -859,7 +859,7 @@ sycl::event getri_batch(const char *func_name, Func func, sycl::queue &queue, st sizeof(T *) * batch_size); auto **scratch_dev_ = reinterpret_cast(scratch_dev); - blas::cublas::cublas_native_func(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32, + blas::cublas::cublas_native_named_func(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32, scratch_dev_, lda, devInfo, batch_size); free(a_batched); From 8fba3193765a6c3d9ac828347675a5a41010f74a Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 20 Sep 2024 08:18:58 -0700 Subject: [PATCH 05/11] Use cublas_native_named_func more Signed-off-by: JackAKirk --- src/blas/backends/cublas/cublas_batch.cpp | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 8b10c7744..031c11a2f 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -832,24 +832,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a); auto **b_ = reinterpret_cast(b); -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, + cublas_native_named_func(func_name, func, err, handle, get_cublas_side_mode(left_right[i]), get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]), get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], (int)group_size[i]); -#else - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, - get_cublas_side_mode(left_right[i]), - get_cublas_fill_mode(upper_lower[i]), - get_cublas_operation(trans[i]), - get_cublas_diag_type(unit_diag[i]), (int)m[i], - (int)n[i], (cuDataType *)&alpha[i], a_ + offset, - (int)lda[i], b_ + offset, (int)ldb[i], - (int)group_size[i]); -#endif offset += group_size[i]; } }); From af4d1fd3335447c2d486d31d1da0ab67aa616260 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 1 Oct 2024 09:43:13 -0700 Subject: [PATCH 06/11] Remove dep check this dep check is overzealous because it enforces that a dependent event cannot be submitted to run on the native device queue but not completed before a later event it is dependent upon has also been marked running on the device. This is not part of the sycl spec and unnecessarily slows down execution. Signed-off-by: JackAKirk --- tests/unit_tests/lapack/common/dependency_check.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit_tests/lapack/common/dependency_check.cpp b/tests/unit_tests/lapack/common/dependency_check.cpp index 30d2d1d4a..86e313aa3 100644 --- a/tests/unit_tests/lapack/common/dependency_check.cpp +++ b/tests/unit_tests/lapack/common/dependency_check.cpp @@ -56,8 +56,7 @@ bool check_dependency(sycl::queue queue, sycl::event in_event, sycl::event func_ do { func_status = func_event.get_info(); - } while (func_status != sycl::info::event_command_status::running && - func_status != sycl::info::event_command_status::complete); + } while (func_status != sycl::info::event_command_status::complete); in_status = in_event.get_info(); /* Print results */ From c082d4a0a90bfa3ecd4a144bd9cb4b0f5adb54cc Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 2 Oct 2024 06:42:28 -0700 Subject: [PATCH 07/11] Improve synchronous impl for scratchpad These funcs are async in the cusolver backend. Signed-off-by: JackAKirk --- .../backends/cusolver/cusolver_lapack.cpp | 115 +++++++----------- 1 file changed, 46 insertions(+), 69 deletions(-) diff --git a/src/lapack/backends/cusolver/cusolver_lapack.cpp b/src/lapack/backends/cusolver/cusolver_lapack.cpp index 3d54c403a..c8190f50d 100644 --- a/src/lapack/backends/cusolver/cusolver_lapack.cpp +++ b/src/lapack/backends/cusolver/cusolver_lapack.cpp @@ -2455,10 +2455,9 @@ inline void gebrd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define GEBRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2507,10 +2506,9 @@ inline void geqrf_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define GEQRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2539,10 +2537,9 @@ inline void gesvd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define GESVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2571,10 +2568,9 @@ inline void getrf_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define GETRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2632,12 +2628,11 @@ inline void heevd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_job(jobz), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define HEEVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2665,12 +2660,11 @@ inline void hegvd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_itype(itype), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, ldb, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define HEGVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2697,11 +2691,10 @@ inline void hetrd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, nullptr, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define HETRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2738,11 +2731,10 @@ inline void orgbr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORGBR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2769,11 +2761,10 @@ inline void orgtr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORGTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2799,11 +2790,10 @@ inline void orgqr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORGQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2845,12 +2835,11 @@ inline void ormqr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORMQRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2878,12 +2867,11 @@ inline void ormtr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define ORMTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2911,11 +2899,10 @@ inline void potrf_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define POTRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2959,11 +2946,10 @@ inline void potri_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define POTRI_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2991,10 +2977,9 @@ inline void sytrf_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, n, nullptr, lda, scratch_size); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, nullptr, lda, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define SYTRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3022,12 +3007,11 @@ inline void syevd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_job(jobz), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define SYEVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3055,12 +3039,11 @@ inline void sygvd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_itype(itype), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_itype(itype), get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, ldb, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define SYGVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3087,11 +3070,10 @@ inline void sytrd_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, nullptr, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define SYTRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3148,11 +3130,10 @@ inline void ungbr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNGBR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3179,11 +3160,10 @@ inline void ungqr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNGQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3209,11 +3189,10 @@ inline void ungtr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nullptr, lda, nullptr, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNGTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3257,12 +3236,11 @@ inline void unmqr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_operation(trans), m, n, k, nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNMQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3290,12 +3268,11 @@ inline void unmtr_scratchpad_size(const char *func_name, Func func, sycl::queue onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_side_mode(side), get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); - }); - queue.wait(); + }).wait(); } #define UNMTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ From ba20653096f768fc83b13128521390f3f4540697 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 2 Oct 2024 06:43:43 -0700 Subject: [PATCH 08/11] Revert changes to lapack tests Signed-off-by: JackAKirk --- tests/unit_tests/lapack/source/gebrd.cpp | 3 --- tests/unit_tests/lapack/source/geqrf.cpp | 3 --- tests/unit_tests/lapack/source/geqrf_batch_group.cpp | 3 --- tests/unit_tests/lapack/source/geqrf_batch_stride.cpp | 3 --- tests/unit_tests/lapack/source/gerqf.cpp | 3 --- tests/unit_tests/lapack/source/gesvd.cpp | 3 --- tests/unit_tests/lapack/source/getrf.cpp | 3 --- tests/unit_tests/lapack/source/getrf_batch_group.cpp | 3 --- tests/unit_tests/lapack/source/getrf_batch_stride.cpp | 3 --- tests/unit_tests/lapack/source/getri.cpp | 3 --- tests/unit_tests/lapack/source/getri_batch_group.cpp | 3 --- tests/unit_tests/lapack/source/getri_batch_stride.cpp | 3 --- tests/unit_tests/lapack/source/getrs.cpp | 3 --- tests/unit_tests/lapack/source/getrs_batch_group.cpp | 3 --- tests/unit_tests/lapack/source/getrs_batch_stride.cpp | 3 --- tests/unit_tests/lapack/source/heevd.cpp | 3 --- tests/unit_tests/lapack/source/hegvd.cpp | 3 --- tests/unit_tests/lapack/source/hetrd.cpp | 3 --- tests/unit_tests/lapack/source/hetrf.cpp | 3 --- tests/unit_tests/lapack/source/orgbr.cpp | 3 --- tests/unit_tests/lapack/source/orgqr.cpp | 3 --- tests/unit_tests/lapack/source/orgqr_batch_group.cpp | 3 --- tests/unit_tests/lapack/source/orgqr_batch_stride.cpp | 3 --- tests/unit_tests/lapack/source/orgtr.cpp | 3 --- tests/unit_tests/lapack/source/ormqr.cpp | 3 --- tests/unit_tests/lapack/source/ormtr.cpp | 3 --- tests/unit_tests/lapack/source/potrf.cpp | 3 --- tests/unit_tests/lapack/source/potrf_batch_group.cpp | 3 --- tests/unit_tests/lapack/source/potrf_batch_stride.cpp | 3 --- tests/unit_tests/lapack/source/potri.cpp | 3 --- tests/unit_tests/lapack/source/potrs.cpp | 3 --- tests/unit_tests/lapack/source/potrs_batch_group.cpp | 3 --- tests/unit_tests/lapack/source/potrs_batch_stride.cpp | 3 --- tests/unit_tests/lapack/source/syevd.cpp | 3 --- tests/unit_tests/lapack/source/sygvd.cpp | 3 --- tests/unit_tests/lapack/source/sytrd.cpp | 3 --- tests/unit_tests/lapack/source/sytrf.cpp | 3 --- tests/unit_tests/lapack/source/trtrs.cpp | 3 --- tests/unit_tests/lapack/source/ungbr.cpp | 3 --- tests/unit_tests/lapack/source/ungqr.cpp | 3 --- tests/unit_tests/lapack/source/ungqr_batch_group.cpp | 3 --- tests/unit_tests/lapack/source/ungqr_batch_stride.cpp | 3 --- tests/unit_tests/lapack/source/ungtr.cpp | 3 --- tests/unit_tests/lapack/source/unmqr.cpp | 3 --- tests/unit_tests/lapack/source/unmrq.cpp | 3 --- tests/unit_tests/lapack/source/unmtr.cpp | 3 --- 46 files changed, 138 deletions(-) diff --git a/tests/unit_tests/lapack/source/gebrd.cpp b/tests/unit_tests/lapack/source/gebrd.cpp index 8695a6ce2..66eb0b231 100644 --- a/tests/unit_tests/lapack/source/gebrd.cpp +++ b/tests/unit_tests/lapack/source/gebrd.cpp @@ -76,7 +76,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::gebrd_scratchpad_size, m, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -150,7 +149,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::gebrd_scratchpad_size, m, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -168,7 +166,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, d_dev, e_dev, tauq_dev, taup_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/geqrf.cpp b/tests/unit_tests/lapack/source/geqrf.cpp index ea4a5fab3..27577e972 100644 --- a/tests/unit_tests/lapack/source/geqrf.cpp +++ b/tests/unit_tests/lapack/source/geqrf.cpp @@ -68,7 +68,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::geqrf_scratchpad_size, m, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -126,7 +125,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::geqrf_scratchpad_size, m, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -144,7 +142,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/geqrf_batch_group.cpp b/tests/unit_tests/lapack/source/geqrf_batch_group.cpp index 087598bd3..416466028 100644 --- a/tests/unit_tests/lapack/source/geqrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/geqrf_batch_group.cpp @@ -103,7 +103,6 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -242,7 +241,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -273,7 +271,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { tau_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp b/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp index 112f5a673..16ceef63a 100644 --- a/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp @@ -69,7 +69,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, int64_ queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m, n, lda, stride_a, stride_tau, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -143,7 +142,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m, n, lda, stride_a, stride_tau, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -161,7 +159,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, lda, stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/gerqf.cpp b/tests/unit_tests/lapack/source/gerqf.cpp index cb83781dc..dac6d79aa 100644 --- a/tests/unit_tests/lapack/source/gerqf.cpp +++ b/tests/unit_tests/lapack/source/gerqf.cpp @@ -68,7 +68,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::gerqf_scratchpad_size, m, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -126,7 +125,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::gerqf_scratchpad_size, m, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -144,7 +142,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/gesvd.cpp b/tests/unit_tests/lapack/source/gesvd.cpp index a5b94eab5..1e143315b 100644 --- a/tests/unit_tests/lapack/source/gesvd.cpp +++ b/tests/unit_tests/lapack/source/gesvd.cpp @@ -79,7 +79,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::mkl::jo scratchpad_size = oneapi::mkl::lapack::gesvd_scratchpad_size, jobu, jobvt, m, n, lda, ldu, ldvt); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -202,7 +201,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::m scratchpad_size = oneapi::mkl::lapack::gesvd_scratchpad_size, jobu, jobvt, m, n, lda, ldu, ldvt); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -220,7 +218,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::m A_dev, lda, s_dev, U_dev, ldu, Vt_dev, ldvt, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrf.cpp b/tests/unit_tests/lapack/source/getrf.cpp index 3521cd373..4537ef665 100644 --- a/tests/unit_tests/lapack/source/getrf.cpp +++ b/tests/unit_tests/lapack/source/getrf.cpp @@ -71,7 +71,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size, m, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -129,7 +128,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size, m, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -147,7 +145,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, ipiv_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrf_batch_group.cpp b/tests/unit_tests/lapack/source/getrf_batch_group.cpp index 0410597e3..12e651746 100644 --- a/tests/unit_tests/lapack/source/getrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getrf_batch_group.cpp @@ -107,7 +107,6 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -251,7 +250,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -282,7 +280,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { A_dev_ptrs, lda_vec.data(), ipiv_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrf_batch_stride.cpp b/tests/unit_tests/lapack/source/getrf_batch_stride.cpp index 1dde4bed3..3e4ef6589 100644 --- a/tests/unit_tests/lapack/source/getrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getrf_batch_stride.cpp @@ -69,7 +69,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, int64_ queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m, n, lda, stride_a, stride_ipiv, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -143,7 +142,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m, n, lda, stride_a, stride_ipiv, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -161,7 +159,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, lda, stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getri.cpp b/tests/unit_tests/lapack/source/getri.cpp index f3b682683..a1aa2deda 100644 --- a/tests/unit_tests/lapack/source/getri.cpp +++ b/tests/unit_tests/lapack/source/getri.cpp @@ -76,7 +76,6 @@ bool accuracy(const sycl::device& dev, int64_t n, int64_t lda, uint64_t seed) { TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -139,7 +138,6 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, uint64_t se TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -158,7 +156,6 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, uint64_t se ipiv_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getri_batch_group.cpp b/tests/unit_tests/lapack/source/getri_batch_group.cpp index 6a23549b8..244acfcc8 100644 --- a/tests/unit_tests/lapack/source/getri_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getri_batch_group.cpp @@ -114,7 +114,6 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -263,7 +262,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -297,7 +295,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getri_batch_stride.cpp b/tests/unit_tests/lapack/source/getri_batch_stride.cpp index 6dbe5908e..5a71d2d7e 100644 --- a/tests/unit_tests/lapack/source/getri_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getri_batch_stride.cpp @@ -76,7 +76,6 @@ bool accuracy(const sycl::device& dev, int64_t n, int64_t lda, int64_t stride_a, queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n, lda, stride_a, stride_ipiv, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -157,7 +156,6 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, int64_t str queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n, lda, stride_a, stride_ipiv, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -176,7 +174,6 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, int64_t str lda, stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrs.cpp b/tests/unit_tests/lapack/source/getrs.cpp index 6a39ce656..bfc271758 100644 --- a/tests/unit_tests/lapack/source/getrs.cpp +++ b/tests/unit_tests/lapack/source/getrs.cpp @@ -77,7 +77,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::transpose trans, int64_t n, scratchpad_size = oneapi::mkl::lapack::getrs_scratchpad_size, trans, n, nrhs, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -150,7 +149,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 scratchpad_size = oneapi::mkl::lapack::getrs_scratchpad_size, trans, n, nrhs, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -170,7 +168,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 A_dev, lda, ipiv_dev, B_dev, ldb, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrs_batch_group.cpp b/tests/unit_tests/lapack/source/getrs_batch_group.cpp index bc1d4aaee..2027663e4 100644 --- a/tests/unit_tests/lapack/source/getrs_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getrs_batch_group.cpp @@ -137,7 +137,6 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { trans_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -324,7 +323,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { trans_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -364,7 +362,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/getrs_batch_stride.cpp b/tests/unit_tests/lapack/source/getrs_batch_stride.cpp index 609a67dda..1faf3d3e6 100644 --- a/tests/unit_tests/lapack/source/getrs_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getrs_batch_stride.cpp @@ -82,7 +82,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::transpose trans, int64_t n, queue, scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -172,7 +171,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 queue, scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -194,7 +192,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 stride_b, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/heevd.cpp b/tests/unit_tests/lapack/source/heevd.cpp index 10d927c6d..62c23c3ad 100644 --- a/tests/unit_tests/lapack/source/heevd.cpp +++ b/tests/unit_tests/lapack/source/heevd.cpp @@ -66,7 +66,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl::uplo scratchpad_size = oneapi::mkl::lapack::heevd_scratchpad_size, jobz, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -125,7 +124,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: scratchpad_size = oneapi::mkl::lapack::heevd_scratchpad_size, jobz, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -143,7 +141,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: A_dev, lda, w_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/hegvd.cpp b/tests/unit_tests/lapack/source/hegvd.cpp index b58fb0199..9a109e6b8 100644 --- a/tests/unit_tests/lapack/source/hegvd.cpp +++ b/tests/unit_tests/lapack/source/hegvd.cpp @@ -72,7 +72,6 @@ bool accuracy(const sycl::device& dev, int64_t itype, oneapi::mkl::job jobz, one scratchpad_size = oneapi::mkl::lapack::hegvd_scratchpad_size, itype, jobz, uplo, n, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -261,7 +260,6 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job scratchpad_size = oneapi::mkl::lapack::hegvd_scratchpad_size, itype, jobz, uplo, n, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -280,7 +278,6 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job n, A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/hetrd.cpp b/tests/unit_tests/lapack/source/hetrd.cpp index 6ee81c54f..13172d64f 100644 --- a/tests/unit_tests/lapack/source/hetrd.cpp +++ b/tests/unit_tests/lapack/source/hetrd.cpp @@ -69,7 +69,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::hetrd_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -183,7 +182,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::hetrd_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -204,7 +202,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/hetrf.cpp b/tests/unit_tests/lapack/source/hetrf.cpp index cb4589fe9..73535a77f 100644 --- a/tests/unit_tests/lapack/source/hetrf.cpp +++ b/tests/unit_tests/lapack/source/hetrf.cpp @@ -69,7 +69,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::hetrf_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -243,7 +242,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::hetrf_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -262,7 +260,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, ipiv_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgbr.cpp b/tests/unit_tests/lapack/source/orgbr.cpp index 40f853611..274cafce0 100644 --- a/tests/unit_tests/lapack/source/orgbr.cpp +++ b/tests/unit_tests/lapack/source/orgbr.cpp @@ -86,7 +86,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::generate vect, int64_t m, in scratchpad_size = oneapi::mkl::lapack::orgbr_scratchpad_size, vect, m, n, k, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -162,7 +161,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t scratchpad_size = oneapi::mkl::lapack::orgbr_scratchpad_size, vect, m, n, k, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -181,7 +179,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgqr.cpp b/tests/unit_tests/lapack/source/orgqr.cpp index d16ea3992..9d62daf5f 100644 --- a/tests/unit_tests/lapack/source/orgqr.cpp +++ b/tests/unit_tests/lapack/source/orgqr.cpp @@ -74,7 +74,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size, m, n, k, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -136,7 +135,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size, m, n, k, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -155,7 +153,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgqr_batch_group.cpp b/tests/unit_tests/lapack/source/orgqr_batch_group.cpp index 2e7e6a489..3af796e7d 100644 --- a/tests/unit_tests/lapack/source/orgqr_batch_group.cpp +++ b/tests/unit_tests/lapack/source/orgqr_batch_group.cpp @@ -111,7 +111,6 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -257,7 +256,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -291,7 +289,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp b/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp index b8b37673b..1cf3471c5 100644 --- a/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp @@ -75,7 +75,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t queue, scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, m, n, k, lda, stride_a, stride_tau, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -154,7 +153,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in queue, scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, m, n, k, lda, stride_a, stride_tau, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -174,7 +172,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/orgtr.cpp b/tests/unit_tests/lapack/source/orgtr.cpp index 309a90b6e..5a01745d5 100644 --- a/tests/unit_tests/lapack/source/orgtr.cpp +++ b/tests/unit_tests/lapack/source/orgtr.cpp @@ -72,7 +72,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::orgtr_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -137,7 +136,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::orgtr_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -156,7 +154,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ormqr.cpp b/tests/unit_tests/lapack/source/ormqr.cpp index e8192c5db..e2ed49b96 100644 --- a/tests/unit_tests/lapack/source/ormqr.cpp +++ b/tests/unit_tests/lapack/source/ormqr.cpp @@ -83,7 +83,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -171,7 +170,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -191,7 +189,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ormtr.cpp b/tests/unit_tests/lapack/source/ormtr.cpp index 1038f155e..4e8dd95b9 100644 --- a/tests/unit_tests/lapack/source/ormtr.cpp +++ b/tests/unit_tests/lapack/source/ormtr.cpp @@ -81,7 +81,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, int64_ scratchpad_size = oneapi::mkl::lapack::ormtr_scratchpad_size, side, uplo, trans, m, n, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -170,7 +169,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, scratchpad_size = oneapi::mkl::lapack::ormtr_scratchpad_size, side, uplo, trans, m, n, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -190,7 +188,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, m, n, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrf.cpp b/tests/unit_tests/lapack/source/potrf.cpp index 82de4884c..7d2df8ea9 100644 --- a/tests/unit_tests/lapack/source/potrf.cpp +++ b/tests/unit_tests/lapack/source/potrf.cpp @@ -67,7 +67,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potrf_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -121,7 +120,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potrf_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -139,7 +137,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrf_batch_group.cpp b/tests/unit_tests/lapack/source/potrf_batch_group.cpp index 33a233b73..4a5b8dd58 100644 --- a/tests/unit_tests/lapack/source/potrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/potrf_batch_group.cpp @@ -99,7 +99,6 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -223,7 +222,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -251,7 +249,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrf_batch_stride.cpp b/tests/unit_tests/lapack/source/potrf_batch_stride.cpp index deb31fb43..fae4f0bcc 100644 --- a/tests/unit_tests/lapack/source/potrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/potrf_batch_stride.cpp @@ -66,7 +66,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo, n, lda, stride_a, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -133,7 +132,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo, n, lda, stride_a, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -151,7 +149,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, A_dev, lda, stride_a, batch_size, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potri.cpp b/tests/unit_tests/lapack/source/potri.cpp index bcef465e4..cd2f86449 100644 --- a/tests/unit_tests/lapack/source/potri.cpp +++ b/tests/unit_tests/lapack/source/potri.cpp @@ -71,7 +71,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potri_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -155,7 +154,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potri_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -173,7 +171,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrs.cpp b/tests/unit_tests/lapack/source/potrs.cpp index 83400241c..c534ec8ba 100644 --- a/tests/unit_tests/lapack/source/potrs.cpp +++ b/tests/unit_tests/lapack/source/potrs.cpp @@ -75,7 +75,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ scratchpad_size = oneapi::mkl::lapack::potrs_scratchpad_size, uplo, n, nrhs, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -143,7 +142,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_size = oneapi::mkl::lapack::potrs_scratchpad_size, uplo, n, nrhs, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -162,7 +160,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrs_batch_group.cpp b/tests/unit_tests/lapack/source/potrs_batch_group.cpp index 2f76d9cac..35c5ead0c 100644 --- a/tests/unit_tests/lapack/source/potrs_batch_group.cpp +++ b/tests/unit_tests/lapack/source/potrs_batch_group.cpp @@ -124,7 +124,6 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { uplo_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -288,7 +287,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { uplo_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -322,7 +320,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrs_batch_stride.cpp b/tests/unit_tests/lapack/source/potrs_batch_stride.cpp index 079f32564..de2568e86 100644 --- a/tests/unit_tests/lapack/source/potrs_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/potrs_batch_stride.cpp @@ -80,7 +80,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ queue, scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, uplo, n, nrhs, lda, stride_a, ldb, stride_b, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -166,7 +165,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, queue, scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, uplo, n, nrhs, lda, stride_a, ldb, stride_b, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -186,7 +184,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/syevd.cpp b/tests/unit_tests/lapack/source/syevd.cpp index 6e166843e..291713354 100644 --- a/tests/unit_tests/lapack/source/syevd.cpp +++ b/tests/unit_tests/lapack/source/syevd.cpp @@ -66,7 +66,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl::uplo scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size, jobz, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -125,7 +124,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size, jobz, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -143,7 +141,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: A_dev, lda, w_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/sygvd.cpp b/tests/unit_tests/lapack/source/sygvd.cpp index 726b8da8a..f800b03dd 100644 --- a/tests/unit_tests/lapack/source/sygvd.cpp +++ b/tests/unit_tests/lapack/source/sygvd.cpp @@ -72,7 +72,6 @@ bool accuracy(const sycl::device& dev, int64_t itype, oneapi::mkl::job jobz, one scratchpad_size = oneapi::mkl::lapack::sygvd_scratchpad_size, itype, jobz, uplo, n, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -266,7 +265,6 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job scratchpad_size = oneapi::mkl::lapack::sygvd_scratchpad_size, itype, jobz, uplo, n, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -285,7 +283,6 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job n, A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/sytrd.cpp b/tests/unit_tests/lapack/source/sytrd.cpp index 78ed3aaa3..01ffe0dff 100644 --- a/tests/unit_tests/lapack/source/sytrd.cpp +++ b/tests/unit_tests/lapack/source/sytrd.cpp @@ -69,7 +69,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::sytrd_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -183,7 +182,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::sytrd_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -204,7 +202,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/sytrf.cpp b/tests/unit_tests/lapack/source/sytrf.cpp index 42419fcec..81d7fdb2d 100644 --- a/tests/unit_tests/lapack/source/sytrf.cpp +++ b/tests/unit_tests/lapack/source/sytrf.cpp @@ -66,7 +66,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::sytrf_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -240,7 +239,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::sytrf_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -259,7 +257,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, ipiv_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/trtrs.cpp b/tests/unit_tests/lapack/source/trtrs.cpp index 7c0d86c09..4018a2c51 100644 --- a/tests/unit_tests/lapack/source/trtrs.cpp +++ b/tests/unit_tests/lapack/source/trtrs.cpp @@ -78,7 +78,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl::tran scratchpad_size = oneapi::mkl::lapack::trtrs_scratchpad_size, uplo, trans, diag, n, nrhs, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -142,7 +141,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl scratchpad_size = oneapi::mkl::lapack::trtrs_scratchpad_size, uplo, trans, diag, n, nrhs, lda, ldb); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -161,7 +159,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl n, nrhs, A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungbr.cpp b/tests/unit_tests/lapack/source/ungbr.cpp index 4e1919348..7cdf8e52a 100644 --- a/tests/unit_tests/lapack/source/ungbr.cpp +++ b/tests/unit_tests/lapack/source/ungbr.cpp @@ -86,7 +86,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::generate vect, int64_t m, in scratchpad_size = oneapi::mkl::lapack::ungbr_scratchpad_size, vect, m, n, k, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -162,7 +161,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t scratchpad_size = oneapi::mkl::lapack::ungbr_scratchpad_size, vect, m, n, k, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -181,7 +179,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungqr.cpp b/tests/unit_tests/lapack/source/ungqr.cpp index 14e775b8d..08b8b1192 100644 --- a/tests/unit_tests/lapack/source/ungqr.cpp +++ b/tests/unit_tests/lapack/source/ungqr.cpp @@ -73,7 +73,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size, m, n, k, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -135,7 +134,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size, m, n, k, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -154,7 +152,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungqr_batch_group.cpp b/tests/unit_tests/lapack/source/ungqr_batch_group.cpp index c03f89837..ddb350828 100644 --- a/tests/unit_tests/lapack/source/ungqr_batch_group.cpp +++ b/tests/unit_tests/lapack/source/ungqr_batch_group.cpp @@ -111,7 +111,6 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -257,7 +256,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); auto A_dev_iter = A_dev_list.begin(); @@ -291,7 +289,6 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp b/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp index f9b4e5f73..e656b9fb7 100644 --- a/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp @@ -75,7 +75,6 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t queue, scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, m, n, k, lda, stride_a, stride_tau, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -154,7 +153,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in queue, scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, m, n, k, lda, stride_a, stride_tau, batch_size); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -174,7 +172,6 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ungtr.cpp b/tests/unit_tests/lapack/source/ungtr.cpp index e5edc7ef5..b0ad8e8f2 100644 --- a/tests/unit_tests/lapack/source/ungtr.cpp +++ b/tests/unit_tests/lapack/source/ungtr.cpp @@ -72,7 +72,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::ungtr_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -137,7 +136,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::ungtr_scratchpad_size, uplo, n, lda); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -156,7 +154,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, lda, tau_dev, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/unmqr.cpp b/tests/unit_tests/lapack/source/unmqr.cpp index 51cfdd004..2f555c1ca 100644 --- a/tests/unit_tests/lapack/source/unmqr.cpp +++ b/tests/unit_tests/lapack/source/unmqr.cpp @@ -83,7 +83,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -171,7 +170,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -191,7 +189,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/unmrq.cpp b/tests/unit_tests/lapack/source/unmrq.cpp index 6526533b5..628063837 100644 --- a/tests/unit_tests/lapack/source/unmrq.cpp +++ b/tests/unit_tests/lapack/source/unmrq.cpp @@ -93,7 +93,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl scratchpad_size = oneapi::mkl::lapack::unmrq_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -180,7 +179,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size = oneapi::mkl::lapack::unmrq_scratchpad_size, left_right, trans, m, n, k, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -200,7 +198,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/unmtr.cpp b/tests/unit_tests/lapack/source/unmtr.cpp index fa8ee60d6..8148c644d 100644 --- a/tests/unit_tests/lapack/source/unmtr.cpp +++ b/tests/unit_tests/lapack/source/unmtr.cpp @@ -81,7 +81,6 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, int64_ scratchpad_size = oneapi::mkl::lapack::unmtr_scratchpad_size, side, uplo, trans, m, n, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -170,7 +169,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, scratchpad_size = oneapi::mkl::lapack::unmtr_scratchpad_size, side, uplo, trans, m, n, lda, ldc); #endif - queue.wait_and_throw(); auto scratchpad_dev = device_alloc(queue, scratchpad_size); host_to_device_copy(queue, A.data(), A_dev, A.size()); @@ -190,7 +188,6 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, m, n, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif - queue.wait_and_throw(); result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); From 61c9a533ae0c251da5b8a3b9a9b838676c317dfa Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 4 Oct 2024 10:30:01 -0400 Subject: [PATCH 09/11] Fix format. Signed-off-by: JackAKirk --- src/blas/backends/cublas/cublas_batch.cpp | 2590 ++++++++++++--------- 1 file changed, 1441 insertions(+), 1149 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 031c11a2f..e882f8ee7 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -29,152 +29,170 @@ namespace column_major { // Buffer APIs -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, int64_t stridex, sycl::buffer, 1> &y, +void copy_batch(sycl::queue &queue, int64_t n, + sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, sycl::buffer, 1> &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, int64_t stridex, sycl::buffer, 1> &y, +void copy_batch(sycl::queue &queue, int64_t n, + sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, sycl::buffer, 1> &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -void axpy_batch(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, +void axpy_batch(sycl::queue &queue, int64_t n, float alpha, + sycl::buffer &x, int64_t incx, int64_t stridex, + sycl::buffer &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -void axpy_batch(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, - int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, - int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); +void axpy_batch(sycl::queue &queue, int64_t n, double alpha, + sycl::buffer &x, int64_t incx, int64_t stridex, + sycl::buffer &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, int64_t stridex, - sycl::buffer, 1> &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, sycl::buffer, 1> &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, int64_t stridex, - sycl::buffer, 1> &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, sycl::buffer, 1> &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, - int64_t incx, int64_t stride_x, float beta, sycl::buffer &y, int64_t incy, - int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &x, int64_t incx, + int64_t stride_x, float beta, sycl::buffer &y, + int64_t incy, int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &x, int64_t incx, int64_t stride_x, double beta, - sycl::buffer &y, int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &x, int64_t incx, + int64_t stride_x, double beta, sycl::buffer &y, + int64_t incy, int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, - int64_t stride_x, std::complex beta, sycl::buffer, 1> &y, - int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, + int64_t incx, int64_t stride_x, std::complex beta, + sycl::buffer, 1> &y, int64_t incy, + int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, - int64_t stride_x, std::complex beta, - sycl::buffer, 1> &y, int64_t incy, int64_t stride_y, - int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, + int64_t incx, int64_t stride_x, std::complex beta, + sycl::buffer, 1> &y, int64_t incy, + int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, - int64_t incx, int64_t stride_x, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &x, int64_t incx, int64_t stride_x, + sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, int64_t incx, int64_t stride_x, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, + int64_t incx, int64_t stride_x, + sycl::buffer, 1> &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, + int64_t incx, int64_t stride_x, + sycl::buffer, 1> &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } template -inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, - Ts beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - using cuTypeA = typename CudaEquivalentType::Type; - using cuTypeB = typename CudaEquivalentType::Type; - using cuTypeC = typename CudaEquivalentType::Type; - using cuTypeS = typename CudaEquivalentType::Type; - overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); - - cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; - queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { - throw oneapi::mkl::unimplemented( - "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); - } - auto a_acc = a.template get_access(cgh); - auto b_acc = b.template get_access(cgh); - auto c_acc = c.template get_access(cgh); - onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - auto c_ = sc.get_mem(c_acc); - cublasStatus_t err; +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, + int64_t ldb, int64_t stride_b, Ts beta, + sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; + overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, + batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; + queue.submit([&](sycl::handler &cgh) { + if (!verify_support(queue, + sycl::aspect::fp16)) { + throw oneapi::mkl::unimplemented( + "blas", "sycl::half", + "half is not supported by the device or the sycl compiler"); + } + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + cublasStatus_t err; #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, - err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a_, - get_cublas_datatype(), lda, stride_a, b_, - get_cublas_datatype(), ldb, stride_b, &beta, c_, - get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, + &alpha, a_, get_cublas_datatype(), lda, stride_a, b_, + get_cublas_datatype(), ldb, stride_b, &beta, c_, + get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #else CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, @@ -186,449 +204,527 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran batch_size, get_cublas_datatype(), cublas_gemm_algo); #endif - }); }); + }); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, \ - lda, stride_a, b, ldb, stride_b, beta, c, \ - ldc, stride_c, batch_size); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, \ + int64_t m, int64_t n, int64_t k, TYPE_S alpha, \ + sycl::buffer &a, int64_t lda, int64_t stride_a, \ + sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, \ + int64_t stride_c, int64_t batch_size) { \ + gemm_batch_impl( \ + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, \ + stride_b, beta, c, ldc, stride_c, batch_size); \ + } GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, - std::complex) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, - std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, + std::complex, std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, + std::complex, std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - throw unimplemented("blas", "gemm_batch", \ - std::string("for dtype unimplemented dtype combination <") + \ - dtype_string() + "," + dtype_string() + "," + \ - dtype_string() + "," + dtype_string() + ">"); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, \ + int64_t m, int64_t n, int64_t k, TYPE_S alpha, \ + sycl::buffer &a, int64_t lda, int64_t stride_a, \ + sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, \ + int64_t stride_c, int64_t batch_size) { \ + throw unimplemented( \ + "blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, double alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, float beta, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, float beta, + sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, double alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, double beta, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, sycl::buffer, 1> &c, - int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + sycl::buffer, 1> &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, - int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + float alpha, sycl::buffer &ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, - int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + double alpha, sycl::buffer &ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &ab, - int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + std::complex alpha, + sycl::buffer, 1> &ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &ab, - int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + std::complex alpha, + sycl::buffer, 1> &ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, float beta, + sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, + int64_t stride_b, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, sycl::buffer, 1> &c, + int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, sycl::buffer, 1> &c, + int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } // USM APIs -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, - int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, + int64_t *incx, float **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, int64_t *incx, double **y, - int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, + int64_t *incx, double **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t *n, + const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, - int64_t *incx, std::complex **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t *n, + const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, int64_t incx, - std::int64_t stridex, float *y, int64_t incy, std::int64_t stridey, - std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, + int64_t incx, std::int64_t stridex, float *y, + int64_t incy, std::int64_t stridey, + std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - std::int64_t stridex, double *y, int64_t incy, std::int64_t stridey, - std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, + int64_t incx, std::int64_t stridex, double *y, + int64_t incy, std::int64_t stridey, + std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - std::int64_t stridex, std::complex *y, int64_t incy, - std::int64_t stridey, std::int64_t batch_size, +sycl::event copy_batch(sycl::queue &queue, int64_t n, + const std::complex *x, int64_t incx, + std::int64_t stridex, std::complex *y, + int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - std::int64_t stridex, std::complex *y, int64_t incy, - std::int64_t stridey, std::int64_t batch_size, +sycl::event copy_batch(sycl::queue &queue, int64_t n, + const std::complex *x, int64_t incx, + std::int64_t stridex, std::complex *y, + int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, const float **x, int64_t *incx, - float **y, int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, + const float **x, int64_t *incx, float **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, const double **x, - int64_t *incx, double **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, + const double **x, int64_t *incx, double **y, + int64_t *incy, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, - const std::complex **x, int64_t *incx, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, + std::complex *alpha, + const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, - const std::complex **x, int64_t *incx, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, + std::complex *alpha, + const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, - int64_t stridex, float *y, int64_t incy, int64_t stridey, int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, + const float *x, int64_t incx, int64_t stridex, float *y, + int64_t incy, int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, const double *x, int64_t incx, - int64_t stridex, double *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); +sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, + const double *x, int64_t incx, int64_t stridex, + double *y, int64_t incy, int64_t stridey, + int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, int64_t stridex, - std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, + const std::complex *x, int64_t incx, + int64_t stridex, std::complex *y, int64_t incy, + int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, int64_t stridex, - std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t n, + std::complex alpha, + const std::complex *x, int64_t incx, + int64_t stridex, std::complex *y, int64_t incy, + int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float beta, float *y, int64_t incy, int64_t stride_y, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, + int64_t stride_a, const float *x, int64_t incx, + int64_t stride_x, float beta, float *y, int64_t incy, + int64_t stride_y, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, const double *x, - int64_t incx, int64_t stride_x, double beta, double *y, int64_t incy, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, + int64_t stride_a, const double *x, int64_t incx, + int64_t stride_x, double beta, double *y, int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, int64_t incx, - int64_t stride_x, std::complex beta, std::complex *y, - int64_t incy, int64_t stride_y, int64_t batch_size, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, + int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, + int64_t incx, int64_t stride_x, std::complex beta, + std::complex *y, int64_t incy, int64_t stride_y, + int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, int64_t incx, - int64_t stride_x, std::complex beta, std::complex *y, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, + int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, + int64_t incx, int64_t stride_x, + std::complex beta, std::complex *y, int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha, - const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta, - float **y, int64_t *incy, int64_t group_count, int64_t *groupsize, +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, + int64_t *n, float *alpha, const float **a, int64_t *lda, + const float **x, int64_t *incx, float *beta, float **y, + int64_t *incy, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha, - const double **a, int64_t *lda, const double **x, int64_t *incx, - double *beta, double **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, + int64_t *n, double *alpha, const double **a, + int64_t *lda, const double **x, int64_t *incx, + double *beta, double **y, int64_t *incy, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, + int64_t *n, std::complex *alpha, + const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, + std::complex *beta, std::complex **y, + int64_t *incy, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, + int64_t *n, std::complex *alpha, + const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, + std::complex *beta, std::complex **y, + int64_t *incy, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a, - int64_t lda, int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, + int64_t n, const float *a, int64_t lda, int64_t stride_a, + const float *x, int64_t incx, int64_t stride_x, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const double *a, - int64_t lda, int64_t stride_a, const double *x, int64_t incx, - int64_t stride_x, double *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, + int64_t n, const double *a, int64_t lda, + int64_t stride_a, const double *x, int64_t incx, + int64_t stride_x, double *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - const std::complex *a, int64_t lda, int64_t stride_a, - const std::complex *x, int64_t incx, int64_t stride_x, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, + int64_t n, const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, + int64_t incx, int64_t stride_x, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - const std::complex *a, int64_t lda, int64_t stride_a, - const std::complex *x, int64_t incx, int64_t stride_x, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, + int64_t n, const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, + int64_t incx, int64_t stride_x, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const float **a, int64_t *lda, const float **x, int64_t *incx, float **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const float **a, int64_t *lda, + const float **x, int64_t *incx, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const double **a, int64_t *lda, const double **x, int64_t *incx, double **c, +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const double **a, int64_t *lda, + const double **x, int64_t *incx, double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const std::complex **a, int64_t *lda, const std::complex **x, - int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, + std::complex **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const std::complex **a, int64_t *lda, const std::complex **x, - int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, + std::complex **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } template -inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, int64_t k, - Ts alpha, const Ta *a, int64_t lda, int64_t stride_a, - const Tb *b, int64_t ldb, int64_t stride_b, Ts beta, - Tc *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, - const std::vector &dependencies) { - using cuTypeA = typename CudaEquivalentType::Type; - using cuTypeB = typename CudaEquivalentType::Type; - using cuTypeC = typename CudaEquivalentType::Type; - using cuTypeS = typename CudaEquivalentType::Type; - overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); - - cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; - auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { - throw oneapi::mkl::unimplemented( - "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); - } - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - cublasStatus_t err; +inline sycl::event gemm_batch_strided_usm_impl( + sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, const Ta *a, int64_t lda, int64_t stride_a, + const Tb *b, int64_t ldb, int64_t stride_b, Ts beta, Tc *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; + overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, + batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; + auto done = queue.submit([&](sycl::handler &cgh) { + if (!verify_support(queue, + sycl::aspect::fp16)) { + throw oneapi::mkl::unimplemented( + "blas", "sycl::half", + "half is not supported by the device or the sycl compiler"); + } + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + cublasStatus_t err; #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, - err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a, - get_cublas_datatype(), lda, stride_a, b, - get_cublas_datatype(), ldb, stride_b, &beta, c, - get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, + &alpha, a, get_cublas_datatype(), lda, stride_a, b, + get_cublas_datatype(), ldb, stride_b, &beta, c, + get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #else CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, @@ -640,44 +736,47 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra batch_size, get_cublas_datatype(), cublas_gemm_algo); #endif - }); }); - return done; -} - -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ - int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size, const std::vector &dependencies) { \ - return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, \ - batch_size, dependencies); \ - } + }); + return done; +} + +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch( \ + sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl( \ + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, \ + stride_b, beta, c, ldc, stride_c, batch_size, dependencies); \ + } GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, - std::complex) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, - std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, + std::complex, std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, + std::complex, std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ - int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size, const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", \ - std::string("for dtype unimplemented dtype combination <") + \ - dtype_string() + "," + dtype_string() + "," + \ - dtype_string() + "," + dtype_string() + ">"); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch( \ + sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented( \ + "blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) @@ -685,45 +784,48 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM template -inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, - int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, - int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, - int64_t *ldc, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - using cuTypeA = typename CudaEquivalentType::Type; - using cuTypeB = typename CudaEquivalentType::Type; - using cuTypeC = typename CudaEquivalentType::Type; - using cuTypeS = typename CudaEquivalentType::Type; - for (int64_t i = 0; i < group_count; i++) { - overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); +inline sycl::event +gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); + } + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; + auto done = queue.submit([&](sycl::handler &cgh) { + if (!verify_support(queue, + sycl::aspect::fp16)) { + throw oneapi::mkl::unimplemented( + "blas", "sycl::half", + "half is not supported by the device or the sycl compiler"); } - - cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; - auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { - throw oneapi::mkl::unimplemented( - "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); - } - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - int64_t offset = 0; - cublasStatus_t err; - for (int64_t i = 0; i < group_count; i++) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int64_t offset = 0; + cublasStatus_t err; + for (int64_t i = 0; i < group_count; i++) { #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, - get_cublas_operation(transa[i]), - get_cublas_operation(transb[i]), (int)m[i], (int)n[i], - (int)k[i], &alpha[i], (const void *const *)(a + offset), - get_cublas_datatype(), (int)lda[i], - (const void *const *)(b + offset), - get_cublas_datatype(), (int)ldb[i], &beta[i], - (void *const *)(c + offset), get_cublas_datatype(), - (int)ldc[i], (int)group_size[i], - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), + (int)m[i], (int)n[i], (int)k[i], &alpha[i], + (const void *const *)(a + offset), get_cublas_datatype(), + (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), + (int)ldc[i], (int)group_size[i], get_cublas_datatype(), + cublas_gemm_algo); #else CUBLAS_ERROR_FUNC_T_SYNC( "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, @@ -735,127 +837,139 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr get_cublas_datatype(), (int)ldc[i], (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); #endif - offset += group_size[i]; - } - }); + offset += group_size[i]; + } }); - return done; -} - -#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ - const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ - ldc, group_count, group_size, dependencies); \ - } + }); + return done; +} + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch( \ + sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, \ + b, ldb, beta, c, ldc, group_count, group_size, \ + dependencies); \ + } GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) GEMM_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_BATCH_LAUNCHER_USM(double, double, double, double) -GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, - std::complex) -GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, - std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, + std::complex, std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, + std::complex, std::complex) #undef GEMM_BATCH_LAUNCHER_USM -#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ - const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", \ - std::string("for dtype unimplemented dtype combination <") + \ - dtype_string() + "," + dtype_string() + "," + \ - dtype_string() + "," + dtype_string() + ">"); \ - } +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch( \ + sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented( \ + "blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, - int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + float alpha, const float *a, int64_t lda, + int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, - int64_t lda, int64_t stride_a, double *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + double alpha, const double *a, int64_t lda, + int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, int64_t stride_a, - std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex *b, + int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, int64_t stride_a, - std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } template -inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &queue, - side *left_right, uplo *upper_lower, transpose *trans, - diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, - int64_t *lda, T **b, int64_t *ldb, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - using cuDataType = typename CudaEquivalentType::Type; - for (int64_t i = 0; i < group_count; i++) { - overflow_check(m[i], n[i], lda[i], ldb[i], group_size[i]); +inline sycl::event +trsm_batch(const char *func_name, Func func, sycl::queue &queue, + side *left_right, uplo *upper_lower, transpose *trans, + diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, + int64_t *lda, T **b, int64_t *ldb, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(m[i], n[i], lda[i], ldb[i], group_size[i]); + } + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); } - auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - int64_t offset = 0; - cublasStatus_t err; - for (int64_t i = 0; i < group_count; i++) { - auto **a_ = reinterpret_cast(a); - auto **b_ = reinterpret_cast(b); - cublas_native_named_func(func_name, func, err, handle, - get_cublas_side_mode(left_right[i]), - get_cublas_fill_mode(upper_lower[i]), - get_cublas_operation(trans[i]), - get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], - (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], - b_ + offset, (int)ldb[i], (int)group_size[i]); - offset += group_size[i]; - } - }); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int64_t offset = 0; + cublasStatus_t err; + for (int64_t i = 0; i < group_count; i++) { + auto **a_ = reinterpret_cast(a); + auto **b_ = reinterpret_cast(b); + cublas_native_named_func( + func_name, func, err, handle, get_cublas_side_mode(left_right[i]), + get_cublas_fill_mode(upper_lower[i]), + get_cublas_operation(trans[i]), get_cublas_diag_type(unit_diag[i]), + (int)m[i], (int)n[i], (cuDataType *)&alpha[i], a_ + offset, + (int)lda[i], b_ + offset, (int)ldb[i], (int)group_size[i]); + offset += group_size[i]; + } }); - return done; -} - -#define TRSM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event trsm_batch(sycl::queue &queue, side *left_right, uplo *upper_lower, \ - transpose *trans, diag *unit_diag, int64_t *m, int64_t *n, TYPE *alpha, \ - const TYPE **a, int64_t *lda, TYPE **b, int64_t *ldb, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return trsm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, left_right, upper_lower, trans, \ - unit_diag, m, n, alpha, a, lda, b, ldb, group_count, group_size, \ - dependencies); \ - } + }); + return done; +} + +#define TRSM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event trsm_batch(sycl::queue &queue, side *left_right, \ + uplo *upper_lower, transpose *trans, diag *unit_diag, \ + int64_t *m, int64_t *n, TYPE *alpha, const TYPE **a, \ + int64_t *lda, TYPE **b, int64_t *ldb, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return trsm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, left_right, \ + upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, \ + ldb, group_count, group_size, dependencies); \ + } TRSM_BATCH_LAUNCHER_USM(float, cublasStrsmBatched) TRSM_BATCH_LAUNCHER_USM(double, cublasDtrsmBatched) @@ -864,209 +978,249 @@ TRSM_BATCH_LAUNCHER_USM(std::complex, cublasZtrsmBatched) #undef TRSM_BATCH_LAUNCHER_USM -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, float *alpha, const float **a, int64_t *lda, float *beta, - float **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, float *alpha, const float **a, + int64_t *lda, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, double *alpha, const double **a, int64_t *lda, double *beta, - double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, double *alpha, const double **a, + int64_t *lda, double *beta, double **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex *beta, std::complex **c, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, std::complex *alpha, + const std::complex **a, int64_t *lda, + std::complex *beta, std::complex **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex *beta, std::complex **c, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, std::complex *alpha, + const std::complex **a, int64_t *lda, + std::complex *beta, std::complex **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t stride_a, float beta, - float *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, float alpha, const float *a, + int64_t lda, int64_t stride_a, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t stride_a, double beta, - double *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, double alpha, const double *a, + int64_t lda, int64_t stride_a, double beta, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, + int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, double alpha, const double *a, + int64_t lda, int64_t stride_a, double *b, + int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, + int64_t ldb, int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, + int64_t ldb, int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, float alpha, float *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, double alpha, double *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, std::complex *ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, std::complex alpha, + std::complex *ab, int64_t lda, int64_t ldb, + int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, std::complex *ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, std::complex alpha, + std::complex *ab, int64_t lda, int64_t ldb, + int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, - float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, + int64_t stride_b, float *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, - double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, + int64_t stride_b, double *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, int64_t stride_b, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, + std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, + int64_t stride_b, std::complex *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, int64_t stride_b, - std::complex *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, + std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, + int64_t stride_b, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - float *alpha, const float **a, int64_t *lda, float **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, float *alpha, const float **a, + int64_t *lda, float **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - double *alpha, const double **a, int64_t *lda, double **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, double *alpha, const double **a, + int64_t *lda, double **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - std::complex **b, int64_t *ldb, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, + const std::complex **a, int64_t *lda, + std::complex **b, int64_t *ldb, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, + const std::complex **a, int64_t *lda, + std::complex **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - float *alpha, float **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, float *alpha, float **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - double *alpha, double **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, double *alpha, double **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, + std::complex **ab, int64_t *lda, int64_t *ldb, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, + std::complex **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } } // namespace column_major @@ -1074,125 +1228,139 @@ namespace row_major { // Buffer APIs -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, int64_t stridex, sycl::buffer, 1> &y, +void copy_batch(sycl::queue &queue, int64_t n, + sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, sycl::buffer, 1> &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, int64_t stridex, sycl::buffer, 1> &y, +void copy_batch(sycl::queue &queue, int64_t n, + sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, sycl::buffer, 1> &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -void axpy_batch(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, +void axpy_batch(sycl::queue &queue, int64_t n, float alpha, + sycl::buffer &x, int64_t incx, int64_t stridex, + sycl::buffer &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -void axpy_batch(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, - int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, - int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); +void axpy_batch(sycl::queue &queue, int64_t n, double alpha, + sycl::buffer &x, int64_t incx, int64_t stridex, + sycl::buffer &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, int64_t stridex, - sycl::buffer, 1> &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, sycl::buffer, 1> &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, int64_t stridex, - sycl::buffer, 1> &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, sycl::buffer, 1> &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, - int64_t incx, int64_t stride_x, float beta, sycl::buffer &y, int64_t incy, - int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &x, int64_t incx, + int64_t stride_x, float beta, sycl::buffer &y, + int64_t incy, int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &x, int64_t incx, int64_t stride_x, double beta, - sycl::buffer &y, int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &x, int64_t incx, + int64_t stride_x, double beta, sycl::buffer &y, + int64_t incy, int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, - int64_t stride_x, std::complex beta, sycl::buffer, 1> &y, - int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, + int64_t incx, int64_t stride_x, std::complex beta, + sycl::buffer, 1> &y, int64_t incy, + int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, - int64_t stride_x, std::complex beta, - sycl::buffer, 1> &y, int64_t incy, int64_t stride_y, - int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, + int64_t incx, int64_t stride_x, std::complex beta, + sycl::buffer, 1> &y, int64_t incy, + int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, - int64_t incx, int64_t stride_x, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &x, int64_t incx, int64_t stride_x, + sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, int64_t incx, int64_t stride_x, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, + int64_t incx, int64_t stride_x, + sycl::buffer, 1> &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, + int64_t incx, int64_t stride_x, + sycl::buffer, 1> &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, \ + int64_t m, int64_t n, int64_t k, TYPE_S alpha, \ + sycl::buffer &a, int64_t lda, int64_t stride_a, \ + sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, \ + int64_t stride_c, int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ + } GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) @@ -1200,386 +1368,460 @@ GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, - std::complex) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, - std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, + std::complex, std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, + std::complex, std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, double alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, float beta, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, float beta, + sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, double alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, double beta, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, sycl::buffer, 1> &c, - int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + sycl::buffer, 1> &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, - int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + float alpha, sycl::buffer &ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, - int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + double alpha, sycl::buffer &ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &ab, - int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + std::complex alpha, + sycl::buffer, 1> &ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &ab, - int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + std::complex alpha, + sycl::buffer, 1> &ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, float beta, + sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, + int64_t stride_b, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); + sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, sycl::buffer, 1> &c, + int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, sycl::buffer, 1> &c, + int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } // USM APIs -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, - int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, + int64_t *incx, float **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, int64_t *incx, double **y, - int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, + int64_t *incx, double **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t *n, + const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, - int64_t *incx, std::complex **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t *n, + const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, int64_t incx, - std::int64_t stridex, float *y, int64_t incy, std::int64_t stridey, - std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, + int64_t incx, std::int64_t stridex, float *y, + int64_t incy, std::int64_t stridey, + std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - std::int64_t stridex, double *y, int64_t incy, std::int64_t stridey, - std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, + int64_t incx, std::int64_t stridex, double *y, + int64_t incy, std::int64_t stridey, + std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - std::int64_t stridex, std::complex *y, int64_t incy, - std::int64_t stridey, std::int64_t batch_size, +sycl::event copy_batch(sycl::queue &queue, int64_t n, + const std::complex *x, int64_t incx, + std::int64_t stridex, std::complex *y, + int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - std::int64_t stridex, std::complex *y, int64_t incy, - std::int64_t stridey, std::int64_t batch_size, +sycl::event copy_batch(sycl::queue &queue, int64_t n, + const std::complex *x, int64_t incx, + std::int64_t stridex, std::complex *y, + int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, const float **x, int64_t *incx, - float **y, int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, + const float **x, int64_t *incx, float **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, const double **x, - int64_t *incx, double **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, + const double **x, int64_t *incx, double **y, + int64_t *incy, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, - const std::complex **x, int64_t *incx, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, + std::complex *alpha, + const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, - const std::complex **x, int64_t *incx, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, + std::complex *alpha, + const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, - int64_t stridex, float *y, int64_t incy, int64_t stridey, int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, + const float *x, int64_t incx, int64_t stridex, float *y, + int64_t incy, int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, const double *x, int64_t incx, - int64_t stridex, double *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); +sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, + const double *x, int64_t incx, int64_t stridex, + double *y, int64_t incy, int64_t stridey, + int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, int64_t stridex, - std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, + const std::complex *x, int64_t incx, + int64_t stridex, std::complex *y, int64_t incy, + int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, int64_t stridex, - std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t n, + std::complex alpha, + const std::complex *x, int64_t incx, + int64_t stridex, std::complex *y, int64_t incy, + int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float beta, float *y, int64_t incy, int64_t stride_y, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, + int64_t stride_a, const float *x, int64_t incx, + int64_t stride_x, float beta, float *y, int64_t incy, + int64_t stride_y, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, const double *x, - int64_t incx, int64_t stride_x, double beta, double *y, int64_t incy, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, + int64_t stride_a, const double *x, int64_t incx, + int64_t stride_x, double beta, double *y, int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, int64_t incx, - int64_t stride_x, std::complex beta, std::complex *y, - int64_t incy, int64_t stride_y, int64_t batch_size, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, + int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, + int64_t incx, int64_t stride_x, std::complex beta, + std::complex *y, int64_t incy, int64_t stride_y, + int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, int64_t incx, - int64_t stride_x, std::complex beta, std::complex *y, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, + int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, + int64_t incx, int64_t stride_x, + std::complex beta, std::complex *y, int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha, - const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta, - float **y, int64_t *incy, int64_t group_count, int64_t *groupsize, +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, + int64_t *n, float *alpha, const float **a, int64_t *lda, + const float **x, int64_t *incx, float *beta, float **y, + int64_t *incy, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha, - const double **a, int64_t *lda, const double **x, int64_t *incx, - double *beta, double **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, + int64_t *n, double *alpha, const double **a, + int64_t *lda, const double **x, int64_t *incx, + double *beta, double **y, int64_t *incy, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, + int64_t *n, std::complex *alpha, + const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, + std::complex *beta, std::complex **y, + int64_t *incy, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, + int64_t *n, std::complex *alpha, + const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, + std::complex *beta, std::complex **y, + int64_t *incy, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a, - int64_t lda, int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, + int64_t n, const float *a, int64_t lda, int64_t stride_a, + const float *x, int64_t incx, int64_t stride_x, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const double *a, - int64_t lda, int64_t stride_a, const double *x, int64_t incx, - int64_t stride_x, double *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, + int64_t n, const double *a, int64_t lda, + int64_t stride_a, const double *x, int64_t incx, + int64_t stride_x, double *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - const std::complex *a, int64_t lda, int64_t stride_a, - const std::complex *x, int64_t incx, int64_t stride_x, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, + int64_t n, const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, + int64_t incx, int64_t stride_x, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - const std::complex *a, int64_t lda, int64_t stride_a, - const std::complex *x, int64_t incx, int64_t stride_x, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, + int64_t n, const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, + int64_t incx, int64_t stride_x, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const float **a, int64_t *lda, const float **x, int64_t *incx, float **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const float **a, int64_t *lda, + const float **x, int64_t *incx, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const double **a, int64_t *lda, const double **x, int64_t *incx, double **c, +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const double **a, int64_t *lda, + const double **x, int64_t *incx, double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const std::complex **a, int64_t *lda, const std::complex **x, - int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, + std::complex **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const std::complex **a, int64_t *lda, const std::complex **x, - int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, + std::complex **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ - int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size, const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch( \ + sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ + } GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) @@ -1587,21 +1829,22 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, - std::complex) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, - std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, + std::complex, std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, + std::complex, std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ - const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ - } +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch( \ + sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ + } GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) @@ -1609,62 +1852,71 @@ GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) GEMM_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_BATCH_LAUNCHER_USM(double, double, double, double) -GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, - std::complex) -GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, - std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, + std::complex, std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, + std::complex, std::complex) #undef GEMM_BATCH_LAUNCHER_USM -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, - int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + float alpha, const float *a, int64_t lda, + int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, - int64_t lda, int64_t stride_a, double *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + double alpha, const double *a, int64_t lda, + int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, int64_t stride_a, - std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex *b, + int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, int64_t stride_a, - std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, + std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } template -inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &queue, - side *left_right, uplo *upper_lower, transpose *trans, - diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, - int64_t *lda, T **b, int64_t *ldb, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} - -#define TRSM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event trsm_batch(sycl::queue &queue, side *left_right, uplo *upper_lower, \ - transpose *trans, diag *unit_diag, int64_t *m, int64_t *n, TYPE *alpha, \ - const TYPE **a, int64_t *lda, TYPE **b, int64_t *ldb, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return trsm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, left_right, upper_lower, trans, \ - unit_diag, m, n, alpha, a, lda, b, ldb, group_count, group_size, \ - dependencies); \ - } +inline sycl::event +trsm_batch(const char *func_name, Func func, sycl::queue &queue, + side *left_right, uplo *upper_lower, transpose *trans, + diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, + int64_t *lda, T **b, int64_t *ldb, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); +} + +#define TRSM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event trsm_batch(sycl::queue &queue, side *left_right, \ + uplo *upper_lower, transpose *trans, diag *unit_diag, \ + int64_t *m, int64_t *n, TYPE *alpha, const TYPE **a, \ + int64_t *lda, TYPE **b, int64_t *ldb, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return trsm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, left_right, \ + upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, \ + ldb, group_count, group_size, dependencies); \ + } TRSM_BATCH_LAUNCHER_USM(float, cublasStrsmBatched) TRSM_BATCH_LAUNCHER_USM(double, cublasDtrsmBatched) @@ -1673,209 +1925,249 @@ TRSM_BATCH_LAUNCHER_USM(std::complex, cublasZtrsmBatched) #undef TRSM_BATCH_LAUNCHER_USM -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, float *alpha, const float **a, int64_t *lda, float *beta, - float **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, float *alpha, const float **a, + int64_t *lda, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, double *alpha, const double **a, int64_t *lda, double *beta, - double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, double *alpha, const double **a, + int64_t *lda, double *beta, double **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex *beta, std::complex **c, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, std::complex *alpha, + const std::complex **a, int64_t *lda, + std::complex *beta, std::complex **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex *beta, std::complex **c, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, std::complex *alpha, + const std::complex **a, int64_t *lda, + std::complex *beta, std::complex **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t stride_a, float beta, - float *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, float alpha, const float *a, + int64_t lda, int64_t stride_a, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t stride_a, double beta, - double *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, double alpha, const double *a, + int64_t lda, int64_t stride_a, double beta, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, + int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, double alpha, const double *a, + int64_t lda, int64_t stride_a, double *b, + int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, + int64_t ldb, int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, + int64_t ldb, int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, float alpha, float *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, double alpha, double *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, std::complex *ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, std::complex alpha, + std::complex *ab, int64_t lda, int64_t ldb, + int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, std::complex *ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, + int64_t n, std::complex alpha, + std::complex *ab, int64_t lda, int64_t ldb, + int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, - float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, + int64_t stride_b, float *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, - double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, + int64_t stride_b, double *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, int64_t stride_b, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, + std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, + int64_t stride_b, std::complex *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, int64_t stride_b, - std::complex *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, + std::complex alpha, + const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, + int64_t stride_b, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - float *alpha, const float **a, int64_t *lda, float **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, float *alpha, const float **a, + int64_t *lda, float **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - double *alpha, const double **a, int64_t *lda, double **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, double *alpha, const double **a, + int64_t *lda, double **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - std::complex **b, int64_t *ldb, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, + const std::complex **a, int64_t *lda, + std::complex **b, int64_t *ldb, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, + const std::complex **a, int64_t *lda, + std::complex **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - float *alpha, float **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, float *alpha, float **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - double *alpha, double **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, double *alpha, double **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, + std::complex **ab, int64_t *lda, int64_t *ldb, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, + std::complex **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } } // namespace row_major From 27b251f15e287c30d76ab8bf3ad71501e09a4fd9 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 7 Oct 2024 14:02:43 +0100 Subject: [PATCH 10/11] Revert "Fix format." This reverts commit 61c9a533ae0c251da5b8a3b9a9b838676c317dfa. --- src/blas/backends/cublas/cublas_batch.cpp | 2590 +++++++++------------ 1 file changed, 1149 insertions(+), 1441 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index e882f8ee7..031c11a2f 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -29,170 +29,152 @@ namespace column_major { // Buffer APIs -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, - int64_t incx, int64_t stridex, sycl::buffer &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, + int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, - int64_t incx, int64_t stridex, sycl::buffer &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, + int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, - sycl::buffer, 1> &x, int64_t incx, - int64_t stridex, sycl::buffer, 1> &y, +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, + int64_t incx, int64_t stridex, sycl::buffer, 1> &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, - sycl::buffer, 1> &x, int64_t incx, - int64_t stridex, sycl::buffer, 1> &y, +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, + int64_t incx, int64_t stridex, sycl::buffer, 1> &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -void axpy_batch(sycl::queue &queue, int64_t n, float alpha, - sycl::buffer &x, int64_t incx, int64_t stridex, - sycl::buffer &y, int64_t incy, int64_t stridey, +void axpy_batch(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, + int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -void axpy_batch(sycl::queue &queue, int64_t n, double alpha, - sycl::buffer &x, int64_t incx, int64_t stridex, - sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); +void axpy_batch(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, + int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, - int64_t stridex, sycl::buffer, 1> &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + sycl::buffer, 1> &x, int64_t incx, int64_t stridex, + sycl::buffer, 1> &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, - int64_t stridex, sycl::buffer, 1> &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + sycl::buffer, 1> &x, int64_t incx, int64_t stridex, + sycl::buffer, 1> &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &x, int64_t incx, - int64_t stride_x, float beta, sycl::buffer &y, - int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, + int64_t incx, int64_t stride_x, float beta, sycl::buffer &y, int64_t incy, + int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &x, int64_t incx, - int64_t stride_x, double beta, sycl::buffer &y, - int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &x, int64_t incx, int64_t stride_x, double beta, + sycl::buffer &y, int64_t incy, int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, - int64_t incx, int64_t stride_x, std::complex beta, - sycl::buffer, 1> &y, int64_t incy, - int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, + int64_t stride_x, std::complex beta, sycl::buffer, 1> &y, + int64_t incy, int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, - int64_t incx, int64_t stride_x, std::complex beta, - sycl::buffer, 1> &y, int64_t incy, - int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, + int64_t stride_x, std::complex beta, + sycl::buffer, 1> &y, int64_t incy, int64_t stride_y, + int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &x, int64_t incx, int64_t stride_x, - sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, + int64_t incx, int64_t stride_x, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, int64_t incx, int64_t stride_x, - sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, - int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, + sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, - int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, + sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } template -inline void gemm_batch_impl(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, int64_t k, - Ts alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, - int64_t ldb, int64_t stride_b, Ts beta, - sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - using cuTypeA = typename CudaEquivalentType::Type; - using cuTypeB = typename CudaEquivalentType::Type; - using cuTypeC = typename CudaEquivalentType::Type; - using cuTypeS = typename CudaEquivalentType::Type; - overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, - batch_size); - - cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; - queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, - sycl::aspect::fp16)) { - throw oneapi::mkl::unimplemented( - "blas", "sycl::half", - "half is not supported by the device or the sycl compiler"); - } - auto a_acc = a.template get_access(cgh); - auto b_acc = b.template get_access(cgh); - auto c_acc = c.template get_access(cgh); - onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - auto c_ = sc.get_mem(c_acc); - cublasStatus_t err; +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; + overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; + queue.submit([&](sycl::handler &cgh) { + if (!verify_support(queue, sycl::aspect::fp16)) { + throw oneapi::mkl::unimplemented( + "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); + } + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + cublasStatus_t err; #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T( - "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, - get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, - &alpha, a_, get_cublas_datatype(), lda, stride_a, b_, - get_cublas_datatype(), ldb, stride_b, &beta, c_, - get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, + err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, + get_cublas_datatype(), ldb, stride_b, &beta, c_, + get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #else CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, @@ -204,527 +186,449 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, batch_size, get_cublas_datatype(), cublas_gemm_algo); #endif + }); }); - }); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, \ - int64_t m, int64_t n, int64_t k, TYPE_S alpha, \ - sycl::buffer &a, int64_t lda, int64_t stride_a, \ - sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, sycl::buffer &c, int64_t ldc, \ - int64_t stride_c, int64_t batch_size) { \ - gemm_batch_impl( \ - queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, \ - stride_b, beta, c, ldc, stride_c, batch_size); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, \ + lda, stride_a, b, ldb, stride_b, beta, c, \ + ldc, stride_c, batch_size); \ + } GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, - std::complex, std::complex) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, - std::complex, std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, \ - int64_t m, int64_t n, int64_t k, TYPE_S alpha, \ - sycl::buffer &a, int64_t lda, int64_t stride_a, \ - sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, sycl::buffer &c, int64_t ldc, \ - int64_t stride_c, int64_t batch_size) { \ - throw unimplemented( \ - "blas", "gemm_batch", \ - std::string("for dtype unimplemented dtype combination <") + \ - dtype_string() + "," + dtype_string() + "," + \ - dtype_string() + "," + dtype_string() + ">"); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, double alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, - int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, - int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, float alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, float beta, - sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, float beta, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, double alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, double beta, - sycl::buffer &c, int64_t ldc, int64_t stride_c, +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, - sycl::buffer, 1> &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, sycl::buffer, 1> &c, + int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, - int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - float alpha, sycl::buffer &ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - double alpha, sycl::buffer &ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, float alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, float beta, - sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &b, int64_t ldb, - int64_t stride_b, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, - int64_t stride_b, sycl::buffer, 1> &c, - int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, - int64_t stride_b, sycl::buffer, 1> &c, - int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } // USM APIs -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, - int64_t *incx, float **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, + int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, - int64_t *incx, double **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, int64_t *incx, double **y, + int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, - const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, - const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, + int64_t *incx, std::complex **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, - int64_t incx, std::int64_t stridex, float *y, - int64_t incy, std::int64_t stridey, - std::int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, int64_t incx, + std::int64_t stridex, float *y, int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, - int64_t incx, std::int64_t stridex, double *y, - int64_t incy, std::int64_t stridey, - std::int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, int64_t incx, + std::int64_t stridex, double *y, int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, - const std::complex *x, int64_t incx, - std::int64_t stridex, std::complex *y, - int64_t incy, std::int64_t stridey, - std::int64_t batch_size, +sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, + std::int64_t stridex, std::complex *y, int64_t incy, + std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, - const std::complex *x, int64_t incx, - std::int64_t stridex, std::complex *y, - int64_t incy, std::int64_t stridey, - std::int64_t batch_size, +sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, + std::int64_t stridex, std::complex *y, int64_t incy, + std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + throw unimplemented("blas", "copy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, - const float **x, int64_t *incx, float **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, const float **x, int64_t *incx, + float **y, int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, - const double **x, int64_t *incx, double **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, const double **x, + int64_t *incx, double **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, - std::complex *alpha, - const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, + const std::complex **x, int64_t *incx, std::complex **y, + int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, - std::complex *alpha, - const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, + const std::complex **x, int64_t *incx, std::complex **y, + int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, - const float *x, int64_t incx, int64_t stridex, float *y, - int64_t incy, int64_t stridey, int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, + int64_t stridex, float *y, int64_t incy, int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, - const double *x, int64_t incx, int64_t stridex, - double *y, int64_t incy, int64_t stridey, - int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); +sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, const double *x, int64_t incx, + int64_t stridex, double *y, int64_t incy, int64_t stridey, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, - int64_t stridex, std::complex *y, int64_t incy, - int64_t stridey, int64_t batch_size, + const std::complex *x, int64_t incx, int64_t stridex, + std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, - std::complex alpha, - const std::complex *x, int64_t incx, - int64_t stridex, std::complex *y, int64_t incy, - int64_t stridey, int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, + const std::complex *x, int64_t incx, int64_t stridex, + std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + throw unimplemented("blas", "axpy_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, - int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float beta, float *y, int64_t incy, - int64_t stride_y, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, const float *x, int64_t incx, + int64_t stride_x, float beta, float *y, int64_t incy, int64_t stride_y, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, - int64_t n, double alpha, const double *a, int64_t lda, - int64_t stride_a, const double *x, int64_t incx, - int64_t stride_x, double beta, double *y, int64_t incy, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, const double *x, + int64_t incx, int64_t stride_x, double beta, double *y, int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, - int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, - int64_t incx, int64_t stride_x, std::complex beta, - std::complex *y, int64_t incy, int64_t stride_y, - int64_t batch_size, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, int64_t incx, + int64_t stride_x, std::complex beta, std::complex *y, + int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, - int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, - int64_t incx, int64_t stride_x, - std::complex beta, std::complex *y, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, int64_t incx, + int64_t stride_x, std::complex beta, std::complex *y, int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, - int64_t *n, float *alpha, const float **a, int64_t *lda, - const float **x, int64_t *incx, float *beta, float **y, - int64_t *incy, int64_t group_count, int64_t *groupsize, +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha, + const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta, + float **y, int64_t *incy, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, - int64_t *n, double *alpha, const double **a, - int64_t *lda, const double **x, int64_t *incx, - double *beta, double **y, int64_t *incy, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha, + const double **a, int64_t *lda, const double **x, int64_t *incx, + double *beta, double **y, int64_t *incy, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, - int64_t *n, std::complex *alpha, - const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, - std::complex *beta, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, std::complex *beta, + std::complex **y, int64_t *incy, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, - int64_t *n, std::complex *alpha, - const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, - std::complex *beta, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, std::complex *beta, + std::complex **y, int64_t *incy, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, - int64_t n, const float *a, int64_t lda, int64_t stride_a, - const float *x, int64_t incx, int64_t stride_x, float *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a, + int64_t lda, int64_t stride_a, const float *x, int64_t incx, + int64_t stride_x, float *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, - int64_t n, const double *a, int64_t lda, - int64_t stride_a, const double *x, int64_t incx, - int64_t stride_x, double *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const double *a, + int64_t lda, int64_t stride_a, const double *x, int64_t incx, + int64_t stride_x, double *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, - int64_t n, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, - int64_t incx, int64_t stride_x, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, + const std::complex *a, int64_t lda, int64_t stride_a, + const std::complex *x, int64_t incx, int64_t stride_x, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, - int64_t n, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, - int64_t incx, int64_t stride_x, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, + const std::complex *a, int64_t lda, int64_t stride_a, + const std::complex *x, int64_t incx, int64_t stride_x, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, - int64_t *n, const float **a, int64_t *lda, - const float **x, int64_t *incx, float **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, + const float **a, int64_t *lda, const float **x, int64_t *incx, float **c, + int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, - int64_t *n, const double **a, int64_t *lda, - const double **x, int64_t *incx, double **c, +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, + const double **a, int64_t *lda, const double **x, int64_t *incx, double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, - int64_t *n, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, - std::complex **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, + const std::complex **a, int64_t *lda, const std::complex **x, + int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, - int64_t *n, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, - std::complex **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, + const std::complex **a, int64_t *lda, const std::complex **x, + int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } template -inline sycl::event gemm_batch_strided_usm_impl( - sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, Ts alpha, const Ta *a, int64_t lda, int64_t stride_a, - const Tb *b, int64_t ldb, int64_t stride_b, Ts beta, Tc *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - using cuTypeA = typename CudaEquivalentType::Type; - using cuTypeB = typename CudaEquivalentType::Type; - using cuTypeC = typename CudaEquivalentType::Type; - using cuTypeS = typename CudaEquivalentType::Type; - overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, - batch_size); - - cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; - auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, - sycl::aspect::fp16)) { - throw oneapi::mkl::unimplemented( - "blas", "sycl::half", - "half is not supported by the device or the sycl compiler"); - } - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - cublasStatus_t err; +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stride_a, + const Tb *b, int64_t ldb, int64_t stride_b, Ts beta, + Tc *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; + overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; + auto done = queue.submit([&](sycl::handler &cgh) { + if (!verify_support(queue, sycl::aspect::fp16)) { + throw oneapi::mkl::unimplemented( + "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); + } + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + cublasStatus_t err; #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T( - "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, - get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, - &alpha, a, get_cublas_datatype(), lda, stride_a, b, - get_cublas_datatype(), ldb, stride_b, &beta, c, - get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, + err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, + get_cublas_datatype(), ldb, stride_b, &beta, c, + get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #else CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, @@ -736,47 +640,44 @@ inline sycl::event gemm_batch_strided_usm_impl( batch_size, get_cublas_datatype(), cublas_gemm_algo); #endif + }); }); - }); - return done; -} - -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch( \ - sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ - int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size, const std::vector &dependencies) { \ - return gemm_batch_strided_usm_impl( \ - queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, \ - stride_b, beta, c, ldc, stride_c, batch_size, dependencies); \ - } + return done; +} + +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, \ + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, \ + batch_size, dependencies); \ + } GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, - std::complex, std::complex) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, - std::complex, std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch( \ - sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ - int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size, const std::vector &dependencies) { \ - throw unimplemented( \ - "blas", "gemm_batch", \ - std::string("for dtype unimplemented dtype combination <") + \ - dtype_string() + "," + dtype_string() + "," + \ - dtype_string() + "," + dtype_string() + ">"); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) @@ -784,48 +685,45 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM template -inline sycl::event -gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, - int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, - int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, - int64_t *ldc, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - using cuTypeA = typename CudaEquivalentType::Type; - using cuTypeB = typename CudaEquivalentType::Type; - using cuTypeC = typename CudaEquivalentType::Type; - using cuTypeS = typename CudaEquivalentType::Type; - for (int64_t i = 0; i < group_count; i++) { - overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); - } - - cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; - auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, - sycl::aspect::fp16)) { - throw oneapi::mkl::unimplemented( - "blas", "sycl::half", - "half is not supported by the device or the sycl compiler"); +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); } - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - int64_t offset = 0; - cublasStatus_t err; - for (int64_t i = 0; i < group_count; i++) { + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; + auto done = queue.submit([&](sycl::handler &cgh) { + if (!verify_support(queue, sycl::aspect::fp16)) { + throw oneapi::mkl::unimplemented( + "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); + } + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int64_t offset = 0; + cublasStatus_t err; + for (int64_t i = 0; i < group_count; i++) { #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T( - "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, - get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), - (int)m[i], (int)n[i], (int)k[i], &alpha[i], - (const void *const *)(a + offset), get_cublas_datatype(), - (int)lda[i], (const void *const *)(b + offset), - get_cublas_datatype(), (int)ldb[i], &beta[i], - (void *const *)(c + offset), get_cublas_datatype(), - (int)ldc[i], (int)group_size[i], get_cublas_datatype(), - cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), + get_cublas_operation(transb[i]), (int)m[i], (int)n[i], + (int)k[i], &alpha[i], (const void *const *)(a + offset), + get_cublas_datatype(), (int)lda[i], + (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), + (int)ldc[i], (int)group_size[i], + get_cublas_datatype(), cublas_gemm_algo); #else CUBLAS_ERROR_FUNC_T_SYNC( "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, @@ -837,139 +735,127 @@ gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, get_cublas_datatype(), (int)ldc[i], (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); #endif - offset += group_size[i]; - } + offset += group_size[i]; + } + }); }); - }); - return done; -} - -#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch( \ - sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ - const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, \ - b, ldb, beta, c, ldc, group_count, group_size, \ - dependencies); \ - } + return done; +} + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ + } GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) GEMM_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_BATCH_LAUNCHER_USM(double, double, double, double) -GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, - std::complex, std::complex) -GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, - std::complex, std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_BATCH_LAUNCHER_USM -#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch( \ - sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ - const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - throw unimplemented( \ - "blas", "gemm_batch", \ - std::string("for dtype unimplemented dtype combination <") + \ - dtype_string() + "," + dtype_string() + "," + \ - dtype_string() + "," + dtype_string() + ">"); \ - } +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, - int64_t stride_a, float *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, + int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, - int64_t stride_a, double *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, + int64_t lda, int64_t stride_a, double *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex *b, - int64_t ldb, int64_t stride_b, int64_t batch_size, +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, int64_t stride_a, + std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, int64_t stride_a, + std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + throw unimplemented("blas", "trsm_batch", "for column_major layout"); } template -inline sycl::event -trsm_batch(const char *func_name, Func func, sycl::queue &queue, - side *left_right, uplo *upper_lower, transpose *trans, - diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, - int64_t *lda, T **b, int64_t *ldb, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - using cuDataType = typename CudaEquivalentType::Type; - for (int64_t i = 0; i < group_count; i++) { - overflow_check(m[i], n[i], lda[i], ldb[i], group_size[i]); - } - auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); +inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &queue, + side *left_right, uplo *upper_lower, transpose *trans, + diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, + int64_t *lda, T **b, int64_t *ldb, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(m[i], n[i], lda[i], ldb[i], group_size[i]); } - onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - int64_t offset = 0; - cublasStatus_t err; - for (int64_t i = 0; i < group_count; i++) { - auto **a_ = reinterpret_cast(a); - auto **b_ = reinterpret_cast(b); - cublas_native_named_func( - func_name, func, err, handle, get_cublas_side_mode(left_right[i]), - get_cublas_fill_mode(upper_lower[i]), - get_cublas_operation(trans[i]), get_cublas_diag_type(unit_diag[i]), - (int)m[i], (int)n[i], (cuDataType *)&alpha[i], a_ + offset, - (int)lda[i], b_ + offset, (int)ldb[i], (int)group_size[i]); - offset += group_size[i]; - } + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int64_t offset = 0; + cublasStatus_t err; + for (int64_t i = 0; i < group_count; i++) { + auto **a_ = reinterpret_cast(a); + auto **b_ = reinterpret_cast(b); + cublas_native_named_func(func_name, func, err, handle, + get_cublas_side_mode(left_right[i]), + get_cublas_fill_mode(upper_lower[i]), + get_cublas_operation(trans[i]), + get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], + (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], + b_ + offset, (int)ldb[i], (int)group_size[i]); + offset += group_size[i]; + } + }); }); - }); - return done; -} - -#define TRSM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event trsm_batch(sycl::queue &queue, side *left_right, \ - uplo *upper_lower, transpose *trans, diag *unit_diag, \ - int64_t *m, int64_t *n, TYPE *alpha, const TYPE **a, \ - int64_t *lda, TYPE **b, int64_t *ldb, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return trsm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, left_right, \ - upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, \ - ldb, group_count, group_size, dependencies); \ - } + return done; +} + +#define TRSM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event trsm_batch(sycl::queue &queue, side *left_right, uplo *upper_lower, \ + transpose *trans, diag *unit_diag, int64_t *m, int64_t *n, TYPE *alpha, \ + const TYPE **a, int64_t *lda, TYPE **b, int64_t *ldb, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return trsm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, left_right, upper_lower, trans, \ + unit_diag, m, n, alpha, a, lda, b, ldb, group_count, group_size, \ + dependencies); \ + } TRSM_BATCH_LAUNCHER_USM(float, cublasStrsmBatched) TRSM_BATCH_LAUNCHER_USM(double, cublasDtrsmBatched) @@ -978,249 +864,209 @@ TRSM_BATCH_LAUNCHER_USM(std::complex, cublasZtrsmBatched) #undef TRSM_BATCH_LAUNCHER_USM -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, - int64_t *n, int64_t *k, float *alpha, const float **a, - int64_t *lda, float *beta, float **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, + int64_t *k, float *alpha, const float **a, int64_t *lda, float *beta, + float **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, - int64_t *n, int64_t *k, double *alpha, const double **a, - int64_t *lda, double *beta, double **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, + int64_t *k, double *alpha, const double **a, int64_t *lda, double *beta, + double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, - int64_t *n, int64_t *k, std::complex *alpha, - const std::complex **a, int64_t *lda, - std::complex *beta, std::complex **c, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, + int64_t *k, std::complex *alpha, const std::complex **a, + int64_t *lda, std::complex *beta, std::complex **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, - int64_t *n, int64_t *k, std::complex *alpha, - const std::complex **a, int64_t *lda, - std::complex *beta, std::complex **c, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, + int64_t *k, std::complex *alpha, const std::complex **a, + int64_t *lda, std::complex *beta, std::complex **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, float alpha, const float *a, - int64_t lda, int64_t stride_a, float beta, float *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + float alpha, const float *a, int64_t lda, int64_t stride_a, float beta, + float *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, double alpha, const double *a, - int64_t lda, int64_t stride_a, double beta, double *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + double alpha, const double *a, int64_t lda, int64_t stride_a, double beta, + double *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, - std::complex *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, - std::complex *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + throw unimplemented("blas", "syrk_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, - int64_t stride_a, float *b, int64_t ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, double alpha, const double *a, - int64_t lda, int64_t stride_a, double *b, - int64_t ldb, int64_t stride_b, int64_t batch_size, +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, - int64_t ldb, int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, - int64_t ldb, int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, float alpha, float *ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, double alpha, double *ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, std::complex alpha, - std::complex *ab, int64_t lda, int64_t ldb, - int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, std::complex alpha, - std::complex *ab, int64_t lda, int64_t ldb, - int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, - float beta, const float *b, int64_t ldb, - int64_t stride_b, float *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, - double beta, const double *b, int64_t ldb, - int64_t stride_b, double *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, - std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, - int64_t stride_b, std::complex *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, - std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, - int64_t stride_b, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, float *alpha, const float **a, - int64_t *lda, float **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + float *alpha, const float **a, int64_t *lda, float **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, double *alpha, const double **a, - int64_t *lda, double **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + double *alpha, const double **a, int64_t *lda, double **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, std::complex *alpha, - const std::complex **a, int64_t *lda, - std::complex **b, int64_t *ldb, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, int64_t *lda, + std::complex **b, int64_t *ldb, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, std::complex *alpha, - const std::complex **a, int64_t *lda, - std::complex **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, + int64_t *lda, std::complex **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, float *alpha, float **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, - int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + float *alpha, float **ab, int64_t *lda, int64_t *ldb, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, double *alpha, double **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, - int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + double *alpha, double **ab, int64_t *lda, int64_t *ldb, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, std::complex *alpha, - std::complex **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, std::complex **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, std::complex *alpha, - std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, - int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, std::complex **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } } // namespace column_major @@ -1228,139 +1074,125 @@ namespace row_major { // Buffer APIs -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, - int64_t incx, int64_t stridex, sycl::buffer &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, + int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, - int64_t incx, int64_t stridex, sycl::buffer &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, + int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, - sycl::buffer, 1> &x, int64_t incx, - int64_t stridex, sycl::buffer, 1> &y, +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, + int64_t incx, int64_t stridex, sycl::buffer, 1> &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -void copy_batch(sycl::queue &queue, int64_t n, - sycl::buffer, 1> &x, int64_t incx, - int64_t stridex, sycl::buffer, 1> &y, +void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, + int64_t incx, int64_t stridex, sycl::buffer, 1> &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -void axpy_batch(sycl::queue &queue, int64_t n, float alpha, - sycl::buffer &x, int64_t incx, int64_t stridex, - sycl::buffer &y, int64_t incy, int64_t stridey, +void axpy_batch(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, + int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -void axpy_batch(sycl::queue &queue, int64_t n, double alpha, - sycl::buffer &x, int64_t incx, int64_t stridex, - sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); +void axpy_batch(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, + int64_t stridey, int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, - int64_t stridex, sycl::buffer, 1> &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + sycl::buffer, 1> &x, int64_t incx, int64_t stridex, + sycl::buffer, 1> &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, - int64_t stridex, sycl::buffer, 1> &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + sycl::buffer, 1> &x, int64_t incx, int64_t stridex, + sycl::buffer, 1> &y, int64_t incy, int64_t stridey, + int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &x, int64_t incx, - int64_t stride_x, float beta, sycl::buffer &y, - int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, + int64_t incx, int64_t stride_x, float beta, sycl::buffer &y, int64_t incy, + int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &x, int64_t incx, - int64_t stride_x, double beta, sycl::buffer &y, - int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &x, int64_t incx, int64_t stride_x, double beta, + sycl::buffer &y, int64_t incy, int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, - int64_t incx, int64_t stride_x, std::complex beta, - sycl::buffer, 1> &y, int64_t incy, - int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, + int64_t stride_x, std::complex beta, sycl::buffer, 1> &y, + int64_t incy, int64_t stride_y, int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, - int64_t incx, int64_t stride_x, std::complex beta, - sycl::buffer, 1> &y, int64_t incy, - int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, + int64_t stride_x, std::complex beta, + sycl::buffer, 1> &y, int64_t incy, int64_t stride_y, + int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &x, int64_t incx, int64_t stride_x, - sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, + int64_t incx, int64_t stride_x, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, int64_t incx, int64_t stride_x, - sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, - int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, + sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, - int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, + sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, \ - int64_t m, int64_t n, int64_t k, TYPE_S alpha, \ - sycl::buffer &a, int64_t lda, int64_t stride_a, \ - sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, sycl::buffer &c, int64_t ldc, \ - int64_t stride_c, int64_t batch_size) { \ - throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ + } GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) @@ -1368,460 +1200,386 @@ GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, - std::complex, std::complex) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, - std::complex, std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, double alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, - int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, - int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, float alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, float beta, - sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, float beta, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, double alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, double beta, - sycl::buffer &c, int64_t ldc, int64_t stride_c, +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, - sycl::buffer, 1> &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, sycl::buffer, 1> &c, + int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, +void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, - int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - float alpha, sycl::buffer &ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - double alpha, sycl::buffer &ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, - sycl::buffer, 1> &ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, float alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, float beta, - sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &b, int64_t ldb, - int64_t stride_b, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, - int64_t stride_b, sycl::buffer, 1> &c, - int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, - int64_t stride_b, sycl::buffer, 1> &c, - int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } // USM APIs -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, - int64_t *incx, float **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, + int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, - int64_t *incx, double **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, int64_t *incx, double **y, + int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, - const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, int64_t *incx, + std::complex **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, - const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, + int64_t *incx, std::complex **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, - int64_t incx, std::int64_t stridex, float *y, - int64_t incy, std::int64_t stridey, - std::int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, int64_t incx, + std::int64_t stridex, float *y, int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, - int64_t incx, std::int64_t stridex, double *y, - int64_t incy, std::int64_t stridey, - std::int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, int64_t incx, + std::int64_t stridex, double *y, int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, - const std::complex *x, int64_t incx, - std::int64_t stridex, std::complex *y, - int64_t incy, std::int64_t stridey, - std::int64_t batch_size, +sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, + std::int64_t stridex, std::complex *y, int64_t incy, + std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event copy_batch(sycl::queue &queue, int64_t n, - const std::complex *x, int64_t incx, - std::int64_t stridex, std::complex *y, - int64_t incy, std::int64_t stridey, - std::int64_t batch_size, +sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, + std::int64_t stridex, std::complex *y, int64_t incy, + std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); + throw unimplemented("blas", "copy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, - const float **x, int64_t *incx, float **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, const float **x, int64_t *incx, + float **y, int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, - const double **x, int64_t *incx, double **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, const double **x, + int64_t *incx, double **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, - std::complex *alpha, - const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, + const std::complex **x, int64_t *incx, std::complex **y, + int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, - std::complex *alpha, - const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, - int64_t group_count, int64_t *group_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, + const std::complex **x, int64_t *incx, std::complex **y, + int64_t *incy, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, - const float *x, int64_t incx, int64_t stridex, float *y, - int64_t incy, int64_t stridey, int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, + int64_t stridex, float *y, int64_t incy, int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, - const double *x, int64_t incx, int64_t stridex, - double *y, int64_t incy, int64_t stridey, - int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); +sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, const double *x, int64_t incx, + int64_t stridex, double *y, int64_t incy, int64_t stridey, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, - int64_t stridex, std::complex *y, int64_t incy, - int64_t stridey, int64_t batch_size, + const std::complex *x, int64_t incx, int64_t stridex, + std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, - std::complex alpha, - const std::complex *x, int64_t incx, - int64_t stridex, std::complex *y, int64_t incy, - int64_t stridey, int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, + const std::complex *x, int64_t incx, int64_t stridex, + std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); + throw unimplemented("blas", "axpy_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, - int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float beta, float *y, int64_t incy, - int64_t stride_y, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, const float *x, int64_t incx, + int64_t stride_x, float beta, float *y, int64_t incy, int64_t stride_y, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, - int64_t n, double alpha, const double *a, int64_t lda, - int64_t stride_a, const double *x, int64_t incx, - int64_t stride_x, double beta, double *y, int64_t incy, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, const double *x, + int64_t incx, int64_t stride_x, double beta, double *y, int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, - int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, - int64_t incx, int64_t stride_x, std::complex beta, - std::complex *y, int64_t incy, int64_t stride_y, - int64_t batch_size, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, int64_t incx, + int64_t stride_x, std::complex beta, std::complex *y, + int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, - int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, - int64_t incx, int64_t stride_x, - std::complex beta, std::complex *y, +sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, const std::complex *x, int64_t incx, + int64_t stride_x, std::complex beta, std::complex *y, int64_t incy, int64_t stride_y, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, - int64_t *n, float *alpha, const float **a, int64_t *lda, - const float **x, int64_t *incx, float *beta, float **y, - int64_t *incy, int64_t group_count, int64_t *groupsize, +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha, + const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta, + float **y, int64_t *incy, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, - int64_t *n, double *alpha, const double **a, - int64_t *lda, const double **x, int64_t *incx, - double *beta, double **y, int64_t *incy, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha, + const double **a, int64_t *lda, const double **x, int64_t *incx, + double *beta, double **y, int64_t *incy, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, - int64_t *n, std::complex *alpha, - const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, - std::complex *beta, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, std::complex *beta, + std::complex **y, int64_t *incy, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, - int64_t *n, std::complex *alpha, - const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, - std::complex *beta, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); +sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, int64_t *lda, + const std::complex **x, int64_t *incx, std::complex *beta, + std::complex **y, int64_t *incy, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, - int64_t n, const float *a, int64_t lda, int64_t stride_a, - const float *x, int64_t incx, int64_t stride_x, float *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a, + int64_t lda, int64_t stride_a, const float *x, int64_t incx, + int64_t stride_x, float *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, - int64_t n, const double *a, int64_t lda, - int64_t stride_a, const double *x, int64_t incx, - int64_t stride_x, double *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const double *a, + int64_t lda, int64_t stride_a, const double *x, int64_t incx, + int64_t stride_x, double *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, - int64_t n, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, - int64_t incx, int64_t stride_x, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, + const std::complex *a, int64_t lda, int64_t stride_a, + const std::complex *x, int64_t incx, int64_t stride_x, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, - int64_t n, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, - int64_t incx, int64_t stride_x, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, + const std::complex *a, int64_t lda, int64_t stride_a, + const std::complex *x, int64_t incx, int64_t stride_x, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, - int64_t *n, const float **a, int64_t *lda, - const float **x, int64_t *incx, float **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, + const float **a, int64_t *lda, const float **x, int64_t *incx, float **c, + int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, - int64_t *n, const double **a, int64_t *lda, - const double **x, int64_t *incx, double **c, +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, + const double **a, int64_t *lda, const double **x, int64_t *incx, double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, - int64_t *n, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, - std::complex **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, + const std::complex **a, int64_t *lda, const std::complex **x, + int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, - int64_t *n, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, - std::complex **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, + const std::complex **a, int64_t *lda, const std::complex **x, + int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch( \ - sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ - int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ - TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size, const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ - } +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ + } GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) @@ -1829,22 +1587,21 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, - std::complex, std::complex) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, - std::complex, std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ - sycl::event gemm_batch( \ - sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ - const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ - } +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ + } GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) @@ -1852,71 +1609,62 @@ GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) GEMM_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_BATCH_LAUNCHER_USM(double, double, double, double) -GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, - std::complex, std::complex) -GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, - std::complex, std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_BATCH_LAUNCHER_USM -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, - int64_t stride_a, float *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, + int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, - int64_t stride_a, double *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, + int64_t lda, int64_t stride_a, double *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex *b, - int64_t ldb, int64_t stride_b, int64_t batch_size, +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, int64_t stride_a, + std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, - transpose trans, diag unit_diag, int64_t m, int64_t n, - std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, +sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, int64_t stride_a, + std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + throw unimplemented("blas", "trsm_batch", "for row_major layout"); } template -inline sycl::event -trsm_batch(const char *func_name, Func func, sycl::queue &queue, - side *left_right, uplo *upper_lower, transpose *trans, - diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, - int64_t *lda, T **b, int64_t *ldb, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} - -#define TRSM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event trsm_batch(sycl::queue &queue, side *left_right, \ - uplo *upper_lower, transpose *trans, diag *unit_diag, \ - int64_t *m, int64_t *n, TYPE *alpha, const TYPE **a, \ - int64_t *lda, TYPE **b, int64_t *ldb, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return trsm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, left_right, \ - upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, \ - ldb, group_count, group_size, dependencies); \ - } +inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &queue, + side *left_right, uplo *upper_lower, transpose *trans, + diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, + int64_t *lda, T **b, int64_t *ldb, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", "for row_major layout"); +} + +#define TRSM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event trsm_batch(sycl::queue &queue, side *left_right, uplo *upper_lower, \ + transpose *trans, diag *unit_diag, int64_t *m, int64_t *n, TYPE *alpha, \ + const TYPE **a, int64_t *lda, TYPE **b, int64_t *ldb, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return trsm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, left_right, upper_lower, trans, \ + unit_diag, m, n, alpha, a, lda, b, ldb, group_count, group_size, \ + dependencies); \ + } TRSM_BATCH_LAUNCHER_USM(float, cublasStrsmBatched) TRSM_BATCH_LAUNCHER_USM(double, cublasDtrsmBatched) @@ -1925,249 +1673,209 @@ TRSM_BATCH_LAUNCHER_USM(std::complex, cublasZtrsmBatched) #undef TRSM_BATCH_LAUNCHER_USM -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, - int64_t *n, int64_t *k, float *alpha, const float **a, - int64_t *lda, float *beta, float **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, + int64_t *k, float *alpha, const float **a, int64_t *lda, float *beta, + float **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, - int64_t *n, int64_t *k, double *alpha, const double **a, - int64_t *lda, double *beta, double **c, int64_t *ldc, - int64_t group_count, int64_t *groupsize, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, + int64_t *k, double *alpha, const double **a, int64_t *lda, double *beta, + double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, - int64_t *n, int64_t *k, std::complex *alpha, - const std::complex **a, int64_t *lda, - std::complex *beta, std::complex **c, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, + int64_t *k, std::complex *alpha, const std::complex **a, + int64_t *lda, std::complex *beta, std::complex **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, - int64_t *n, int64_t *k, std::complex *alpha, - const std::complex **a, int64_t *lda, - std::complex *beta, std::complex **c, +sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, + int64_t *k, std::complex *alpha, const std::complex **a, + int64_t *lda, std::complex *beta, std::complex **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, float alpha, const float *a, - int64_t lda, int64_t stride_a, float beta, float *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + float alpha, const float *a, int64_t lda, int64_t stride_a, float beta, + float *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, double alpha, const double *a, - int64_t lda, int64_t stride_a, double beta, double *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + double alpha, const double *a, int64_t lda, int64_t stride_a, double beta, + double *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, - std::complex *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, - int64_t n, int64_t k, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, - std::complex *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, +sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex beta, std::complex *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + throw unimplemented("blas", "syrk_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, - int64_t stride_a, float *b, int64_t ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, double alpha, const double *a, - int64_t lda, int64_t stride_a, double *b, - int64_t ldb, int64_t stride_b, int64_t batch_size, +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, - int64_t ldb, int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, - int64_t ldb, int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, float alpha, float *ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, double alpha, double *ab, int64_t lda, - int64_t ldb, int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, std::complex alpha, - std::complex *ab, int64_t lda, int64_t ldb, - int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, - int64_t n, std::complex alpha, - std::complex *ab, int64_t lda, int64_t ldb, - int64_t stride, int64_t batch_size, +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, - float beta, const float *b, int64_t ldb, - int64_t stride_b, float *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, - double beta, const double *b, int64_t ldb, - int64_t stride_b, double *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, - std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, - int64_t stride_b, std::complex *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, - std::complex alpha, - const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, - int64_t stride_b, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, float *alpha, const float **a, - int64_t *lda, float **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + float *alpha, const float **a, int64_t *lda, float **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, double *alpha, const double **a, - int64_t *lda, double **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + double *alpha, const double **a, int64_t *lda, double **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, std::complex *alpha, - const std::complex **a, int64_t *lda, - std::complex **b, int64_t *ldb, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, int64_t *lda, + std::complex **b, int64_t *ldb, int64_t group_count, + int64_t *groupsize, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, std::complex *alpha, - const std::complex **a, int64_t *lda, - std::complex **b, int64_t *ldb, +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, + int64_t *lda, std::complex **b, int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, float *alpha, float **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, - int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + float *alpha, float **ab, int64_t *lda, int64_t *ldb, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, double *alpha, double **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, - int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + double *alpha, double **ab, int64_t *lda, int64_t *ldb, + int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, std::complex *alpha, - std::complex **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, std::complex **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, - int64_t *n, std::complex *alpha, - std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, - int64_t *groupsize, +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, std::complex **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } } // namespace row_major From 94dcc7e6db8a3c1e27d8ae6aea05073935650d19 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Mon, 7 Oct 2024 14:28:17 +0100 Subject: [PATCH 11/11] Fix format try 2. Signed-off-by: JackAKirk --- src/blas/backends/cublas/cublas_batch.cpp | 97 ++++++++++------------- 1 file changed, 43 insertions(+), 54 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 031c11a2f..9f198b653 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -168,23 +168,19 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran auto c_ = sc.get_mem(c_acc); cublasStatus_t err; #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, - err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a_, - get_cublas_datatype(), lda, stride_a, b_, - get_cublas_datatype(), ldb, stride_b, &beta, c_, - get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), + ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #else - CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", - cublasGemmStridedBatchedEx, err, handle, - get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a_, - get_cublas_datatype(), lda, stride_a, b_, - get_cublas_datatype(), ldb, stride_b, &beta, - c_, get_cublas_datatype(), ldc, stride_c, - batch_size, get_cublas_datatype(), - cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), + ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #endif }); }); @@ -622,23 +618,19 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra auto handle = sc.get_handle(queue); cublasStatus_t err; #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, - err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a, - get_cublas_datatype(), lda, stride_a, b, - get_cublas_datatype(), ldb, stride_b, &beta, c, - get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), + ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #else - CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", - cublasGemmStridedBatchedEx, err, handle, - get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a, - get_cublas_datatype(), lda, stride_a, b, - get_cublas_datatype(), ldb, stride_b, &beta, - c, get_cublas_datatype(), ldc, stride_c, - batch_size, get_cublas_datatype(), - cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), + ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #endif }); }); @@ -714,26 +706,23 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr cublasStatus_t err; for (int64_t i = 0; i < group_count; i++) { #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, - get_cublas_operation(transa[i]), - get_cublas_operation(transb[i]), (int)m[i], (int)n[i], - (int)k[i], &alpha[i], (const void *const *)(a + offset), - get_cublas_datatype(), (int)lda[i], - (const void *const *)(b + offset), - get_cublas_datatype(), (int)ldb[i], &beta[i], - (void *const *)(c + offset), get_cublas_datatype(), - (int)ldc[i], (int)group_size[i], - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], + (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), + get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], + (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); #else CUBLAS_ERROR_FUNC_T_SYNC( "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), - get_cublas_datatype(), (int)lda[i], - (const void *const *)(b + offset), get_cublas_datatype(), - (int)ldb[i], &beta[i], (void *const *)(c + offset), - get_cublas_datatype(), (int)ldc[i], (int)group_size[i], - get_cublas_datatype(), cublas_gemm_algo); + get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], + (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); #endif offset += group_size[i]; } @@ -832,13 +821,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a); auto **b_ = reinterpret_cast(b); - cublas_native_named_func(func_name, func, err, handle, - get_cublas_side_mode(left_right[i]), - get_cublas_fill_mode(upper_lower[i]), - get_cublas_operation(trans[i]), - get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], - (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], - b_ + offset, (int)ldb[i], (int)group_size[i]); + cublas_native_named_func( + func_name, func, err, handle, get_cublas_side_mode(left_right[i]), + get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]), + get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], + (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], + (int)group_size[i]); + offset += group_size[i]; } });