-
Notifications
You must be signed in to change notification settings - Fork 738
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SYCL] Fix sycl::vec::convert<>
to allow conversion to and from sycl::vec
of bfloat16
type to that of other data types
#14105
Changes from 19 commits
0aa7a9a
f2a1dc2
361eea7
48a8574
6bce35d
8d8295e
91cd730
20580de
d76bc66
3b7826e
5608861
402073d
fc6aa6a
c60ec69
7651a30
439e03a
1aa0304
98588cd
a264ee5
8a6caf1
85a33f8
e5ef19a
fbeb2db
a9df444
cc288cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,8 @@ | |
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s... | ||
#include <sycl/exception.hpp> // for errc | ||
|
||
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16 | ||
|
||
#ifndef __SYCL_DEVICE_ONLY__ | ||
#include <cfenv> // for fesetround, fegetround | ||
#endif | ||
|
@@ -123,6 +125,15 @@ using is_float_to_float = | |
std::bool_constant<detail::is_floating_point<T>::value && | ||
detail::is_floating_point<R>::value>; | ||
|
||
using bfloat16 = sycl::ext::oneapi::bfloat16; | ||
template <typename T, typename R> | ||
using is_bf16_to_float = | ||
std::bool_constant<std::is_same_v<T, bfloat16> && std::is_same_v<R, float>>; | ||
|
||
template <typename T, typename R> | ||
using is_float_to_bf16 = | ||
std::bool_constant<std::is_same_v<R, bfloat16> && std::is_same_v<T, float>>; | ||
|
||
#ifndef __SYCL_DEVICE_ONLY__ | ||
template <typename From, typename To, int VecSize, | ||
typename Enable = std::enable_if_t<VecSize == 1>> | ||
|
@@ -196,8 +207,27 @@ template <typename From, typename To, int VecSize, | |
To ConvertFToU(From Value) { | ||
return ConvertFToS<From, To, VecSize, Enable, roundingMode>(Value); | ||
} | ||
#else | ||
|
||
template <typename NativeBFT, typename NativeFloatT, int VecSize> | ||
inline NativeFloatT ConvertBF16ToF(NativeBFT val) { | ||
static_assert(VecSize == 1); | ||
// On host, ensure that we don't convert BF16 to uint16 for conversion. | ||
static_assert(std::is_same_v<NativeBFT, sycl::ext::oneapi::bfloat16>); | ||
|
||
return (NativeFloatT)val; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please don't use C-style casts. |
||
} | ||
|
||
// Create a bfloat16 from float. | ||
template <typename NativeFloatT, typename NativeBFT, int VecSize> | ||
inline NativeBFT ConvertFToBF16(NativeFloatT val) { | ||
static_assert(VecSize == 1); | ||
// On host, ensure that we don't convert BF16 to uint16 for conversion. | ||
static_assert(std::is_same_v<NativeBFT, sycl::ext::oneapi::bfloat16>); | ||
|
||
return NativeBFT(val); | ||
} | ||
|
||
#else | ||
// Bunch of helpers to "specialize" each template for its own destination type | ||
// and vector size. | ||
|
||
|
@@ -498,6 +528,51 @@ __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE(double) | |
#undef __SYCL_FLOAT_FLOAT_CONVERT | ||
#undef __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE | ||
|
||
template <typename NativeBFT, typename NativeFloatT, int VecSize> | ||
inline NativeFloatT ConvertBF16ToF(NativeBFT vec) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the scalar case, are we going There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem is that A long term solution, would be to refactor |
||
if constexpr (VecSize == 1) { | ||
// On device, we interpret bfloat16 as a uint16_t scalar or vector. | ||
static_assert( | ||
std::is_same_v<NativeBFT, sycl::ext::oneapi::detail::Bfloat16StorageT>); | ||
|
||
// Bitcast to BF16 and typecast to float. | ||
bfloat16 convertedBF = sycl::bit_cast<bfloat16>(vec); | ||
return (float)convertedBF; | ||
} else { | ||
bfloat16 *src = sycl::bit_cast<bfloat16 *>(&vec); | ||
|
||
// OpenCL vector of 3 elements is aligned to 4 multiplied by | ||
// the size of data type. | ||
constexpr int AdjustedSize = (VecSize == 3) ? 4 : VecSize; | ||
float dst[AdjustedSize]; | ||
sycl::ext::oneapi::detail::BF16VecToFloatVec<VecSize>(src, dst); | ||
|
||
return sycl::bit_cast<NativeFloatT>(dst); | ||
} | ||
} | ||
|
||
template <typename NativeFloatT, typename NativeBFT, int VecSize> | ||
inline NativeBFT ConvertFToBF16(NativeFloatT vec) { | ||
if constexpr (VecSize == 1) { | ||
// On device, we interpret bfloat16 as a uint16_t scalar or vector. | ||
static_assert( | ||
std::is_same_v<NativeBFT, sycl::ext::oneapi::detail::Bfloat16StorageT>); | ||
|
||
auto bf16Val = bfloat16(vec); | ||
return sycl::bit_cast<NativeBFT>(bf16Val); | ||
} else { | ||
float *src = sycl::bit_cast<float *>(&vec); | ||
|
||
// OpenCL vector of 3 elements is aligned to 4 multiplied by | ||
// the size of data type. | ||
constexpr int AdjustedSize = (VecSize == 3) ? 4 : VecSize; | ||
bfloat16 dst[AdjustedSize]; | ||
|
||
sycl::ext::oneapi::detail::FloatVecToBF16Vec<VecSize>(src, dst); | ||
return sycl::bit_cast<NativeBFT>(dst); | ||
} | ||
} | ||
|
||
#endif // __SYCL_DEVICE_ONLY__ | ||
|
||
/// Entry point helper for all kinds of converts between scalars and vectors, it | ||
|
@@ -537,6 +612,10 @@ NativeToT convertImpl(NativeFromT Value) { | |
else if constexpr (is_float_to_float<FromT, ToT>::value) | ||
return FConvert<NativeFromT, NativeToT, VecSize, ElemTy, RoundingMode>( | ||
Value); | ||
else if constexpr (is_bf16_to_float<FromT, ToT>::value) | ||
return ConvertBF16ToF<NativeFromT, NativeToT, VecSize>(Value); | ||
else if constexpr (is_float_to_bf16<FromT, ToT>::value) | ||
return ConvertFToBF16<NativeFromT, NativeToT, VecSize>(Value); | ||
else if constexpr (is_float_to_sint<FromT, ToT>::value) | ||
return ConvertFToS<NativeFromT, NativeToT, VecSize, ElemTy, RoundingMode>( | ||
Value); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -786,9 +786,25 @@ template <typename Type, int NumElements> class vec { | |
detail::ConvertToOpenCLType_t<vec_data_t<convertT>>>, | ||
vec<convertT, NumElements>> | ||
convert() const { | ||
using bfloat16 = sycl::ext::oneapi::bfloat16; | ||
static_assert(std::is_integral_v<vec_data_t<convertT>> || | ||
detail::is_floating_point<convertT>::value, | ||
detail::is_floating_point<convertT>::value || | ||
// Conversion to BF16 available only for float. | ||
(std::is_same_v<convertT, bfloat16> && | ||
std::is_same_v<DataT, float>), | ||
"Unsupported convertT"); | ||
|
||
// Currently, for BF16 <--> float conversion, we only support | ||
// Round-to-even rounding mode. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd expect that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC, there can not be a 1:1 mapping between |
||
constexpr bool isFloatToBF16Conv = | ||
std::is_same_v<convertT, bfloat16> && std::is_same_v<DataT, float>; | ||
constexpr bool isBF16ToFloatConv = | ||
std::is_same_v<DataT, bfloat16> && std::is_same_v<convertT, float>; | ||
if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) { | ||
static_assert(roundingMode == rounding_mode::automatic || | ||
roundingMode == rounding_mode::rte); | ||
} | ||
|
||
using T = vec_data_t<DataT>; | ||
using R = vec_data_t<convertT>; | ||
using OpenCLT = detail::ConvertToOpenCLType_t<T>; | ||
|
@@ -828,10 +844,19 @@ template <typename Type, int NumElements> class vec { | |
{ | ||
// Otherwise, we fallback to per-element conversion: | ||
for (size_t I = 0; I < NumElements; ++I) { | ||
Result.setValue( | ||
I, vec_data<convertT>::get( | ||
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>( | ||
vec_data<DataT>::get(getValue(I))))); | ||
// For float -> bf16. | ||
if constexpr (isFloatToBF16Conv) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
As a workaround to this, I've added explicit There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternatively, we can just refactor the entire |
||
Result[I] = bfloat16((*this)[I]); | ||
} else | ||
// For bf16 -> float. | ||
if constexpr (isBF16ToFloatConv) { | ||
Result[I] = (float)((*this)[I]); | ||
} else { | ||
Result.setValue(I, vec_data<convertT>::get( | ||
detail::convertImpl<T, R, roundingMode, 1, | ||
OpenCLT, OpenCLR>( | ||
vec_data<DataT>::get(getValue(I))))); | ||
} | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -419,9 +419,25 @@ class vec : public detail::vec_arith<DataT, NumElements> { | |
|
||
using T = ConvertBoolAndByteT<DataT>; | ||
using R = ConvertBoolAndByteT<convertT>; | ||
static_assert(std::is_integral_v<R> || detail::is_floating_point<R>::value, | ||
using bfloat16 = sycl::ext::oneapi::bfloat16; | ||
static_assert(std::is_integral_v<R> || | ||
detail::is_floating_point<R>::value || | ||
std::is_same_v<R, bfloat16>, | ||
"Unsupported convertT"); | ||
|
||
{ | ||
// Currently, for BF16 <--> float conversion, we only support | ||
// Round-to-even rounding mode. | ||
constexpr bool isFloatToBF16Conv = | ||
std::is_same_v<convertT, bfloat16> && std::is_same_v<DataT, float>; | ||
constexpr bool isBF16ToFloatConv = | ||
std::is_same_v<DataT, bfloat16> && std::is_same_v<convertT, float>; | ||
if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) { | ||
static_assert(roundingMode == rounding_mode::automatic || | ||
roundingMode == rounding_mode::rte); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a message to this static assert to explicitly say that not all rounding modes are supported for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. Fixed in 8a6caf1 |
||
} | ||
} | ||
|
||
using OpenCLT = detail::ConvertToOpenCLType_t<T>; | ||
using OpenCLR = detail::ConvertToOpenCLType_t<R>; | ||
vec<convertT, NumElements> Result; | ||
|
@@ -479,11 +495,16 @@ class vec : public detail::vec_arith<DataT, NumElements> { | |
auto val = | ||
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>( | ||
getValue(I)); | ||
Result[I] = static_cast<convertT>(val); | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
// On device, we interpret BF16 as uint16. | ||
if constexpr (std::is_same_v<convertT, bfloat16>) | ||
Result[I] = sycl::bit_cast<convertT>(val); | ||
else | ||
#endif | ||
Result[I] = static_cast<convertT>(val); | ||
} | ||
} | ||
} | ||
|
||
return Result; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can
NativeFloatT
be anything other thanfloat
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On host, no.
NativeFloatT
is alwaysfloat
.