Skip to content

Commit

Permalink
Add gemm_batch dtypes to netlib (unimplemented)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aidan committed Apr 3, 2024
1 parent b98adf0 commit e436f1c
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 0 deletions.
105 changes: 105 additions & 0 deletions include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,39 @@ void gemm_batch(backend_selector<backend::netlib> selector, transpose transa, tr
ldc, stride_c, batch_size);
}

// Buffer-API gemm_batch overload for the netlib backend selector:
// sycl::half inputs (a, b), float output (c), float alpha/beta.
// Passes every argument through unchanged to the netlib MAJOR-namespace
// gemm_batch, using the queue held by the selector.
void gemm_batch(backend_selector<backend::netlib> selector, transpose transa, transpose transb,
                std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
                sycl::buffer<sycl::half, 1> &a, std::int64_t lda, std::int64_t stride_a,
                sycl::buffer<sycl::half, 1> &b, std::int64_t ldb, std::int64_t stride_b, float beta,
                sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
                std::int64_t batch_size) {
    sycl::queue &queue = selector.get_queue();
    oneapi::mkl::blas::netlib::MAJOR::gemm_batch(queue, transa, transb, m, n, k, alpha,
                                                 a, lda, stride_a,
                                                 b, ldb, stride_b,
                                                 beta, c, ldc, stride_c, batch_size);
}

// Buffer-API gemm_batch overload for the netlib backend selector:
// std::int8_t inputs (a, b), float output (c), float alpha/beta.
// Passes every argument through unchanged to the netlib MAJOR-namespace
// gemm_batch, using the queue held by the selector.
void gemm_batch(backend_selector<backend::netlib> selector, transpose transa, transpose transb,
                std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
                sycl::buffer<std::int8_t, 1> &a, std::int64_t lda, std::int64_t stride_a,
                sycl::buffer<std::int8_t, 1> &b, std::int64_t ldb, std::int64_t stride_b,
                float beta, sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
                std::int64_t batch_size) {
    sycl::queue &queue = selector.get_queue();
    oneapi::mkl::blas::netlib::MAJOR::gemm_batch(queue, transa, transb, m, n, k, alpha,
                                                 a, lda, stride_a,
                                                 b, ldb, stride_b,
                                                 beta, c, ldc, stride_c, batch_size);
}

// Buffer-API gemm_batch overload for the netlib backend selector:
// std::int8_t inputs (a, b), std::int32_t output (c), float alpha/beta.
// Passes every argument through unchanged to the netlib MAJOR-namespace
// gemm_batch, using the queue held by the selector.
void gemm_batch(backend_selector<backend::netlib> selector, transpose transa, transpose transb,
                std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
                sycl::buffer<std::int8_t, 1> &a, std::int64_t lda, std::int64_t stride_a,
                sycl::buffer<std::int8_t, 1> &b, std::int64_t ldb, std::int64_t stride_b,
                float beta, sycl::buffer<std::int32_t, 1> &c, std::int64_t ldc,
                std::int64_t stride_c, std::int64_t batch_size) {
    sycl::queue &queue = selector.get_queue();
    oneapi::mkl::blas::netlib::MAJOR::gemm_batch(queue, transa, transb, m, n, k, alpha,
                                                 a, lda, stride_a,
                                                 b, ldb, stride_b,
                                                 beta, c, ldc, stride_c, batch_size);
}

void syrk(backend_selector<backend::netlib> selector, uplo upper_lower, transpose trans,
std::int64_t n, std::int64_t k, float alpha, sycl::buffer<float, 1> &a,
std::int64_t lda, float beta, sycl::buffer<float, 1> &c, std::int64_t ldc) {
Expand Down Expand Up @@ -2672,6 +2705,42 @@ sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose *tr
return done;
}

// USM group-API gemm_batch overload for the netlib backend selector:
// sycl::half inputs (a, b), float output (c), float alpha/beta arrays.
// Forwards all arguments (including `dependencies`) to the netlib
// MAJOR-namespace gemm_batch on the selector's queue and returns its event.
sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose *transa,
                       transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k,
                       float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b,
                       std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc,
                       std::int64_t group_count, std::int64_t *group_size,
                       const std::vector<sycl::event> &dependencies) {
    return oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb,
                                                        m, n, k, alpha,
                                                        a, lda, b, ldb,
                                                        beta, c, ldc,
                                                        group_count, group_size, dependencies);
}

