Skip to content

Commit

Permalink
[SYCL][ESIMD]Replace use of intrinsics with spirv functions (#13553)
Browse files Browse the repository at this point in the history
  • Loading branch information
fineg74 committed May 14, 2024
1 parent 9beb70b commit af65855
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 34 deletions.
46 changes: 46 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/detail/math_intrin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,52 @@
#define __ESIMD_cpp_vec_t(T, SZ) \
__ESIMD_DNS::vector_type_t<__ESIMD_DNS::__cpp_t<T>, SZ>

// Declarations of the __spirv_ocl_native_* math built-ins used by ESIMD.
// They are declared here to avoid unintended use by other targets, where they
// cause run-time failures because they are implemented for Intel GPU only.
// Each built-in is declared twice: a scalar overload taking T, and a vector
// overload taking __ESIMD_raw_vec_t(T, N).

// native_exp2: scalar and vector overloads.
template <typename T> extern __DPCPP_SYCL_EXTERNAL T __spirv_ocl_native_exp2(T);
template <typename T, int N>
extern __DPCPP_SYCL_EXTERNAL __ESIMD_raw_vec_t(T, N)
__spirv_ocl_native_exp2(__ESIMD_raw_vec_t(T, N));

// native_recip: scalar and vector overloads.
template <typename T>
extern __DPCPP_SYCL_EXTERNAL T __spirv_ocl_native_recip(T);
template <typename T, int N>
extern __DPCPP_SYCL_EXTERNAL __ESIMD_raw_vec_t(T, N)
__spirv_ocl_native_recip(__ESIMD_raw_vec_t(T, N));

// native_cos: scalar and vector overloads.
template <typename T> extern __DPCPP_SYCL_EXTERNAL T __spirv_ocl_native_cos(T);
template <typename T, int N>
extern __DPCPP_SYCL_EXTERNAL __ESIMD_raw_vec_t(T, N)
__spirv_ocl_native_cos(__ESIMD_raw_vec_t(T, N));

// native_log2: scalar and vector overloads.
template <typename T> extern __DPCPP_SYCL_EXTERNAL T __spirv_ocl_native_log2(T);
template <typename T, int N>
extern __DPCPP_SYCL_EXTERNAL __ESIMD_raw_vec_t(T, N)
__spirv_ocl_native_log2(__ESIMD_raw_vec_t(T, N));

// native_rsqrt: scalar and vector overloads.
template <typename T>
extern __DPCPP_SYCL_EXTERNAL T __spirv_ocl_native_rsqrt(T);
template <typename T, int N>
extern __DPCPP_SYCL_EXTERNAL __ESIMD_raw_vec_t(T, N)
__spirv_ocl_native_rsqrt(__ESIMD_raw_vec_t(T, N));

// native_sin: scalar and vector overloads.
template <typename T> extern __DPCPP_SYCL_EXTERNAL T __spirv_ocl_native_sin(T);
template <typename T, int N>
extern __DPCPP_SYCL_EXTERNAL __ESIMD_raw_vec_t(T, N)
__spirv_ocl_native_sin(__ESIMD_raw_vec_t(T, N));

// native_sqrt: scalar and vector overloads.
template <typename T> extern __DPCPP_SYCL_EXTERNAL T __spirv_ocl_native_sqrt(T);
template <typename T, int N>
extern __DPCPP_SYCL_EXTERNAL __ESIMD_raw_vec_t(T, N)
__spirv_ocl_native_sqrt(__ESIMD_raw_vec_t(T, N));

// native_powr: scalar and vector overloads, each taking two operands.
// NOTE(review): unlike every other declaration in this group, the vector
// overload below is declared with __ESIMD_INTRIN rather than
// `extern __DPCPP_SYCL_EXTERNAL`. The definition of __ESIMD_INTRIN is not
// visible here - confirm the difference is intentional.
template <typename T>
extern __DPCPP_SYCL_EXTERNAL T __spirv_ocl_native_powr(T, T);
template <typename T, int N>
__ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
__spirv_ocl_native_powr(__ESIMD_raw_vec_t(T, N), __ESIMD_raw_vec_t(T, N));

// saturation intrinsics
template <typename T0, typename T1, int SZ>
__ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
Expand Down
126 changes: 99 additions & 27 deletions sycl/include/sycl/ext/intel/esimd/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ __esimd_abs_common_internal(simd<TArg, SZ> src0) {
}

template <typename TRes, typename TArg>
ESIMD_NODEBUG
ESIMD_INLINE std::enable_if_t<detail::is_esimd_scalar<TRes>::value &&
detail::is_esimd_scalar<TArg>::value,
TRes>
__esimd_abs_common_internal(TArg src0) {

__ESIMD_API std::enable_if_t<detail::is_esimd_scalar<TRes>::value &&
detail::is_esimd_scalar<TArg>::value,
TRes>
__esimd_abs_common_internal(TArg src0) {
simd<TArg, 1> Src0 = src0;
simd<TArg, 1> Result = __esimd_abs_common_internal<TArg>(Src0);
return convert<TRes>(Result)[0];
Expand Down Expand Up @@ -342,67 +342,98 @@ std::enable_if_t<detail::is_esimd_scalar<T>::value, T>(min)(T src0, T src1,
/// @addtogroup sycl_esimd_math_ext
/// @{

// Function bodies shared by the unary math wrappers defined through
// __ESIMD_UNARY_INTRINSIC_DEF. Both macros expect `src` and the template
// parameters `T` and `Sat` (plus `N` for the vector form) to be in scope at
// the expansion site.
//  - Device compilation (__SYCL_DEVICE_ONLY__ defined): call
//    __spirv_ocl_native_<iname> instantiated for the raw element type of T.
//    The result is returned as is when Sat is saturation_off_tag, otherwise it
//    is passed through esimd::saturate<T> first (the scalar form wraps the
//    value in simd<T, 1> and extracts element 0).
//  - Otherwise: the bodies collapse to `return 0;`.
// The `name` argument is accepted but not used by either macro.
#if defined(__SYCL_DEVICE_ONLY__)
#define __ESIMD_VECTOR_IMPL(T, name, iname) \
__ESIMD_DNS::vector_type_t<__ESIMD_DNS::__raw_t<T>, N> res = \
__spirv_ocl_native_##iname<__ESIMD_DNS::__raw_t<T>, N>(src.data()); \
if constexpr (std::is_same_v<Sat, saturation_off_tag>) \
return res; \
else \
return esimd::saturate<T>(simd<T, N>(res));
#define __ESIMD_SCALAR_IMPL(T, name, iname) \
__ESIMD_DNS::__raw_t<T> res = \
__spirv_ocl_native_##iname<__ESIMD_DNS::__raw_t<T>>(src); \
if constexpr (std::is_same_v<Sat, saturation_off_tag>) \
return res; \
else \
return esimd::saturate<T>(simd<T, 1>(res))[0];
#else
#define __ESIMD_VECTOR_IMPL(T, name, iname) return 0;
#define __ESIMD_SCALAR_IMPL(T, name, iname) return 0;
#endif // __SYCL_DEVICE_ONLY__

#define __ESIMD_UNARY_INTRINSIC_DEF(COND, name, iname) \
/** Vector version. */ \
template <class T, int N, class Sat = saturation_off_tag, \
class = std::enable_if_t<COND>> \
__ESIMD_API simd<T, N> name(simd<T, N> src, Sat sat = {}) { \
__ESIMD_DNS::vector_type_t<__ESIMD_DNS::__raw_t<T>, N> res = \
__esimd_##iname<T, N>(src.data()); \
if constexpr (std::is_same_v<Sat, saturation_off_tag>) \
return res; \
else \
return esimd::saturate<T>(simd<T, N>(res)); \
__ESIMD_VECTOR_IMPL(T, name, iname) \
} \
\
/** Scalar version. */ \
template <typename T, class Sat = saturation_off_tag, \
class = std::enable_if_t<COND>> \
__ESIMD_API T name(T src, Sat sat = {}) { \
simd<T, 1> src_vec = src; \
simd<T, 1> res = name<T, 1>(src_vec, sat); \
return res[0]; \
__ESIMD_SCALAR_IMPL(T, name, iname) \
}

// Type predicates used to constrain the extended-math overloads below. Each
// macro expands to a boolean constant expression over a template parameter
// named `T`, which must be in scope at the point of use.
//
// Every expansion is wrapped in parentheses so that it remains a single
// expression - and a single macro argument, despite the commas inside the
// std::is_same_v template argument lists - when it is forwarded to another
// function-like macro or combined with further operators.

/// True for generic floating-point types of at most 4 bytes.
#define __ESIMD_EMATH_COND                                                     \
  (detail::is_generic_floating_point_v<T> && (sizeof(T) <= 4))

/// True for generic floating-point types of at least 4 bytes.
#define __ESIMD_EMATH_IEEE_COND                                                \
  (detail::is_generic_floating_point_v<T> && (sizeof(T) >= 4))

/// True only for \c float and \c sycl::half.
#define __ESIMD_EMATH_SPIRV_COND                                               \
  (std::is_same_v<T, float> || std::is_same_v<T, sycl::half>)

/// Inversion - calculates (1/x). Supports \c half and \c float.
/// Precision: 1 ULP.
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_COND, inv, inv)
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_SPIRV_COND, inv, recip)

/// Logarithm base 2. Supports \c half and \c float.
/// Precision depending on argument range:
/// - [0.5..2]: absolute error is <code>2^-21</code> or less
/// - (0..0.5) or (2..+INF]: relative error is <code>2^-21</code> or less
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_COND, log2, log)
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_SPIRV_COND, log2, log2)

