Skip to content

Commit

Permalink
[cublas] add missing support for gemv_batch (#586)
Browse files Browse the repository at this point in the history
Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
  • Loading branch information
JackAKirk authored Oct 11, 2024
1 parent d19d454 commit 21229ee
Showing 1 changed file with 43 additions and 25 deletions.
68 changes: 43 additions & 25 deletions src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,35 +502,53 @@ sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t
throw unimplemented("blas", "gemv_batch", "for column_major layout");
}

sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha,
const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta,
float **y, int64_t *incy, int64_t group_count, int64_t *groupsize,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemv_batch", "for column_major layout");
/// Shared implementation behind the pointer-array (group) `gemv_batch` overloads.
///
/// For every group `g` in `[0, group_count)` this forwards one call to the supplied
/// cuBLAS routine `func`, using that group's `trans[g]`, `m[g]`, `n[g]`, `lda[g]`,
/// `incx[g]`, `incy[g]`, `alpha[g]` and `beta[g]`, with `group_size[g]` as the batch
/// count. The `a`, `x` and `y` pointer arrays are consumed contiguously: each group
/// starts at the running sum of the sizes of the groups before it.
///
/// @param func_name    Name of the cuBLAS routine, handed on to `cublas_native_named_func`.
/// @param func         The cuBLAS routine to invoke.
/// @param queue        Queue the work is submitted to; also used to obtain the handle.
/// @param dependencies Events the submitted command group depends on.
/// @return The event returned by `queue.submit`.
///
/// NOTE(review): all pointer arguments are captured by value and dereferenced inside
/// the host task, so they presumably must remain valid until the returned event has
/// completed - confirm against the callers.
template <typename Func, typename T>
inline sycl::event gemv_batch(const char *func_name, Func func, sycl::queue &queue,
                              transpose *trans, int64_t *m, int64_t *n, T *alpha, const T **a,
                              int64_t *lda, const T **x, int64_t *incx, T *beta, T **y,
                              int64_t *incy, int64_t group_count, int64_t *group_size,
                              const std::vector<sycl::event> &dependencies) {
    using cuDataType = typename CudaEquivalentType<T>::Type;

    // Check every group's sizes before anything is submitted to the queue.
    for (int64_t group = 0; group < group_count; ++group) {
        overflow_check(m[group], n[group], lda[group], incx[group], incy[group],
                       group_size[group]);
    }

    return queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dependencies);
        onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) {
            auto handle = sc.get_handle(queue);
            cublasStatus_t err;

            // View the caller's pointer arrays through the CUDA-equivalent element type.
            const cuDataType **a_array = reinterpret_cast<const cuDataType **>(a);
            const cuDataType **x_array = reinterpret_cast<const cuDataType **>(x);
            cuDataType **y_array = reinterpret_cast<cuDataType **>(y);

            // Index of the current group's first entry in a/x/y.
            int64_t first = 0;
            for (int64_t group = 0; group < group_count; ++group) {
                cublas_native_named_func(func_name, func, err, handle,
                                         get_cublas_operation(trans[group]),
                                         static_cast<int>(m[group]),
                                         static_cast<int>(n[group]),
                                         reinterpret_cast<cuDataType *>(alpha + group),
                                         a_array + first, static_cast<int>(lda[group]),
                                         x_array + first, static_cast<int>(incx[group]),
                                         reinterpret_cast<cuDataType *>(beta + group),
                                         y_array + first, static_cast<int>(incy[group]),
                                         static_cast<int>(group_size[group]));
                first += group_size[group];
            }
        });
    });
}

// Group-API gemv_batch overload for double.
// Stub only: none of the arguments are read; the call unconditionally throws
// `unimplemented` with the message "for column_major layout".
// NOTE(review): this has the same parameter list as the overload produced by
// GEMV_BATCH_LAUNCHER_USM(double, cublasDgemvBatched) in this listing; presumably
// this is the pre-change side of the diff - confirm only one of them is compiled.
sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha,
const double **a, int64_t *lda, const double **x, int64_t *incx,
double *beta, double **y, int64_t *incy, int64_t group_count,
int64_t *groupsize, const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemv_batch", "for column_major layout");
}
// Defines the public pointer-array (group) gemv_batch overload for element type TYPE.
// The generated function forwards every argument unchanged to the templated
// gemv_batch helper, passing CUBLAS_ROUTINE both as the callable and, stringified,
// as the routine name.
#define GEMV_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE)                                       \
    sycl::event gemv_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n,    \
                           TYPE *alpha, const TYPE **a, int64_t *lda, const TYPE **x,       \
                           int64_t *incx, TYPE *beta, TYPE **y, int64_t *incy,              \
                           int64_t group_count, int64_t *group_size,                        \
                           const std::vector<sycl::event> &dependencies) {                  \
        return gemv_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a,    \
                          lda, x, incx, beta, y, incy, group_count, group_size,             \
                          dependencies);                                                    \
    }

// Group-API gemv_batch overload for std::complex<float>.
// Stub only: none of the arguments are read; the call unconditionally throws
// `unimplemented` with the message "for column_major layout".
// NOTE(review): this has the same parameter list as the overload produced by
// GEMV_BATCH_LAUNCHER_USM(std::complex<float>, cublasCgemvBatched) in this listing;
// presumably this is the pre-change side of the diff - confirm only one is compiled.
sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n,
std::complex<float> *alpha, const std::complex<float> **a, int64_t *lda,
const std::complex<float> **x, int64_t *incx, std::complex<float> *beta,
std::complex<float> **y, int64_t *incy, int64_t group_count,
int64_t *groupsize, const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemv_batch", "for column_major layout");
}
// Generate the group-API gemv_batch overloads for the four supported element types,
// pairing each with the cuBLAS batched GEMV routine passed as the second argument.
GEMV_BATCH_LAUNCHER_USM(float, cublasSgemvBatched)
GEMV_BATCH_LAUNCHER_USM(double, cublasDgemvBatched)
GEMV_BATCH_LAUNCHER_USM(std::complex<float>, cublasCgemvBatched)
GEMV_BATCH_LAUNCHER_USM(std::complex<double>, cublasZgemvBatched)

// Group-API gemv_batch overload for std::complex<double>.
// Stub only: none of the arguments are read; the call unconditionally throws
// `unimplemented` with the message "for column_major layout".
// NOTE(review): this has the same parameter list as the overload produced by
// GEMV_BATCH_LAUNCHER_USM(std::complex<double>, cublasZgemvBatched) in this listing;
// presumably this is the pre-change side of the diff - confirm only one is compiled.
sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n,
std::complex<double> *alpha, const std::complex<double> **a, int64_t *lda,
const std::complex<double> **x, int64_t *incx, std::complex<double> *beta,
std::complex<double> **y, int64_t *incy, int64_t group_count,
int64_t *groupsize, const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemv_batch", "for column_major layout");
}
// The launcher macro is not used past this point; drop it so it does not leak further.
#undef GEMV_BATCH_LAUNCHER_USM

sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a,
int64_t lda, int64_t stride_a, const float *x, int64_t incx,
Expand Down

0 comments on commit 21229ee

Please sign in to comment.