// USM group-API gemm_batch overload for the netlib backend selector:
// std::int8_t inputs (a, b), float output (c), float alpha/beta arrays.
// Forwards all arguments (including `dependencies`) to the netlib
// MAJOR-namespace gemm_batch on the selector's queue and returns its event.
sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose *transa,
                       transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k,
                       float *alpha, const std::int8_t **a, std::int64_t *lda,
                       const std::int8_t **b, std::int64_t *ldb, float *beta, float **c,
                       std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size,
                       const std::vector<sycl::event> &dependencies) {
    return oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb,
                                                        m, n, k, alpha,
                                                        a, lda, b, ldb,
                                                        beta, c, ldc,
                                                        group_count, group_size, dependencies);
}

// USM group-API gemm_batch overload for the netlib backend selector:
// std::int8_t inputs (a, b), std::int32_t output (c), float alpha/beta arrays.
// Forwards all arguments (including `dependencies`) to the netlib
// MAJOR-namespace gemm_batch on the selector's queue and returns its event.
sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose *transa,
                       transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k,
                       float *alpha, const std::int8_t **a, std::int64_t *lda,
                       const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c,
                       std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size,
                       const std::vector<sycl::event> &dependencies) {
    return oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb,
                                                        m, n, k, alpha,
                                                        a, lda, b, ldb,
                                                        beta, c, ldc,
                                                        group_count, group_size, dependencies);
}

sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose transa,
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
float alpha, const float *a, std::int64_t lda, std::int64_t stride_a,
Expand Down Expand Up @@ -2739,6 +2808,42 @@ sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose tra
return done;
}

// USM strided gemm_batch overload for the netlib backend selector:
// sycl::half inputs (a, b), float output (c), float alpha/beta.
// Forwards all arguments (including `dependencies`) to the netlib
// MAJOR-namespace gemm_batch on the selector's queue and returns its event.
sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose transa,
                       transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
                       float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a,
                       const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta,
                       float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size,
                       const std::vector<sycl::event> &dependencies) {
    return oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb,
                                                        m, n, k, alpha,
                                                        a, lda, stride_a,
                                                        b, ldb, stride_b,
                                                        beta, c, ldc, stride_c,
                                                        batch_size, dependencies);
}

// USM strided gemm_batch overload for the netlib backend selector:
// std::int8_t inputs (a, b), float output (c), float alpha/beta.
// Forwards all arguments (including `dependencies`) to the netlib
// MAJOR-namespace gemm_batch on the selector's queue and returns its event.
sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose transa,
                       transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
                       float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a,
                       const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta,
                       float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size,
                       const std::vector<sycl::event> &dependencies) {
    return oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb,
                                                        m, n, k, alpha,
                                                        a, lda, stride_a,
                                                        b, ldb, stride_b,
                                                        beta, c, ldc, stride_c,
                                                        batch_size, dependencies);
}

// USM strided gemm_batch overload for the netlib backend selector:
// std::int8_t inputs (a, b), std::int32_t output (c), float alpha/beta.
// Forwards all arguments (including `dependencies`) to the netlib
// MAJOR-namespace gemm_batch on the selector's queue and returns its event.
sycl::event gemm_batch(backend_selector<backend::netlib> selector, transpose transa,
                       transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
                       float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a,
                       const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta,
                       std::int32_t *c, std::int64_t ldc, std::int64_t stride_c,
                       std::int64_t batch_size, const std::vector<sycl::event> &dependencies) {
    return oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb,
                                                        m, n, k, alpha,
                                                        a, lda, stride_a,
                                                        b, ldb, stride_b,
                                                        beta, c, ldc, stride_c,
                                                        batch_size, dependencies);
}