/// Exponent base 2. Supports \c half and \c float.
/// Precision: 4 ULP.
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_COND, exp2, exp)
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_SPIRV_COND, exp2, exp2)

/// Square root. Is not IEEE754-compatible. Supports \c half, \c float and
/// \c double. Precision: 4 ULP.
__ESIMD_UNARY_INTRINSIC_DEF(detail::is_generic_floating_point_v<T>, sqrt, sqrt)

/// IEEE754-compliant square root. Supports \c float and \c double.
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_IEEE_COND, sqrt_ieee, ieee_sqrt)
/** Vector version.
 * Calls __esimd_ieee_sqrt on the raw data of \p src. The raw result is
 * returned directly when \c Sat is \c saturation_off_tag; otherwise it is
 * wrapped in a simd object and passed through esimd::saturate<T>.
 */
template <class T, int N, class Sat = saturation_off_tag,
          class = std::enable_if_t<__ESIMD_EMATH_IEEE_COND>>
__ESIMD_API simd<T, N> sqrt_ieee(simd<T, N> src, Sat sat = {}) {
  using RawVecT = __ESIMD_DNS::vector_type_t<__ESIMD_DNS::__raw_t<T>, N>;
  RawVecT raw_root = __esimd_ieee_sqrt<T, N>(src.data());
  if constexpr (!std::is_same_v<Sat, saturation_off_tag>)
    return esimd::saturate<T>(simd<T, N>(raw_root));
  else
    return raw_root;
}

