Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Fix sycl::vec::convert<> to allow conversion to and from sycl::vec of bfloat16 type to that of other data types #14105

Merged
merged 25 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0aa7a9a
Add copy constructor
uditagarwal97 Apr 18, 2024
f2a1dc2
Merge branch 'sycl' of https://github.com/uditagarwal97/llvm into sycl
uditagarwal97 May 28, 2024
361eea7
Merge branch 'sycl' of https://github.com/uditagarwal97/llvm into sycl
uditagarwal97 May 31, 2024
48a8574
Add vector overloads on ConvertBFloat16ToFINTEL and ConvertFToBFloat1…
uditagarwal97 Jun 6, 2024
6bce35d
Fix test case
uditagarwal97 Jun 6, 2024
8d8295e
Fix tests; Address reviews.
uditagarwal97 Jun 7, 2024
91cd730
Fix formatting
uditagarwal97 Jun 7, 2024
20580de
Merge branch 'bf16tof' into opt_math_builtins
uditagarwal97 Jun 7, 2024
d76bc66
Fix conversion between vec<float> <--> vec<bfloat16>.
uditagarwal97 Jun 9, 2024
3b7826e
Call libdevice primitives instead of sprirv ones
uditagarwal97 Jun 11, 2024
5608861
Merge remote-tracking branch 'upstream/sycl' into opt_math_builtins
uditagarwal97 Jun 12, 2024
402073d
Fix formatting
uditagarwal97 Jun 12, 2024
fc6aa6a
Simplify convert for float and BF16
uditagarwal97 Jun 12, 2024
c60ec69
Merge remote-tracking branch 'upstream/sycl' into opt_math_builtins
uditagarwal97 Jun 14, 2024
7651a30
fix for older vec implementation
uditagarwal97 Jun 14, 2024
439e03a
Remove redundant statement
uditagarwal97 Jun 14, 2024
1aa0304
Re-enable e2e BF16 vec test for older vec implementation.
uditagarwal97 Jun 14, 2024
98588cd
Fix test failure and add assert
uditagarwal97 Jun 16, 2024
a264ee5
Fix formatting
uditagarwal97 Jun 16, 2024
8a6caf1
Add comment in static assert.
uditagarwal97 Jun 17, 2024
85a33f8
Fix build error
uditagarwal97 Jun 17, 2024
e5ef19a
Support all rounding modes for BF16 <--> float conversion.
uditagarwal97 Jun 18, 2024
fbeb2db
Fix formatting
uditagarwal97 Jun 18, 2024
a9df444
Don't emit __imf_ functions on non-intel hardwares
uditagarwal97 Jun 20, 2024
cc288cb
Address review. Fix build error.
uditagarwal97 Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <sycl/half_type.hpp> // for BIsRepresentationT
#include <sycl/multi_ptr.hpp> // for multi_ptr, address_spa...

#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16 storage type.

