Skip to content

Commit

Permalink
change contains interface to be more restrictive (facebook#2309)
Browse files Browse the repository at this point in the history
Summary:

previous interface allowed for implicit conversions.
So it so happens that the user would have to guard against it.

I think I better guard here.

Differential Revision: D63979345
  • Loading branch information
DenisYaroshevskiy authored and facebook-github-bot committed Oct 7, 2024
1 parent 45ffb40 commit 59978fe
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 11 deletions.
46 changes: 37 additions & 9 deletions folly/algorithm/simd/Contains.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,32 @@ template <typename R>
using std_range_value_t = typename std::iterator_traits<decltype(std::begin(
std::declval<R&>()))>::value_type;

template <typename From, typename To>
constexpr bool convertible_with_no_loss() {
if (sizeof(From) > sizeof(To)) {
return false;
}
if (std::is_signed_v<From>) {
return std::is_signed_v<To>;
}

return is_unsigned_v<To> || sizeof(From) < sizeof(To);
}

template <typename R, typename T>
constexpr bool contains_haystack_needle_test() {
if constexpr (!std::is_invocable_v<AsSimdFriendlyUintFn, R>) {
return false;
} else if constexpr (!has_integral_simd_friendly_equivalent_scalar_v<T>) {
return false;
} else {
using simd_haystack_value =
simd_friendly_equivalent_scalar_t<std_range_value_t<R>>;
using simd_needle = simd_friendly_equivalent_scalar_t<T>;
return convertible_with_no_loss<simd_needle, simd_haystack_value>();
}
}

} // namespace detail

/**
Expand All @@ -53,23 +79,25 @@ using std_range_value_t = typename std::iterator_traits<decltype(std::begin(
struct contains_fn {
template <
typename R,
typename = std::enable_if_t<
std::is_invocable_v<detail::AsSimdFriendlyUintFn, R>>>
FOLLY_ERASE bool operator()(R&& r, detail::std_range_value_t<R> x) const {
typename T,
typename =
std::enable_if_t<detail::contains_haystack_needle_test<R, T>()>>
FOLLY_ERASE bool operator()(R&& r, T x) const {
auto castR = detail::asSimdFriendlyUint(folly::span(r));
auto castX = detail::asSimdFriendlyUint(x);
using value_type = detail::std_range_value_t<decltype(castR)>;

using T = decltype(castX);
auto castX = static_cast<value_type>(x);

if constexpr (std::is_same_v<T, std::uint8_t>) {
if constexpr (std::is_same_v<value_type, std::uint8_t>) {
return detail::containsU8(castR, castX);
} else if constexpr (std::is_same_v<T, std::uint16_t>) {
} else if constexpr (std::is_same_v<value_type, std::uint16_t>) {
return detail::containsU16(castR, castX);
} else if constexpr (std::is_same_v<T, std::uint32_t>) {
} else if constexpr (std::is_same_v<value_type, std::uint32_t>) {
return detail::containsU32(castR, castX);
} else {
static_assert(
std::is_same_v<T, std::uint64_t>, "internal error, unknown type");
std::is_same_v<value_type, std::uint64_t>,
"internal error, unknown type");
return detail::containsU64(castR, castX);
}
}
Expand Down
4 changes: 2 additions & 2 deletions folly/algorithm/simd/detail/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ using simd_friendly_equivalent_scalar_t = std::enable_if_t<
like_t<T, decltype(findSimdFriendlyEquivalent<std::remove_const_t<T>>())>>;

template <typename T>
constexpr bool has_integral_simd_friendly_equivalent_scalar =
constexpr bool has_integral_simd_friendly_equivalent_scalar_v =
std::is_integral_v< // void will return false
decltype(findSimdFriendlyEquivalent<std::remove_const_t<T>>())>;

template <typename T>
using unsigned_simd_friendly_equivalent_scalar_t = std::enable_if_t<
has_integral_simd_friendly_equivalent_scalar<T>,
has_integral_simd_friendly_equivalent_scalar_v<T>,
like_t<T, uint_bits_t<sizeof(T) * 8>>>;

template <typename R>
Expand Down
48 changes: 48 additions & 0 deletions folly/algorithm/simd/test/ContainsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

#include <folly/portability/GTest.h>

#include <list>
#include <vector>

namespace folly::simd {
namespace not_invocable_tests {

Expand All @@ -33,6 +36,51 @@ static_assert(std::is_invocable_v< //
std::vector<int>&,
int>);

static_assert(std::is_invocable_v< //
folly::simd::contains_fn,
std::vector<int>&,
int>);

static_assert(std::is_invocable_v< //
folly::simd::contains_fn,
std::vector<int>&,
std::int16_t>);

static_assert(std::is_invocable_v< //
folly::simd::contains_fn,
std::vector<int>&,
std::uint16_t>);

static_assert(!std::is_invocable_v< //
folly::simd::contains_fn,
std::vector<int>&,
std::uint32_t>);

static_assert(!std::is_invocable_v< //
folly::simd::contains_fn,
std::vector<int>&,
std::int64_t>);

static_assert(!std::is_invocable_v< //
folly::simd::contains_fn,
std::vector<std::uint32_t>&,
std::int16_t>);

static_assert(std::is_invocable_v< //
folly::simd::contains_fn,
std::vector<std::uint32_t>&,
std::uint16_t>);

static_assert(!std::is_invocable_v< //
folly::simd::contains_fn,
std::list<std::int32_t>&,
std::int32_t>);

static_assert(!std::is_invocable_v< //
folly::simd::contains_fn,
const std::vector<std::vector<std::int32_t>>&,
std::vector<std::int32_t>>);

} // namespace not_invocable_tests

template <typename T>
Expand Down

0 comments on commit 59978fe

Please sign in to comment.