diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index f92d7bd0e..009bb9541 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -190,7 +190,6 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) -GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, @@ -212,6 +211,7 @@ GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::com dtype_string() + "," + dtype_string() + ">"); \ } +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -632,7 +632,6 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, @@ -654,6 +653,7 @@ GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std: dtype_string() + "," + dtype_string() + ">"); \ } +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM @@ -714,7 +714,6 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) -GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_BATCH_LAUNCHER_USM(float, float, float, float) GEMM_BATCH_LAUNCHER_USM(double, double, double, double) GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, @@ -736,6 +735,7 @@ GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex dtype_string() + "," + dtype_string() + ">"); \ } +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM