From 1f37b5ed804123d9068d78a739f8451584fdafca Mon Sep 17 00:00:00 2001 From: aelovikov-intel Date: Fri, 9 Feb 2024 02:33:50 -0800 Subject: [PATCH] [SYCL] Fix SFINAE rules for integer builtins/bitselect (#12671) In case of vectors/swizzles of integer types only fixed width types are allowed per SYCL 2020 revision 8. Update the implementation to match that. --- sycl/include/sycl/builtins_preview.hpp | 9 +- sycl/include/sycl/builtins_utils_scalar.hpp | 11 ++ .../sycl/detail/builtins/helper_macros.hpp | 8 ++ .../detail/builtins/integer_functions.inc | 7 +- .../detail/builtins/relational_functions.inc | 10 +- sycl/source/builtins/host_helper_macros.hpp | 7 + sycl/source/builtins/integer_functions.cpp | 3 +- sycl/source/builtins/relational_functions.cpp | 3 +- .../builtins/builtin_unit_tests.cpp | 133 ++++++++++++++++++ 9 files changed, 181 insertions(+), 10 deletions(-) create mode 100644 sycl/test/basic_tests/builtins/builtin_unit_tests.cpp diff --git a/sycl/include/sycl/builtins_preview.hpp b/sycl/include/sycl/builtins_preview.hpp index dfc44e5848f91..91fd50d2ec4f2 100644 --- a/sycl/include/sycl/builtins_preview.hpp +++ b/sycl/include/sycl/builtins_preview.hpp @@ -137,7 +137,14 @@ auto builtin_marray_impl(FuncTy F, const Ts &...x) { marray Res; constexpr auto N = T::size(); for (size_t I = 0; I < N / 2; ++I) { - auto PartialRes = F(to_vec2(x, I * 2)...); + auto PartialRes = [&]() { + using elem_ty = get_elem_type_t; + if constexpr (std::is_integral_v) + return F(to_vec2(x, I * 2) + .template as, 2>>()...); + else + return F(to_vec2(x, I * 2)...); + }(); std::memcpy(&Res[I * 2], &PartialRes, sizeof(decltype(PartialRes))); } if (N % 2) diff --git a/sycl/include/sycl/builtins_utils_scalar.hpp b/sycl/include/sycl/builtins_utils_scalar.hpp index 62d86df045cdf..31f267892243b 100644 --- a/sycl/include/sycl/builtins_utils_scalar.hpp +++ b/sycl/include/sycl/builtins_utils_scalar.hpp @@ -128,6 +128,17 @@ template struct get_unsigned_int_by_size { template struct same_size_unsigned_int { using type = typename get_unsigned_int_by_size::type; }; +template +using same_size_unsigned_int_t = typename same_size_unsigned_int::type; + +template struct get_fixed_sized_int { + static_assert(std::is_integral_v); + using type = + std::conditional_t, same_size_signed_int_t, + same_size_unsigned_int_t>; +}; +template +using get_fixed_sized_int_t = typename get_fixed_sized_int::type; // Utility trait for getting an upsampled integer type. // NOTE: For upsampling we look for an integer of double the size of the diff --git a/sycl/include/sycl/detail/builtins/helper_macros.hpp b/sycl/include/sycl/detail/builtins/helper_macros.hpp index 49d4af8981d70..38014c4b62c8f 100644 --- a/sycl/include/sycl/detail/builtins/helper_macros.hpp +++ b/sycl/include/sycl/detail/builtins/helper_macros.hpp @@ -48,6 +48,11 @@ FOR_EACH4_A6(BASE_CASE, FIXED1, FIXED2, FIXED3, FIXED4, ARG1, ARG2, ARG3, \ ARG4, ARG5, ARG6) \ BASE_CASE(FIXED1, FIXED2, FIXED3, FIXED4, ARG7) +#define FOR_EACH4_A8(BASE_CASE, FIXED1, FIXED2, FIXED3, FIXED4, ARG1, ARG2, \ + ARG3, ARG4, ARG5, ARG6, ARG7, ARG8) \ + FOR_EACH4_A7(BASE_CASE, FIXED1, FIXED2, FIXED3, FIXED4, ARG1, ARG2, ARG3, \ + ARG4, ARG5, ARG6, ARG7) \ + BASE_CASE(FIXED1, FIXED2, FIXED3, FIXED4, ARG8) #define FOR_EACH4_A11(BASE_CASE, FIXED1, FIXED2, FIXED3, FIXED4, ARG1, ARG2, \ ARG3, ARG4, ARG5, ARG6, ARG7, ARG8, ARG9, ARG10, ARG11) \ FOR_EACH4_A7(BASE_CASE, FIXED1, FIXED2, FIXED3, FIXED4, ARG1, ARG2, ARG3, \ @@ -169,6 +174,9 @@ unsigned char, unsigned short, unsigned int, unsigned long, unsigned long long // 11 types #define INTEGER_TYPES SIGNED_TYPES, UNSIGNED_TYPES +// 8 types +#define FIXED_WIDTH_INTEGER_TYPES \ + int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t #define DEVICE_IMPL_TEMPLATE_CUSTOM_DELEGATE( \ NUM_ARGS, NAME, ENABLER, DELEGATOR, NS, /*SCALAR_VEC_IMPL*/...) \ diff --git a/sycl/include/sycl/detail/builtins/integer_functions.inc b/sycl/include/sycl/detail/builtins/integer_functions.inc index dab0cee8a647a..44699765ff7fd 100644 --- a/sycl/include/sycl/detail/builtins/integer_functions.inc +++ b/sycl/include/sycl/detail/builtins/integer_functions.inc @@ -16,9 +16,10 @@ namespace detail { template struct integer_elem_type : std::bool_constant< - check_type_in_v, char, signed char, short, int, - long, long long, unsigned char, unsigned short, - unsigned int, unsigned long, unsigned long long>> {}; + (is_vec_or_swizzle_v && + check_type_in_v, FIXED_WIDTH_INTEGER_TYPES>) || + (!is_vec_or_swizzle_v && + check_type_in_v, INTEGER_TYPES>)> {}; template struct suint32_elem_type : std::bool_constant< diff --git a/sycl/include/sycl/detail/builtins/relational_functions.inc b/sycl/include/sycl/detail/builtins/relational_functions.inc index fb0b3f7682b0a..d63a7716aa9b7 100644 --- a/sycl/include/sycl/detail/builtins/relational_functions.inc +++ b/sycl/include/sycl/detail/builtins/relational_functions.inc @@ -15,10 +15,12 @@ inline namespace _V1 { namespace detail { template struct bitselect_elem_type - : std::bool_constant, float, double, half, char, signed char, short, - int, long, long long, unsigned char, unsigned short, unsigned int, - unsigned long, unsigned long long>> {}; + : std::bool_constant< + check_type_in_v, FP_TYPES> || + (is_vec_or_swizzle_v && + check_type_in_v, FIXED_WIDTH_INTEGER_TYPES>) || + (!is_vec_or_swizzle_v && + check_type_in_v, INTEGER_TYPES>)> {}; template struct rel_ret_traits diff --git a/sycl/source/builtins/host_helper_macros.hpp b/sycl/source/builtins/host_helper_macros.hpp index 484b0bc95fb8b..41aac2148db71 100644 --- a/sycl/source/builtins/host_helper_macros.hpp +++ b/sycl/source/builtins/host_helper_macros.hpp @@ -56,6 +56,9 @@ #define EXPORT_VEC(NUM_ARGS, NAME, TYPE, VL) \ EXPORT_VEC_NS(NUM_ARGS, NAME, sycl, TYPE, VL) +#define EXPORT_VEC_1_16_IMPL(NUM_ARGS, NAME, NS, TYPE) \ + FOR_VEC_1_16(EXPORT_VEC_NS, NUM_ARGS, NAME, NS, TYPE) + #define EXPORT_SCALAR_AND_VEC_1_16_IMPL(NUM_ARGS, NAME, NS, TYPE) \ EXPORT_SCALAR_NS(NUM_ARGS, NAME, NS, TYPE) \ FOR_VEC_1_16(EXPORT_VEC_NS, NUM_ARGS, NAME, NS, TYPE) @@ -69,8 +72,12 @@ #define EXPORT_SCALAR_AND_VEC_1_16_NS(NUM_ARGS, NAME, NS, ...) \ FOR_EACH3(EXPORT_SCALAR_AND_VEC_1_16_IMPL, NUM_ARGS, NAME, NS, __VA_ARGS__) +#define EXPORT_VEC_1_16_NS(NUM_ARGS, NAME, NS, ...) \ + FOR_EACH3(EXPORT_VEC_1_16_IMPL, NUM_ARGS, NAME, NS, __VA_ARGS__) #define EXPORT_SCALAR_AND_VEC_1_16(NUM_ARGS, NAME, ...) \ EXPORT_SCALAR_AND_VEC_1_16_NS(NUM_ARGS, NAME, sycl, __VA_ARGS__) +#define EXPORT_VEC_1_16(NUM_ARGS, NAME, ...) \ + EXPORT_VEC_1_16_NS(NUM_ARGS, NAME, sycl, __VA_ARGS__) #define EXPORT_SCALAR_AND_VEC_2_4(NUM_ARGS, NAME, ...) \ FOR_EACH2(EXPORT_SCALAR_AND_VEC_2_4_IMPL, NUM_ARGS, NAME, __VA_ARGS__) diff --git a/sycl/source/builtins/integer_functions.cpp b/sycl/source/builtins/integer_functions.cpp index 381d6f1fa0a10..cd92b2180df73 100644 --- a/sycl/source/builtins/integer_functions.cpp +++ b/sycl/source/builtins/integer_functions.cpp @@ -76,7 +76,8 @@ namespace sycl { inline namespace _V1 { #define BUILTIN_GENINT(NUM_ARGS, NAME, IMPL) \ HOST_IMPL(NAME, IMPL) \ - EXPORT_SCALAR_AND_VEC_1_16(NUM_ARGS, NAME, INTEGER_TYPES) + FOR_EACH2(EXPORT_SCALAR, NUM_ARGS, NAME, INTEGER_TYPES) \ + EXPORT_VEC_1_16(NUM_ARGS, NAME, FIXED_WIDTH_INTEGER_TYPES) #define BUILTIN_GENINT_SU(NUM_ARGS, NAME, IMPL) \ BUILTIN_GENINT(NUM_ARGS, NAME, IMPL) diff --git a/sycl/source/builtins/relational_functions.cpp b/sycl/source/builtins/relational_functions.cpp index b54c55e283e5e..b8b7795f6fb79 100644 --- a/sycl/source/builtins/relational_functions.cpp +++ b/sycl/source/builtins/relational_functions.cpp @@ -103,6 +103,7 @@ HOST_IMPL(bitselect, [](auto x, auto y, auto z) { assert((ures & std::numeric_limits::max()) == ures); return bit_cast(static_cast(ures)); }) -EXPORT_SCALAR_AND_VEC_1_16(THREE_ARGS, bitselect, INTEGER_TYPES, FP_TYPES) +FOR_EACH2(EXPORT_SCALAR, THREE_ARGS, bitselect, INTEGER_TYPES, FP_TYPES) +EXPORT_VEC_1_16(THREE_ARGS, bitselect, FIXED_WIDTH_INTEGER_TYPES, FP_TYPES) } // namespace _V1 } // namespace sycl diff --git a/sycl/test/basic_tests/builtins/builtin_unit_tests.cpp b/sycl/test/basic_tests/builtins/builtin_unit_tests.cpp new file mode 100644 index 0000000000000..d241a90568fc3 --- /dev/null +++ b/sycl/test/basic_tests/builtins/builtin_unit_tests.cpp @@ -0,0 +1,133 @@ +// RUN: %clangxx -fsycl -fpreview-breaking-changes -fsyntax-only %s -Xclang -verify +// REQUIRES: preview-breaking-changes-supported + +#include + +using namespace sycl; +using namespace sycl::detail; + +namespace builtin_same_shape_v_tests { +using swizzle1 = decltype(std::declval>().swizzle<0>()); +using swizzle2 = decltype(std::declval>().swizzle<0, 0>()); +using swizzle3 = decltype(std::declval>().swizzle<0, 0, 1>()); + +static_assert(builtin_same_shape_v); +static_assert(builtin_same_shape_v); +static_assert(builtin_same_shape_v>); +static_assert(builtin_same_shape_v, marray>); +static_assert(builtin_same_shape_v>); +static_assert(builtin_same_shape_v, vec>); +static_assert(builtin_same_shape_v, swizzle2>); + +static_assert(!builtin_same_shape_v>); +static_assert(!builtin_same_shape_v>); +static_assert(!builtin_same_shape_v, vec>); +static_assert(!builtin_same_shape_v); +static_assert(!builtin_same_shape_v, swizzle1>); +static_assert(!builtin_same_shape_v); +} // namespace builtin_same_shape_v_tests + +namespace builtin_marray_impl_tests { +// Integer functions/relational bitselect only accept fixed-width integer +// element types for vector/swizzle elements. Make sure that our marray->vec +// delegator can handle that. + +auto foo(char x) { return x; } +auto foo(signed char x) { return x; } +auto foo(unsigned char x) { return x; } +auto foo(vec x) { return x; } +auto foo(vec x) { return x; } + +auto test() { + marray x; + marray y; + marray z; + auto TestOne = [](auto x) { + std::ignore = builtin_marray_impl([](auto x) { return foo(x); }, x); + }; + TestOne(x); + TestOne(y); + TestOne(z); +} +} // namespace builtin_marray_impl_tests + +namespace builtin_enable_integer_tests { +using swizzle1 = decltype(std::declval>().swizzle<0>()); +using swizzle2 = decltype(std::declval>().swizzle<0, 0>()); +template void ignore() {} + +void test() { + // clang-format off + ignore, + builtin_enable_integer_t, + builtin_enable_integer_t>(); + // clang-format on + + ignore>, + builtin_enable_integer_t>>(); + + ignore>(); + ignore, vec>>(); + ignore, swizzle2>>(); + ignore>(); + + { + // Only one of char/signed char maps onto int8_t. The other type isn't a + // valid vector element type for integer builtins. + + static_assert(std::is_signed_v); + + // clang-format off + // expected-error-re@*:* {{no type named 'type' in 'sycl::detail::builtin_enable>'}} + // expected-note@+1 {{in instantiation of template type alias 'builtin_enable_integer_t' requested here}} + ignore>, builtin_enable_integer_t>>(); + // clang-format on + } + + // expected-error@*:* {{no type named 'type' in 'sycl::detail::builtin_enable'}} + // expected-note@+1 {{in instantiation of template type alias 'builtin_enable_integer_t' requested here}} + ignore>(); +} +} // namespace builtin_enable_integer_tests + +namespace builtin_enable_bitselect_tests { +// Essentially the same as builtin_enable_integer_t + FP types support. +using swizzle1 = decltype(std::declval>().swizzle<0>()); +using swizzle2 = decltype(std::declval>().swizzle<0, 0>()); +template void ignore() {} + +void test() { + // clang-format off + ignore, + builtin_enable_bitselect_t, + builtin_enable_bitselect_t, + builtin_enable_bitselect_t>(); + // clang-format on + + ignore>, + builtin_enable_bitselect_t>, + builtin_enable_bitselect_t>>(); + + ignore>(); + ignore, vec>>(); + ignore, swizzle2>>(); + ignore>(); + + { + // Only one of char/signed char maps onto int8_t. The other type isn't a + // valid vector element type for integer builtins. + + static_assert(std::is_signed_v); + + // clang-format off + // expected-error-re@*:* {{no type named 'type' in 'sycl::detail::builtin_enable>'}} + // expected-note@+1 {{in instantiation of template type alias 'builtin_enable_bitselect_t' requested here}} + ignore>, builtin_enable_bitselect_t>>(); + // clang-format on + } + + // expected-error@*:* {{no type named 'type' in 'sycl::detail::builtin_enable'}} + // expected-note@+1 {{in instantiation of template type alias 'builtin_enable_bitselect_t' requested here}} + ignore>(); +} +} // namespace builtin_enable_bitselect_tests