From 18096ee9e9539e3cb80032d562f729529c9e2f83 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 1 Mar 2024 12:57:49 -0800 Subject: [PATCH] add a way to generate tf32 dpas currently (disabled by default) --- samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl index d759252..6660340 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -13,6 +13,7 @@ float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc) { float res = acc; +#if 1 res = fma(sub_group_broadcast(a, 0), b.s0, res); res = fma(sub_group_broadcast(a, 1), b.s1, res); res = fma(sub_group_broadcast(a, 2), b.s2, res); @@ -21,6 +22,12 @@ float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc) res = fma(sub_group_broadcast(a, 5), b.s5, res); res = fma(sub_group_broadcast(a, 6), b.s6, res); res = fma(sub_group_broadcast(a, 7), b.s7, res); +#else +float __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(short a, int8 b, float acc); + uint a_ui = as_uint(sub_group_shuffle(a, get_sub_group_local_id() / 2)); + short aData = get_sub_group_local_id() % 2 ? as_short2(a_ui).hi : as_short2(a_ui).lo; + res = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(aData, as_int8(b), res); +#endif return res; }