Skip to content

Commit

Permalink
[SYCL][ESIMD]Limit bfloat16 operators to scalars to enable operations…
Browse files Browse the repository at this point in the history
… with simd vectors (intel#12089)

The purpose of this change is to limit the operators defined for bfloat16 to
scalar types, so that arithmetic operations between bfloat16 scalars and
simd vectors become possible. This allows the separately defined simd
operators, which support operations between vectors and scalars, to be used.
  • Loading branch information
fineg74 authored Jan 3, 2024
1 parent a90aaa7 commit 8c92df9
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
return O;
}

// Specializes the is_esimd_arithmetic_type trait for bfloat16 so that it
// derives from std::true_type.
// NOTE(review): presumably this is what lets the ESIMD simd vector/scalar
// operators accept bfloat16 scalars - confirm against the trait's users.
template <> struct is_esimd_arithmetic_type<bfloat16, void> : std::true_type {};

} // namespace ext::intel::esimd::detail
} // namespace _V1
} // namespace sycl
Expand Down
222 changes: 164 additions & 58 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,69 +132,175 @@ class bfloat16 {
#endif
}

// Compound assignment operators (+=, -=, *=, /=)
// Adds rhs to this object: both stored values are widened with to_float, the
// float sum is narrowed back with from_float, and *this is returned.
bfloat16 &operator+=(const bfloat16 &rhs) {
  const float sum = to_float(value) + to_float(rhs.value);
  value = from_float(sum);
  return *this;
}

// Subtracts rhs from this object: both stored values are widened with
// to_float, the float difference is narrowed back with from_float, and *this
// is returned.
bfloat16 &operator-=(const bfloat16 &rhs) {
  const float difference = to_float(value) - to_float(rhs.value);
  value = from_float(difference);
  return *this;
}

// Multiplies this object by rhs: both stored values are widened with
// to_float, the float product is narrowed back with from_float, and *this is
// returned.
bfloat16 &operator*=(const bfloat16 &rhs) {
  const float product = to_float(value) * to_float(rhs.value);
  value = from_float(product);
  return *this;
}

// Divides this object by rhs: both stored values are widened with to_float,
// the float quotient is narrowed back with from_float, and *this is returned.
bfloat16 &operator/=(const bfloat16 &rhs) {
  const float quotient = to_float(value) / to_float(rhs.value);
  value = from_float(quotient);
  return *this;
}

// Operator ++, --
// Prefix increment: adds one to the float form of the stored value, stores
// the result converted back with from_float, and returns *this.
bfloat16 &operator++() {
  const float incremented = to_float(value) + 1.0f;
  value = from_float(incremented);
  return *this;
}

// Postfix increment: returns a copy taken before the increment, then
// delegates the actual update to the prefix form.
bfloat16 operator++(int) {
  bfloat16 previous = *this;
  this->operator++();
  return previous;
}

// Prefix decrement: subtracts one from the float form of the stored value,
// stores the result converted back with from_float, and returns *this.
bfloat16 &operator--() {
  const float decremented = to_float(value) - 1.0f;
  value = from_float(decremented);
  return *this;
}

// Postfix decrement: returns a copy taken before the decrement, then
// delegates the actual update to the prefix form.
bfloat16 operator--(int) {
  bfloat16 previous = *this;
  this->operator--();
  return previous;
}