/** Scalar version.
 * Wraps \p src in a one-element simd object, forwards it (together with
 * \p sat) to the vector overload and returns element 0 of the result.
 */
template <typename T, class Sat = saturation_off_tag,
          class = std::enable_if_t<__ESIMD_EMATH_IEEE_COND>>
__ESIMD_API T sqrt_ieee(T src, Sat sat = {}) {
  simd<T, 1> arg = src;
  simd<T, 1> root = sqrt_ieee<T, 1>(arg, sat);
  return root[0];
}

/// Square root reciprocal - calculates <code>1/sqrt(x)</code>.
/// Supports \c half and \c float.
/// Precision: 4 ULP.
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_COND, rsqrt, rsqrt)
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_SPIRV_COND, rsqrt, rsqrt)

/// Sine. Supports \c half and \c float.
/// Absolute error: \c 0.0008 or less for the range [-32767*pi, 32767*pi].
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_COND, sin, sin)
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_SPIRV_COND, sin, sin)

/// Cosine. Supports \c half and \c float.
/// Absolute error: \c 0.0008 or less for the range [-32767*pi, 32767*pi].
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_COND, cos, cos)
__ESIMD_UNARY_INTRINSIC_DEF(__ESIMD_EMATH_SPIRV_COND, cos, cos)

