Skip to content

Commit

Permalink
[SYCL][libdevice] Add sqrt with rounding mode supported in sycl::ext:…
Browse files Browse the repository at this point in the history
…:intel::math (#12571)

Signed-off-by: jinge90 <ge.jin@intel.com>
  • Loading branch information
jinge90 authored Feb 1, 2024
1 parent 8427bd2 commit 6c1dde4
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 0 deletions.
4 changes: 4 additions & 0 deletions libdevice/imf_impl_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ class __iml_ui128 {
return (this->bits[0] > x.bits[0]);
}

bool operator>=(const __iml_ui128 &x) {
return operator==(x) || operator>(x);
}

bool operator>(const uint64_t &x) {
if (this->bits[1] > 0)
return true;
Expand Down
113 changes: 113 additions & 0 deletions libdevice/imf_rounding_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define __LIBDEVICE_IMF_ROUNDING_OP_H__
#include "imf_impl_utils.hpp"
#include <limits>

template <typename Ty>
static Ty __handling_fp_overflow(unsigned z_sig, int rd) {
typedef typename __iml_fp_config<Ty>::utype UTy;
Expand Down Expand Up @@ -1569,4 +1570,116 @@ template <typename FTy> FTy __fp_fma(FTy x, FTy y, FTy z, int rd) {
}
}

template <typename UTy> UTy integer_sqrt(UTy n, bool &is_squares) {
UTy x{n}, c{0}, d{1};
d = d << (sizeof(UTy) * 8 - 2);
while (d > n)
d = d >> 2;

while (d != 0) {
if (x >= (c + d)) {
x -= (c + d);
c = (c >> 1) + d;
} else
c = c >> 1;
d = d >> 2;
}

if (c * c > n)
c -= 1;
if (c * c == n)
is_squares = true;
else
is_squares = false;
return c;
}

template <typename FTy> FTy __fp_sqrt(FTy x, int rd) {
typedef typename __iml_fp_config<FTy>::utype UTy;
typedef typename __iml_get_double_size_unsigned<UTy>::utype DSUTy;
constexpr int fra_digits = std::numeric_limits<FTy>::digits - 1;
UTy x_bit = __builtin_bit_cast(UTy, x);
UTy x_exp = (x_bit & __iml_fp_config<FTy>::pos_inf_bits) >> fra_digits;
UTy x_fra = x_bit & __iml_fp_config<FTy>::fra_mask;
UTy x_sig = x_bit >> (sizeof(FTy) * 8 - 1);
DSUTy Bit1(1);
constexpr UTy NAN_BITS = __iml_fp_config<FTy>::nan_bits;
constexpr UTy INF_BITS = __iml_fp_config<FTy>::pos_inf_bits;

if ((x_exp == __iml_fp_config<FTy>::exp_mask) && (x_fra != 0x0))
return __builtin_bit_cast(FTy, NAN_BITS);

if ((x_exp == 0x0) && (x_fra == 0x0))
return __builtin_bit_cast(FTy, static_cast<UTy>(0x0));

if (x_sig == 1)
return __builtin_bit_cast(FTy, NAN_BITS);

if ((x_exp == __iml_fp_config<FTy>::exp_mask) && (x_fra == 0x0))
return __builtin_bit_cast(FTy, INF_BITS);

// For all postive subnormal and normal values, the result of sqrt
// is a normal value.
int32_t sx_exp = x_exp;
if (sx_exp == 0x0)
sx_exp = 1 - __iml_fp_config<FTy>::bias;
else
sx_exp -= __iml_fp_config<FTy>::bias;

DSUTy fra_holder{x_fra};
if (x_exp != 0)
fra_holder = (Bit1 << fra_digits) | fra_holder;
sx_exp -= fra_digits;

// 2^x_exp * 1.mant can be represented as: 2^(x_exp - 52) * fra_holder
// for normal value and 2^-1022 * 0.mant can be represented as:
// 2^(-1074) * fra_holder for subnormal value. For fp32, 2^x_exp * 1.mant
// can be represented as: 2^(x_exp - 23) * fra_holder for normal value and
// 2^-126 * 0.mant can be represented as 2^-149 * fra_holder for subnormal.
// fra_holder is a non-zero value.
size_t lz = 0;
if constexpr (std::is_same<DSUTy, __iml_ui128>::value)
lz = 127 - fra_holder.ui128_msb_pos();
else
lz = 63 - get_msb_pos(fra_holder);

fra_holder = fra_holder << lz;
sx_exp -= lz;
if (static_cast<uint32_t>(sx_exp) & 0x1) {
sx_exp += 1;
fra_holder = fra_holder >> 1;
}

bool is_squares = false;
DSUTy sqrt_fra = integer_sqrt<DSUTy>(fra_holder, is_squares);
sx_exp = sx_exp / 2;

if constexpr (std::is_same<DSUTy, __iml_ui128>::value)
lz = 127 - sqrt_fra.ui128_msb_pos();
else
lz = 63 - get_msb_pos(sqrt_fra);
UTy fra1 =
static_cast<UTy>(sqrt_fra >> (sizeof(DSUTy) * 8 - lz - fra_digits - 1));
fra1 = fra1 & __iml_fp_config<FTy>::fra_mask;
sx_exp += sizeof(DSUTy) * 8 - 1 - lz + __iml_fp_config<FTy>::bias;

size_t grs_nsbit = sizeof(FTy) * 16 - lz - 1 - fra_digits;
uint32_t grs_bits =
static_cast<uint32_t>(sqrt_fra & ((Bit1 << grs_nsbit) - Bit1));
uint32_t s_bits =
grs_bits & static_cast<uint32_t>((Bit1 << (grs_nsbit - 3)) - Bit1);
grs_bits = grs_bits >> (grs_nsbit - 3);
if ((s_bits > 0) || !is_squares)
grs_bits |= 0x1;

uint32_t rb =
__handling_rounding(0U, static_cast<uint32_t>(fra1), grs_bits, rd);
fra1 += rb;
if (fra1 > __iml_fp_config<FTy>::fra_mask) {
fra1 = 0x0;
sx_exp++;
}
return __builtin_bit_cast(FTy,
(static_cast<UTy>(sx_exp) << fra_digits) | fra1);
}
#endif
12 changes: 12 additions & 0 deletions libdevice/imf_utils/fp32_round.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,16 @@ DEVICE_EXTERN_C_INLINE
float __devicelib_imf_fmaf_rz(float x, float y, float z) {
return __fp_fma(x, y, z, __IML_RTZ);
}

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rd(float x) { return __fp_sqrt(x, __IML_RTN); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rn(float x) { return __fp_sqrt(x, __IML_RTE); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_ru(float x) { return __fp_sqrt(x, __IML_RTP); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rz(float x) { return __fp_sqrt(x, __IML_RTZ); }
#endif
12 changes: 12 additions & 0 deletions libdevice/imf_utils/fp64_round.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,16 @@ DEVICE_EXTERN_C_INLINE
double __devicelib_imf_fma_rz(double x, double y, double z) {
return __fp_fma(x, y, z, __IML_RTZ);
}

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rd(double x) { return __fp_sqrt(x, __IML_RTN); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rn(double x) { return __fp_sqrt(x, __IML_RTE); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_ru(double x) { return __fp_sqrt(x, __IML_RTP); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rz(double x) { return __fp_sqrt(x, __IML_RTZ); }
#endif
24 changes: 24 additions & 0 deletions libdevice/imf_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1992,4 +1992,28 @@ DEVICE_EXTERN_C_INLINE
float __imf_fmaf_rz(float x, float y, float z) {
return __devicelib_imf_fmaf_rz(x, y, z);
}

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rd(float);

DEVICE_EXTERN_C_INLINE
float __imf_sqrtf_rd(float x) { return __devicelib_imf_sqrtf_rd(x); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rn(float);

DEVICE_EXTERN_C_INLINE
float __imf_sqrtf_rn(float x) { return __devicelib_imf_sqrtf_rn(x); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_ru(float);

DEVICE_EXTERN_C_INLINE
float __imf_sqrtf_ru(float x) { return __devicelib_imf_sqrtf_ru(x); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rz(float);

DEVICE_EXTERN_C_INLINE
float __imf_sqrtf_rz(float x) { return __devicelib_imf_sqrtf_rz(x); }
#endif // __LIBDEVICE_IMF_ENABLED__
24 changes: 24 additions & 0 deletions libdevice/imf_wrapper_fp64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,4 +549,28 @@ DEVICE_EXTERN_C_INLINE
double __imf_fma_rz(double x, double y, double z) {
return __devicelib_imf_fma_rz(x, y, z);
}

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rd(double);

DEVICE_EXTERN_C_INLINE
double __imf_sqrt_rd(double x) { return __devicelib_imf_sqrt_rd(x); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rn(double);

DEVICE_EXTERN_C_INLINE
double __imf_sqrt_rn(double x) { return __devicelib_imf_sqrt_rn(x); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_ru(double);

DEVICE_EXTERN_C_INLINE
double __imf_sqrt_ru(double x) { return __devicelib_imf_sqrt_ru(x); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rz(double);

DEVICE_EXTERN_C_INLINE
double __imf_sqrt_rz(double x) { return __devicelib_imf_sqrt_rz(x); }
#endif // __LIBDEVICE_IMF_ENABLED__
8 changes: 8 additions & 0 deletions llvm/tools/sycl-post-link/SYCLDeviceLibReqMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ SYCLDeviceLibFuncMap SDLMap = {
{"__devicelib_imf_fmaf_rn", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_fmaf_ru", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_fmaf_rz", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_sqrtf_rd", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_sqrtf_rn", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_sqrtf_ru", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_sqrtf_rz", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_float2int_rd", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_float2int_rn", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_float2int_ru", DeviceLibExt::cl_intel_devicelib_imf},
Expand Down Expand Up @@ -528,6 +532,10 @@ SYCLDeviceLibFuncMap SDLMap = {
{"__devicelib_imf_fma_rn", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_fma_ru", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_fma_rz", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_sqrt_rd", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_sqrt_rn", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_sqrt_ru", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_sqrt_rz", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_bfloat162float",
DeviceLibExt::cl_intel_devicelib_imf_bf16},
{"__devicelib_imf_bfloat162int_rd",
Expand Down
8 changes: 8 additions & 0 deletions sycl/include/sycl/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rd(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rn(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_ru(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rz(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rd(float x);
extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rn(float x);
extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_ru(float x);
extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rz(float x);
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_rd(float x);
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_rn(float x);
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_ru(float x);
Expand Down Expand Up @@ -358,6 +362,10 @@ extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rd(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rn(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_ru(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rz(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rd(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rn(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_ru(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rz(double x);
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_rd(double x);
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_rn(double x);
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_ru(double x);
Expand Down
24 changes: 24 additions & 0 deletions sycl/include/sycl/ext/intel/math/imf_rounding_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ float __imf_fmaf_rz(float, float, float);
float __imf_fmaf_rn(float, float, float);
float __imf_fmaf_ru(float, float, float);
float __imf_fmaf_rd(float, float, float);
float __imf_sqrtf_rz(float);
float __imf_sqrtf_rn(float);
float __imf_sqrtf_ru(float);
float __imf_sqrtf_rd(float);

double __imf_dadd_rz(double, double);
double __imf_dadd_rn(double, double);
Expand All @@ -60,6 +64,10 @@ double __imf_fma_rz(double, double, double);
double __imf_fma_rn(double, double, double);
double __imf_fma_ru(double, double, double);
double __imf_fma_rd(double, double, double);
double __imf_sqrt_rz(double);
double __imf_sqrt_rn(double);
double __imf_sqrt_ru(double);
double __imf_sqrt_rd(double);
};

namespace sycl {
Expand Down Expand Up @@ -154,6 +162,14 @@ template <typename Tp = float> Tp fmaf_rz(Tp x, Tp y, Tp z) {
return __imf_fmaf_rz(x, y, z);
}

template <typename Tp = float> Tp fsqrt_rd(Tp x) { return __imf_sqrtf_rd(x); }

template <typename Tp = float> Tp fsqrt_rn(Tp x) { return __imf_sqrtf_rn(x); }

template <typename Tp = float> Tp fsqrt_ru(Tp x) { return __imf_sqrtf_ru(x); }

template <typename Tp = float> Tp fsqrt_rz(Tp x) { return __imf_sqrtf_rz(x); }

template <typename Tp = double> Tp dadd_rd(Tp x, Tp y) {
return __imf_dadd_rd(x, y);
}
Expand Down Expand Up @@ -242,6 +258,14 @@ template <typename Tp = double> Tp fma_rz(Tp x, Tp y, Tp z) {
return __imf_fma_rz(x, y, z);
}

template <typename Tp = double> Tp dsqrt_rd(Tp x) { return __imf_sqrt_rd(x); }

template <typename Tp = double> Tp dsqrt_rn(Tp x) { return __imf_sqrt_rn(x); }

template <typename Tp = double> Tp dsqrt_ru(Tp x) { return __imf_sqrt_ru(x); }

template <typename Tp = double> Tp dsqrt_rz(Tp x) { return __imf_sqrt_rz(x); }

} // namespace ext::intel::math
} // namespace _V1
} // namespace sycl
30 changes: 30 additions & 0 deletions sycl/test-e2e/DeviceLib/imf_fp32_rounding_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,35 @@ int main(int, char **) {
std::cout << "sycl::ext::intel::math::fmaf_rz passes." << std::endl;
}

{
std::initializer_list<float> input_vals = {
0x1.ba90e6p+1, 0x1.4p+1, 0x1.ea77e6p-2, 0x1.e8330ap+19,
0x1.4ffd68p+5, 0x1.443084p-15, 0x1.605fb2p+6, 0x1.2eb718p-7};
std::initializer_list<unsigned> ref_vals_rd = {
0x3fee0264, 0x3fca62c1, 0x3f312c12, 0x4479faa2,
0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e};
std::initializer_list<unsigned> ref_vals_rn = {
0x3fee0265, 0x3fca62c2, 0x3f312c13, 0x4479faa2,
0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e};
std::initializer_list<unsigned> ref_vals_ru = {
0x3fee0265, 0x3fca62c2, 0x3f312c13, 0x4479faa3,
0x40cf616e, 0x3bcbb4d1, 0x41162c49, 0x3dc4d80f};
std::initializer_list<unsigned> ref_vals_rz = {
0x3fee0264, 0x3fca62c1, 0x3f312c12, 0x4479faa2,
0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e};
test(device_queue, input_vals, ref_vals_rd,
FT(unsigned, sycl::ext::intel::math::fsqrt_rd));
std::cout << "sycl::ext::intel::math::fsqrt_rd passes." << std::endl;
test(device_queue, input_vals, ref_vals_rn,
FT(unsigned, sycl::ext::intel::math::fsqrt_rn));
std::cout << "sycl::ext::intel::math::fsqrt_rn passes." << std::endl;
test(device_queue, input_vals, ref_vals_ru,
FT(unsigned, sycl::ext::intel::math::fsqrt_ru));
std::cout << "sycl::ext::intel::math::fsqrt_ru passes." << std::endl;
test(device_queue, input_vals, ref_vals_rz,
FT(unsigned, sycl::ext::intel::math::fsqrt_rz));
std::cout << "sycl::ext::intel::math::fsqrt_rz passes." << std::endl;
}

return 0;
}
29 changes: 29 additions & 0 deletions sycl/test-e2e/DeviceLib/imf_fp64_rounding_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,34 @@ int main(int, char **) {
std::cout << "sycl::ext::intel::math::fmaf_rz passes." << std::endl;
}

{
std::initializer_list<double> input_vals1 = {
0x1p+2, 0x1.fbd37afb0f8edp-1, 0x1.9238e38e38e35p+6, 0x1.7p+3};
std::initializer_list<unsigned long long> ref_vals_rd = {
0x4000000000000000, 0x3fefde8a59acb0bb, 0x40240e33d899cd1b,
0x400b211b1c70d023};
std::initializer_list<unsigned long long> ref_vals_rn = {
0x4000000000000000, 0x3fefde8a59acb0bc, 0x40240e33d899cd1c,
0x400b211b1c70d023};
std::initializer_list<unsigned long long> ref_vals_ru = {
0x4000000000000000, 0x3fefde8a59acb0bc, 0x40240e33d899cd1c,
0x400b211b1c70d024};
std::initializer_list<unsigned long long> ref_vals_rz = {
0x4000000000000000, 0x3fefde8a59acb0bb, 0x40240e33d899cd1b,
0x400b211b1c70d023};
test(device_queue, input_vals1, ref_vals_rd,
FT(unsigned long long, sycl::ext::intel::math::dsqrt_rd));
std::cout << "sycl::ext::intel::math::dsqrt_rd passes." << std::endl;
test(device_queue, input_vals1, ref_vals_rn,
FT(unsigned long long, sycl::ext::intel::math::dsqrt_rn));
std::cout << "sycl::ext::intel::math::dsqrt_rn passes." << std::endl;
test(device_queue, input_vals1, ref_vals_ru,
FT(unsigned long long, sycl::ext::intel::math::dsqrt_ru));
std::cout << "sycl::ext::intel::math::dsqrt_ru passes." << std::endl;
test(device_queue, input_vals1, ref_vals_rz,
FT(unsigned long long, sycl::ext::intel::math::dsqrt_rz));
std::cout << "sycl::ext::intel::math::dsqrt_rz passes." << std::endl;
}

return 0;
}

0 comments on commit 6c1dde4

Please sign in to comment.