Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spirv_new: fix test_decorate to use the device's default rounding #1987

Merged
merged 4 commits into from
Jul 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions test_conformance/spirv_new/test_decorate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ static inline Ti generate_saturated_rhs_input(RandomSeed &seed)
}

template <typename Ti, typename Tl, typename To>
static inline To compute_saturated_output(Ti lhs, Ti rhs)
static inline To compute_saturated_output(Ti lhs, Ti rhs,
cl_half_rounding_mode half_rounding)
{
constexpr auto loVal = std::numeric_limits<To>::min();
constexpr auto hiVal = std::numeric_limits<To>::max();
Expand All @@ -226,7 +227,7 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
cl_float f = cl_half_to_float(lhs) * cl_half_to_float(rhs);

// Quantize to fp16:
f = cl_half_to_float(cl_half_from_float(f, CL_HALF_RTE));
f = cl_half_to_float(cl_half_from_float(f, half_rounding));

To val = (To)std::min<float>(std::max<float>(f, loVal), hiVal);
if (isnan(cl_half_to_float(rhs)))
Expand All @@ -246,6 +247,26 @@ static inline To compute_saturated_output(Ti lhs, Ti rhs)
return val;
}

static cl_half_rounding_mode get_half_rounding_mode(cl_device_id deviceID)
{
const cl_device_fp_config fpConfigHalf =
get_default_rounding_mode(deviceID, CL_DEVICE_HALF_FP_CONFIG);

if (fpConfigHalf == CL_FP_ROUND_TO_NEAREST)
{
return CL_HALF_RTE;
}
else if (fpConfigHalf == CL_FP_ROUND_TO_ZERO)
{
return CL_HALF_RTZ;
}
else
{
log_error("Error while acquiring half rounding mode");
}
return CL_HALF_RTE;
}

template <typename Ti, typename Tl, typename To>
int verify_saturated_results(cl_device_id deviceID, cl_context context,
cl_command_queue queue, const char *kname,
Expand Down Expand Up @@ -303,9 +324,16 @@ int verify_saturated_results(cl_device_id deviceID, cl_context context,
err = clEnqueueReadBuffer(queue, res, CL_TRUE, 0, out_bytes, &h_res[0], 0, NULL, NULL);
SPIRV_CHECK_ERROR(err, "Failed to read to output");

cl_half_rounding_mode half_rounding = CL_HALF_RTE;
if (std::is_same<Ti, cl_half>::value)
{
half_rounding = get_half_rounding_mode(deviceID);
}

for (int i = 0; i < num; i++)
{
To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i]);
To val = compute_saturated_output<Ti, Tl, To>(h_lhs[i], h_rhs[i],
half_rounding);

if (val != h_res[i])
{
Expand Down
Loading