template <class T, int N, class Sat = saturation_off_tag>
__ESIMD_API std::enable_if_t<std::is_same_v<T, double>, simd<double, N>>
Expand All @@ -424,6 +455,8 @@ rsqrt(T src, Sat sat = {}) {
}

#undef __ESIMD_UNARY_INTRINSIC_DEF
#undef __ESIMD_VECTOR_IMPL
#undef __ESIMD_SCALAR_IMPL

#define __ESIMD_BINARY_INTRINSIC_DEF(COND, name, iname) \
/** (vector, vector) version. */ \
Expand Down Expand Up @@ -457,15 +490,54 @@ rsqrt(T src, Sat sat = {}) {

/// Power - calculates \c src0 in power of \c src1. Note available in DG2, PVC.
/// Supports \c half and \c float.
/// TODO document accuracy etc.
__ESIMD_BINARY_INTRINSIC_DEF(__ESIMD_EMATH_COND, pow, pow)
/** (vector, vector) version.
 * Device compilation: converts the raw data of \p src1 to the element type of
 * \p src0 and calls __spirv_ocl_native_powr on the two raw vectors. The raw
 * result is returned directly when \c Sat is \c saturation_off_tag; otherwise
 * it is passed through esimd::saturate<T>.
 * Host compilation: returns 0.
 */
template <class T, int N, class U, class Sat = saturation_off_tag,
          class = std::enable_if_t<__ESIMD_EMATH_SPIRV_COND>>
__ESIMD_API simd<T, N> pow(simd<T, N> src0, simd<U, N> src1, Sat sat = {}) {
#ifdef __SYCL_DEVICE_ONLY__
  using RawElemT = __ESIMD_DNS::__raw_t<T>;
  using RawVecT = __ESIMD_DNS::vector_type_t<RawElemT, N>;
  // Bring the exponent to the element type of the base before the call.
  RawVecT exponent = detail::convert_vector<T, U, N>(src1.data());
  RawVecT power = __spirv_ocl_native_powr<RawElemT, N>(src0.data(), exponent);
  if constexpr (!std::is_same_v<Sat, saturation_off_tag>)
    return esimd::saturate<T>(simd<T, N>(power));
  else
    return power;
#else
  return 0;
#endif // __SYCL_DEVICE_ONLY__
}

/** (vector, scalar) version.
 * Builds a simd<U, N> from the scalar \p src1 and forwards to the
 * (vector, vector) overload.
 */
template <class T, int N, class U, class Sat = saturation_off_tag,
          class = std::enable_if_t<__ESIMD_EMATH_SPIRV_COND>>
__ESIMD_API simd<T, N> pow(simd<T, N> src0, U src1, Sat sat = {}) {
  simd<U, N> exponent(src1);
  return pow<T, N, U>(src0, exponent, sat);
}

/** (scalar, scalar) version.
 * Device compilation: converts \p src1 to the type of \p src0 and calls the
 * scalar __spirv_ocl_native_powr. The raw result is returned directly when
 * \c Sat is \c saturation_off_tag; otherwise it is wrapped in a one-element
 * simd object, saturated, and element 0 is returned.
 * Host compilation: returns 0.
 */
template <class T, class U, class Sat = saturation_off_tag,
          class = std::enable_if_t<__ESIMD_EMATH_SPIRV_COND>>
__ESIMD_API T pow(T src0, U src1, Sat sat = {}) {
#ifdef __SYCL_DEVICE_ONLY__
  using RawT = __ESIMD_DNS::__raw_t<T>;
  // Bring the exponent to the type of the base before the call.
  RawT exponent = detail::convert_scalar<T, U>(src1);
  RawT power = __spirv_ocl_native_powr<RawT>(src0, exponent);
  if constexpr (!std::is_same_v<Sat, saturation_off_tag>)
    return esimd::saturate<T>(simd<T, 1>(power))[0];
  else
    return power;
#else
  return 0;
#endif // __SYCL_DEVICE_ONLY__
}

/// IEEE754-compliant floating-point division. Supports \c float and \c double.
__ESIMD_BINARY_INTRINSIC_DEF(__ESIMD_EMATH_IEEE_COND, div_ieee, ieee_div)

#undef __ESIMD_BINARY_INTRINSIC_DEF
#undef __ESIMD_EMATH_COND
#undef __ESIMD_EMATH_IEEE_COND
#undef __ESIMD_EMATH_SPIRV_COND

/// @} sycl_esimd_math_ext

Expand Down
12 changes: 6 additions & 6 deletions sycl/test/esimd/math_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<float, 16> sycl_math(simd<float, 16> x) {
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<float, 16>
esimd_math(simd<float, 16> x) {
simd<float, 16> v = 0;
//CHECK: call spir_func noundef <16 x float> @_Z11__esimd_cos
//CHECK: call spir_func noundef <16 x float> @_Z22__spirv_ocl_native_cos{{[^\(]*}}
v = esimd::cos(x);
//CHECK: call spir_func noundef <16 x float> @_Z11__esimd_sin
//CHECK: call spir_func noundef <16 x float> @_Z22__spirv_ocl_native_sin{{[^\(]*}}
v = esimd::sin(v);
//CHECK: call spir_func noundef <16 x float> @_Z11__esimd_log
//CHECK: call spir_func noundef <16 x float> @_Z23__spirv_ocl_native_log2{{[^\(]*}}
v = esimd::log2(v);
//CHECK: call spir_func noundef <16 x float> @_Z11__esimd_exp
//CHECK: call spir_func noundef <16 x float> @_Z23__spirv_ocl_native_exp2{{[^\(]*}}
v = esimd::exp2(v);
return v;
}
Expand All @@ -47,9 +47,9 @@ esimd_math(simd<float, 16> x) {
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<float, 16>
esimd_math_emu(simd<float, 16> x) {
simd<float, 16> v = 0;
//CHECK: call spir_func noundef <16 x float> @_Z11__esimd_log
//CHECK: call spir_func noundef <16 x float> @_Z23__spirv_ocl_native_log2{{[^\(]*}}
v = esimd::log(x);
//CHECK: call spir_func noundef <16 x float> @_Z11__esimd_exp
//CHECK: call spir_func noundef <16 x float> @_Z23__spirv_ocl_native_exp2{{[^\(]*}}
v = esimd::exp(v);
return v;
}
Expand Down
2 changes: 1 addition & 1 deletion sycl/test/esimd/sycl_half_math_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ SYCL_EXTERNAL auto test_ext_math_op(simd<sycl::half, 8> val) SYCL_ESIMD_FUNCTION
// CHECK: define dso_local spir_func <8 x half> @_Z16test_ext_math_op{{[^\(]*}}(
// CHECK: <8 x half> %[[VAL_VEC:[a-zA-Z0-9_\.]+]]){{.*}} {
return esimd::cos(val);
// CHECK: %[[RES:[a-zA-Z0-9_\.]+]] = call <8 x half> @llvm.genx.cos.v8f16(<8 x half> %[[VAL_VEC]])
// CHECK: %[[RES:[a-zA-Z0-9_\.]+]] = call spir_func noundef <8 x half> @_Z22__spirv_ocl_native_cos{{[^\(]*}}(<8 x half> noundef %[[VAL_VEC]])
// CHECK-NEXT: ret <8 x half> %[[RES]]
// CHECK-LABEL: }
}
Expand Down

0 comments on commit af65855

Please sign in to comment.