From c030617752bf4d0fe2e86c09406f253615fb9243 Mon Sep 17 00:00:00 2001 From: Chuang-Yu Cheng Date: Thu, 20 Jun 2024 07:33:11 +0100 Subject: [PATCH] spirv_new: fix test_decorate to use the device's default rounding mode instead of CL_HALF_RTE for half conversion. --- test_conformance/spirv_new/test_decorate.cpp | 30 +++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/test_conformance/spirv_new/test_decorate.cpp b/test_conformance/spirv_new/test_decorate.cpp index 4c2f82b540..0717b0ddfc 100644 --- a/test_conformance/spirv_new/test_decorate.cpp +++ b/test_conformance/spirv_new/test_decorate.cpp @@ -216,7 +216,8 @@ static inline Ti generate_saturated_rhs_input(RandomSeed &seed) } template -static inline To compute_saturated_output(Ti lhs, Ti rhs) +static inline To compute_saturated_output(Ti lhs, Ti rhs, + cl_half_rounding_mode half_rounding) { constexpr auto loVal = std::numeric_limits::min(); constexpr auto hiVal = std::numeric_limits::max(); @@ -226,10 +227,10 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs) cl_float f = cl_half_to_float(lhs) * cl_half_to_float(rhs); // Quantize to fp16: - f = cl_half_to_float(cl_half_from_float(f, CL_HALF_RTE)); + f = cl_half_to_float(cl_half_from_float(f, half_rounding)); To val = (To)std::min(std::max(f, loVal), hiVal); - if (isnan(cl_half_from_float(rhs, CL_HALF_RTE))) + if (isnan(cl_half_from_float(rhs, half_rounding))) { val = 0; } @@ -246,6 +247,26 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs) return val; } +static cl_half_rounding_mode get_half_rounding_mode(cl_device_id deviceID) +{ + const cl_device_fp_config fpConfigHalf = + get_default_rounding_mode(deviceID, CL_DEVICE_HALF_FP_CONFIG); + + if (fpConfigHalf == CL_FP_ROUND_TO_NEAREST) + { + return CL_HALF_RTE; + } + else if (fpConfigHalf == CL_FP_ROUND_TO_ZERO) + { + return CL_HALF_RTZ; + } + else + { + log_error("Error while acquiring half rounding mode"); + } + return CL_HALF_RTE; +} + template int verify_saturated_results(cl_device_id deviceID, cl_context context, cl_command_queue queue, const char *kname, @@ -305,7 +326,8 @@ int verify_saturated_results(cl_device_id deviceID, cl_context context, for (int i = 0; i < num; i++) { - To val = compute_saturated_output(h_lhs[i], h_rhs[i]); + To val = compute_saturated_output(h_lhs[i], h_rhs[i], + get_half_rounding_mode(deviceID)); if (val != h_res[i]) {