From bef767cfb21546c69c39e2bf4ecd6c0e05412f43 Mon Sep 17 00:00:00 2001 From: Denis Yaroshevskiy Date: Fri, 27 Sep 2024 04:55:21 -0700 Subject: [PATCH] clearing bit utils (#2301) Summary: Pull Request resolved: https://github.com/facebook/folly/pull/2301 make_maskl make_maskr set_lzero set_lone set_rzero set_lone Simple utils that correctly handle corner cases, such as shift == 64. I looked at the assembly a bit, probably that's ok. For x86 I used bmi2 where was appropriate. Differential Revision: D63329499 --- folly/lang/Bits.h | 150 ++++++++++++++++++++++++++ folly/lang/test/BitsTest.cpp | 200 +++++++++++++++++++++++++++++++++++ 2 files changed, 350 insertions(+) diff --git a/folly/lang/Bits.h b/folly/lang/Bits.h index 981feab9be7..1afc204b747 100644 --- a/folly/lang/Bits.h +++ b/folly/lang/Bits.h @@ -67,6 +67,10 @@ #include #include +#ifdef __BMI2__ +#include +#endif + #if __has_include() && (__cplusplus >= 202002L || (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)) #include #endif @@ -106,6 +110,11 @@ constexpr std::make_unsigned_t bits_to_unsigned(Src const s) { static_assert(std::is_unsigned::value, "signed type"); return static_cast(to_unsigned(s)); } + +template +inline constexpr bool supported_in_bits_operations_v = + std::is_unsigned_v && sizeof(T) <= 8; + } // namespace detail /// findFirstSet @@ -223,6 +232,147 @@ inline constexpr T strictPrevPowTwo(T const v) { return v > 1 ? prevPowTwo(T(v - 1)) : T(0); } +/// make_maskr +/// make_maskr_fn +/// +/// Returns an unsigned integer of type T, where n +/// least significant (right) bits are set and others are not. +template +struct make_maskr_fn { + static_assert(detail::supported_in_bits_operations_v, ""); + + FOLLY_NODISCARD constexpr T operator()(std::uint32_t n) const { + if (!folly::is_constant_evaluated_or(true)) { + compiler_may_unsafely_assume(n <= sizeof(T) * 8); + +#ifdef __BMI2__ + if constexpr (sizeof(T) <= 4) { + return static_cast(_bzhi_u32(static_cast(-1), n)); + } + return static_cast(_bzhi_u64(static_cast(-1), n)); +#endif + } + + if (sizeof(T) == 8 && n == 64) { + return static_cast(-1); + } + return static_cast((std::uint64_t{1} << n) - 1); + } +}; + +template +inline constexpr make_maskr_fn make_maskr; + +/// make_maskl +/// make_maskl_fn +/// +/// Returns an unsigned integer of type T, where n +/// most significant bits (left) are set and others are not. +template +struct make_maskl_fn { + static_assert(detail::supported_in_bits_operations_v, ""); + + FOLLY_NODISCARD constexpr T operator()(std::uint32_t n) const { + if (!folly::is_constant_evaluated_or(true)) { + compiler_may_unsafely_assume(n <= sizeof(T) * 8); + +#ifdef __BMI2__ + // assembler looks smaller here, if we use bzhi from `set_lowest_n_bits` + if constexpr (sizeof(T) == 8) { + return static_cast(~make_maskr(64 - n)); + } +#endif + } + + if (sizeof(T) == 8 && n == 0) { + return 0; + } + n = sizeof(T) * 8 - n; + + std::uint64_t ones = static_cast(~0); + return static_cast(ones << n); + } +}; + +template +inline constexpr make_maskl_fn make_maskl; + +/// set_rzero +/// set_rzero_fn +/// +/// Clears n least significant (right) bits. Other bits stay the same. +struct set_rzero_fn { + template + FOLLY_NODISCARD constexpr T operator()(T x, std::uint32_t n) const { + static_assert(detail::supported_in_bits_operations_v, ""); + + // alternative is to do two shifts but that has + // a dependency between them, so is likely worse + return x & make_maskl(sizeof(T) * 8 - n); + } +}; + +inline constexpr set_rzero_fn set_rzero; + +/// set_rone +/// set_rone_fn +/// +/// Sets n least significant (right) bits. Other bits stay the same. +struct set_rone_fn { + template + FOLLY_NODISCARD constexpr T operator()(T x, std::uint32_t n) const { + static_assert(detail::supported_in_bits_operations_v, ""); + + // alternative is to do two shifts but that has + // a dependency between them, so is likely worse + return x | make_maskr(n); + } +}; + +inline constexpr set_rone_fn set_rone; + +/// set_lzero +/// set_lzero_fn +/// +/// Clears n most significant (left) bits. Other bits stay the same. +struct set_lzero_fn { + template + FOLLY_NODISCARD constexpr T operator()(T x, std::uint32_t n) const { + static_assert(detail::supported_in_bits_operations_v, ""); + + if (!folly::is_constant_evaluated_or(true)) { + compiler_may_unsafely_assume(n <= sizeof(T) * 8); + +#ifdef __BMI2__ + if constexpr (sizeof(T) <= 4) { + return static_cast(_bzhi_u32(x, sizeof(T) * 8 - n)); + } + return static_cast(_bzhi_u64(x, sizeof(T) * 8 - n)); +#endif + } + + // alternative is to do two shifts but that has + // a dependency between them, so is likely worse + return x & make_maskr(sizeof(T) * 8 - n); + } +}; + +inline constexpr set_lzero_fn set_lzero; + +/// set_lone +/// set_lone_fn +/// +/// Sets n most significant (left) bits. Other bits stay the same. +struct set_lone_fn { + template + FOLLY_NODISCARD constexpr T operator()(T x, std::uint32_t n) const { + static_assert(detail::supported_in_bits_operations_v, ""); + return x | make_maskl(n); + } +}; + +inline constexpr set_lone_fn set_lone; + /** * Endianness detection and manipulation primitives. */ diff --git a/folly/lang/test/BitsTest.cpp b/folly/lang/test/BitsTest.cpp index 5888c748872..fec0d0bb34b 100644 --- a/folly/lang/test/BitsTest.cpp +++ b/folly/lang/test/BitsTest.cpp @@ -74,6 +74,14 @@ void testEFS() { } } +template +struct BitsAllUintsTest : ::testing::Test {}; + +using UintsToTest = + ::testing::Types; + +TYPED_TEST_SUITE(BitsAllUintsTest, UintsToTest); + } // namespace TEST(Bits, FindFirstSet) { @@ -350,4 +358,196 @@ TEST(Bits, LoadUnalignedUB) { EXPECT_EQ(0, x); } +TYPED_TEST(BitsAllUintsTest, MakeMaskR) { + using T = TypeParam; + + static_assert(make_maskr(0) == 0b0, ""); + static_assert(make_maskr(1) == 0b1, ""); + static_assert(make_maskr(2) == 0b11, ""); + static_assert(make_maskr(3) == 0b111, ""); + static_assert(make_maskr(4) == 0b1111, ""); + + auto test = [] { + for (std::uint32_t i = 0; i <= std::min(sizeof(T) * 8, 63UL); ++i) { + std::uint64_t expected = (std::uint64_t{1} << i) - 1; + T actual = make_maskr(i); + if (expected != actual) { + EXPECT_EQ(expected, actual) << i; + return false; + } + if (std::countr_one(expected) != static_cast(i)) { + EXPECT_EQ(i, std::countr_one(expected)) << i; + return false; + } + } + + if (sizeof(T) == 8) { + std::uint64_t expected = std::numeric_limits::max(); + T actual = make_maskr(64); + if (expected != actual) { + EXPECT_EQ(expected, actual) << 64; + return false; + } + } + + return true; + }; + + static_assert(test(), ""); + + // runtime can use a different implementation + EXPECT_TRUE(test()); +} + +TYPED_TEST(BitsAllUintsTest, MakeMaskL) { + using T = TypeParam; + + constexpr std::size_t kBitSize = sizeof(T) * 8; + + static_assert(make_maskl(kBitSize) == static_cast(~0b0), ""); + static_assert(make_maskl(kBitSize - 1) == static_cast(~0b1), ""); + static_assert(make_maskl(kBitSize - 2) == static_cast(~0b11), ""); + static_assert(make_maskl(kBitSize - 3) == static_cast(~0b111), ""); + static_assert(make_maskl(kBitSize - 4) == static_cast(~0b1111), ""); + + auto test = [] { + for (std::uint32_t i = 0; i <= kBitSize; ++i) { + T expected = ~make_maskr(kBitSize - i); + T actual = make_maskl(i); + if (expected != actual) { + EXPECT_EQ(expected, actual) << i; + return false; + } + if (std::countl_one(expected) != static_cast(i)) { + EXPECT_EQ(i, std::countl_one(expected)) << i; + return false; + } + } + return true; + }; + + static_assert(test(), ""); + + // runtime can use a different implementation + EXPECT_TRUE(test()); +} + +TYPED_TEST(BitsAllUintsTest, SetRzero) { + using T = TypeParam; + + constexpr std::size_t kBitSize = sizeof(T) * 8; + + static_assert(set_rzero(T{0b11U}, 1U) == 0b10U, ""); + static_assert(set_rzero(T{0b101U}, 1U) == 0b100U, ""); + + auto test = [] { + for (std::uint32_t i = 0; i <= kBitSize; ++i) { + T expected = make_maskl(kBitSize - i); + T actual = set_rzero(static_cast(-1), i); + if (expected != actual) { + EXPECT_EQ(expected, actual) << i; + return false; + } + if (std::countr_zero(expected) != static_cast(i)) { + EXPECT_EQ(i, std::countr_zero(expected)) << i; + return false; + } + } + return true; + }; + static_assert(test(), ""); + + // runtime can use a different implementation + EXPECT_TRUE(test()); +} + +TYPED_TEST(BitsAllUintsTest, SetRone) { + using T = TypeParam; + + constexpr std::size_t kBitSize = sizeof(T) * 8; + + static_assert(set_rone(T{0b10U}, 1U) == 0b11U, ""); + static_assert(set_rone(T{0b100U}, 1U) == 0b101U, ""); + static_assert(set_rone(T{0b100U}, 2U) == 0b111U, ""); + + auto test = [] { + for (std::uint32_t i = 0; i <= kBitSize; ++i) { + T expected = make_maskr(i); + T actual = set_rone(T{}, i); + if (expected != actual) { + EXPECT_EQ(expected, actual) << i; + return false; + } + if (std::countr_one(expected) != static_cast(i)) { + EXPECT_EQ(i, std::countr_one(expected)) << i; + return false; + } + } + return true; + }; + static_assert(test(), ""); + + // runtime can use a different implementation + EXPECT_TRUE(test()); +} + +TYPED_TEST(BitsAllUintsTest, SetLzero) { + using T = TypeParam; + + constexpr std::size_t kBitSize = sizeof(T) * 8; + + static_assert(set_lzero(T{0b101U}, kBitSize - 1) == 0b1U, ""); + static_assert(set_lzero(T{0b1100U}, kBitSize - 3) == 0b100U, ""); + + auto test = [] { + for (std::uint32_t i = 0; i <= kBitSize; ++i) { + T expected = make_maskr(kBitSize - i); + T actual = set_lzero(static_cast(-1), i); + if (expected != actual) { + EXPECT_EQ(expected, actual) << i; + return false; + } + if (std::countl_zero(expected) != static_cast(i)) { + EXPECT_EQ(i, std::countl_zero(expected)) << i; + return false; + } + } + return true; + }; + static_assert(test(), ""); + + // runtime can use a different implementation + EXPECT_TRUE(test()); +} + +TYPED_TEST(BitsAllUintsTest, SetLone) { + using T = TypeParam; + + constexpr std::size_t kBitSize = sizeof(T) * 8; + + static_assert(set_lone(T{0b1}, kBitSize - 2) == static_cast(~0b10), ""); + static_assert( + set_lone(T{0b1100U}, kBitSize - 3) == static_cast(~0b11), ""); + + auto test = [] { + for (std::uint32_t i = 0; i <= kBitSize; ++i) { + T expected = make_maskl(i); + T actual = set_lone(static_cast(0), i); + if (expected != actual) { + EXPECT_EQ(expected, actual) << i; + return false; + } + if (std::countl_one(expected) != static_cast(i)) { + EXPECT_EQ(i, std::countl_one(expected)) << i; + return false; + } + } + return true; + }; + static_assert(test(), ""); + + // runtime can use a different implementation + EXPECT_TRUE(test()); +} + } // namespace folly