[SYCL] Fix `sycl::vec::convert<>` to allow conversion to and from `sycl::vec` of `bfloat16` type to that of other data types #14105
Converted this PR back to draft to: |
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>(
vec_data<DataT>::get(getValue(I)))));
// For float -> bf16.
if constexpr (isFloatToBF16Conv) {
`detail::convertImpl<>` expects an OpenCL type as input and returns the OpenCL type corresponding to `convertT`. In the case of BF16, the OpenCL type will be `uint16` on device and `bfloat16` on host.
However, `vec_data<bfloat16>::get()` currently returns a `bfloat16` value on both device and host.
As a workaround, I've added an explicit `if constexpr` for BF16 <--> `float` conversion. A proper fix would require more `if` conditions/ifdefs, which IMO is not worth it since we will be replacing `vector.hpp` with `vector_preview.hpp` soon anyway.
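For readers following along, the shape of that workaround can be sketched in plain C++ without a SYCL toolchain. Everything below is an illustrative assumption: `bf16` is a stand-in for `sycl::ext::oneapi::bfloat16`, `convertGeneric` stands in for the generic `detail::convertImpl<>` path, and `convertElement` is a hypothetical name, not a function from the PR.

```cpp
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical stand-in for sycl::ext::oneapi::bfloat16: it stores the upper
// 16 bits of an IEEE-754 binary32 value.
struct bf16 {
  std::uint16_t bits;
};

// float -> bf16 with round-to-nearest-even. NaN handling is omitted for brevity.
inline bf16 floatToBF16(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  u += 0x7FFFu + ((u >> 16) & 1u);
  return bf16{static_cast<std::uint16_t>(u >> 16)};
}

// bf16 -> float: shift the 16 stored bits back into the upper half.
inline float bf16ToFloat(bf16 b) {
  std::uint32_t u = static_cast<std::uint32_t>(b.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Stand-in for the generic conversion path (assumption, not the real convertImpl).
template <typename To, typename From> To convertGeneric(From v) {
  return static_cast<To>(v);
}

// Special-case BF16 <--> float at compile time, fall back to the generic path otherwise.
template <typename To, typename From> To convertElement(From v) {
  constexpr bool isFloatToBF16Conv =
      std::is_same_v<From, float> && std::is_same_v<To, bf16>;
  constexpr bool isBF16ToFloatConv =
      std::is_same_v<From, bf16> && std::is_same_v<To, float>;
  if constexpr (isFloatToBF16Conv)
    return floatToBF16(v);
  else if constexpr (isBF16ToFloatConv)
    return bf16ToFloat(v);
  else
    return convertGeneric<To>(v);
}
```

Because the branches are selected with `if constexpr`, the generic path is never instantiated for the BF16 cases, so it does not need to understand the stand-in type at all.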
Alternatively, we can just refactor the entire `convertImpl`, if you have a good plan/picture for that.
sycl/include/sycl/vector_preview.hpp
Outdated
std::is_same_v<DataT, bfloat16> && std::is_same_v<convertT, float>;
if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) {
static_assert(roundingMode == rounding_mode::automatic ||
roundingMode == rounding_mode::rte);
Should we add a message to this static assert to explicitly say that not all rounding modes are supported for `bfloat16`?
Sure. Fixed in 8a6caf1
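A minimal sketch of what such a check with a diagnostic message could look like, written in plain C++ so it compiles without SYCL. The `rounding_mode` enum and `bfloat16` struct below are stand-ins for the SYCL types, `checkBF16RoundingMode` is a hypothetical helper, and the message text is an assumption modeled on the one used elsewhere in this PR; the actual wording in 8a6caf1 may differ.

```cpp
#include <type_traits>

// Stand-ins for sycl::rounding_mode and sycl::ext::oneapi::bfloat16
// (assumptions so the sketch is self-contained).
enum class rounding_mode { automatic, rte, rtz, rtp, rtn };
struct bfloat16 {};

// Returns true when the conversion is a BF16 <--> float conversion.
// For those conversions, any rounding mode other than automatic/RTE is
// rejected at compile time with an explicit message.
template <typename From, typename To, rounding_mode Mode>
constexpr bool checkBF16RoundingMode() {
  constexpr bool isFloatToBF16Conv =
      std::is_same_v<From, float> && std::is_same_v<To, bfloat16>;
  constexpr bool isBF16ToFloatConv =
      std::is_same_v<From, bfloat16> && std::is_same_v<To, float>;
  if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) {
    static_assert(Mode == rounding_mode::automatic ||
                      Mode == rounding_mode::rte,
                  "Only automatic/RTE rounding mode is supported for "
                  "bfloat16 conversions.");
    return true;
  } else {
    return false;
  }
}
```

Instantiating `checkBF16RoundingMode<float, bfloat16, rounding_mode::rtz>()` would then fail to compile and print the message, instead of a bare "static assertion failed".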
template <typename NativeBFT, typename NativeFloatT, int VecSize>
inline NativeFloatT ConvertBF16ToF(NativeBFT val) {
Can `NativeFloatT` be anything other than `float`?
On host, no. `NativeFloatT` is always `float`.
// On host, ensure that we don't convert BF16 to uint16 for conversion.
static_assert(std::is_same_v<NativeBFT, sycl::ext::oneapi::bfloat16>);

return (NativeFloatT)val;
Please don't use C-style casts.
@@ -498,6 +528,51 @@ __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE(double)
#undef __SYCL_FLOAT_FLOAT_CONVERT
#undef __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE

template <typename NativeBFT, typename NativeFloatT, int VecSize>
inline NativeFloatT ConvertBF16ToF(NativeBFT vec) {
For the scalar case, are we going `vec<bf16,1>` -> `operator[]` -> cast to ushort -> cast back to bf16 -> convert to float here + in the caller? Do you think it still makes sense after we changed storage type in `vec`?
The problem is that `convertImpl` accepts the native OpenCL type on device, whether it is `uint16` (for `vec<bfloat, 1>`) or `uint16 ext_vector_type()` (for `vec<bfloat, N>`).
I had to do the casts to provide a unified interface for `vec::convert` (to use `convertImpl`); plus, I expect the compiler to get rid of these extra casts.
A long-term solution would be to refactor `convertImpl` entirely, but that is tangential to this PR.
sycl/include/sycl/vector.hpp
Outdated
// Currently, for BF16 <--> float conversion, we only support
// Round-to-even rounding mode.
I'd expect that `bfloat` maps precisely onto `float`s, so that direction should "support" all the rounding modes. Am I wrong here?
IIUC, there cannot be a 1:1 mapping between `float` and `bfloat`, as `bfloat` has only an 8-bit mantissa while `float` has a 24-bit mantissa. The default rounding mode is RTE (https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_rounding_modes_for_conversions) for floating-point to floating-point conversion.
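To make the two directions in this exchange concrete, here is a rough bit-level sketch in plain C++. It assumes the usual bfloat16 layout (1 sign bit, 8 exponent bits, 7 stored mantissa bits, i.e. the upper half of an IEEE-754 binary32 value); the function names are hypothetical and this is not the code path used by the PR. Under that layout, widening bf16 to `float` only appends zero bits and is exact, so rounding only matters when narrowing `float` to bf16.

```cpp
#include <cstdint>
#include <cstring>

// float -> bf16 bits, rounding to nearest with ties to even (RTE).
inline std::uint16_t float_to_bf16_rte(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // NaN: truncate and force the quiet bit so the result stays a NaN.
  if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0)
    return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
  // Add 0x7FFF plus the lowest kept bit: values above the halfway point
  // round up, and exact ties round towards an even result.
  std::uint32_t lsb = (bits >> 16) & 1u;
  bits += 0x7FFFu + lsb;
  return static_cast<std::uint16_t>(bits >> 16);
}

// bf16 bits -> float: exact, the low 16 mantissa bits are simply zero.
inline float bf16_to_float(std::uint16_t b) {
  std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
```

For example, `1.00390625f` (exactly halfway between two neighbouring bf16 values) narrows to `0x3F80`, the even neighbour, while any bf16 bit pattern widened with `bf16_to_float` and narrowed again yields the same bits.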
@cperkinsintel Since @aelovikov-intel is OOO, could you help review this PR?
roundingMode == SYCLRoundingMode::rte,
"Only automatic/RTE rounding mode is supported for double type.");
return getBFloat16FromDoubleWithRoundingMode(a, roundingMode);
}
Is there a possibility of other floating-point types besides `float` and `double`? Half? Should there be a `std::is_floating_point<T>` clause for the future?
Nice catch. I've added the clause for `half` as well.
looks good, had one question.
@intel/llvm-gatekeepers the PR is ready to be merged!
Follow-up and blocked by: #14105. Currently, `vec<bfloat>` math builtins do element-by-element operations. This PR optimizes `vec<bfloat>` math builtins by: (1) converting `vec<bfloat>` to `vec<float>`; (2) doing the operation on `vec<float>` (which uses SPIR-V built-ins underneath for optimized vector operations); (3) converting the return value back to `vec<bfloat>`. Look at the beautiful diff in `check_device_code/vector/vector_bf16_builtins.cpp` to visualize the device code generated before and after this optimization.
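The three steps in that follow-up can be sketched on the host with plain C++. This is only an illustration under stated assumptions: `std::array` stands in for `sycl::vec`, raw `std::uint16_t` bit patterns stand in for `bfloat16`, and `widen`, `narrow_rte`, and `apply_via_float` are hypothetical names. The real implementation relies on SPIR-V built-ins on device rather than loops like these.

```cpp
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

using bf16_bits = std::uint16_t; // stand-in for bfloat16 storage

// bf16 -> float (exact).
inline float widen(bf16_bits b) {
  std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// float -> bf16, round-to-nearest-even. NaN handling is omitted for brevity.
inline bf16_bits narrow_rte(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  u += 0x7FFFu + ((u >> 16) & 1u);
  return static_cast<bf16_bits>(u >> 16);
}

template <std::size_t N, typename Op>
std::array<bf16_bits, N> apply_via_float(const std::array<bf16_bits, N> &in,
                                         Op op) {
  std::array<float, N> tmp{};
  for (std::size_t i = 0; i < N; ++i)
    tmp[i] = widen(in[i]); // (1) vec<bfloat> -> vec<float>
  for (std::size_t i = 0; i < N; ++i)
    tmp[i] = op(tmp[i]); // (2) operate on the float vector
  std::array<bf16_bits, N> out{};
  for (std::size_t i = 0; i < N; ++i)
    out[i] = narrow_rte(tmp[i]); // (3) vec<float> -> vec<bfloat>
  return out;
}
```

Calling `apply_via_float` with the bit patterns for 4.0 and 9.0 (`0x4080`, `0x4110`) and a square-root lambda gives the bit patterns for 2.0 and 3.0 (`0x4000`, `0x4040`).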
Follow-up of and blocked by: #14085
After this change:
On host, conversion between `vec<bfloat16>` and `vec<float>` will happen element-by-element, while on device we'll use the SPIR-V intrinsics `OpConvertFToBF16INTEL` and `OpConvertBF16ToFINTEL` (https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_bfloat16_conversion.asciidoc) for vector conversion.