#include <cstddef> // for byte
#include <cstdint> // for uint8_t
#include <limits> // for numeric_limits
Expand Down Expand Up @@ -386,7 +388,13 @@ template <typename T> auto convertToOpenCLType(T &&x) {
static_assert(sizeof(OpenCLType) == sizeof(T));
return static_cast<OpenCLType>(x);
} else if constexpr (is_bfloat16_v<no_ref>) {
// On host, don't interpret BF16 as uint16.
#ifdef __SYCL_DEVICE_ONLY__
using OpenCLType = sycl::ext::oneapi::detail::Bfloat16StorageT;
return sycl::bit_cast<OpenCLType>(x);
#else
return std::forward<T>(x);
#endif
} else if constexpr (std::is_floating_point_v<no_ref>) {
static_assert(std::is_same_v<no_ref, float> ||
std::is_same_v<no_ref, double>,
Expand Down
307 changes: 306 additions & 1 deletion sycl/include/sycl/detail/vector_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,100 @@
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
#include <sycl/exception.hpp> // for errc

#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#ifndef __SYCL_DEVICE_ONLY__
#include <cfenv> // for fesetround, fegetround
#endif

#include <type_traits>

// Enable on only intel devices.
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
extern "C" {
// For converting BF16 to other types.
extern __DPCPP_SYCL_EXTERNAL float __imf_bfloat162float(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned int __imf_bfloat162uint_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned int __imf_bfloat162uint_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned int __imf_bfloat162uint_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned int __imf_bfloat162uint_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat162ushort_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat162ushort_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat162ushort_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat162ushort_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned long long
__imf_bfloat162ull_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned long long
__imf_bfloat162ull_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned long long
__imf_bfloat162ull_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned long long
__imf_bfloat162ull_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL int __imf_bfloat162int_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL int __imf_bfloat162int_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL int __imf_bfloat162int_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL int __imf_bfloat162int_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat162short_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat162short_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat162short_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat162short_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL long long __imf_bfloat162ll_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL long long __imf_bfloat162ll_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL long long __imf_bfloat162ll_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL long long __imf_bfloat162ll_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat16_as_short(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat16_as_ushort(uint16_t x);

// For converting other types to BF16.
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16_rd(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16_rn(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16_ru(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16_rz(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort2bfloat16_rd(unsigned short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort2bfloat16_rn(unsigned short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort2bfloat16_ru(unsigned short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort2bfloat16_rz(unsigned short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_uint2bfloat16_rd(unsigned int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_uint2bfloat16_rn(unsigned int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_uint2bfloat16_ru(unsigned int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_uint2bfloat16_rz(unsigned int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ull2bfloat16_rd(unsigned long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ull2bfloat16_rn(unsigned long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ull2bfloat16_ru(unsigned long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ull2bfloat16_rz(unsigned long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short2bfloat16_rd(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short2bfloat16_rn(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short2bfloat16_ru(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short2bfloat16_rz(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_int2bfloat16_rd(int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_int2bfloat16_rn(int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_int2bfloat16_ru(int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_int2bfloat16_rz(int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_ll2bfloat16_rd(long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_ll2bfloat16_rn(long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_ll2bfloat16_ru(long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_ll2bfloat16_rz(long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_double2bfloat16(double x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short_as_bfloat16(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort_as_bfloat16(unsigned short x);
}
#endif // __SYCL_DEVICE_ONLY__ && (defined(__SPIR__) || defined(__SPIRV__))

namespace sycl {

enum class rounding_mode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 };
Expand All @@ -81,6 +169,10 @@ inline double trunc(double);
#endif
namespace detail {

template <typename FromT, typename ToT, sycl::rounding_mode RoundingMode,
int VecSize, typename NativeFromT, typename NativeToT>
NativeToT convertImpl(NativeFromT);

template <typename T, typename R>
using is_sint_to_sint =
std::bool_constant<is_sigeninteger_v<T> && is_sigeninteger_v<R>>;
Expand Down Expand Up @@ -123,6 +215,8 @@ using is_float_to_float =
std::bool_constant<detail::is_floating_point<T>::value &&
detail::is_floating_point<R>::value>;

using bfloat16 = sycl::ext::oneapi::bfloat16;

#ifndef __SYCL_DEVICE_ONLY__
template <typename From, typename To, int VecSize,
typename Enable = std::enable_if_t<VecSize == 1>>
Expand Down Expand Up @@ -196,8 +290,29 @@ template <typename From, typename To, int VecSize,
To ConvertFToU(From Value) {
return ConvertFToS<From, To, VecSize, Enable, roundingMode>(Value);
}
#else

template <typename NativeToT, sycl::rounding_mode RoundingMode>
inline NativeToT ConvertFromBF16Scalar(bfloat16 val) {
// On host, NativeBF16T is bfloat16. Convert BF16 to float losslessly.
float fval = static_cast<float>(val);

if constexpr (std::is_same_v<NativeToT, float>)
return fval;
else
// Convert float to the desired type.
return convertImpl<float, NativeToT, RoundingMode, 1, float, NativeToT>(
fval);
}

template <typename NativeFromT, sycl::rounding_mode RoundingMode>
bfloat16 ConvertToBF16Scalar(NativeFromT val) {

constexpr int rm = static_cast<int>(RoundingMode);
return sycl::ext::oneapi::detail::ConvertToBfloat16::
getBfloat16WithRoundingMode<NativeFromT, rm>(val);
}

#else
// Bunch of helpers to "specialize" each template for its own destination type
// and vector size.

Expand Down Expand Up @@ -498,8 +613,190 @@ __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE(double)
#undef __SYCL_FLOAT_FLOAT_CONVERT
#undef __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE

template <typename NativeBFT, typename NativeFloatT, int VecSize>
inline NativeFloatT ConvertBF16ToFVec(NativeBFT vec) {
bfloat16 *src = sycl::bit_cast<bfloat16 *>(&vec);

// OpenCL vector of 3 elements is aligned to 4 multiplied by
// the size of data type.
constexpr int AdjustedSize = (VecSize == 3) ? 4 : VecSize;
float dst[AdjustedSize];
sycl::ext::oneapi::detail::BF16VecToFloatVec<VecSize>(src, dst);

return sycl::bit_cast<NativeFloatT>(dst);
}

template <typename NativeFloatT, typename NativeBFT, int VecSize>
inline NativeBFT ConvertFToBF16Vec(NativeFloatT vec) {
float *src = sycl::bit_cast<float *>(&vec);

// OpenCL vector of 3 elements is aligned to 4 multiplied by
// the size of data type.
constexpr int AdjustedSize = (VecSize == 3) ? 4 : VecSize;
bfloat16 dst[AdjustedSize];

sycl::ext::oneapi::detail::FloatVecToBF16Vec<VecSize>(src, dst);
return sycl::bit_cast<NativeBFT>(dst);
}

/* Emit _imf_* funcs only on Intel hardware. */
#if defined(__SPIR__) || defined(__SPIRV__)
#define EXPAND_BF16_ROUNDING_MODE(type, type_str, rmode, rmode_str) \
template <typename NativeToT, sycl::rounding_mode RoundingMode> \
std::enable_if_t<(std::is_same_v<NativeToT, type> && RoundingMode == rmode), \
NativeToT> \
ConvertFromBF16Scalar(uint16_t val) { \
return __imf_bfloat162##type_str##_##rmode_str(val); \
} \
template <typename NativeFromT, sycl::rounding_mode RoundingMode> \
std::enable_if_t< \
(std::is_same_v<NativeFromT, type> && RoundingMode == rmode), uint16_t> \
ConvertToBF16Scalar(NativeFromT val) { \
return __imf_##type_str##2bfloat16_##rmode_str(val); \
}

#else // __SYCL_DEVICE_ONLY__ && (defined(__SPIR__) || defined(__SPIRV__))
/* On non-Intel HWs, convert BF16 to float (losslessly) and convert float*/ /* to
the desired type. */
#define EXPAND_BF16_ROUNDING_MODE(type, type_str, rmode, rmode_str) \
template <typename NativeToT, sycl::rounding_mode RoundingMode> \
std::enable_if_t<(std::is_same_v<NativeToT, type> && RoundingMode == rmode), \
NativeToT> \
ConvertFromBF16Scalar(uint16_t val) { \
bfloat16 bfval = sycl::bit_cast<bfloat16>(val); \
float fval = static_cast<float>(bfval); \
return convertImpl<fval, NativeToT, RoundingMode, 1, float, NativeToT>( \
fval); \
} \
template <typename NativeFromT, sycl::rounding_mode RoundingMode> \
std::enable_if_t< \
(std::is_same_v<NativeFromT, type> && RoundingMode == rmode), uint16_t> \
ConvertToBF16Scalar(NativeFromT val) { \
constexpr int rm = static_cast<int>(RoundingMode); \
bfloat16 bfval = sycl::ext::oneapi::detail::ConvertToBfloat16:: \
getBfloat16WithRoundingMode<NativeFromT, rm>(val); \
return sycl::bit_cast<uint16_t>(bfval); \
}
#endif // __SYCL_DEVICE_ONLY__ && (defined(__SPIR__) || defined(__SPIRV__))

#define EXPAND_BF16_TYPE(type, type_str) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::automatic, \
rn) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::rte, rn) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::rtp, ru) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::rtn, rd) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::rtz, rz)

EXPAND_BF16_TYPE(uint, uint)
EXPAND_BF16_TYPE(int, int)
EXPAND_BF16_TYPE(ushort, ushort)
EXPAND_BF16_TYPE(short, short)
EXPAND_BF16_TYPE(long, ll)
EXPAND_BF16_TYPE(unsigned long long, ull)

#undef EXPAND_BF16_TYPE
#undef EXPAND_BF16_ROUNDING_MODE

// Mapping from BF16 to float is 1:1, lossless, so we accept all
// rounding modes.
template <typename NativeToT, sycl::rounding_mode RoundingMode>
std::enable_if_t<std::is_same_v<NativeToT, float>, NativeToT>
ConvertFromBF16Scalar(uint16_t val) {
bfloat16 bfval = sycl::bit_cast<bfloat16>(val);
return static_cast<float>(bfval);
}

// Conversion of double to BF16 is lossless, so we accept all
// rounding modes.
template <typename NativeFromT, sycl::rounding_mode RoundingMode>
std::enable_if_t<std::is_same_v<NativeFromT, double>, uint16_t>
ConvertToBF16Scalar(NativeFromT val) {
#if defined(__SPIR__) || defined(__SPIRV__)
return __imf_double2bfloat16(val);
#else
constexpr int rm = static_cast<int>(RoundingMode);
bfloat16 bfval =
sycl::ext::oneapi::detail::ConvertToBfloat16::getBfloat16WithRoundingMode<
NativeFromT, rm>(val);
return sycl::bit_cast<uint16_t>(bfval);
#endif
}

template <typename NativeFromT, sycl::rounding_mode RoundingMode>
std::enable_if_t<std::is_same_v<NativeFromT, float>, uint16_t>
ConvertToBF16Scalar(NativeFromT val) {

#if defined(__SPIR__) || defined(__SPIRV__)
if constexpr (RoundingMode == sycl::rounding_mode::automatic ||
RoundingMode == sycl::rounding_mode::rte)
return __imf_float2bfloat16_rn(val);
else if constexpr (RoundingMode == sycl::rounding_mode::rtp)
return __imf_float2bfloat16_ru(val);
else if constexpr (RoundingMode == sycl::rounding_mode::rtn)
return __imf_float2bfloat16_rd(val);
else if constexpr (RoundingMode == sycl::rounding_mode::rtz)
return __imf_float2bfloat16_rz(val);
else
static_assert(false, "Invalid rounding mode.");
#else
constexpr int rm = static_cast<int>(RoundingMode);
bfloat16 bfval =
sycl::ext::oneapi::detail::ConvertToBfloat16::getBfloat16WithRoundingMode<
float, rm>(val);
return sycl::bit_cast<uint16_t>(bfval);
#endif
}

#endif // __SYCL_DEVICE_ONLY__

// Wrapper function for scalar and vector conversions from BF16 type.
template <typename ToT, typename NativeFromT, typename NativeToT,
sycl::rounding_mode RoundingMode, int VecSize>
NativeToT ConvertFromBF16(NativeFromT val) {
#ifdef __SYCL_DEVICE_ONLY__
// Use vector conversion from BF16 to float for all rounding modes.
if constexpr (std::is_same_v<ToT, float> && VecSize > 1)
return ConvertBF16ToFVec<NativeFromT, NativeToT, VecSize>(val);
else
#endif
// For VecSize > 1. Only for device.
if constexpr (VecSize > 1) {
NativeToT retval;
for (int i = 0; i < VecSize; i++) {
retval[i] = ConvertFromBF16Scalar<ToT, RoundingMode>(val[i]);
}
return retval;
}
// For VecSize == 1.
else
return ConvertFromBF16Scalar<NativeToT, RoundingMode>(val);
}

// Wrapper function for scalar and vector conversions to BF16 type.
template <typename FromT, typename NativeFromT, typename NativeToT,
sycl::rounding_mode RoundingMode, int VecSize>
NativeToT ConvertToBF16(NativeFromT val) {
#ifdef __SYCL_DEVICE_ONLY__
// Use vector conversion to BF16 from float for RNE rounding mode.
if constexpr (std::is_same_v<FromT, float> && VecSize > 1 &&
(RoundingMode == sycl::rounding_mode::automatic ||
RoundingMode == sycl::rounding_mode::rte))
return ConvertFToBF16Vec<NativeFromT, NativeToT, VecSize>(val);
else
#endif
// For VecSize > 1. Only for device.
if constexpr (VecSize > 1) {
NativeToT retval;
for (int i = 0; i < VecSize; i++) {
retval[i] = ConvertToBF16Scalar<FromT, RoundingMode>(val[i]);
}
return retval;
}
// For VecSize == 1.
else
return ConvertToBF16Scalar<NativeFromT, RoundingMode>(val);
}

/// Entry point helper for all kinds of converts between scalars and vectors, it
/// dispatches to a right function depending on source and destination types.
///
Expand Down Expand Up @@ -537,6 +834,14 @@ NativeToT convertImpl(NativeFromT Value) {
else if constexpr (is_float_to_float<FromT, ToT>::value)
return FConvert<NativeFromT, NativeToT, VecSize, ElemTy, RoundingMode>(
Value);
// BF16 conversion to other types.
else if constexpr (std::is_same_v<FromT, bfloat16>)
return ConvertFromBF16<ToT, NativeFromT, NativeToT, RoundingMode, VecSize>(
Value);
// conversion from other types to BF16.
else if constexpr (std::is_same_v<ToT, bfloat16>)
return ConvertToBF16<FromT, NativeFromT, NativeToT, RoundingMode, VecSize>(
Value);
else if constexpr (is_float_to_sint<FromT, ToT>::value)
return ConvertFToS<NativeFromT, NativeToT, VecSize, ElemTy, RoundingMode>(
Value);
Expand Down
Loading
Loading