Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Fix sycl::vec::convert<> to allow conversion to and from sycl::vec of bfloat16 type to that of other data types #14105

Merged
merged 25 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0aa7a9a
Add copy constructor
uditagarwal97 Apr 18, 2024
f2a1dc2
Merge branch 'sycl' of https://github.com/uditagarwal97/llvm into sycl
uditagarwal97 May 28, 2024
361eea7
Merge branch 'sycl' of https://github.com/uditagarwal97/llvm into sycl
uditagarwal97 May 31, 2024
48a8574
Add vector overloads on ConvertBFloat16ToFINTEL and ConvertFToBFloat1…
uditagarwal97 Jun 6, 2024
6bce35d
Fix test case
uditagarwal97 Jun 6, 2024
8d8295e
Fix tests; Address reviews.
uditagarwal97 Jun 7, 2024
91cd730
Fix formatting
uditagarwal97 Jun 7, 2024
20580de
Merge branch 'bf16tof' into opt_math_builtins
uditagarwal97 Jun 7, 2024
d76bc66
Fix conversion between vec<float> <--> vec<bfloat16>.
uditagarwal97 Jun 9, 2024
3b7826e
Call libdevice primitives instead of sprirv ones
uditagarwal97 Jun 11, 2024
5608861
Merge remote-tracking branch 'upstream/sycl' into opt_math_builtins
uditagarwal97 Jun 12, 2024
402073d
Fix formatting
uditagarwal97 Jun 12, 2024
fc6aa6a
Simplify convert for float and BF16
uditagarwal97 Jun 12, 2024
c60ec69
Merge remote-tracking branch 'upstream/sycl' into opt_math_builtins
uditagarwal97 Jun 14, 2024
7651a30
fix for older vec implementation
uditagarwal97 Jun 14, 2024
439e03a
Remove redundant statement
uditagarwal97 Jun 14, 2024
1aa0304
Re-enable e2e BF16 vec test for older vec implementation.
uditagarwal97 Jun 14, 2024
98588cd
Fix test failure and add assert
uditagarwal97 Jun 16, 2024
a264ee5
Fix formatting
uditagarwal97 Jun 16, 2024
8a6caf1
Add comment in static assert.
uditagarwal97 Jun 17, 2024
85a33f8
Fix build error
uditagarwal97 Jun 17, 2024
e5ef19a
Support all rounding modes for BF16 <--> float conversion.
uditagarwal97 Jun 18, 2024
fbeb2db
Fix formatting
uditagarwal97 Jun 18, 2024
a9df444
Don't emit __imf_ functions on non-intel hardwares
uditagarwal97 Jun 20, 2024
cc288cb
Address review. Fix build error.
uditagarwal97 Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <sycl/half_type.hpp> // for BIsRepresentationT
#include <sycl/multi_ptr.hpp> // for multi_ptr, address_spa...

#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16 storage type.

