Skip to content

Commit

Permalink
spirv_new: fix test_decorate to use the device's default rounding mod…
Browse files Browse the repository at this point in the history
…e instead of CL_HALF_RTE for half conversion.
  • Loading branch information
cycheng committed Jun 20, 2024
1 parent 2b26643 commit c030617
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions test_conformance/spirv_new/test_decorate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ static inline Ti generate_saturated_rhs_input(RandomSeed &seed)
}

template <typename Ti, typename Tl, typename To>
static inline To compute_saturated_output(Ti lhs, Ti rhs)
static inline To compute_saturated_output(Ti lhs, Ti rhs,
cl_half_rounding_mode half_rounding)
{
constexpr auto loVal = std::numeric_limits<To>::min();
constexpr auto hiVal = std::numeric_limits<To>::max();
Expand All @@ -226,10 +227,10 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
cl_float f = cl_half_to_float(lhs) * cl_half_to_float(rhs);

// Quantize to fp16:
f = cl_half_to_float(cl_half_from_float(f, CL_HALF_RTE));
f = cl_half_to_float(cl_half_from_float(f, half_rounding));

To val = (To)std::min<float>(std::max<float>(f, loVal), hiVal);
if (isnan(cl_half_from_float(rhs, CL_HALF_RTE)))
if (isnan(cl_half_from_float(rhs, half_rounding)))
{
val = 0;
}
Expand All @@ -246,6 +247,26 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
return val;
}

static cl_half_rounding_mode get_half_rounding_mode(cl_device_id deviceID)
{
const cl_device_fp_config fpConfigHalf =
get_default_rounding_mode(deviceID, CL_DEVICE_HALF_FP_CONFIG);

if (fpConfigHalf == CL_FP_ROUND_TO_NEAREST)
{
return CL_HALF_RTE;
}
else if (fpConfigHalf == CL_FP_ROUND_TO_ZERO)
{
return CL_HALF_RTZ;
}
else
{
log_error("Error while acquiring half rounding mode");
}
return CL_HALF_RTE;
}

template <typename Ti, typename Tl, typename To>
int verify_saturated_results(cl_device_id deviceID, cl_context context,
cl_command_queue queue, const char *kname,
Expand Down Expand Up @@ -305,7 +326,8 @@ int verify_saturated_results(cl_device_id deviceID, cl_context context,

for (int i = 0; i < num; i++)
{
To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i]);
To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i],
get_half_rounding_mode(deviceID));

if (val != h_res[i])
{
Expand Down

0 comments on commit c030617

Please sign in to comment.