sycl::event spmv(backend_selector<backend::netlib> selector, uplo upper_lower, std::int64_t n,
float alpha, const float *a, const float *x, std::int64_t incx, float beta,
float *y, std::int64_t incy,
Expand Down
117 changes: 117 additions & 0 deletions src/blas/backends/netlib/netlib_batch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,45 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t
#endif
}

// Buffer-API gemm_batch stub (sycl::half a/b, float c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout. No argument is read.
void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n,
                int64_t k, float alpha, sycl::buffer<sycl::half, 1> &a, int64_t lda,
                int64_t stride_a, sycl::buffer<sycl::half, 1> &b, int64_t ldb, int64_t stride_b,
                float beta, sycl::buffer<float, 1> &c, int64_t ldc, int64_t stride_c,
                int64_t batch_size) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

// Buffer-API gemm_batch stub (std::int8_t a/b, float c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout. No argument is read.
void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n,
                int64_t k, float alpha, sycl::buffer<std::int8_t, 1> &a, int64_t lda,
                int64_t stride_a, sycl::buffer<std::int8_t, 1> &b, int64_t ldb, int64_t stride_b,
                float beta, sycl::buffer<float, 1> &c, int64_t ldc, int64_t stride_c,
                int64_t batch_size) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

// Buffer-API gemm_batch stub (std::int8_t a/b, std::int32_t c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout. No argument is read.
void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n,
                int64_t k, float alpha, sycl::buffer<std::int8_t, 1> &a, int64_t lda,
                int64_t stride_a, sycl::buffer<std::int8_t, 1> &b, int64_t ldb, int64_t stride_b,
                float beta, sycl::buffer<std::int32_t, 1> &c, int64_t ldc, int64_t stride_c,
                int64_t batch_size) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans,
diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer<float, 1> &a,
int64_t lda, int64_t stride_a, sycl::buffer<float, 1> &b, int64_t ldb,
Expand Down Expand Up @@ -983,6 +1022,45 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb,
#endif
}

// USM group-API gemm_batch stub (sycl::half a/b, float c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout; no event is produced.
sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m,
                       int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda,
                       const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc,
                       int64_t group_count, int64_t *groupsize,
                       const std::vector<sycl::event> &dependencies) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

// USM group-API gemm_batch stub (std::int8_t a/b, float c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout; no event is produced.
sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m,
                       int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda,
                       const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc,
                       int64_t group_count, int64_t *groupsize,
                       const std::vector<sycl::event> &dependencies) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

// USM group-API gemm_batch stub (std::int8_t a/b, std::int32_t c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout; no event is produced.
sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m,
                       int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda,
                       const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c,
                       int64_t *ldc, int64_t group_count, int64_t *groupsize,
                       const std::vector<sycl::event> &dependencies) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m,
int64_t n, int64_t k, float alpha, const float *a, int64_t lda,
int64_t stride_a, const float *b, int64_t ldb, int64_t stride_b,
Expand Down Expand Up @@ -1052,6 +1130,45 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i
#endif
}

// USM strided gemm_batch stub (sycl::half a/b, float c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout; no event is produced.
sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n,
                       int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a,
                       const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c,
                       int64_t ldc, int64_t stride_c, int64_t batch_size,
                       const std::vector<sycl::event> &dependencies) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

// USM strided gemm_batch stub (std::int8_t a/b, float c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout; no event is produced.
sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n,
                       int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a,
                       const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c,
                       int64_t ldc, int64_t stride_c, int64_t batch_size,
                       const std::vector<sycl::event> &dependencies) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

// USM strided gemm_batch stub (std::int8_t a/b, std::int32_t c).
// Not implemented here: under either layout macro the body only throws
// `unimplemented`, with a message naming that layout; no event is produced.
sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n,
                       int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a,
                       const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta,
                       std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size,
                       const std::vector<sycl::event> &dependencies) {
#if defined(COLUMN_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for column_major layout");
#elif defined(ROW_MAJOR)
    throw unimplemented("blas", "gemm_batch", "for row_major layout");
#endif
}

sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower,
transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha,
const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb,
Expand Down

0 comments on commit e436f1c

Please sign in to comment.