Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Enabled more data types for oneMKL's gemm_batch API #8236

Merged
merged 4 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions ggml/src/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4116,10 +4116,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();;

bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
main_stream->get_backend() == sycl::backend::ext_oneapi_hip;


void * src0_ddq = src0->data;
sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
float * src1_ddf = (float *) src1->data;
Expand All @@ -4137,15 +4133,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
: src1_f16_alloc.get();

ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
char * dst_t;

dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
if (no_mixed_dtypes) {
cu_compute_type = dpct::library_data_t::real_half;
cu_data_type = dpct::library_data_t::real_half;
}

// dst strides
size_t nbd2 = dst->nb[2];
Expand All @@ -4154,26 +4145,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;

const void * alpha = &alpha_f32;
const void * beta = &beta_f32;
if (no_mixed_dtypes) {
alpha = &alpha_f16;
beta = &beta_f16;
}

// TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
// when oneMKL open source supports half, half, float, float: datatypes

dst_t = (char *) dst_ddf;
if (no_mixed_dtypes) {
dst_t = (char *) dst_f16.alloc(ne_dst);

nbd2 /= sizeof(float) / sizeof(sycl::half);
nbd3 /= sizeof(float) / sizeof(sycl::half);
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
Expand Down Expand Up @@ -4235,11 +4210,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
cu_compute_type)));
}

if (no_mixed_dtypes) {
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-sycl/dpct/helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,7 @@ namespace dpct
b, ldb, beta, c, ldc, batch_size);
break;
}
#endif
AidanBeltonS marked this conversation as resolved.
Show resolved Hide resolved
case detail::get_type_combination_id(
library_data_t::real_int8, library_data_t::real_int8,
library_data_t::real_int32, library_data_t::real_int32):
Expand Down Expand Up @@ -2458,7 +2459,6 @@ namespace dpct
batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_float):
Expand Down Expand Up @@ -2595,6 +2595,7 @@ namespace dpct
stride_c, batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_int8, library_data_t::real_int8,
library_data_t::real_int32, library_data_t::real_int32):
Expand Down Expand Up @@ -2623,7 +2624,6 @@ namespace dpct
beta, c, ldc, stride_c, batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_float):
Expand Down
Loading