Skip to content

Commit

Permalink
Add checks for unsafe implicit conversions in RangePolicy (kokkos#6754)
Browse files Browse the repository at this point in the history
* Added a check for unintended implicit conversion in RangePolicy

* Changed to integrate Rangepolicy implicit checks in existing constructors instead

* Changed ifdef conditions to allow outputting warning messages even when deprecated code is used
Modified the unit test to test warning outputs
Changed implicit conversion check to be tested in debug mode only

* Fixed incorrect gtest_skip call and unused var warning

* Removed ifdef kokkos_enable_debug guards

* Switch to use the new interface
  • Loading branch information
ldh4 authored Feb 7, 2024
1 parent 31fb476 commit 442e4d4
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 6 deletions.
75 changes: 69 additions & 6 deletions core/src/Kokkos_ExecPolicy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ static_assert(false,
#include <impl/Kokkos_AnalyzePolicy.hpp>
#include <Kokkos_Concepts.hpp>
#include <typeinfo>
#include <limits>

//----------------------------------------------------------------------------

Expand Down Expand Up @@ -114,39 +115,57 @@ class RangePolicy : public Impl::PolicyTraits<Properties...> {
m_granularity_mask(0) {}

/** \brief Total range */
template <typename IndexType1, typename IndexType2,
std::enable_if_t<(std::is_convertible_v<IndexType1, member_type> &&
std::is_convertible_v<IndexType2, member_type>),
bool> = false>
inline RangePolicy(const typename traits::execution_space& work_space,
const member_type work_begin, const member_type work_end)
const IndexType1 work_begin, const IndexType2 work_end)
: m_space(work_space),
m_begin(work_begin),
m_end(work_end),
m_granularity(0),
m_granularity_mask(0) {
check_conversion_safety(work_begin);
check_conversion_safety(work_end);
check_bounds_validity();
set_auto_chunk_size();
}

/** \brief Total range */
inline RangePolicy(const member_type work_begin, const member_type work_end)
template <typename IndexType1, typename IndexType2,
std::enable_if_t<(std::is_convertible_v<IndexType1, member_type> &&
std::is_convertible_v<IndexType2, member_type>),
bool> = false>
inline RangePolicy(const IndexType1 work_begin, const IndexType2 work_end)
: RangePolicy(typename traits::execution_space(), work_begin, work_end) {}

/** \brief Total range */
template <class... Args>
template <typename IndexType1, typename IndexType2, typename... Args,
std::enable_if_t<(std::is_convertible_v<IndexType1, member_type> &&
std::is_convertible_v<IndexType2, member_type>),
bool> = false>
inline RangePolicy(const typename traits::execution_space& work_space,
const member_type work_begin, const member_type work_end,
const IndexType1 work_begin, const IndexType2 work_end,
Args... args)
: m_space(work_space),
m_begin(work_begin),
m_end(work_end),
m_granularity(0),
m_granularity_mask(0) {
check_conversion_safety(work_begin);
check_conversion_safety(work_end);
check_bounds_validity();
set_auto_chunk_size();
set(args...);
}

/** \brief Total range */
template <class... Args>
inline RangePolicy(const member_type work_begin, const member_type work_end,
template <typename IndexType1, typename IndexType2, typename... Args,
std::enable_if_t<(std::is_convertible_v<IndexType1, member_type> &&
std::is_convertible_v<IndexType2, member_type>),
bool> = false>
inline RangePolicy(const IndexType1 work_begin, const IndexType2 work_end,
Args... args)
: RangePolicy(typename traits::execution_space(), work_begin, work_end,
args...) {}
Expand Down Expand Up @@ -233,6 +252,50 @@ class RangePolicy : public Impl::PolicyTraits<Properties...> {
}
}

// To be replaced with std::in_range (c++20)
template <typename IndexType>
static void check_conversion_safety(const IndexType bound) {
#if !defined(KOKKOS_ENABLE_DEPRECATED_CODE_4) || \
defined(KOKKOS_ENABLE_DEPRECATION_WARNINGS)

std::string msg =
"Kokkos::RangePolicy bound type error: an unsafe implicit conversion "
"is performed on a bound (" +
std::to_string(bound) +
"), which may "
"not preserve its original value.\n";
bool warn = false;

if constexpr (std::is_signed_v<IndexType> !=
std::is_signed_v<member_type>) {
// check signed to unsigned
if constexpr (std::is_signed_v<IndexType>)
warn |= (bound < static_cast<IndexType>(
std::numeric_limits<member_type>::min()));

// check unsigned to signed
if constexpr (std::is_signed_v<member_type>)
warn |= (bound > static_cast<IndexType>(
std::numeric_limits<member_type>::max()));
}

// check narrowing
warn |= (static_cast<IndexType>(static_cast<member_type>(bound)) != bound);

if (warn) {
#ifndef KOKKOS_ENABLE_DEPRECATED_CODE_4
Kokkos::abort(msg.c_str());
#endif

#ifdef KOKKOS_ENABLE_DEPRECATION_WARNINGS
Kokkos::Impl::log_warning(msg);
#endif
}
#else
(void)bound;
#endif
}

public:
/** \brief Subrange for a partition's rank and size.
*
Expand Down
75 changes: 75 additions & 0 deletions core/unit_test/TestRangePolicyConstructors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <Kokkos_Core.hpp>

#include <regex>
#include <limits>

namespace {

Expand Down Expand Up @@ -121,4 +122,78 @@ TEST(TEST_CATEGORY_DEATH, range_policy_invalid_bounds) {
#endif
}

TEST(TEST_CATEGORY_DEATH, range_policy_implicitly_converted_bounds) {
using UIntIndexType = Kokkos::IndexType<unsigned>;
using IntIndexType = Kokkos::IndexType<int>;
using UIntPolicy = Kokkos::RangePolicy<TEST_EXECSPACE, UIntIndexType>;
using IntPolicy = Kokkos::RangePolicy<TEST_EXECSPACE, IntIndexType>;

std::string msg =
"Kokkos::RangePolicy bound type error: an unsafe implicit conversion is "
"performed on a bound (), which may not preserve its original value.\n";

auto get_error_msg = [](auto str, auto val) {
return str.insert(str.find("(") + 1, std::to_string(val).c_str());
};
#ifndef KOKKOS_ENABLE_DEPRECATED_CODE_4
std::string expected = std::regex_replace(msg, std::regex("\\(|\\)"), "\\$&");
{
int test_val = -1;
ASSERT_DEATH({ (void)UIntPolicy(test_val, 10); },
get_error_msg(expected, test_val));
}
{
unsigned test_val = std::numeric_limits<unsigned>::max();
ASSERT_DEATH({ (void)IntPolicy(0u, test_val); },
get_error_msg(expected, test_val));
}
{
long long test_val = std::numeric_limits<long long>::max();
ASSERT_DEATH({ (void)IntPolicy(0LL, test_val); },
get_error_msg(expected, test_val));
}
{
int test_val = -1;
ASSERT_DEATH({ (void)UIntPolicy(test_val, 10, Kokkos::ChunkSize(2)); },
get_error_msg(expected, test_val));
}

#else
{
::testing::internal::CaptureStderr();
int test_val = -1;
UIntPolicy policy(test_val, 10);
ASSERT_EQ(policy.begin(), 0u);
ASSERT_EQ(policy.end(), 0u);
#ifdef KOKKOS_ENABLE_DEPRECATION_WARNINGS
if (Kokkos::show_warnings()) {
auto s = std::string(::testing::internal::GetCapturedStderr());
ASSERT_EQ(s.substr(0, s.find("\n") + 1), get_error_msg(msg, test_val));
}
#else
ASSERT_TRUE(::testing::internal::GetCapturedStderr().empty());
(void)msg;
(void)get_error_msg;
#endif
}
{
::testing::internal::CaptureStderr();
unsigned test_val = std::numeric_limits<unsigned>::max();
IntPolicy policy(0u, test_val);
ASSERT_EQ(policy.begin(), 0);
ASSERT_EQ(policy.end(), 0);
#ifdef KOKKOS_ENABLE_DEPRECATION_WARNINGS
if (Kokkos::show_warnings()) {
auto s = std::string(::testing::internal::GetCapturedStderr());
ASSERT_EQ(s.substr(0, s.find("\n") + 1), get_error_msg(msg, test_val));
}
#else
ASSERT_TRUE(::testing::internal::GetCapturedStderr().empty());
(void)msg;
(void)get_error_msg;
#endif
}
#endif
}

} // namespace

0 comments on commit 442e4d4

Please sign in to comment.