Skip to content

Commit

Permalink
spirv_new: fix test_decorate to use the device's default rounding (#1987
Browse files Browse the repository at this point in the history
)

The verification code assumes the hardware uses CL_HALF_RTE, which
causes a mismatch computation results when the hardware uses RTZ. Fix to
use the hardware's default rounding mode.
  • Loading branch information
cycheng committed Jul 2, 2024
1 parent 340b7c9 commit 1cd0266
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions test_conformance/spirv_new/test_decorate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ static inline Ti generate_saturated_rhs_input(RandomSeed &seed)
}

template <typename Ti, typename Tl, typename To>
static inline To compute_saturated_output(Ti lhs, Ti rhs)
static inline To compute_saturated_output(Ti lhs, Ti rhs,
cl_half_rounding_mode half_rounding)
{
constexpr auto loVal = std::numeric_limits<To>::min();
constexpr auto hiVal = std::numeric_limits<To>::max();
Expand All @@ -226,7 +227,7 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
cl_float f = cl_half_to_float(lhs) * cl_half_to_float(rhs);

// Quantize to fp16:
f = cl_half_to_float(cl_half_from_float(f, CL_HALF_RTE));
f = cl_half_to_float(cl_half_from_float(f, half_rounding));

To val = (To)std::min<float>(std::max<float>(f, loVal), hiVal);
if (isnan(cl_half_to_float(rhs)))
Expand All @@ -246,6 +247,26 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
return val;
}

static cl_half_rounding_mode get_half_rounding_mode(cl_device_id deviceID)
{
const cl_device_fp_config fpConfigHalf =
get_default_rounding_mode(deviceID, CL_DEVICE_HALF_FP_CONFIG);

if (fpConfigHalf == CL_FP_ROUND_TO_NEAREST)
{
return CL_HALF_RTE;
}
else if (fpConfigHalf == CL_FP_ROUND_TO_ZERO)
{
return CL_HALF_RTZ;
}
else
{
log_error("Error while acquiring half rounding mode");
}
return CL_HALF_RTE;
}

template <typename Ti, typename Tl, typename To>
int verify_saturated_results(cl_device_id deviceID, cl_context context,
cl_command_queue queue, const char *kname,
Expand Down Expand Up @@ -303,9 +324,16 @@ int verify_saturated_results(cl_device_id deviceID, cl_context context,
err = clEnqueueReadBuffer(queue, res, CL_TRUE, 0, out_bytes, &h_res[0], 0, NULL, NULL);
SPIRV_CHECK_ERROR(err, "Failed to read to output");

cl_half_rounding_mode half_rounding = CL_HALF_RTE;
if (std::is_same<Ti, cl_half>::value)
{
half_rounding = get_half_rounding_mode(deviceID);
}

for (int i = 0; i < num; i++)
{
To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i]);
To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i],
half_rounding);

if (val != h_res[i])
{
Expand Down

0 comments on commit 1cd0266

Please sign in to comment.