From 6c1dde4243b5685c351b93530c290e769c421de3 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 2 Feb 2024 01:40:55 +0800 Subject: [PATCH] [SYCL][libdevice] Add sqrt with rounding mode supported in sycl::ext::intel::math (#12571) Signed-off-by: jinge90 --- libdevice/imf_impl_utils.hpp | 4 + libdevice/imf_rounding_op.hpp | 113 ++++++++++++++++++ libdevice/imf_utils/fp32_round.cpp | 12 ++ libdevice/imf_utils/fp64_round.cpp | 12 ++ libdevice/imf_wrapper.cpp | 24 ++++ libdevice/imf_wrapper_fp64.cpp | 24 ++++ .../sycl-post-link/SYCLDeviceLibReqMask.cpp | 8 ++ sycl/include/sycl/builtins.hpp | 8 ++ .../sycl/ext/intel/math/imf_rounding_math.hpp | 24 ++++ .../DeviceLib/imf_fp32_rounding_test.cpp | 30 +++++ .../DeviceLib/imf_fp64_rounding_test.cpp | 29 +++++ 11 files changed, 288 insertions(+) diff --git a/libdevice/imf_impl_utils.hpp b/libdevice/imf_impl_utils.hpp index d052d9bb4ec94..eedfb6fcb10d5 100644 --- a/libdevice/imf_impl_utils.hpp +++ b/libdevice/imf_impl_utils.hpp @@ -216,6 +216,10 @@ class __iml_ui128 { return (this->bits[0] > x.bits[0]); } + bool operator>=(const __iml_ui128 &x) { + return operator==(x) || operator>(x); + } + bool operator>(const uint64_t &x) { if (this->bits[1] > 0) return true; diff --git a/libdevice/imf_rounding_op.hpp b/libdevice/imf_rounding_op.hpp index 80ea9eb78690d..3a5c2a59b331a 100644 --- a/libdevice/imf_rounding_op.hpp +++ b/libdevice/imf_rounding_op.hpp @@ -11,6 +11,7 @@ #define __LIBDEVICE_IMF_ROUNDING_OP_H__ #include "imf_impl_utils.hpp" #include + template static Ty __handling_fp_overflow(unsigned z_sig, int rd) { typedef typename __iml_fp_config::utype UTy; @@ -1569,4 +1570,116 @@ template FTy __fp_fma(FTy x, FTy y, FTy z, int rd) { } } +template UTy integer_sqrt(UTy n, bool &is_squares) { + UTy x{n}, c{0}, d{1}; + d = d << (sizeof(UTy) * 8 - 2); + while (d > n) + d = d >> 2; + + while (d != 0) { + if (x >= (c + d)) { + x -= (c + d); + c = (c >> 1) + d; + } else + c = c >> 1; + d = d >> 2; + } + + if (c * c > n) + c -= 1; + if (c * c == n) + is_squares = true; + else + is_squares = false; + return c; +} + +template FTy __fp_sqrt(FTy x, int rd) { + typedef typename __iml_fp_config::utype UTy; + typedef typename __iml_get_double_size_unsigned::utype DSUTy; + constexpr int fra_digits = std::numeric_limits::digits - 1; + UTy x_bit = __builtin_bit_cast(UTy, x); + UTy x_exp = (x_bit & __iml_fp_config::pos_inf_bits) >> fra_digits; + UTy x_fra = x_bit & __iml_fp_config::fra_mask; + UTy x_sig = x_bit >> (sizeof(FTy) * 8 - 1); + DSUTy Bit1(1); + constexpr UTy NAN_BITS = __iml_fp_config::nan_bits; + constexpr UTy INF_BITS = __iml_fp_config::pos_inf_bits; + + if ((x_exp == __iml_fp_config::exp_mask) && (x_fra != 0x0)) + return __builtin_bit_cast(FTy, NAN_BITS); + + if ((x_exp == 0x0) && (x_fra == 0x0)) + return __builtin_bit_cast(FTy, static_cast(0x0)); + + if (x_sig == 1) + return __builtin_bit_cast(FTy, NAN_BITS); + + if ((x_exp == __iml_fp_config::exp_mask) && (x_fra == 0x0)) + return __builtin_bit_cast(FTy, INF_BITS); + + // For all postive subnormal and normal values, the result of sqrt + // is a normal value. + int32_t sx_exp = x_exp; + if (sx_exp == 0x0) + sx_exp = 1 - __iml_fp_config::bias; + else + sx_exp -= __iml_fp_config::bias; + + DSUTy fra_holder{x_fra}; + if (x_exp != 0) + fra_holder = (Bit1 << fra_digits) | fra_holder; + sx_exp -= fra_digits; + + // 2^x_exp * 1.mant can be represented as: 2^(x_exp - 52) * fra_holder + // for normal value and 2^-1022 * 0.mant can be represented as: + // 2^(-1074) * fra_holder for subnormal value. For fp32, 2^x_exp * 1.mant + // can be represented as: 2^(x_exp - 23) * fra_holder for normal value and + // 2^-126 * 0.mant can be represented as 2^-149 * fra_holder for subnormal. + // fra_holder is a non-zero value. + size_t lz = 0; + if constexpr (std::is_same::value) + lz = 127 - fra_holder.ui128_msb_pos(); + else + lz = 63 - get_msb_pos(fra_holder); + + fra_holder = fra_holder << lz; + sx_exp -= lz; + if (static_cast(sx_exp) & 0x1) { + sx_exp += 1; + fra_holder = fra_holder >> 1; + } + + bool is_squares = false; + DSUTy sqrt_fra = integer_sqrt(fra_holder, is_squares); + sx_exp = sx_exp / 2; + + if constexpr (std::is_same::value) + lz = 127 - sqrt_fra.ui128_msb_pos(); + else + lz = 63 - get_msb_pos(sqrt_fra); + UTy fra1 = + static_cast(sqrt_fra >> (sizeof(DSUTy) * 8 - lz - fra_digits - 1)); + fra1 = fra1 & __iml_fp_config::fra_mask; + sx_exp += sizeof(DSUTy) * 8 - 1 - lz + __iml_fp_config::bias; + + size_t grs_nsbit = sizeof(FTy) * 16 - lz - 1 - fra_digits; + uint32_t grs_bits = + static_cast(sqrt_fra & ((Bit1 << grs_nsbit) - Bit1)); + uint32_t s_bits = + grs_bits & static_cast((Bit1 << (grs_nsbit - 3)) - Bit1); + grs_bits = grs_bits >> (grs_nsbit - 3); + if ((s_bits > 0) || !is_squares) + grs_bits |= 0x1; + + uint32_t rb = + __handling_rounding(0U, static_cast(fra1), grs_bits, rd); + fra1 += rb; + if (fra1 > __iml_fp_config::fra_mask) { + fra1 = 0x0; + sx_exp++; + } + return __builtin_bit_cast(FTy, + (static_cast(sx_exp) << fra_digits) | fra1); +} #endif diff --git a/libdevice/imf_utils/fp32_round.cpp b/libdevice/imf_utils/fp32_round.cpp index c10af7d16ff3f..32548b1ccf912 100644 --- a/libdevice/imf_utils/fp32_round.cpp +++ b/libdevice/imf_utils/fp32_round.cpp @@ -109,4 +109,16 @@ DEVICE_EXTERN_C_INLINE float __devicelib_imf_fmaf_rz(float x, float y, float z) { return __fp_fma(x, y, z, __IML_RTZ); } + +DEVICE_EXTERN_C_INLINE +float __devicelib_imf_sqrtf_rd(float x) { return __fp_sqrt(x, __IML_RTN); } + +DEVICE_EXTERN_C_INLINE +float __devicelib_imf_sqrtf_rn(float x) { return __fp_sqrt(x, __IML_RTE); } + +DEVICE_EXTERN_C_INLINE +float __devicelib_imf_sqrtf_ru(float x) { return __fp_sqrt(x, __IML_RTP); } + +DEVICE_EXTERN_C_INLINE +float __devicelib_imf_sqrtf_rz(float x) { return __fp_sqrt(x, __IML_RTZ); } #endif diff --git a/libdevice/imf_utils/fp64_round.cpp b/libdevice/imf_utils/fp64_round.cpp index 9da6872433c99..aa4de27a669e1 100644 --- a/libdevice/imf_utils/fp64_round.cpp +++ b/libdevice/imf_utils/fp64_round.cpp @@ -109,4 +109,16 @@ DEVICE_EXTERN_C_INLINE double __devicelib_imf_fma_rz(double x, double y, double z) { return __fp_fma(x, y, z, __IML_RTZ); } + +DEVICE_EXTERN_C_INLINE +double __devicelib_imf_sqrt_rd(double x) { return __fp_sqrt(x, __IML_RTN); } + +DEVICE_EXTERN_C_INLINE +double __devicelib_imf_sqrt_rn(double x) { return __fp_sqrt(x, __IML_RTE); } + +DEVICE_EXTERN_C_INLINE +double __devicelib_imf_sqrt_ru(double x) { return __fp_sqrt(x, __IML_RTP); } + +DEVICE_EXTERN_C_INLINE +double __devicelib_imf_sqrt_rz(double x) { return __fp_sqrt(x, __IML_RTZ); } #endif diff --git a/libdevice/imf_wrapper.cpp b/libdevice/imf_wrapper.cpp index 9cca5c264ce61..d0c36e7fbc087 100644 --- a/libdevice/imf_wrapper.cpp +++ b/libdevice/imf_wrapper.cpp @@ -1992,4 +1992,28 @@ DEVICE_EXTERN_C_INLINE float __imf_fmaf_rz(float x, float y, float z) { return __devicelib_imf_fmaf_rz(x, y, z); } + +DEVICE_EXTERN_C_INLINE +float __devicelib_imf_sqrtf_rd(float); + +DEVICE_EXTERN_C_INLINE +float __imf_sqrtf_rd(float x) { return __devicelib_imf_sqrtf_rd(x); } + +DEVICE_EXTERN_C_INLINE +float __devicelib_imf_sqrtf_rn(float); + +DEVICE_EXTERN_C_INLINE +float __imf_sqrtf_rn(float x) { return __devicelib_imf_sqrtf_rn(x); } + +DEVICE_EXTERN_C_INLINE +float __devicelib_imf_sqrtf_ru(float); + +DEVICE_EXTERN_C_INLINE +float __imf_sqrtf_ru(float x) { return __devicelib_imf_sqrtf_ru(x); } + +DEVICE_EXTERN_C_INLINE +float __devicelib_imf_sqrtf_rz(float); + +DEVICE_EXTERN_C_INLINE +float __imf_sqrtf_rz(float x) { return __devicelib_imf_sqrtf_rz(x); } #endif // __LIBDEVICE_IMF_ENABLED__ diff --git a/libdevice/imf_wrapper_fp64.cpp b/libdevice/imf_wrapper_fp64.cpp index dfd6e00278b1d..7fa60f0011468 100644 --- a/libdevice/imf_wrapper_fp64.cpp +++ b/libdevice/imf_wrapper_fp64.cpp @@ -549,4 +549,28 @@ DEVICE_EXTERN_C_INLINE double __imf_fma_rz(double x, double y, double z) { return __devicelib_imf_fma_rz(x, y, z); } + +DEVICE_EXTERN_C_INLINE +double __devicelib_imf_sqrt_rd(double); + +DEVICE_EXTERN_C_INLINE +double __imf_sqrt_rd(double x) { return __devicelib_imf_sqrt_rd(x); } + +DEVICE_EXTERN_C_INLINE +double __devicelib_imf_sqrt_rn(double); + +DEVICE_EXTERN_C_INLINE +double __imf_sqrt_rn(double x) { return __devicelib_imf_sqrt_rn(x); } + +DEVICE_EXTERN_C_INLINE +double __devicelib_imf_sqrt_ru(double); + +DEVICE_EXTERN_C_INLINE +double __imf_sqrt_ru(double x) { return __devicelib_imf_sqrt_ru(x); } + +DEVICE_EXTERN_C_INLINE +double __devicelib_imf_sqrt_rz(double); + +DEVICE_EXTERN_C_INLINE +double __imf_sqrt_rz(double x) { return __devicelib_imf_sqrt_rz(x); } #endif // __LIBDEVICE_IMF_ENABLED__ diff --git a/llvm/tools/sycl-post-link/SYCLDeviceLibReqMask.cpp b/llvm/tools/sycl-post-link/SYCLDeviceLibReqMask.cpp index 1b6cdefb9a541..9c52b8b524dd8 100644 --- a/llvm/tools/sycl-post-link/SYCLDeviceLibReqMask.cpp +++ b/llvm/tools/sycl-post-link/SYCLDeviceLibReqMask.cpp @@ -243,6 +243,10 @@ SYCLDeviceLibFuncMap SDLMap = { {"__devicelib_imf_fmaf_rn", DeviceLibExt::cl_intel_devicelib_imf}, {"__devicelib_imf_fmaf_ru", DeviceLibExt::cl_intel_devicelib_imf}, {"__devicelib_imf_fmaf_rz", DeviceLibExt::cl_intel_devicelib_imf}, + {"__devicelib_imf_sqrtf_rd", DeviceLibExt::cl_intel_devicelib_imf}, + {"__devicelib_imf_sqrtf_rn", DeviceLibExt::cl_intel_devicelib_imf}, + {"__devicelib_imf_sqrtf_ru", DeviceLibExt::cl_intel_devicelib_imf}, + {"__devicelib_imf_sqrtf_rz", DeviceLibExt::cl_intel_devicelib_imf}, {"__devicelib_imf_float2int_rd", DeviceLibExt::cl_intel_devicelib_imf}, {"__devicelib_imf_float2int_rn", DeviceLibExt::cl_intel_devicelib_imf}, {"__devicelib_imf_float2int_ru", DeviceLibExt::cl_intel_devicelib_imf}, @@ -528,6 +532,10 @@ SYCLDeviceLibFuncMap SDLMap = { {"__devicelib_imf_fma_rn", DeviceLibExt::cl_intel_devicelib_imf_fp64}, {"__devicelib_imf_fma_ru", DeviceLibExt::cl_intel_devicelib_imf_fp64}, {"__devicelib_imf_fma_rz", DeviceLibExt::cl_intel_devicelib_imf_fp64}, + {"__devicelib_imf_sqrt_rd", DeviceLibExt::cl_intel_devicelib_imf_fp64}, + {"__devicelib_imf_sqrt_rn", DeviceLibExt::cl_intel_devicelib_imf_fp64}, + {"__devicelib_imf_sqrt_ru", DeviceLibExt::cl_intel_devicelib_imf_fp64}, + {"__devicelib_imf_sqrt_rz", DeviceLibExt::cl_intel_devicelib_imf_fp64}, {"__devicelib_imf_bfloat162float", DeviceLibExt::cl_intel_devicelib_imf_bf16}, {"__devicelib_imf_bfloat162int_rd", diff --git a/sycl/include/sycl/builtins.hpp b/sycl/include/sycl/builtins.hpp index 4edeeb8a73878..87a4e84122d83 100644 --- a/sycl/include/sycl/builtins.hpp +++ b/sycl/include/sycl/builtins.hpp @@ -121,6 +121,10 @@ extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rd(float x, float y, float z); extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rn(float x, float y, float z); extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_ru(float x, float y, float z); extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rz(float x, float y, float z); +extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rd(float x); +extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rn(float x); +extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_ru(float x); +extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rz(float x); extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_rd(float x); extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_rn(float x); extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_ru(float x); @@ -358,6 +362,10 @@ extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rd(double x); extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rn(double x); extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_ru(double x); extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rz(double x); +extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rd(double x); +extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rn(double x); +extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_ru(double x); +extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rz(double x); extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_rd(double x); extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_rn(double x); extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_ru(double x); diff --git a/sycl/include/sycl/ext/intel/math/imf_rounding_math.hpp b/sycl/include/sycl/ext/intel/math/imf_rounding_math.hpp index e76b897599e80..d895d587fc987 100644 --- a/sycl/include/sycl/ext/intel/math/imf_rounding_math.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_rounding_math.hpp @@ -35,6 +35,10 @@ float __imf_fmaf_rz(float, float, float); float __imf_fmaf_rn(float, float, float); float __imf_fmaf_ru(float, float, float); float __imf_fmaf_rd(float, float, float); +float __imf_sqrtf_rz(float); +float __imf_sqrtf_rn(float); +float __imf_sqrtf_ru(float); +float __imf_sqrtf_rd(float); double __imf_dadd_rz(double, double); double __imf_dadd_rn(double, double); @@ -60,6 +64,10 @@ double __imf_fma_rz(double, double, double); double __imf_fma_rn(double, double, double); double __imf_fma_ru(double, double, double); double __imf_fma_rd(double, double, double); +double __imf_sqrt_rz(double); +double __imf_sqrt_rn(double); +double __imf_sqrt_ru(double); +double __imf_sqrt_rd(double); }; namespace sycl { @@ -154,6 +162,14 @@ template Tp fmaf_rz(Tp x, Tp y, Tp z) { return __imf_fmaf_rz(x, y, z); } +template Tp fsqrt_rd(Tp x) { return __imf_sqrtf_rd(x); } + +template Tp fsqrt_rn(Tp x) { return __imf_sqrtf_rn(x); } + +template Tp fsqrt_ru(Tp x) { return __imf_sqrtf_ru(x); } + +template Tp fsqrt_rz(Tp x) { return __imf_sqrtf_rz(x); } + template Tp dadd_rd(Tp x, Tp y) { return __imf_dadd_rd(x, y); } @@ -242,6 +258,14 @@ template Tp fma_rz(Tp x, Tp y, Tp z) { return __imf_fma_rz(x, y, z); } +template Tp dsqrt_rd(Tp x) { return __imf_sqrt_rd(x); } + +template Tp dsqrt_rn(Tp x) { return __imf_sqrt_rn(x); } + +template Tp dsqrt_ru(Tp x) { return __imf_sqrt_ru(x); } + +template Tp dsqrt_rz(Tp x) { return __imf_sqrt_rz(x); } + } // namespace ext::intel::math } // namespace _V1 } // namespace sycl diff --git a/sycl/test-e2e/DeviceLib/imf_fp32_rounding_test.cpp b/sycl/test-e2e/DeviceLib/imf_fp32_rounding_test.cpp index ff3af30dabe37..ba8f9de045c9e 100644 --- a/sycl/test-e2e/DeviceLib/imf_fp32_rounding_test.cpp +++ b/sycl/test-e2e/DeviceLib/imf_fp32_rounding_test.cpp @@ -180,5 +180,35 @@ int main(int, char **) { std::cout << "sycl::ext::intel::math::fmaf_rz passes." << std::endl; } + { + std::initializer_list input_vals = { + 0x1.ba90e6p+1, 0x1.4p+1, 0x1.ea77e6p-2, 0x1.e8330ap+19, + 0x1.4ffd68p+5, 0x1.443084p-15, 0x1.605fb2p+6, 0x1.2eb718p-7}; + std::initializer_list ref_vals_rd = { + 0x3fee0264, 0x3fca62c1, 0x3f312c12, 0x4479faa2, + 0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e}; + std::initializer_list ref_vals_rn = { + 0x3fee0265, 0x3fca62c2, 0x3f312c13, 0x4479faa2, + 0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e}; + std::initializer_list ref_vals_ru = { + 0x3fee0265, 0x3fca62c2, 0x3f312c13, 0x4479faa3, + 0x40cf616e, 0x3bcbb4d1, 0x41162c49, 0x3dc4d80f}; + std::initializer_list ref_vals_rz = { + 0x3fee0264, 0x3fca62c1, 0x3f312c12, 0x4479faa2, + 0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e}; + test(device_queue, input_vals, ref_vals_rd, + FT(unsigned, sycl::ext::intel::math::fsqrt_rd)); + std::cout << "sycl::ext::intel::math::fsqrt_rd passes." << std::endl; + test(device_queue, input_vals, ref_vals_rn, + FT(unsigned, sycl::ext::intel::math::fsqrt_rn)); + std::cout << "sycl::ext::intel::math::fsqrt_rn passes." << std::endl; + test(device_queue, input_vals, ref_vals_ru, + FT(unsigned, sycl::ext::intel::math::fsqrt_ru)); + std::cout << "sycl::ext::intel::math::fsqrt_ru passes." << std::endl; + test(device_queue, input_vals, ref_vals_rz, + FT(unsigned, sycl::ext::intel::math::fsqrt_rz)); + std::cout << "sycl::ext::intel::math::fsqrt_rz passes." << std::endl; + } + return 0; } diff --git a/sycl/test-e2e/DeviceLib/imf_fp64_rounding_test.cpp b/sycl/test-e2e/DeviceLib/imf_fp64_rounding_test.cpp index d7f40bf0c408a..44236e37b4c4d 100644 --- a/sycl/test-e2e/DeviceLib/imf_fp64_rounding_test.cpp +++ b/sycl/test-e2e/DeviceLib/imf_fp64_rounding_test.cpp @@ -210,5 +210,34 @@ int main(int, char **) { std::cout << "sycl::ext::intel::math::fmaf_rz passes." << std::endl; } + { + std::initializer_list input_vals1 = { + 0x1p+2, 0x1.fbd37afb0f8edp-1, 0x1.9238e38e38e35p+6, 0x1.7p+3}; + std::initializer_list ref_vals_rd = { + 0x4000000000000000, 0x3fefde8a59acb0bb, 0x40240e33d899cd1b, + 0x400b211b1c70d023}; + std::initializer_list ref_vals_rn = { + 0x4000000000000000, 0x3fefde8a59acb0bc, 0x40240e33d899cd1c, + 0x400b211b1c70d023}; + std::initializer_list ref_vals_ru = { + 0x4000000000000000, 0x3fefde8a59acb0bc, 0x40240e33d899cd1c, + 0x400b211b1c70d024}; + std::initializer_list ref_vals_rz = { + 0x4000000000000000, 0x3fefde8a59acb0bb, 0x40240e33d899cd1b, + 0x400b211b1c70d023}; + test(device_queue, input_vals1, ref_vals_rd, + FT(unsigned long long, sycl::ext::intel::math::dsqrt_rd)); + std::cout << "sycl::ext::intel::math::dsqrt_rd passes." << std::endl; + test(device_queue, input_vals1, ref_vals_rn, + FT(unsigned long long, sycl::ext::intel::math::dsqrt_rn)); + std::cout << "sycl::ext::intel::math::dsqrt_rn passes." << std::endl; + test(device_queue, input_vals1, ref_vals_ru, + FT(unsigned long long, sycl::ext::intel::math::dsqrt_ru)); + std::cout << "sycl::ext::intel::math::dsqrt_ru passes." << std::endl; + test(device_queue, input_vals1, ref_vals_rz, + FT(unsigned long long, sycl::ext::intel::math::dsqrt_rz)); + std::cout << "sycl::ext::intel::math::dsqrt_rz passes." << std::endl; + } + return 0; }