From b94958dff044a1c063beff0d998e402525f913ca Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 23 Oct 2024 20:21:55 +0200 Subject: [PATCH] Prevent canceling uncancelable generic callbacks (#303) Generic callbacks should not be canceled once they start running. The guarantee provided by `ucxx::Worker` is that once the registering function returns it is safe to destroy the callback and all its associated resources. That guarantee is broken if the callback is scheduled for cancelation while it is already running; therefore, it is a requirement to check whether the callback is already executing and, if so, block the caller until the callback has finished. If the callback never completes this may cause an irrecoverable hang which cannot be dealt with from UCXX, since it's impossible to stop a callback from executing once it has started; it's the user's responsibility to guarantee that the callback returns. A warning is raised on every 10th failed attempt to cancel a callback that is being executed, so that the user is informed of what is happening. The most notable issue is somewhat frequently observable in CI, where the Python async test `test_from_worker_address_multinode` would segfault, in particular with a larger number of endpoints. This was observable in those tests more frequently because a large number of endpoints are created simultaneously by multiple processes, putting more pressure on resources and causing endpoint creation to take several seconds to complete. In those cases the generic callback executing `ucp_ep_create` would take longer than the default timeout of 3 seconds, and in some cases that would be interpreted as the callback having timed out, causing the worker to attempt to cancel the callback while it was still executing. 
With this change, the callback will still time out, but only if it has not started executing yet. If `ucp_ep_create` ends up never returning, this will cause a deadlock in the application; there's no way for UCXX to recover on its own, so warnings are raised. Those hypothetical deadlocks have not been observed in local tests so far. Segfaults should not occur in this situation anymore. Additionally, unit tests for generic callbacks are now included, which previously were a gap in the testing suite. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/ucxx/pull/303 --- cpp/include/ucxx/delayed_submission.h | 21 ++- cpp/include/ucxx/worker.h | 16 ++- cpp/src/delayed_submission.cpp | 2 +- cpp/src/endpoint.cpp | 6 +- cpp/src/worker.cpp | 40 ++++-- cpp/src/worker_progress_thread.cpp | 2 +- cpp/tests/worker.cpp | 191 +++++++++++++++++++++++++- 7 files changed, 247 insertions(+), 31 deletions(-) diff --git a/cpp/include/ucxx/delayed_submission.h b/cpp/include/ucxx/delayed_submission.h index fc248694..209cf158 100644 --- a/cpp/include/ucxx/delayed_submission.h +++ b/cpp/include/ucxx/delayed_submission.h @@ -45,6 +45,8 @@ class BaseDelayedSubmissionCollection { std::string _name{"undefined"}; ///< The human-readable name of the collection, used for logging bool _enabled{true}; ///< Whether the resource required to process the collection is enabled. ItemIdType _itemId{0}; ///< The item ID counter, used to allow cancelation. + std::optional _processing{ + std::nullopt}; ///< The ID of the item being processed, if any. std::deque> _collection{}; ///< The collection. std::set _canceled{}; ///< IDs of canceled items. std::mutex _mutex{}; ///< Mutex to provide access to `_collection`. 
@@ -150,10 +152,17 @@ class BaseDelayedSubmissionCollection { item = std::move(_collection.front()); _collection.pop_front(); if (_canceled.erase(item.first)) continue; + _processing = std::optional{item.first}; } processItem(item.first, item.second); } + + { + // Clear the value of `_processing` as no more requests will be processed. + std::lock_guard lock(_mutex); + _processing = std::nullopt; + } } /** @@ -162,17 +171,17 @@ class BaseDelayedSubmissionCollection { * Cancel a pending callback and thus do not execute it, unless the execution has * already begun, in which case cancelation cannot be done. * + * @throws std::runtime_error if the item is being processed and canceling is not + * possible anymore. + * * @param[in] id the ID of the scheduled item, as returned by `schedule()`. */ void cancel(ItemIdType id) { std::lock_guard lock(_mutex); - // TODO: Check if not cancellable anymore? Will likely need a separate set to keep - // track of registered items. - // - // If the callback is already running - // and the user has no way of knowing that but still destroys it, undefined - // behavior may occur. + if (_processing.has_value() && _processing.value() == id) + throw std::runtime_error("Cannot cancel, item is being processed."); + _canceled.insert(id); ucxx_trace_req("Canceled item: %lu", id); } diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index a5fb3c8d..c9808d0b 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -437,8 +437,12 @@ class Worker : public Component { * If `period` is `0` this is a blocking call that only returns when the callback has been * executed and will always return `true`, and if `period` is a positive integer the time * in nanoseconds will be waited for the callback to complete and return `true` in the - * successful case or `false` otherwise. `period` only applies if the worker progress - * thread is running, otherwise the callback is immediately executed. 
+ * successful case or `false` otherwise. However, if the callback is not cancelable + * anymore (i.e., it has already started), this method will keep retrying and may never + * return if the callback never completes, it is unsafe to return as this would allow the + * caller to destroy the callback and its resources causing undefined behavior. `period` + * only applies if the worker progress thread is running, otherwise the callback is + * immediately executed. * * @param[in] callback the callback to execute before progressing the worker. * @param[in] period the time in nanoseconds to wait for the callback to complete. @@ -462,8 +466,12 @@ class Worker : public Component { * If `period` is `0` this is a blocking call that only returns when the callback has been * executed and will always return `true`, and if `period` is a positive integer the time * in nanoseconds will be waited for the callback to complete and return `true` in the - * successful case or `false` otherwise. `period` only applies if the worker progress - * thread is running, otherwise the callback is immediately executed. + * successful case or `false` otherwise. However, if the callback is not cancelable + * anymore (i.e., it has already started), this method will keep retrying and may never + * return if the callback never completes, it is unsafe to return as this would allow the + * caller to destroy the callback and its resources causing undefined behavior. `period` + * only applies if the worker progress thread is running, otherwise the callback is + * immediately executed. * * @param[in] callback the callback to execute before progressing the worker. * @param[in] period the time in nanoseconds to wait for the callback to complete. 
diff --git a/cpp/src/delayed_submission.cpp b/cpp/src/delayed_submission.cpp index 978ba268..ff7f1ed0 100644 --- a/cpp/src/delayed_submission.cpp +++ b/cpp/src/delayed_submission.cpp @@ -94,6 +94,6 @@ ItemIdType DelayedSubmissionCollection::registerGenericPost(DelayedSubmissionCal void DelayedSubmissionCollection::cancelGenericPre(ItemIdType id) { _genericPre.cancel(id); } -void DelayedSubmissionCollection::cancelGenericPost(ItemIdType id) { _genericPre.cancel(id); } +void DelayedSubmissionCollection::cancelGenericPost(ItemIdType id) { _genericPost.cancel(id); } } // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index ac431370..4731b78d 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -136,10 +136,12 @@ void Endpoint::create(ucp_ep_params_t* params) 3000000000 /* 3s */)) break; - if (i == maxAttempts - 1) + if (i == maxAttempts - 1) { + status = UCS_ERR_TIMED_OUT; ucxx_error("Timeout waiting for ucp_ep_create, all attempts failed"); - else + } else { ucxx_warn("Timeout waiting for ucp_ep_create, retrying"); + } } utils::ucsErrorThrow(status); } else { diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index bc80d609..2adf8eca 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -347,11 +347,21 @@ bool Worker::registerGenericPre(DelayedSubmissionCallbackType callback, uint64_t } signalWorkerFunction(); - auto ret = callbackNotifier.wait(period, signalWorkerFunction); - - if (!ret) _delayedSubmissionCollection->cancelGenericPre(id); - - return ret; + size_t retryCount = 0; + while (true) { + auto ret = callbackNotifier.wait(period, signalWorkerFunction); + + try { + if (!ret) _delayedSubmissionCollection->cancelGenericPre(id); + return ret; + } catch (const std::runtime_error& e) { + if (++retryCount % 10 == 0) + ucxx_warn( + "Could not cancel after %lu attempts, the callback has not returned and the process " + "may stop responding.", + retryCount); + } + } } } @@ -384,11 +394,21 @@ bool 
Worker::registerGenericPost(DelayedSubmissionCallbackType callback, uint64_ } signalWorkerFunction(); - auto ret = callbackNotifier.wait(period, signalWorkerFunction); - - if (!ret) _delayedSubmissionCollection->cancelGenericPost(id); - - return ret; + size_t retryCount = 0; + while (true) { + auto ret = callbackNotifier.wait(period, signalWorkerFunction); + + try { + if (!ret) _delayedSubmissionCollection->cancelGenericPost(id); + return ret; + } catch (const std::runtime_error& e) { + if (++retryCount % 10 == 0) + ucxx_warn( + "Could not cancel after %lu attempts, the callback has not returned and the process " + "may stop responding.", + retryCount); + } + } } } diff --git a/cpp/src/worker_progress_thread.cpp b/cpp/src/worker_progress_thread.cpp index 21e56573..c645be3f 100644 --- a/cpp/src/worker_progress_thread.cpp +++ b/cpp/src/worker_progress_thread.cpp @@ -83,7 +83,7 @@ void WorkerProgressThread::stop() }); _signalWorkerFunction(); if (!callbackNotifierPost.wait(3000000000)) { - _delayedSubmissionCollection->cancelGenericPre(idPost); + _delayedSubmissionCollection->cancelGenericPost(idPost); } _thread.join(); diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index 77667045..c941c20a 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -17,6 +18,18 @@ namespace { using ::testing::Combine; using ::testing::Values; +enum class GenericCallbackType { + None = 0, + Pre, + Post, + PrePost, + PostPre, +}; + +struct ExtraParams { + GenericCallbackType genericCallbackType{GenericCallbackType::None}; +}; + class WorkerTest : public ::testing::Test { protected: std::shared_ptr _context{ @@ -43,16 +56,18 @@ class WorkerCapabilityTest : public ::testing::Test, } }; -class WorkerProgressTest : public WorkerTest, - public ::testing::WithParamInterface> { +class WorkerProgressTest + : public WorkerTest, + public ::testing::WithParamInterface> { protected: std::function _progressWorker; bool 
_enableDelayedSubmission; ProgressMode _progressMode; + ExtraParams _extraParams; void SetUp() { - std::tie(_enableDelayedSubmission, _progressMode) = GetParam(); + std::tie(_enableDelayedSubmission, _progressMode, _extraParams) = GetParam(); _worker = _context->createWorker(_enableDelayedSubmission); @@ -74,6 +89,10 @@ class WorkerProgressTest : public WorkerTest, } }; +class WorkerGenericCallbackTest : public WorkerProgressTest {}; + +class WorkerGenericCallbackSingleTest : public WorkerProgressTest {}; + TEST_F(WorkerTest, HandleIsValid) { ASSERT_TRUE(_worker->getHandle() != nullptr); } TEST_P(WorkerCapabilityTest, CheckCapability) @@ -327,6 +346,144 @@ TEST_P(WorkerProgressTest, ProgressTagMulti) } } +TEST_P(WorkerGenericCallbackTest, RegisterGeneric) +{ + bool done1 = false; + bool done2 = false; + auto callback1 = [&done1]() { done1 = true; }; + auto callback2 = [&done2]() { done2 = true; }; + + if (_extraParams.genericCallbackType == GenericCallbackType::Pre) { + ASSERT_TRUE(_worker->registerGenericPre(callback1)); + ASSERT_TRUE(done1); + } else if (_extraParams.genericCallbackType == GenericCallbackType::Post) { + ASSERT_TRUE(_worker->registerGenericPost(callback1)); + ASSERT_TRUE(done1); + } else if (_extraParams.genericCallbackType == GenericCallbackType::PrePost) { + ASSERT_TRUE(_worker->registerGenericPre(callback1)); + ASSERT_TRUE(_worker->registerGenericPost(callback2)); + ASSERT_TRUE(done1); + ASSERT_TRUE(done2); + } else if (_extraParams.genericCallbackType == GenericCallbackType::PostPre) { + ASSERT_TRUE(_worker->registerGenericPost(callback1)); + ASSERT_TRUE(_worker->registerGenericPre(callback2)); + ASSERT_TRUE(done1); + ASSERT_TRUE(done2); + } +} + +TEST_P(WorkerGenericCallbackTest, RegisterGenericCancel) +{ + bool threadStarted = false; + bool terminateThread = false; + bool done = false; + auto callback = [&done] { done = true; }; + + std::mutex m{}; + std::condition_variable conditionVariable{}; + + std::thread thread = + std::thread([this, 
&threadStarted, &terminateThread, &m, &conditionVariable]() { + auto threadCallback = [&threadStarted, &terminateThread, &m, &conditionVariable]() { + // Allow main thread to test for generic callback cancelation. + { std::lock_guard l(m); threadStarted = true; }  // Set under `m` to avoid a lost wakeup. + conditionVariable.notify_one(); + + { + std::unique_lock l(m); + // Wait until the main thread had a generic callback cancelled + conditionVariable.wait(l, [&terminateThread] { return terminateThread; }); + } + }; + + if (_extraParams.genericCallbackType == GenericCallbackType::Pre || + _extraParams.genericCallbackType == GenericCallbackType::PrePost) { + ASSERT_TRUE(_worker->registerGenericPre(threadCallback)); + } else if (_extraParams.genericCallbackType == GenericCallbackType::Post || + _extraParams.genericCallbackType == GenericCallbackType::PostPre) { + ASSERT_TRUE(_worker->registerGenericPost(threadCallback)); + } + }); + + { + std::unique_lock l(m); + // Wait until thread starts and blocks. + conditionVariable.wait(l, [&threadStarted] { return threadStarted; }); + } + + // The thread should be running, therefore the callback will be canceled before running. + // Note here `PrePost`/`PostPre` order is the opposite as from `thread`. + if (_extraParams.genericCallbackType == GenericCallbackType::Pre || + _extraParams.genericCallbackType == GenericCallbackType::PostPre) { + ASSERT_FALSE(_worker->registerGenericPre(callback, 1)); + } else if (_extraParams.genericCallbackType == GenericCallbackType::Post || + _extraParams.genericCallbackType == GenericCallbackType::PrePost) { + ASSERT_FALSE(_worker->registerGenericPost(callback, 1)); + } + ASSERT_FALSE(done); + + // Unblock thread to terminate. + { std::lock_guard l(m); terminateThread = true; }  // Set under `m` to avoid a lost wakeup. + conditionVariable.notify_one(); + thread.join(); + + // Nothing should be blocking the progress thread now, the callback should succeed. + // Note here `PrePost`/`PostPre` order is the opposite as from `thread`. 
+ if (_extraParams.genericCallbackType == GenericCallbackType::Pre || + _extraParams.genericCallbackType == GenericCallbackType::PostPre) { + ASSERT_TRUE(_worker->registerGenericPre(callback)); + } else if (_extraParams.genericCallbackType == GenericCallbackType::Post || + _extraParams.genericCallbackType == GenericCallbackType::PrePost) { + ASSERT_TRUE(_worker->registerGenericPost(callback)); + } + ASSERT_TRUE(done); +} + +TEST_P(WorkerGenericCallbackSingleTest, RegisterGenericPreUncancelable) +{ + bool terminateThread = false; + bool match = false; + + std::mutex m{}; + std::condition_variable conditionVariable{}; + + std::thread thread = std::thread([this, &terminateThread, &m, &conditionVariable]() { + auto threadCallback = [&terminateThread, &m, &conditionVariable]() { + { + std::unique_lock l(m); + conditionVariable.wait(l, [&terminateThread] { return terminateThread; }); + } + }; + + // This will submit the callback and attempt to cancel once every 1ms, + // a warning is logged when multiples of 10 attempts to cancel are made. + if (_extraParams.genericCallbackType == GenericCallbackType::Pre) + ASSERT_TRUE(_worker->registerGenericPre(threadCallback, 1000000 /* 1ms */)); + else if (_extraParams.genericCallbackType == GenericCallbackType::Post) + ASSERT_TRUE(_worker->registerGenericPost(threadCallback, 1000000 /* 1ms */)); + }); + + loopWithTimeout(std::chrono::milliseconds(5000), [&match] { + testing::internal::CaptureStdout(); + + // We need to allow some time for stdout to be populated, + // `GetCapturedStdout()` does not return the cumulative log. + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + match = ::testing::Matches(::testing::ContainsRegex( + "Could not cancel after .* attempts, the callback has not returned and the process may stop " + "responding."))(::testing::internal::GetCapturedStdout()); + return match; + }); + + // Unblock thread to terminate. 
+ { std::lock_guard l(m); terminateThread = true; }  // Set under `m` to avoid a lost wakeup. + conditionVariable.notify_one(); + thread.join(); + + ASSERT_TRUE(match); +} + INSTANTIATE_TEST_SUITE_P(ProgressModes, WorkerProgressTest, Combine(Values(false), @@ -334,11 +491,31 @@ INSTANTIATE_TEST_SUITE_P(ProgressModes, ProgressMode::Blocking, ProgressMode::Wait, ProgressMode::ThreadPolling, - ProgressMode::ThreadBlocking))); + ProgressMode::ThreadBlocking), + Values(ExtraParams{}))); + +INSTANTIATE_TEST_SUITE_P(DelayedSubmission, + WorkerProgressTest, + Combine(Values(true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(ExtraParams{}))); + +INSTANTIATE_TEST_SUITE_P( + GenericCallbacks, + WorkerGenericCallbackTest, + Combine(Values(false, true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(ExtraParams{.genericCallbackType = GenericCallbackType::Pre}, + ExtraParams{.genericCallbackType = GenericCallbackType::Post}, + ExtraParams{.genericCallbackType = GenericCallbackType::PrePost}, + ExtraParams{.genericCallbackType = GenericCallbackType::PostPre}))); INSTANTIATE_TEST_SUITE_P( - DelayedSubmission, - WorkerProgressTest, - Combine(Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking))); + GenericCallbacksSingle, + WorkerGenericCallbackSingleTest, + Combine(Values(false, true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(ExtraParams{.genericCallbackType = GenericCallbackType::Pre}, + ExtraParams{.genericCallbackType = GenericCallbackType::Post}))); } // namespace