diff --git a/test_conformance/spirv_new/test_decorate.cpp b/test_conformance/spirv_new/test_decorate.cpp
index 3a1f422af..b85419300 100644
--- a/test_conformance/spirv_new/test_decorate.cpp
+++ b/test_conformance/spirv_new/test_decorate.cpp
@@ -216,7 +216,8 @@ static inline Ti generate_saturated_rhs_input(RandomSeed &seed)
 }
 
 template <typename Ti, typename Tl, typename To>
-static inline To compute_saturated_output(Ti lhs, Ti rhs)
+static inline To compute_saturated_output(Ti lhs, Ti rhs,
+                                          cl_half_rounding_mode half_rounding)
 {
     constexpr auto loVal = std::numeric_limits<To>::min();
     constexpr auto hiVal = std::numeric_limits<To>::max();
@@ -226,7 +227,7 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
         cl_float f = cl_half_to_float(lhs) * cl_half_to_float(rhs);
 
         // Quantize to fp16:
-        f = cl_half_to_float(cl_half_from_float(f, CL_HALF_RTE));
+        f = cl_half_to_float(cl_half_from_float(f, half_rounding));
 
         To val = (To)std::min<float>(std::max<float>(f, loVal), hiVal);
         if (isnan(cl_half_to_float(rhs)))
@@ -246,6 +247,30 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
     return val;
 }
 
+// Maps the device's default half-precision rounding mode, as reported by
+// get_default_rounding_mode() for CL_DEVICE_HALF_FP_CONFIG, to the
+// corresponding cl_half_rounding_mode. Logs an error and falls back to
+// CL_HALF_RTE when the reported mode is neither round-to-nearest nor RTZ.
+static cl_half_rounding_mode get_half_rounding_mode(cl_device_id deviceID)
+{
+    const cl_device_fp_config fpConfigHalf =
+        get_default_rounding_mode(deviceID, CL_DEVICE_HALF_FP_CONFIG);
+
+    if (fpConfigHalf == CL_FP_ROUND_TO_NEAREST)
+    {
+        return CL_HALF_RTE;
+    }
+    else if (fpConfigHalf == CL_FP_ROUND_TO_ZERO)
+    {
+        return CL_HALF_RTZ;
+    }
+    else
+    {
+        log_error("Error while acquiring half rounding mode\n");
+    }
+    return CL_HALF_RTE;
+}
+
 template <typename Ti, typename Tl, typename To>
 int verify_saturated_results(cl_device_id deviceID, cl_context context,
                              cl_command_queue queue, const char *kname,
@@ -303,9 +328,16 @@ int verify_saturated_results(cl_device_id deviceID, cl_context context,
     err = clEnqueueReadBuffer(queue, res, CL_TRUE, 0, out_bytes, &h_res[0], 0,
                               NULL, NULL);
     SPIRV_CHECK_ERROR(err, "Failed to read to output");
+    cl_half_rounding_mode half_rounding = CL_HALF_RTE;
+    if (std::is_same<Ti, cl_half>::value)
+    {
+        half_rounding = get_half_rounding_mode(deviceID);
+    }
+
     for (int i = 0; i < num; i++)
     {
-        To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i]);
+        To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i],
+                                                      half_rounding);
 
         if (val != h_res[i])
         {