diff --git a/include/oneapi/mkl/blas.hxx b/include/oneapi/mkl/blas.hxx index 5a703fea2..374585912 100644 --- a/include/oneapi/mkl/blas.hxx +++ b/include/oneapi/mkl/blas.hxx @@ -382,6 +382,39 @@ static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose tr stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } +static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + +static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + +static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + static inline void gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, @@ -2246,6 +2279,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, return done; } +static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, + const sycl::half **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, @@ -2312,6 +2384,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, trans return done; } +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + static inline sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx index 784eeafee..afebb93c3 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx @@ -464,6 +464,30 @@ static inline void gemm_batch(backend_selector selector, trans sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + static inline void spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, sycl::buffer &a, sycl::buffer &x, std::int64_t incx, float beta, @@ -1870,6 +1894,30 @@ static inline sycl::event gemm_batch(backend_selector selector std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, + float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, + float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, + float *beta, std::int32_t **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + static inline sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -1911,6 +1959,33 @@ static inline sycl::event gemm_batch( sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + static inline sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hxx b/include/oneapi/mkl/blas/detail/blas_loader.hxx index d964d0024..98d93b2ad 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hxx +++ b/include/oneapi/mkl/blas/detail/blas_loader.hxx @@ -124,6 +124,27 @@ ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, tr std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); ONEMKL_EXPORT void syrk(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, @@ -1227,6 +1248,29 @@ ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &qu sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -1263,6 +1307,30 @@ ONEMKL_EXPORT sycl::event gemm_batch( std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); ONEMKL_EXPORT sycl::event syrk(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx index 65ae5b853..9483a66c1 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx @@ -186,6 +186,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2670,6 +2703,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2737,6 +2806,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx index f94e09426..1141eb238 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx @@ -804,6 +804,25 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -2040,6 +2059,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, const float *b, @@ -2081,6 +2118,27 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t batch_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, std::int32_t *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx index 004a4c11c..1724bf5c7 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2672,6 +2705,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2739,6 +2808,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, diff --git a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx index d365a39c4..c69257e9c 100644 --- a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2622,6 +2655,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose *transa, transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const float **a, std::int64_t *lda, const float **b, @@ -2685,6 +2754,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, diff --git a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx index fe5b56b48..404d79ae0 100644 --- a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2672,6 +2705,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2739,6 +2808,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, diff --git a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx index fe81ae5aa..fbb64a6a0 100644 --- a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx +++ b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx @@ -973,6 +973,30 @@ ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + ONEMKL_EXPORT void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -2558,6 +2582,32 @@ ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -2599,6 +2649,33 @@ ONEMKL_EXPORT sycl::event gemm_batch( std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event gemmt(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, diff --git a/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx index 2a092a61b..8a66ed707 100644 --- a/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx @@ -187,6 +187,39 @@ void gemm_batch(backend_selector selector, transpose transa, c, ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -2576,6 +2609,42 @@ sycl::event gemm_batch(backend_selector selector, transpose * return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose *transa, transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const float **a, std::int64_t *lda, const float **b, @@ -2638,6 +2707,42 @@ sycl::event gemm_batch(backend_selector selector, transpose t return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, diff --git a/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx index 32188fed7..bc86929b0 100644 --- a/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx @@ -181,6 +181,36 @@ void gemm_batch(backend_selector selector, transpose transa, t c, ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, int64_t n, int64_t k, float alpha, sycl::buffer &a, int64_t lda, float beta, sycl::buffer &c, int64_t ldc) { @@ -2538,6 +2568,39 @@ sycl::event gemm_batch(backend_selector selector, transpose *t return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const sycl::half **a, int64_t *lda, const sycl::half **b, int64_t *ldb, + float *beta, float **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const std::int8_t **a, int64_t *lda, const std::int8_t **b, int64_t *ldb, + float *beta, float **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const std::int8_t **a, int64_t *lda, const std::int8_t **b, int64_t *ldb, + float *beta, std::int32_t **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, @@ -2598,6 +2661,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const sycl::half *a, int64_t lda, int64_t stride_a, const sycl::half *b, + int64_t ldb, int64_t stride_b, float beta, float *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const std::int8_t *a, int64_t lda, int64_t stride_a, const std::int8_t *b, + int64_t ldb, int64_t stride_b, float beta, float *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const std::int8_t *a, int64_t lda, int64_t stride_a, const std::int8_t *b, + int64_t ldb, int64_t stride_b, float beta, std::int32_t *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, int64_t n, float alpha, const float *a, const float *x, int64_t incx, float beta, float *y, int64_t incy, const std::vector &dependencies) { diff --git a/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx b/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx index e4cd77c4a..70aabaaf9 100644 --- a/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx +++ b/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx @@ -744,6 +744,24 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t sycl::half beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size); +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -1848,6 +1866,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, sycl::half **c, int64_t *ldc, int64_t group_count, int64_t *group_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, int64_t stride_b, float beta, float *c, @@ -1880,6 +1916,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i sycl::half beta, sycl::half *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, const float *b, int64_t ldb, float beta, float *c, int64_t ldc, diff --git a/src/blas/backends/backend_wrappers.cxx b/src/blas/backends/backend_wrappers.cxx index 34af9cf2f..62f6ced13 100644 --- a/src/blas/backends/backend_wrappers.cxx +++ b/src/blas/backends/backend_wrappers.cxx @@ -200,6 +200,9 @@ oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, @@ -455,6 +458,12 @@ oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index beefd6eeb..009bb9541 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -140,16 +140,21 @@ void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -template -inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, int64_t k, T alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, - int64_t ldb, int64_t stride_b, T beta, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - using cuDataType = typename CudaEquivalentType::Type; +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if (!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -158,33 +163,56 @@ inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, tra auto c_acc = c.template get_access(cgh); onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - auto c_ = sc.get_mem(c_acc); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, - a_, lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_, - ldc, stride_c, batch_size); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), + ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); }); }); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, \ + lda, stride_a, b, ldb, stride_b, beta, c, \ + ldc, stride_c, batch_size); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) + +#undef GEMM_STRIDED_BATCH_LAUNCHER + +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -553,17 +581,23 @@ sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, - T alpha, const T *a, int64_t lda, int64_t stride_a, const T *b, - int64_t ldb, int64_t stride_b, T beta, T *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - using cuDataType = typename CudaEquivalentType::Type; +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stride_a, + const Tb *b, int64_t ldb, int64_t stride_b, Ts beta, + Tc *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if (!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -573,50 +607,74 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que } onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = reinterpret_cast(a); - auto b_ = reinterpret_cast(b); - auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, - a_, lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_, - ldc, stride_c, batch_size); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), + ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); }); }); return done; } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stride_a, const TYPE *b, int64_t ldb, int64_t stride_b, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stride_c, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, \ - dependencies); \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, \ + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, \ + batch_size, dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose *transa, transpose *transb, int64_t *m, int64_t *n, - int64_t *k, T *alpha, const T **a, int64_t *lda, const T **b, - int64_t *ldb, T *beta, T **c, int64_t *ldc, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - using cuDataType = typename CudaEquivalentType::Type; +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; for (int64_t i = 0; i < group_count; i++) { overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); } + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if (!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -629,14 +687,14 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que int64_t offset = 0; cublasStatus_t err; for (int64_t i = 0; i < group_count; i++) { - auto **a_ = reinterpret_cast(a); - auto **b_ = reinterpret_cast(b); - auto **c_ = reinterpret_cast(c); CUBLAS_ERROR_FUNC_T_SYNC( - func_name, func, err, handle, get_cublas_operation(transa[i]), - get_cublas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], - (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], - (cuDataType *)&beta[i], c_ + offset, (int)ldc[i], (int)group_size[i]); + "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], + (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), + get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], + (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); offset += group_size[i]; } }); @@ -644,21 +702,41 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que return done; } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmBatched) -GEMM_BATCH_LAUNCHER_USM(float, cublasSgemmBatched) -GEMM_BATCH_LAUNCHER_USM(double, cublasDgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasCgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasZgemmBatched) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) + +#undef GEMM_BATCH_LAUNCHER_USM + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM @@ -1066,30 +1144,25 @@ void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -template -inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, int64_t k, T alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, - int64_t ldb, int64_t stride_b, T beta, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -1458,59 +1531,47 @@ sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, - T alpha, const T *a, int64_t lda, int64_t stride_a, const T *b, - int64_t ldb, int64_t stride_b, T beta, T *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stride_a, const TYPE *b, int64_t ldb, int64_t stride_b, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stride_c, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, \ - dependencies); \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose *transa, transpose *transb, int64_t *m, int64_t *n, - int64_t *k, T *alpha, const T **a, int64_t *lda, const T **b, - int64_t *ldb, T *beta, T **c, int64_t *ldc, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmBatched) -GEMM_BATCH_LAUNCHER_USM(float, cublasSgemmBatched) -GEMM_BATCH_LAUNCHER_USM(double, cublasDgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasCgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasZgemmBatched) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_BATCH_LAUNCHER_USM diff --git a/src/blas/backends/cublas/cublas_helper.hpp b/src/blas/backends/cublas/cublas_helper.hpp index 0ee9930e3..0fe7e7c5a 100644 --- a/src/blas/backends/cublas/cublas_helper.hpp +++ b/src/blas/backends/cublas/cublas_helper.hpp @@ -35,6 +35,7 @@ #include "oneapi/mkl/types.hpp" #include "runtime_support_helper.hpp" +#include "dtype_string.hpp" namespace oneapi { namespace mkl { @@ -231,6 +232,56 @@ inline cublasSideMode_t get_cublas_side_mode(oneapi::mkl::side lr) { } } +template +inline cudaDataType_t get_cublas_datatype() { + static_assert(false); +} + +template <> +inline cudaDataType_t get_cublas_datatype<__half>() { + return CUDA_R_16F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_32F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_64F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_C_32F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_C_64F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_8I; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_8U; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_32I; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_32U; +} + /*converting std::complex to cuComplex*/ /*converting sycl::half to __half*/ template diff --git a/src/blas/backends/cublas/cublas_wrappers.cpp b/src/blas/backends/cublas/cublas_wrappers.cpp index fe479e195..ee5c7239f 100644 --- a/src/blas/backends/cublas/cublas_wrappers.cpp +++ b/src/blas/backends/cublas/cublas_wrappers.cpp @@ -205,6 +205,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, @@ -460,6 +463,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemmt, oneapi::mkl::blas::cublas::column_major::gemmt, oneapi::mkl::blas::cublas::column_major::gemmt, @@ -686,6 +695,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, @@ -941,6 +953,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemmt, oneapi::mkl::blas::cublas::row_major::gemmt, oneapi::mkl::blas::cublas::row_major::gemmt, diff --git a/src/blas/backends/mkl_common/mkl_batch.cxx b/src/blas/backends/mkl_common/mkl_batch.cxx index 0a204d5b7..6358a3922 100644 --- a/src/blas/backends/mkl_common/mkl_batch.cxx +++ b/src/blas/backends/mkl_common/mkl_batch.cxx @@ -182,6 +182,33 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t stride_b, beta, c, ldc, stride_c, batch_size); } +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -642,6 +669,33 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size, dependencies); +} + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, const float **a, int64_t *lda, const float **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, @@ -689,6 +743,33 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, ldc, group_count, groupsize, dependencies); } +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, group_count, groupsize, dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, group_count, groupsize, dependencies); +} + sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, diff --git a/src/blas/backends/mklcpu/mklcpu_batch.cpp b/src/blas/backends/mklcpu/mklcpu_batch.cpp index 9dd231629..5ecf4cc69 100644 --- a/src/blas/backends/mklcpu/mklcpu_batch.cpp +++ b/src/blas/backends/mklcpu/mklcpu_batch.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/blas/detail/mklcpu/onemkl_blas_mklcpu.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "../mkl_common/mkl_blas_backend.hpp" namespace oneapi { diff --git a/src/blas/backends/mklgpu/mklgpu_batch.cpp b/src/blas/backends/mklgpu/mklgpu_batch.cpp index d859a3b78..bad2db82c 100644 --- a/src/blas/backends/mklgpu/mklgpu_batch.cpp +++ b/src/blas/backends/mklgpu/mklgpu_batch.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/blas/detail/mklgpu/onemkl_blas_mklgpu.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "../mkl_common/mkl_blas_backend.hpp" namespace oneapi { diff --git a/src/blas/backends/netlib/netlib_batch.cxx b/src/blas/backends/netlib/netlib_batch.cxx index a029a60bc..7a2839dd4 100644 --- a/src/blas/backends/netlib/netlib_batch.cxx +++ b/src/blas/backends/netlib/netlib_batch.cxx @@ -279,6 +279,45 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t #endif } +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -983,6 +1022,45 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, #endif } +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, int64_t stride_b, @@ -1052,6 +1130,45 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i #endif } +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, diff --git a/src/blas/backends/portblas/portblas_batch.cxx b/src/blas/backends/portblas/portblas_batch.cxx index 581fcd2e5..28c7ee5dc 100644 --- a/src/blas/backends/portblas/portblas_batch.cxx +++ b/src/blas/backends/portblas/portblas_batch.cxx @@ -213,6 +213,33 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: throw unimplemented("blas", "gemm_batch", " for complex"); } +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); +} + void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, @@ -700,6 +727,33 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, throw unimplemented("blas", "gemm_batch", " for USM"); } +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const sycl::half **a, std::int64_t *lda, + const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, @@ -754,6 +808,36 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, throw unimplemented("blas", "gemm_batch", " for USM"); } +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, + std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, + std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, + std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, float alpha, diff --git a/src/blas/backends/rocblas/CMakeLists.txt b/src/blas/backends/rocblas/CMakeLists.txt index 3a71eda1c..76dc126ad 100644 --- a/src/blas/backends/rocblas/CMakeLists.txt +++ b/src/blas/backends/rocblas/CMakeLists.txt @@ -39,6 +39,7 @@ add_dependencies(onemkl_backend_libs_blas ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/src ${PROJECT_BINARY_DIR}/bin ${ONEMKL_GENERATED_INCLUDE_PATH} diff --git a/src/blas/backends/rocblas/rocblas_batch.cpp b/src/blas/backends/rocblas/rocblas_batch.cpp index 9a0a1be28..5fa103055 100644 --- a/src/blas/backends/rocblas/rocblas_batch.cpp +++ b/src/blas/backends/rocblas/rocblas_batch.cpp @@ -227,14 +227,20 @@ DGMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zdgmm_strided_batched) #undef DGMM_STRIDED_BATCH_LAUNCHER -template -inline void gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, - int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, T beta, - sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { - using rocDataType = typename RocEquivalentType::Type; +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stridea, strideb, stridec, batch_size); + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -242,32 +248,58 @@ inline void gemm_batch(Func func, sycl::queue &queue, transpose transa, transpos onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - auto c_ = sc.get_mem(c_acc); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, k, (rocDataType *)&alpha, - a_, lda, stridea, b_, ldb, strideb, (rocDataType *)&beta, c_, - ldc, stridec, batch_size); + ROCBLAS_ERROR_FUNC_SYNC(rocblas_gemm_strided_batched_ex, err, handle, + get_rocblas_operation(transa), get_rocblas_operation(transb), m, + n, k, &alpha, a_, get_rocblas_datatype(), lda, + stridea, b_, get_rocblas_datatype(), ldb, strideb, + &beta, c_, get_rocblas_datatype(), ldc, stridec, c_, + get_rocblas_datatype(), ldc, stridec, batch_size, + get_rocblas_datatype(), rocblas_gemm_algo_standard, + solution_index, flags); }); }); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, int64_t stridea, \ - sycl::buffer &b, int64_t ldb, int64_t strideb, TYPE beta, \ - sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { \ - gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, \ - ldb, strideb, beta, c, ldc, stridec, batch_size); \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, \ + beta, c, ldc, stridec, batch_size); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER + +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -816,63 +848,100 @@ DGMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zdgmm_batched) #undef DGMM_BATCH_LAUNCHER -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, - int64_t stridea, const T *b, int64_t ldb, int64_t strideb, T beta, - T *c, int64_t ldc, int64_t stridec, int64_t batch_size, - const std::vector &dependencies) { - using rocDataType = typename RocEquivalentType::Type; +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stridea, + const Tb *b, int64_t ldb, int64_t strideb, Ts beta, + Tc *c, int64_t ldc, int64_t stridec, + int64_t batch_size, + const std::vector &dependencies) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stridea, strideb, stridec, batch_size); + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; auto done = queue.submit([&](sycl::handler &cgh) { cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = reinterpret_cast(a); - auto b_ = reinterpret_cast(b); - auto c_ = reinterpret_cast(c); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, k, (rocDataType *)&alpha, - a_, lda, stridea, b_, ldb, strideb, (rocDataType *)&beta, c_, - ldc, stridec, batch_size); + ROCBLAS_ERROR_FUNC_SYNC(rocblas_gemm_strided_batched_ex, err, handle, + get_rocblas_operation(transa), get_rocblas_operation(transb), m, + n, k, &alpha, a_, get_rocblas_datatype(), lda, + stridea, b_, get_rocblas_datatype(), ldb, strideb, + &beta, c_, get_rocblas_datatype(), ldc, stridec, c_, + get_rocblas_datatype(), ldc, stridec, batch_size, + get_rocblas_datatype(), rocblas_gemm_algo_standard, + solution_index, flags); }); }); return done; } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stridea, const TYPE *b, int64_t ldb, int64_t strideb, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stridec, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ - b, ldb, strideb, beta, c, ldc, stridec, batch_size, dependencies); \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ + b, ldb, strideb, beta, c, ldc, stridec, batch_size, \ + dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, transpose *transb, - int64_t *m, int64_t *n, int64_t *k, T *alpha, const T **a, - int64_t *lda, const T **b, int64_t *ldb, T *beta, T **c, int64_t *ldc, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - using rocDataType = typename RocEquivalentType::Type; +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; for (int64_t i = 0; i < group_count; i++) { overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); } + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; auto done = queue.submit([&](sycl::handler &cgh) { cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { @@ -881,14 +950,18 @@ inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, int64_t offset = 0; rocblas_status err; for (int64_t i = 0; i < group_count; i++) { - auto **a_ = reinterpret_cast(a); - auto **b_ = reinterpret_cast(b); - auto **c_ = reinterpret_cast(c); + auto **a_ = reinterpret_cast(a); + auto **b_ = reinterpret_cast(b); + auto **c_ = reinterpret_cast(c); ROCBLAS_ERROR_FUNC_SYNC( - func, err, handle, get_rocblas_operation(transa[i]), - get_rocblas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], - (rocDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], - (rocDataType *)&beta[i], c_ + offset, (int)ldc[i], (int)group_size[i]); + rocblas_gemm_batched_ex, err, handle, get_rocblas_operation(transa[i]), + get_rocblas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], &alpha[i], + a_ + offset, get_rocblas_datatype(), (int)lda[i], b_ + offset, + get_rocblas_datatype(), (int)ldb[i], &beta[i], c_ + offset, + get_rocblas_datatype(), (int)ldc[i], c_ + offset, + get_rocblas_datatype(), (int)ldc[i], (int)group_size[i], + get_rocblas_datatype(), rocblas_gemm_algo_standard, solution_index, + flags); offset += group_size[i]; } }); @@ -897,21 +970,41 @@ inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, return done; } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, \ - beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_batched) -GEMM_BATCH_LAUNCHER_USM(float, rocblas_sgemm_batched) -GEMM_BATCH_LAUNCHER_USM(double, rocblas_dgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_batched) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) + +#undef GEMM_BATCH_LAUNCHER_USM + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM @@ -1442,32 +1535,55 @@ DGMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zdgmm_strided_batched) #undef DGMM_STRIDED_BATCH_LAUNCHER -template -inline void gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, - int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, T beta, - sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { auto new_transa = transb; auto new_transb = transa; - column_major::gemm_batch(func, queue, new_transa, new_transb, n, m, k, alpha, b, ldb, strideb, - a, lda, stridea, beta, c, ldc, stridec, batch_size); + column_major::gemm_batch(queue, new_transa, new_transb, n, m, k, alpha, b, ldb, strideb, a, lda, + stridea, beta, c, ldc, stridec, batch_size); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, int64_t stridea, \ - sycl::buffer &b, int64_t ldb, int64_t strideb, TYPE beta, \ - sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { \ - gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, \ - ldb, strideb, beta, c, ldc, stridec, batch_size); \ +#undef GEMM_STRIDED_BATCH_LAUNCHER +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, \ + beta, c, ldc, stridec, batch_size); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER + +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -1936,67 +2052,110 @@ DGMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zdgmm_batched) #undef DGMM_BATCH_LAUNCHER -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, - int64_t stridea, const T *b, int64_t ldb, int64_t strideb, T beta, - T *c, int64_t ldc, int64_t stridec, int64_t batch_size, - const std::vector &dependencies) { +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stridea, + const Tb *b, int64_t ldb, int64_t strideb, Ts beta, + Tc *c, int64_t ldc, int64_t stridec, + int64_t batch_size, + const std::vector &dependencies) { auto new_transa = transb; auto new_transb = transa; - return column_major::gemm_batch(func, queue, new_transa, new_transb, n, m, k, alpha, b, ldb, - strideb, a, lda, stridea, beta, c, ldc, stridec, batch_size, + return column_major::gemm_batch(queue, new_transa, new_transb, n, m, k, alpha, b, ldb, strideb, + a, lda, stridea, beta, c, ldc, stridec, batch_size, dependencies); } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stridea, const TYPE *b, int64_t ldb, int64_t strideb, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stridec, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ - b, ldb, strideb, beta, c, ldc, stridec, batch_size, dependencies); \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ + b, ldb, strideb, beta, c, ldc, stridec, batch_size, \ + dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, transpose *transb, - int64_t *m, int64_t *n, int64_t *k, T *alpha, const T **a, - int64_t *lda, const T **b, int64_t *ldb, T *beta, T **c, int64_t *ldc, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { for (int64_t i = 0; i < group_count; i++) { std::swap(transa[i], transb[i]); } - return column_major::gemm_batch(func, queue, transa, transb, n, m, k, alpha, b, ldb, a, lda, - beta, c, ldc, group_count, group_size, dependencies); + return column_major::gemm_batch(queue, transa, transb, n, m, k, alpha, b, ldb, a, lda, beta, c, + ldc, group_count, group_size, dependencies); } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, \ - beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ + } + +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) + +#undef GEMM_BATCH_LAUNCHER_USM + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_batched) -GEMM_BATCH_LAUNCHER_USM(float, rocblas_sgemm_batched) -GEMM_BATCH_LAUNCHER_USM(double, rocblas_dgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_batched) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM diff --git a/src/blas/backends/rocblas/rocblas_helper.hpp b/src/blas/backends/rocblas/rocblas_helper.hpp index eeeb5a11c..ae6301a7a 100644 --- a/src/blas/backends/rocblas/rocblas_helper.hpp +++ b/src/blas/backends/rocblas/rocblas_helper.hpp @@ -31,6 +31,7 @@ #include #include "oneapi/mkl/types.hpp" #include +#include "dtype_string.hpp" namespace oneapi { namespace mkl { @@ -205,6 +206,66 @@ inline rocblas_side get_rocblas_side_mode(oneapi::mkl::side lr) { } } +template +inline rocblas_datatype get_rocblas_datatype() { + static_assert(false); +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f16_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f64_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f32_c; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f64_c; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_i8_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_u8_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_i32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_u32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_bf16_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype>() { + return rocblas_datatype_bf16_c; +} + /*converting std::complex to roc__complex sycl::half to rocblas_half*/ template diff --git a/src/blas/backends/rocblas/rocblas_wrappers.cpp b/src/blas/backends/rocblas/rocblas_wrappers.cpp index 87fc78b86..ce4c92da5 100644 --- a/src/blas/backends/rocblas/rocblas_wrappers.cpp +++ b/src/blas/backends/rocblas/rocblas_wrappers.cpp @@ -207,6 +207,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, @@ -462,6 +465,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemmt, oneapi::mkl::blas::rocblas::column_major::gemmt, oneapi::mkl::blas::rocblas::column_major::gemmt, @@ -688,6 +697,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, @@ -943,6 +955,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemmt, oneapi::mkl::blas::rocblas::row_major::gemmt, oneapi::mkl::blas::rocblas::row_major::gemmt, diff --git a/src/blas/blas_loader.cpp b/src/blas/blas_loader.cpp index 490d730a7..c1f1339c6 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -1342,6 +1342,39 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa stride_c, batch_size); } +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].column_major_gemm_f16f16f32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].column_major_gemm_s8s8f32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].column_major_gemm_s8s8s32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + void trsm_batch(oneapi::mkl::device libkey, sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -3405,6 +3438,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose group_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_f16f16f32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_s8s8f32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_s8s8s32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -3463,6 +3529,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose stride_c, batch_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_f16f16f32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_s8s8f32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_s8s8s32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + sycl::event gemmt(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, @@ -5177,6 +5276,39 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa stride_c, batch_size); } +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].row_major_gemm_f16f16f32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].row_major_gemm_s8s8f32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].row_major_gemm_s8s8s32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + void trsm_batch(oneapi::mkl::device libkey, sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -7236,6 +7368,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose group_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_f16f16f32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_s8s8f32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_s8s8s32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -7294,6 +7459,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose stride_c, batch_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_f16f16f32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_s8s8f32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_s8s8s32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + sycl::event gemmt(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index c9d640b1c..a242fd0c0 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -869,6 +869,26 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + void (*column_major_gemm_f16f16f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void (*column_major_gemm_s8s8f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*column_major_gemm_s8s8s32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); void (*column_major_strsm_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -2180,6 +2200,24 @@ typedef struct { std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); + sycl::event (*column_major_gemm_f16f16f32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*column_major_gemm_s8s8f32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*column_major_gemm_s8s8s32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); sycl::event (*column_major_sgemm_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, @@ -2213,6 +2251,24 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_gemm_f16f16f32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_gemm_s8s8f32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_gemm_s8s8s32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); sycl::event (*column_major_sgemmt_usm_sycl)(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, @@ -3269,6 +3325,26 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_gemm_f16f16f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_gemm_s8s8f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*row_major_gemm_s8s8s32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); void (*row_major_strsm_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -4581,6 +4657,24 @@ typedef struct { std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); + sycl::event (*row_major_gemm_f16f16f32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*row_major_gemm_s8s8f32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*row_major_gemm_s8s8s32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); sycl::event (*row_major_sgemm_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, @@ -4614,6 +4708,24 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_gemm_f16f16f32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_gemm_s8s8f32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_gemm_s8s8s32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); sycl::event (*row_major_sgemmt_usm_sycl)(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, diff --git a/src/include/dtype_string.hpp b/src/include/dtype_string.hpp new file mode 100644 index 000000000..6f2a87feb --- /dev/null +++ b/src/include/dtype_string.hpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_ERROR_HELPER_HPP_ +#define _ONEMKL_ERROR_HELPER_HPP_ + +#include + +template +inline const std::string dtype_string(); +template <> +inline const std::string dtype_string() { + return "float"; +} +template <> +inline const std::string dtype_string() { + return "double"; +} +template <> +inline const std::string dtype_string() { + return "half"; +} +template <> +inline const std::string dtype_string>() { + return "complex"; +} +template <> +inline const std::string dtype_string>() { + return "complex"; +} +template <> +inline const std::string dtype_string() { + return "int32"; +} +template <> +inline const std::string dtype_string() { + return "int8"; +} + +#endif //_ONEMKL_ERROR_HELPER_HPP_ diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index d194e2007..12af18ec9 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -47,13 +47,13 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Prepare data. int64_t m, n, k; int64_t lda, ldb, ldc; oneapi::mkl::transpose transa, transb; - fp alpha, beta; + Ts alpha, beta; int64_t i, tmp; batch_size = 1 + std::rand() % 20; @@ -63,14 +63,11 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { lda = std::max(m, k); ldb = std::max(n, k); ldc = std::max(m, n); - alpha = rand_scalar(); - beta = rand_scalar(); + alpha = rand_scalar(); + beta = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa = (oneapi::mkl::transpose)(std::rand() % 2); - transb = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa = oneapi::mkl::transpose::conjtrans; @@ -82,6 +79,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { else transb = (oneapi::mkl::transpose)tmp; } + else { + transa = (oneapi::mkl::transpose)(std::rand() % 2); + transb = (oneapi::mkl::transpose)(std::rand() % 2); + } int64_t stride_a, stride_b, stride_c; @@ -99,8 +100,12 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { default: break; } - vector> A(stride_a * batch_size), B(stride_b * batch_size); - vector> C(stride_c * batch_size), C_ref(stride_c * batch_size); + vector> A(stride_a * batch_size); + vector> B(stride_b * batch_size); + vector> C(stride_c * batch_size), + C_cast_ref(stride_c * batch_size); + vector> A_ref(stride_a * batch_size), B_ref(stride_b * batch_size), + C_ref(stride_c * batch_size); for (i = 0; i < batch_size; i++) { rand_matrix(A.data() + stride_a * i, layout, transa, m, k, lda); @@ -108,10 +113,18 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { rand_matrix(C.data() + stride_c * i, layout, oneapi::mkl::transpose::nontrans, m, n, ldc); } - C_ref = C; + for (size_t i = 0; i < A.size(); ++i) { + A_ref[i] = A[i]; + } + for (size_t i = 0; i < B.size(); ++i) { + B_ref[i] = B[i]; + } + for (size_t i = 0; i < C.size(); ++i) { + C_ref[i] = C[i]; + } // Call reference GEMM_BATCH_STRIDE. - using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int m_ref = (int)m; int n_ref = (int)n; int k_ref = (int)k; @@ -121,12 +134,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int batch_size_ref = (int)batch_size; for (i = 0; i < batch_size_ref; i++) { - ::gemm( - convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), - convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, - (const int *)&k_ref, (const fp_ref *)&alpha, (const fp_ref *)(A.data() + stride_a * i), - (const int *)&lda_ref, (const fp_ref *)(B.data() + stride_b * i), (const int *)&ldb_ref, - (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), (const int *)&ldc_ref); + ::gemm(convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), + convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, + (const int *)&k_ref, (const fp_ref *)&alpha, + (const fp_ref *)(A_ref.data() + stride_a * i), (const int *)&lda_ref, + (const fp_ref *)(B_ref.data() + stride_b * i), (const int *)&ldb_ref, + (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), + (const int *)&ldc_ref); } // Call DPC++ GEMM_BATCH_STRIDE. @@ -147,9 +161,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { queue main_queue(*dev, exception_handler); - buffer A_buffer(A.data(), range<1>(A.size())); - buffer B_buffer(B.data(), range<1>(B.size())); - buffer C_buffer(C.data(), range<1>(C.size())); + buffer A_buffer(A.data(), range<1>(A.size())); + buffer B_buffer(B.data(), range<1>(B.size())); + buffer C_buffer(C.data(), range<1>(C.size())); try { #ifdef CALL_RT_API @@ -183,6 +197,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { default: break; } #endif + main_queue.wait_and_throw(); } catch (exception const &e) { std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n" @@ -200,11 +215,18 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. + int tol_scalar = 10; + int error_mag = tol_scalar * k; + if (std::is_same_v) + error_mag = 1; + for (size_t i = 0; i < C_ref.size(); ++i) { + C_cast_ref[i] = C_ref[i]; + } auto C_accessor = C_buffer.template get_host_access(read_only); - bool good = - check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, - 1, stride_c * batch_size, 10 * k, std::cout); + bool good = check_almost_equal_matrix(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, + stride_c * batch_size, 1, stride_c * batch_size, + error_mag, std::cout); return (int)good; } @@ -213,29 +235,49 @@ class GemmBatchStrideTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchStrideTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, HalfHalfFloatPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, Int8Int8SinglePrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, Int8Int8Int32Precision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchStrideTestSuite, GemmBatchStrideTests, diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index b0d8ec90b..97f2dd086 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -47,7 +47,7 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -72,7 +72,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t m, n, k; int64_t lda, ldb, ldc; oneapi::mkl::transpose transa, transb; - fp alpha, beta; + Ts alpha, beta; int64_t i, tmp; @@ -83,13 +83,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { lda = std::max(m, k); ldb = std::max(n, k); ldc = std::max(m, n); - alpha = rand_scalar(); - beta = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa = (oneapi::mkl::transpose)(std::rand() % 2); - transb = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + alpha = rand_scalar(); + beta = rand_scalar(); + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa = oneapi::mkl::transpose::conjtrans; @@ -101,6 +98,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { else transb = (oneapi::mkl::transpose)tmp; } + else { + transa = (oneapi::mkl::transpose)(std::rand() % 2); + transb = (oneapi::mkl::transpose)(std::rand() % 2); + } int64_t stride_a, stride_b, stride_c; @@ -118,18 +119,27 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { default: break; } - auto ua = usm_allocator(cxt, *dev); - vector A(ua), B(ua), C(ua), C_ref(ua); + auto ua = usm_allocator(cxt, *dev); + auto ub = usm_allocator(cxt, *dev); + auto uc = usm_allocator(cxt, *dev); + auto us = usm_allocator(cxt, *dev); + vector A(ua); + vector B(ub); + vector C(uc), C_cast_ref(uc); + vector A_ref(us), B_ref(us), C_ref(us); A.resize(stride_a * batch_size); B.resize(stride_b * batch_size); C.resize(stride_c * batch_size); + A_ref.resize(stride_c * batch_size); + B_ref.resize(stride_c * batch_size); C_ref.resize(stride_c * batch_size); + C_cast_ref.resize(stride_c * batch_size); - fp **a_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **b_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **c_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **c_ref_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + Ta **a_array = (Ta **)oneapi::mkl::malloc_shared(64, sizeof(Ta *) * batch_size, *dev, cxt); + Tb **b_array = (Tb **)oneapi::mkl::malloc_shared(64, sizeof(Tb *) * batch_size, *dev, cxt); + Tc **c_array = (Tc **)oneapi::mkl::malloc_shared(64, sizeof(Tc *) * batch_size, *dev, cxt); + Ts **c_ref_array = (Ts **)oneapi::mkl::malloc_shared(64, sizeof(Ts *) * batch_size, *dev, cxt); if ((a_array == NULL) || (b_array == NULL) || (c_array == NULL) || (c_ref_array == NULL)) { std::cout << "Error cannot allocate arrays of pointers\n"; @@ -153,11 +163,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { stride_b * batch_size, 1, stride_b * batch_size); rand_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size); + copy_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, + stride_a * batch_size, 1, stride_a * batch_size, A_ref); + copy_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, + stride_b * batch_size, 1, stride_b * batch_size, B_ref); copy_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size, C_ref); // Call reference GEMM_BATCH_STRIDE. - using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int m_ref = (int)m; int n_ref = (int)n; int k_ref = (int)k; @@ -166,12 +180,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int ldc_ref = (int)ldc; int batch_size_ref = (int)batch_size; for (i = 0; i < batch_size_ref; i++) { - ::gemm( - convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), - convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, - (const int *)&k_ref, (const fp_ref *)&alpha, (const fp_ref *)(A.data() + stride_a * i), - (const int *)&lda_ref, (const fp_ref *)(B.data() + stride_b * i), (const int *)&ldb_ref, - (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), (const int *)&ldc_ref); + ::gemm(convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), + convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, + (const int *)&k_ref, (const fp_ref *)&alpha, + (const fp_ref *)(A_ref.data() + stride_a * i), (const int *)&lda_ref, + (const fp_ref *)(B_ref.data() + stride_b * i), (const int *)&ldb_ref, + (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), + (const int *)&ldc_ref); } // Call DPC++ GEMM_BATCH_STRIDE. @@ -191,7 +206,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { break; default: break; } - done.wait(); + done.wait_and_throw(); #else switch (layout) { case oneapi::mkl::layout::col_major: @@ -208,7 +223,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { break; default: break; } - main_queue.wait(); + main_queue.wait_and_throw(); #endif } catch (exception const &e) { @@ -231,8 +246,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal_matrix(C, C_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, - 1, stride_c * batch_size, 10 * k, std::cout); + int tol_scalar = 10; + int error_mag = tol_scalar * k; + if (std::is_same_v) + error_mag = 1; + + for (size_t i = 0; i < C_ref.size(); ++i) { + C_cast_ref[i] = C_ref[i]; + } + bool good = check_almost_equal_matrix(C, C_cast_ref, oneapi::mkl::layout::col_major, + stride_c * batch_size, 1, stride_c * batch_size, + error_mag, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); @@ -246,29 +270,49 @@ class GemmBatchStrideUsmTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchStrideUsmTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, HalfHalfFloatPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, Int8Int8SinglePrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, Int8Int8Int32Precision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchStrideUsmTestSuite, GemmBatchStrideUsmTests, diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index 58963a889..a651f9ae3 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -47,7 +47,7 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -76,8 +76,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { auto uatranspose = usm_allocator(cxt, *dev); vector transa(uatranspose), transb(uatranspose); - auto uafp = usm_allocator(cxt, *dev); - vector alpha(uafp), beta(uafp); + auto uaTs = usm_allocator(cxt, *dev); + vector alpha(uaTs), beta(uaTs); m.resize(group_count); n.resize(group_count); @@ -104,13 +104,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { lda[i] = std::max(m[i], k[i]); ldb[i] = std::max(n[i], k[i]); ldc[i] = std::max(m[i], n[i]); - alpha[i] = rand_scalar(); - beta[i] = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa[i] = (oneapi::mkl::transpose)(std::rand() % 2); - transb[i] = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + alpha[i] = rand_scalar(); + beta[i] = rand_scalar(); + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa[i] = oneapi::mkl::transpose::conjtrans; @@ -122,15 +119,27 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { else transb[i] = (oneapi::mkl::transpose)tmp; } + else { + transa[i] = (oneapi::mkl::transpose)(std::rand() % 2); + transb[i] = (oneapi::mkl::transpose)(std::rand() % 2); + } total_batch_count += group_size[i]; } - auto uafpp = usm_allocator(cxt, *dev); - vector a_array(uafpp), b_array(uafpp), c_array(uafpp), - c_ref_array(uafpp); + auto uaTap = usm_allocator(cxt, *dev); + auto uaTbp = usm_allocator(cxt, *dev); + auto uaTcp = usm_allocator(cxt, *dev); + auto uaTsp = usm_allocator(cxt, *dev); + vector a_array(uaTap); + vector b_array(uaTbp); + vector c_array(uaTcp), c_cast_ref_array(uaTcp); + vector a_ref_array(uaTsp), b_ref_array(uaTsp), c_ref_array(uaTsp); a_array.resize(total_batch_count); b_array.resize(total_batch_count); c_array.resize(total_batch_count); + a_ref_array.resize(total_batch_count); + b_ref_array.resize(total_batch_count); + c_cast_ref_array.resize(total_batch_count); c_ref_array.resize(total_batch_count); idx = 0; @@ -149,13 +158,19 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { default: break; } for (j = 0; j < group_size[i]; j++) { - a_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_a, *dev, cxt); - b_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_b, *dev, cxt); - c_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_c, *dev, cxt); - c_ref_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_c, *dev, cxt); + a_array[idx] = (Ta *)oneapi::mkl::malloc_shared(64, sizeof(Ta) * size_a, *dev, cxt); + b_array[idx] = (Tb *)oneapi::mkl::malloc_shared(64, sizeof(Tb) * size_b, *dev, cxt); + c_array[idx] = (Tc *)oneapi::mkl::malloc_shared(64, sizeof(Tc) * size_c, *dev, cxt); + a_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_a, *dev, cxt); + b_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_b, *dev, cxt); + c_cast_ref_array[idx] = + (Tc *)oneapi::mkl::malloc_shared(64, sizeof(Tc) * size_c, *dev, cxt); + c_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_c, *dev, cxt); rand_matrix(a_array[idx], layout, transa[i], m[i], k[i], lda[i]); rand_matrix(b_array[idx], layout, transb[i], k[i], n[i], ldb[i]); rand_matrix(c_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i]); + copy_matrix(a_array[idx], layout, transa[i], m[i], k[i], lda[i], a_ref_array[idx]); + copy_matrix(b_array[idx], layout, transb[i], k[i], n[i], ldb[i], b_ref_array[idx]); copy_matrix(c_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i], c_ref_array[idx]); idx++; @@ -163,7 +178,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { } // Call reference GEMM_BATCH. - using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int *m_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); int *n_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); int *k_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); @@ -196,6 +211,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -216,9 +234,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { for (j = 0; j < group_size_ref[i]; j++) { ::gemm(convert_to_cblas_layout(layout), transa_ref[i], transb_ref[i], (const int *)&m_ref[i], (const int *)&n_ref[i], (const int *)&k_ref[i], - (const fp_ref *)&alpha[i], (const fp_ref *)a_array[idx], - (const int *)&lda_ref[i], (const fp_ref *)b_array[idx], (const int *)&ldb_ref[i], - (const fp_ref *)&beta[i], (fp_ref *)c_ref_array[idx], (const int *)&ldc_ref[i]); + (const fp_ref *)&alpha[i], (const fp_ref *)a_ref_array[idx], + (const int *)&lda_ref[i], (const fp_ref *)b_ref_array[idx], + (const int *)&ldb_ref[i], (const fp_ref *)&beta[i], (fp_ref *)c_ref_array[idx], + (const int *)&ldc_ref[i]); idx++; } } @@ -231,37 +250,37 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemm_batch( main_queue, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], &ldb[0], &beta[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: done = oneapi::mkl::blas::row_major::gemm_batch( main_queue, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], &ldb[0], &beta[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; default: break; } - done.wait(); + done.wait_and_throw(); #else switch (layout) { case oneapi::mkl::layout::col_major: TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_batch, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_batch, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], + (const Ta **)&a_array[0], &lda[0], (const Ta **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; default: break; } - main_queue.wait(); + main_queue.wait_and_throw(); #endif } catch (exception const &e) { @@ -286,6 +305,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -299,11 +321,19 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { bool good = true; // Compare the results of reference implementation and DPC++ implementation. + int tol_scalar = 10; + idx = 0; for (i = 0; i < group_count; i++) { for (j = 0; j < group_size[i]; j++) { - good = good && check_equal_matrix(c_array[idx], c_ref_array[idx], layout, m[i], n[i], - ldc[i], 10 * k[i], std::cout); + int error_mag = tol_scalar * k[i]; + if (std::is_same_v) + error_mag = 1; + + copy_matrix(c_ref_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], + ldc[i], c_cast_ref_array[idx]); + good = good && check_almost_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, + m[i], n[i], ldc[i], error_mag, std::cout); idx++; } } @@ -322,6 +352,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -334,29 +367,49 @@ class GemmBatchUsmTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchUsmTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests, diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index d8c7029b1..5d607991e 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -249,21 +249,21 @@ void copy_matrix(vec_src &src, oneapi::mkl::layout layout, oneapi::mkl::transpos } } -template -void copy_matrix(fp *src, oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int m, int n, - int ld, fp *dest) { +template +void copy_matrix(fp_src *src, oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int m, + int n, int ld, fp_dst *dest) { if (((trans == oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::col_major)) || ((trans != oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::row_major))) { for (int j = 0; j < n; j++) for (int i = 0; i < m; i++) - dest[i + j * ld] = (fp)src[i + j * ld]; + dest[i + j * ld] = (fp_dst)src[i + j * ld]; } else { for (int i = 0; i < m; i++) for (int j = 0; j < n; j++) - dest[j + i * ld] = (fp)src[j + i * ld]; + dest[j + i * ld] = (fp_dst)src[j + i * ld]; } } @@ -655,4 +655,57 @@ bool check_equal_trsm_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, i return good; } +// Helper for using std::result_of for evalutation operator[] return type +template +struct access_index { + auto operator()(T M) { + return M[0]; + } +}; + +// Helper for checking if a matrix/vector/accessor structure returns an integral type +template +constexpr bool is_matrix_type_integral() { + return std::is_integral_v< + std::remove_reference_t(T)>::type>>; +} + +template +typename std::enable_if::value, bool>::type check_almost_equal_int( + fp x, fp x_ref, int error_mag) { + return (std::abs(x - x_ref) <= error_mag); +} + +template +bool check_almost_equal_matrix_int(Ta &M, Tb &M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { + static_assert(is_matrix_type_integral() && is_matrix_type_integral()); + bool good = true; + int idx, count = 0; + for (int j = 0; j < n; j++) { + for (int i = 0; i < m; i++) { + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; + if (!check_almost_equal_int(M[idx], M_ref[idx], error_mag)) { + out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] + << " vs. Reference " << M_ref[idx] << std::endl; + good = false; + count++; + if (count > MAX_NUM_PRINT) + return good; + } + } + } + + return good; +} + +template +bool check_almost_equal_matrix(Ta &M, Tb &M_ref, oneapi::mkl::layout layout, int m, int n, int ld, + int error_mag, std::ostream &out) { + // Only call if returned dtype is integral + if constexpr (is_matrix_type_integral() && is_matrix_type_integral()) + return check_almost_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); + return check_equal_matrix(M, M_ref, layout, m, n, ld, error_mag, out); +} + #endif /* header guard */