// Operator +, -, *, /
#define OP(op) \
friend bfloat16 &operator op(bfloat16 &lhs) { \
float f = to_float(lhs.value); \
lhs.value = from_float(op f); \
return lhs; \
} \
friend bfloat16 operator op(bfloat16 &lhs, int) { \
bfloat16 old = lhs; \
operator op(lhs); \
return old; \
}
OP(++)
OP(--)
friend bfloat16 operator op(const bfloat16 lhs, const bfloat16 rhs) { \
return to_float(lhs.value) op to_float(rhs.value); \
} \
friend double operator op(const bfloat16 lhs, const double rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend double operator op(const double lhs, const bfloat16 rhs) { \
return lhs op to_float(rhs.value); \
} \
friend float operator op(const bfloat16 lhs, const float rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend float operator op(const float lhs, const bfloat16 rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bfloat16 operator op(const bfloat16 lhs, const int rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bfloat16 operator op(const int lhs, const bfloat16 rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bfloat16 operator op(const bfloat16 lhs, const long rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bfloat16 operator op(const long lhs, const bfloat16 rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bfloat16 operator op(const bfloat16 lhs, const long long rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bfloat16 operator op(const long long lhs, const bfloat16 rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bfloat16 operator op(const bfloat16 &lhs, const unsigned int &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bfloat16 operator op(const unsigned int &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bfloat16 operator op(const bfloat16 &lhs, const unsigned long &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bfloat16 operator op(const unsigned long &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bfloat16 operator op(const bfloat16 &lhs, \
const unsigned long long &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bfloat16 operator op(const unsigned long long &lhs, \
const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
}
OP(+)
OP(-)
OP(*)
OP(/)

#undef OP

// Assignment operators overloading
// Operator ==, !=, <, >, <=, >=
#define OP(op) \
friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
} \
template <typename T> \
friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
} \
template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) { \
float f = static_cast<float>(lhs); \
f op static_cast<float>(rhs); \
return lhs = f; \
}
OP(+=)
OP(-=)
OP(*=)
OP(/=)
#undef OP
friend bool operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
return to_float(lhs.value) op to_float(rhs.value); \
} \
friend bool operator op(const bfloat16 &lhs, const double &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bool operator op(const double &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bool operator op(const bfloat16 &lhs, const float &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bool operator op(const float &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bool operator op(const bfloat16 &lhs, const int &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bool operator op(const int &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bool operator op(const bfloat16 &lhs, const long &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bool operator op(const long &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bool operator op(const bfloat16 &lhs, const long long &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bool operator op(const long long &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bool operator op(const bfloat16 &lhs, const unsigned int &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bool operator op(const unsigned int &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bool operator op(const bfloat16 &lhs, const unsigned long &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bool operator op(const unsigned long &lhs, const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
} \
friend bool operator op(const bfloat16 &lhs, \
const unsigned long long &rhs) { \
return to_float(lhs.value) op rhs; \
} \
friend bool operator op(const unsigned long long &lhs, \
const bfloat16 &rhs) { \
return lhs op to_float(rhs.value); \
}
OP(==)
OP(!=)
OP(<)
OP(>)
OP(<=)
OP(>=)

// Binary operators overloading
#define OP(type, op) \
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
} \
template <typename T> \
friend type operator op(const bfloat16 &lhs, const T &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
} \
template <typename T> \
friend type operator op(const T &lhs, const bfloat16 &rhs) { \
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
}
OP(bfloat16, +)
OP(bfloat16, -)
OP(bfloat16, *)
OP(bfloat16, /)
OP(bool, ==)
OP(bool, !=)
OP(bool, <)
OP(bool, >)
OP(bool, <=)
OP(bool, >=)
#undef OP

// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,11 @@ int main() {
}

#ifdef USE_BF16
// TODO: Reenable once the issue with bfloat16 is resolved
// Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
#endif
#ifdef USE_TF32
Passed &= test<sycl::ext::intel::experimental::esimd::tfloat32>(Q);
#endif
std::cout << (Passed ? "Passed\n" : "FAILED\n");
return Passed ? 0 : 1;
}
}
100 changes: 100 additions & 0 deletions sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
//==- bfloat16_vector_plus_scalar.cpp - Test for bfloat16 operators ------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "../esimd_test_utils.hpp"
#include <iostream>
#include <sycl/ext/intel/esimd.hpp>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::intel::esimd;
using namespace sycl::ext::intel::experimental::esimd;

// Runs vector-plus-scalar arithmetic (+, -, *, /) for element type T inside an
// ESIMD kernel and compares every lane against the same operation performed
// on host scalars.
//
// The kernel writes four groups of N results into a shared allocation:
// [0, N) sums, [N, 2N) differences, [2N, 3N) products, [3N, 4N) quotients.
// Returns true when all lanes of all four groups match.
template <typename T> ESIMD_NOINLINE bool test(queue Q) {
  std::cout << "Testing T=" << esimd_test::type_name<T>() << "...\n";

  constexpr int N = 8;

  constexpr int NumOps = 4;
  constexpr int CSize = NumOps * N;

  T *Results = malloc_shared<T>(CSize, Q);
  T One = static_cast<T>(1);
  T Ten = static_cast<T>(10);

  Q.single_task([=]() SYCL_ESIMD_KERNEL {
     {
       simd<T, N> Sum(One);
       Sum = Sum + Ten;
       Sum.copy_to(Results);
     }
     {
       simd<T, N> Diff(One);
       Diff = Diff - Ten;
       Diff.copy_to(Results + N);
     }
     {
       simd<T, N> Prod(One);
       Prod = Prod * Ten;
       Prod.copy_to(Results + 2 * N);
     }
     {
       simd<T, N> Quot(One);
       Quot = Quot / Ten;
       Quot.copy_to(Results + 3 * N);
     }
   }).wait();

  // Checks one lane across all four result groups. Sum, difference and
  // product must match exactly; the quotient may instead be within a relative
  // error of 0.001. The tolerance branch is evaluated only when the exact
  // comparison fails.
  const auto LaneIsCorrect = [&](int Lane) -> bool {
    if (Results[Lane] != One + Ten)
      return false;
    if (Results[Lane + N] != One - Ten)
      return false;
    if (Results[Lane + 2 * N] != One * Ten)
      return false;

    const bool ExactQuotient = Results[Lane + 3 * N] == (One / Ten);
    return ExactQuotient ||
           std::abs(static_cast<double>(Results[Lane + 3 * N] - (One / Ten)) /
                    static_cast<double>(One / Ten)) <= 0.001;
  };

  // Stop at the first failing lane.
  bool AllLanesCorrect = true;
  for (int Lane = 0; Lane < N; ++Lane) {
    if (!LaneIsCorrect(Lane)) {
      AllLanesCorrect = false;
      break;
    }
  }

  free(Results, Q);
  return AllLanesCorrect;
}

// Test driver: runs test<T> for int and float, for sycl::half when the device
// reports the fp16 aspect, and for bfloat16 / tfloat32 when USE_BF16 /
// USE_TF32 are defined. Prints "Passed" or "FAILED" and returns 0 or 1.
int main() {
  queue Q;
  esimd_test::printTestLabel(Q);

  const bool HasFp16 = Q.get_device().has(aspect::fp16);

  // Each test<T> call is the left operand of &&, so it always runs even after
  // an earlier failure.
  bool AllPassed = test<int>(Q);
  AllPassed = test<float>(Q) && AllPassed;
  if (HasFp16) {
    AllPassed = test<sycl::half>(Q) && AllPassed;
  }
#ifdef USE_BF16
  AllPassed = test<sycl::ext::oneapi::bfloat16>(Q) && AllPassed;
#endif
#ifdef USE_TF32
  AllPassed =
      test<sycl::ext::intel::experimental::esimd::tfloat32>(Q) && AllPassed;
#endif

  if (AllPassed) {
    std::cout << "Passed\n";
    return 0;
  }
  std::cout << "FAILED\n";
  return 1;
}
14 changes: 14 additions & 0 deletions sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar_pvc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//==- bfloat16_vector_plus_scalar_pvc.cpp - Test for bfloat16 operators -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: gpu-intel-pvc
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#define USE_BF16
#define USE_TF32
#include "bfloat16_vector_plus_scalar.cpp"

0 comments on commit 8c92df9

Please sign in to comment.