#include <cstddef> // for byte
#include <cstdint> // for uint8_t
#include <limits> // for numeric_limits
Expand Down Expand Up @@ -386,7 +388,13 @@ template <typename T> auto convertToOpenCLType(T &&x) {
static_assert(sizeof(OpenCLType) == sizeof(T));
return static_cast<OpenCLType>(x);
} else if constexpr (is_bfloat16_v<no_ref>) {
// On host, don't interpret BF16 as uint16.
#ifdef __SYCL_DEVICE_ONLY__
using OpenCLType = sycl::ext::oneapi::detail::Bfloat16StorageT;
return sycl::bit_cast<OpenCLType>(x);
#else
return std::forward<T>(x);
#endif
} else if constexpr (std::is_floating_point_v<no_ref>) {
static_assert(std::is_same_v<no_ref, float> ||
std::is_same_v<no_ref, double>,
Expand Down
81 changes: 80 additions & 1 deletion sycl/include/sycl/detail/vector_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
#include <sycl/exception.hpp> // for errc

#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#ifndef __SYCL_DEVICE_ONLY__
#include <cfenv> // for fesetround, fegetround
#endif
Expand Down Expand Up @@ -123,6 +125,15 @@ using is_float_to_float =
std::bool_constant<detail::is_floating_point<T>::value &&
detail::is_floating_point<R>::value>;

using bfloat16 = sycl::ext::oneapi::bfloat16;
template <typename T, typename R>
using is_bf16_to_float =
std::bool_constant<std::is_same_v<T, bfloat16> && std::is_same_v<R, float>>;

template <typename T, typename R>
using is_float_to_bf16 =
std::bool_constant<std::is_same_v<R, bfloat16> && std::is_same_v<T, float>>;

#ifndef __SYCL_DEVICE_ONLY__
template <typename From, typename To, int VecSize,
typename Enable = std::enable_if_t<VecSize == 1>>
Expand Down Expand Up @@ -196,8 +207,27 @@ template <typename From, typename To, int VecSize,
To ConvertFToU(From Value) {
return ConvertFToS<From, To, VecSize, Enable, roundingMode>(Value);
}
#else

template <typename NativeBFT, typename NativeFloatT, int VecSize>
inline NativeFloatT ConvertBF16ToF(NativeBFT val) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can NativeFloatT be anything other than float?

Copy link
Contributor Author

@uditagarwal97 uditagarwal97 Jun 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On host, no. NativeFloatT is always float.

static_assert(VecSize == 1);
// On host, ensure that we don't convert BF16 to uint16 for conversion.
static_assert(std::is_same_v<NativeBFT, sycl::ext::oneapi::bfloat16>);

return (NativeFloatT)val;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't use C-style casts.

}

// Host-side helper that creates a bfloat16 from a single float value.
// NativeBFT must be sycl::ext::oneapi::bfloat16 itself and only VecSize == 1
// is accepted; both requirements are enforced at compile time.
template <typename NativeFloatT, typename NativeBFT, int VecSize>
inline NativeBFT ConvertFToBF16(NativeFloatT val) {
  // On host, ensure that we don't convert BF16 to uint16 for conversion.
  static_assert(std::is_same_v<NativeBFT, sycl::ext::oneapi::bfloat16>);
  static_assert(VecSize == 1);

  return static_cast<NativeBFT>(val);
}

#else
// Bunch of helpers to "specialize" each template for its own destination type
// and vector size.

Expand Down Expand Up @@ -498,6 +528,51 @@ __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE(double)
#undef __SYCL_FLOAT_FLOAT_CONVERT
#undef __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE

template <typename NativeBFT, typename NativeFloatT, int VecSize>
inline NativeFloatT ConvertBF16ToF(NativeBFT vec) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the scalar case, are we going vec<bf16,1> -> operator[] -> cast_to_ushort->cast back to bf16 -> convert to float here + in the caller? Do you think it still makes sense after we changed storage type in vec?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that convertImpl accepts native OpenCL type for device, whether it is uint16 (For vec<bfloat, 1>) or uint16 ext_vector_type() (For vec<bfloat, N>).
I had to do the casts to provide a unified interface for vec::convert (to use convertImpl), plus I expect compiler to get rid of these extra casts.

A long-term solution would be to refactor convertImpl entirely, but that is tangential to this PR.

if constexpr (VecSize == 1) {
// On device, we interpret bfloat16 as a uint16_t scalar or vector.
static_assert(
std::is_same_v<NativeBFT, sycl::ext::oneapi::detail::Bfloat16StorageT>);

// Bitcast to BF16 and typecast to float.
bfloat16 convertedBF = sycl::bit_cast<bfloat16>(vec);
return (float)convertedBF;
} else {
bfloat16 *src = sycl::bit_cast<bfloat16 *>(&vec);

// OpenCL vector of 3 elements is aligned to 4 multiplied by
// the size of data type.
constexpr int AdjustedSize = (VecSize == 3) ? 4 : VecSize;
float dst[AdjustedSize];
sycl::ext::oneapi::detail::BF16VecToFloatVec<VecSize>(src, dst);

return sycl::bit_cast<NativeFloatT>(dst);
}
}

template <typename NativeFloatT, typename NativeBFT, int VecSize>
inline NativeBFT ConvertFToBF16(NativeFloatT vec) {
if constexpr (VecSize == 1) {
// On device, we interpret bfloat16 as a uint16_t scalar or vector.
static_assert(
std::is_same_v<NativeBFT, sycl::ext::oneapi::detail::Bfloat16StorageT>);

auto bf16Val = bfloat16(vec);
return sycl::bit_cast<NativeBFT>(bf16Val);
} else {
float *src = sycl::bit_cast<float *>(&vec);

// OpenCL vector of 3 elements is aligned to 4 multiplied by
// the size of data type.
constexpr int AdjustedSize = (VecSize == 3) ? 4 : VecSize;
bfloat16 dst[AdjustedSize];

sycl::ext::oneapi::detail::FloatVecToBF16Vec<VecSize>(src, dst);
return sycl::bit_cast<NativeBFT>(dst);
}
}

#endif // __SYCL_DEVICE_ONLY__

/// Entry point helper for all kinds of converts between scalars and vectors, it
Expand Down Expand Up @@ -537,6 +612,10 @@ NativeToT convertImpl(NativeFromT Value) {
else if constexpr (is_float_to_float<FromT, ToT>::value)
return FConvert<NativeFromT, NativeToT, VecSize, ElemTy, RoundingMode>(
Value);
else if constexpr (is_bf16_to_float<FromT, ToT>::value)
return ConvertBF16ToF<NativeFromT, NativeToT, VecSize>(Value);
else if constexpr (is_float_to_bf16<FromT, ToT>::value)
return ConvertFToBF16<NativeFromT, NativeToT, VecSize>(Value);
else if constexpr (is_float_to_sint<FromT, ToT>::value)
return ConvertFToS<NativeFromT, NativeToT, VecSize, ElemTy, RoundingMode>(
Value);
Expand Down
35 changes: 30 additions & 5 deletions sycl/include/sycl/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,25 @@ template <typename Type, int NumElements> class vec {
detail::ConvertToOpenCLType_t<vec_data_t<convertT>>>,
vec<convertT, NumElements>>
convert() const {
using bfloat16 = sycl::ext::oneapi::bfloat16;
static_assert(std::is_integral_v<vec_data_t<convertT>> ||
detail::is_floating_point<convertT>::value,
detail::is_floating_point<convertT>::value ||
// Conversion to BF16 available only for float.
(std::is_same_v<convertT, bfloat16> &&
std::is_same_v<DataT, float>),
"Unsupported convertT");

// Currently, for BF16 <--> float conversion, we only support
// Round-to-even rounding mode.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd expect that bfloat maps precisely onto floats, so that direction should "support" all the rounding modes. Am I wrong here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, there cannot be a 1:1 mapping between float and bfloat, as bfloat has only an 8-bit mantissa while float has a 24-bit mantissa. The default rounding mode is RTE (https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_rounding_modes_for_conversions) for floating point to floating point conversion.

constexpr bool isFloatToBF16Conv =
std::is_same_v<convertT, bfloat16> && std::is_same_v<DataT, float>;
constexpr bool isBF16ToFloatConv =
std::is_same_v<DataT, bfloat16> && std::is_same_v<convertT, float>;
if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) {
static_assert(roundingMode == rounding_mode::automatic ||
roundingMode == rounding_mode::rte);
}

using T = vec_data_t<DataT>;
using R = vec_data_t<convertT>;
using OpenCLT = detail::ConvertToOpenCLType_t<T>;
Expand Down Expand Up @@ -828,10 +844,19 @@ template <typename Type, int NumElements> class vec {
{
// Otherwise, we fallback to per-element conversion:
for (size_t I = 0; I < NumElements; ++I) {
Result.setValue(
I, vec_data<convertT>::get(
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>(
vec_data<DataT>::get(getValue(I)))));
// For float -> bf16.
if constexpr (isFloatToBF16Conv) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

detail::convertImpl<> expects OpenCL type as input and returns the OpenCL type corresponding to convertT. In the case of BF16, the OpenCL type will be uint16 for device and bfloat16 on host.
However, currently, vec_data<bfloat16>::get() returns bfloat16 value on both device and host.

As a workaround for this, I've added an explicit if constexpr for BF16 <--> float conversion. A proper fix would require more if conditions/ifdefs, which IMO is not worth it since we will be replacing vector.hpp with vector_preview.hpp soon anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we can just refactor the entire convertImpl, if you have a good plan/picture for that.

Result[I] = bfloat16((*this)[I]);
} else
// For bf16 -> float.
if constexpr (isBF16ToFloatConv) {
Result[I] = (float)((*this)[I]);
} else {
Result.setValue(I, vec_data<convertT>::get(
detail::convertImpl<T, R, roundingMode, 1,
OpenCLT, OpenCLR>(
vec_data<DataT>::get(getValue(I)))));
}
}
}

Expand Down
27 changes: 24 additions & 3 deletions sycl/include/sycl/vector_preview.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,25 @@ class vec : public detail::vec_arith<DataT, NumElements> {

using T = ConvertBoolAndByteT<DataT>;
using R = ConvertBoolAndByteT<convertT>;
static_assert(std::is_integral_v<R> || detail::is_floating_point<R>::value,
using bfloat16 = sycl::ext::oneapi::bfloat16;
static_assert(std::is_integral_v<R> ||
detail::is_floating_point<R>::value ||
std::is_same_v<R, bfloat16>,
"Unsupported convertT");

{
// Currently, for BF16 <--> float conversion, we only support
// Round-to-even rounding mode.
constexpr bool isFloatToBF16Conv =
std::is_same_v<convertT, bfloat16> && std::is_same_v<DataT, float>;
constexpr bool isBF16ToFloatConv =
std::is_same_v<DataT, bfloat16> && std::is_same_v<convertT, float>;
if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) {
static_assert(roundingMode == rounding_mode::automatic ||
roundingMode == rounding_mode::rte);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a message to this static assert to explicitly say that not all rounding modes are supported for bfloat16?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Fixed in 8a6caf1

}
}

using OpenCLT = detail::ConvertToOpenCLType_t<T>;
using OpenCLR = detail::ConvertToOpenCLType_t<R>;
vec<convertT, NumElements> Result;
Expand Down Expand Up @@ -479,11 +495,16 @@ class vec : public detail::vec_arith<DataT, NumElements> {
auto val =
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>(
getValue(I));
Result[I] = static_cast<convertT>(val);
#ifdef __SYCL_DEVICE_ONLY__
// On device, we interpret BF16 as uint16.
if constexpr (std::is_same_v<convertT, bfloat16>)
Result[I] = sycl::bit_cast<convertT>(val);
else
#endif
Result[I] = static_cast<convertT>(val);
}
}
}

return Result;
}

Expand Down
69 changes: 54 additions & 15 deletions sycl/test-e2e/BFloat16/bfloat16_vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
// RUN: %if preview-breaking-changes-supported %{ %{run} %t2.out %}

#include <sycl/detail/core.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>
#include <sycl/stream.hpp>

#include <sycl/ext/oneapi/bfloat16.hpp>

constexpr unsigned N =
10; // init plus arithmetic + - * / for vec<1> and vec<2>
14; // init plus arithmetic + - * / plus convert for vec<1> and vec<2>

int main() {

Expand All @@ -46,17 +47,26 @@ int main() {
sycl::vec<T, 1> simple_multiplication = oneA * oneB;
sycl::vec<T, 1> simple_division = oneA / oneB;

// Test bf16 to float vec conversion on host
sycl::vec<float, 1> fConv = init_float.template convert<float>();
// Test float to bf16 vec conversion on host
sycl::vec<T, 1> brev = fConv.template convert<T>();

std::cout << "iniitialization : " << oneA[0] << " float: " << init_float[0] << std::endl;
std::cout << "addition. ref: " << addition_ref0 << " vec: " << simple_addition[0] << std::endl;
std::cout << "subtraction. ref: " << subtraction_ref0 << " vec: " << simple_subtraction[0] << std::endl;
std::cout << "multiplication. ref: " << multiplication_ref0 << " vec: " << simple_multiplication[0] << std::endl;
std::cout << "division. ref: " << division_ref0 << " vec: " << simple_division[0] << std::endl;
std::cout << "float conv. ref: " << (float)init_float[0]<< " vec: " << fConv[0] << std::endl;
std::cout << "bf16 conv. ref: " << init_float[0] << " vec: " << brev[0] << std::endl;

assert(oneA[0] == init_float[0]);
assert(addition_ref0 == simple_addition[0]);
assert(subtraction_ref0 == simple_subtraction[0]);
assert(multiplication_ref0 == simple_multiplication[0]);
assert(division_ref0 == simple_division[0]);
assert((float)init_float[0] == fConv[0]);
assert(brev[0] == init_float[0]);

std::cout << " --- ON DEVICE --- " << std::endl;
sycl::range<1> r(N);
Expand All @@ -72,17 +82,26 @@ int main() {
sycl::vec<T, 1> device_multiplication = oneA * oneB;
sycl::vec<T, 1> device_division = oneA / oneB;

// Test bf16 to float vec conversion on device
sycl::vec<float, 1> fConv = dev_float.template convert<float>();
// Test float to bf16 vec conversion on device
sycl::vec<T, 1> brev = fConv.template convert<T>();

out << "iniitialization : " << oneA[0] << " float: " << dev_float[0] << sycl::endl;
out << "addition. ref: " << addition_ref0 << " vec: " << device_addition[0] << sycl::endl;
out << "subtraction. ref: " << subtraction_ref0 << " vec: " << device_subtraction[0] << sycl::endl;
out << "multiplication. ref: " << multiplication_ref0 << " vec: " << device_multiplication[0] << sycl::endl;
out << "division. ref: " << division_ref0 << " vec: " << device_division[0] << sycl::endl;
out << "float conv. ref: " << (float)dev_float[0] << " vec: " << fConv[0] << sycl::endl;
out << "bf16 conv. ref: " << dev_float[0] << " vec: " << brev[0] << sycl::endl;

acc[0] = (oneA[0] == dev_float[0]);
acc[1] = (addition_ref0 == device_addition[0]);
acc[2] = (subtraction_ref0 == device_subtraction[0]);
acc[3] = (multiplication_ref0 == device_multiplication[0]);
acc[4] = (division_ref0 == device_division[0]);
acc[5] = ((float)dev_float[0] == fConv[0]);
acc[6] = (brev[0] == dev_float[0]);

});
}).wait();
Expand All @@ -105,6 +124,11 @@ int main() {
sycl::vec<T, 2> double_multiplication = twoA * twoB;
sycl::vec<T, 2> double_division = twoA / twoB;

// Test bf16 to float vec conversion on host
sycl::vec<float, 2> fConv2 = double_float.template convert<float>();
// Test float to bf16 vec conversion on host
sycl::vec<T, 2> brev2 = fConv2.template convert<T>();

std::cout << "init ref: " << twoA[0] << " ref1: " << twoA[1] << std::endl;
std::cout << " float0: " << double_float[0] << " float1: " << double_float[1] << std::endl;
std::cout << "+ ref0: " << addition_ref0 << " ref1: " << addition_ref1 << std::endl;
Expand All @@ -115,13 +139,18 @@ int main() {
std::cout << "mul[0]: " << double_multiplication[0] << " mul[1]: " << double_multiplication[1] << std::endl;
std::cout << "/ ref0: " << division_ref0 << " ref1: " << division_ref1 << std::endl;
std::cout << "div[0]: " << double_division[0] << " div[1]: " << double_division[1] << std::endl;

std::cout << "Float convert ref0: " << double_float[0] << " ref1: " << double_float[1] << std::endl;
std::cout << "convert[0]: " << fConv2[0] << " convert[1]: " << fConv2[1] << std::endl;
std::cout << "bf16 convert[0]: " << brev2[0] << " bf16 convert[1]: " << brev2[1] << std::endl;

assert(twoA[0] == double_float[0]); assert(twoA[1] == double_float[1]);
assert(addition_ref0 == double_addition[0]); assert(addition_ref1 == double_addition[1]);
assert(subtraction_ref0 == double_subtraction[0]); assert(subtraction_ref1 == double_subtraction[1]);
assert(multiplication_ref0 == double_multiplication[0]); assert(multiplication_ref1 == double_multiplication[1]);
assert(division_ref0 == double_division[0]); assert(division_ref1 == double_division[1]);

assert(fConv2[0] == (float)double_float[0]); assert(fConv2[1] == (float)double_float[1]);
assert(brev2[0] == double_float[0]); assert(brev2[1] == double_float[1]);

std::cout << " --- ON DEVICE --- " << std::endl;
q.submit([&](sycl::handler &cgh) {
sycl::stream out(2024, 400, cgh);
Expand All @@ -133,6 +162,11 @@ int main() {
sycl::vec<T, 2> device_multiplication = twoA * twoB;
sycl::vec<T, 2> device_division = twoA / twoB;

// Test bf16 to float vec conversion on device
sycl::vec<float, 2> fConv2 = device_float.template convert<float>();
// Test float to bf16 vec conversion on device
sycl::vec<T, 2> brev2 = fConv2.template convert<T>();

out << "init ref: " << twoA[0] << " ref1: " << twoA[1] << sycl::endl;
out << " float0: " << device_float[0] << " float1: " << device_float[1] << sycl::endl;
out << "+ ref0: " << addition_ref0 << " ref1: " << addition_ref1 << sycl::endl;
Expand All @@ -143,21 +177,26 @@ int main() {
out << "mul[0]: " << device_multiplication[0] << " mul[1]: " << device_multiplication[1] << sycl::endl;
out << "/ ref0: " << division_ref0 << " ref1: " << division_ref1 << sycl::endl;
out << "div[0]: " << device_division[0] << " div[1]: " << device_division[1] << sycl::endl;

acc[5] = (twoA[0] == device_float[0]) && (twoA[1] == device_float[1]);
acc[6] = (addition_ref0 == device_addition[0]) && (addition_ref1 == device_addition[1]);
acc[7] = (subtraction_ref0 == device_subtraction[0]) && (subtraction_ref1 == device_subtraction[1]);
acc[8] = (multiplication_ref0 == device_multiplication[0]) && (multiplication_ref1 == device_multiplication[1]);
acc[9] = (division_ref0 == device_division[0]) && (division_ref1 == device_division[1]);

out << "Float convert ref0: " << device_float[0] << " ref1: " << device_float[1] << sycl::endl;
out << "convert[0]: " << fConv2[0] << " convert[1]: " << fConv2[1] << sycl::endl;
out << "bf16 convert[0]: " << brev2[0] << " bf16 convert[1]: " << brev2[1] << sycl::endl;

acc[7] = (twoA[0] == device_float[0]) && (twoA[1] == device_float[1]);
acc[8] = (addition_ref0 == device_addition[0]) && (addition_ref1 == device_addition[1]);
acc[9] = (subtraction_ref0 == device_subtraction[0]) && (subtraction_ref1 == device_subtraction[1]);
acc[10] = (multiplication_ref0 == device_multiplication[0]) && (multiplication_ref1 == device_multiplication[1]);
acc[11] = (division_ref0 == device_division[0]) && (division_ref1 == device_division[1]);
acc[12] = (fConv2[0] == (float)device_float[0]) && (fConv2[1] == (float)device_float[1]);
acc[13] = (brev2[0] == device_float[0]) && (brev2[1] == device_float[1]);
});
}).wait();
// clang-format on

sycl::host_accessor h_acc(buf, sycl::read_only);
for(unsigned i = 0; i < N; i++){
assert(h_acc[i]);
for (unsigned i = 0; i < N; i++) {
assert(h_acc[i]);
}

// clang-format on
return 0;
std::cout << "Test Passed." << std::endl;
return 0;
}
Loading