diff --git a/cpp/include/ucxx/delayed_submission.h b/cpp/include/ucxx/delayed_submission.h index fc248694..209cf158 100644 --- a/cpp/include/ucxx/delayed_submission.h +++ b/cpp/include/ucxx/delayed_submission.h @@ -45,6 +45,8 @@ class BaseDelayedSubmissionCollection { std::string _name{"undefined"}; ///< The human-readable name of the collection, used for logging bool _enabled{true}; ///< Whether the resource required to process the collection is enabled. ItemIdType _itemId{0}; ///< The item ID counter, used to allow cancelation. + std::optional _processing{ + std::nullopt}; ///< The ID of the item being processed, if any. std::deque> _collection{}; ///< The collection. std::set _canceled{}; ///< IDs of canceled items. std::mutex _mutex{}; ///< Mutex to provide access to `_collection`. @@ -150,10 +152,17 @@ class BaseDelayedSubmissionCollection { item = std::move(_collection.front()); _collection.pop_front(); if (_canceled.erase(item.first)) continue; + _processing = std::optional{item.first}; } processItem(item.first, item.second); } + + { + // Clear the value of `_processing` as no more requests will be processed. + std::lock_guard lock(_mutex); + _processing = std::nullopt; + } } /** @@ -162,17 +171,17 @@ class BaseDelayedSubmissionCollection { * Cancel a pending callback and thus do not execute it, unless the execution has * already begun, in which case cancelation cannot be done. * + * @throws std::runtime_error if the item is being processed and canceling is not + * possible anymore. + * * @param[in] id the ID of the scheduled item, as returned by `schedule()`. */ void cancel(ItemIdType id) { std::lock_guard lock(_mutex); - // TODO: Check if not cancellable anymore? Will likely need a separate set to keep - // track of registered items. - // - // If the callback is already running - // and the user has no way of knowing that but still destroys it, undefined - // behavior may occur. + if (_processing.has_value() && _processing.value() == id) + throw std::runtime_error("Cannot cancel, item is being processed."); + _canceled.insert(id); ucxx_trace_req("Canceled item: %lu", id); } diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index a5fb3c8d..c9808d0b 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -437,8 +437,12 @@ class Worker : public Component { * If `period` is `0` this is a blocking call that only returns when the callback has been * executed and will always return `true`, and if `period` is a positive integer the time * in nanoseconds will be waited for the callback to complete and return `true` in the - * successful case or `false` otherwise. `period` only applies if the worker progress - * thread is running, otherwise the callback is immediately executed. + * successful case or `false` otherwise. However, if the callback is not cancelable + * anymore (i.e., it has already started), this method will keep retrying and may never + * return if the callback never completes, it is unsafe to return as this would allow the + * caller to destroy the callback and its resources causing undefined behavior. `period` + * only applies if the worker progress thread is running, otherwise the callback is + * immediately executed. * * @param[in] callback the callback to execute before progressing the worker. * @param[in] period the time in nanoseconds to wait for the callback to complete. @@ -462,8 +466,12 @@ class Worker : public Component { * If `period` is `0` this is a blocking call that only returns when the callback has been * executed and will always return `true`, and if `period` is a positive integer the time * in nanoseconds will be waited for the callback to complete and return `true` in the - * successful case or `false` otherwise. `period` only applies if the worker progress - * thread is running, otherwise the callback is immediately executed. + * successful case or `false` otherwise. However, if the callback is not cancelable + * anymore (i.e., it has already started), this method will keep retrying and may never + * return if the callback never completes, it is unsafe to return as this would allow the + * caller to destroy the callback and its resources causing undefined behavior. `period` + * only applies if the worker progress thread is running, otherwise the callback is + * immediately executed. * * @param[in] callback the callback to execute before progressing the worker. * @param[in] period the time in nanoseconds to wait for the callback to complete. diff --git a/cpp/src/delayed_submission.cpp b/cpp/src/delayed_submission.cpp index 978ba268..ff7f1ed0 100644 --- a/cpp/src/delayed_submission.cpp +++ b/cpp/src/delayed_submission.cpp @@ -94,6 +94,6 @@ ItemIdType DelayedSubmissionCollection::registerGenericPost(DelayedSubmissionCal void DelayedSubmissionCollection::cancelGenericPre(ItemIdType id) { _genericPre.cancel(id); } -void DelayedSubmissionCollection::cancelGenericPost(ItemIdType id) { _genericPre.cancel(id); } +void DelayedSubmissionCollection::cancelGenericPost(ItemIdType id) { _genericPost.cancel(id); } } // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index ac431370..4731b78d 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -136,10 +136,12 @@ void Endpoint::create(ucp_ep_params_t* params) 3000000000 /* 3s */)) break; - if (i == maxAttempts - 1) + if (i == maxAttempts - 1) { + status = UCS_ERR_TIMED_OUT; ucxx_error("Timeout waiting for ucp_ep_create, all attempts failed"); - else + } else { ucxx_warn("Timeout waiting for ucp_ep_create, retrying"); + } } utils::ucsErrorThrow(status); } else { diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index bc80d609..2adf8eca 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -347,11 +347,21 @@ bool Worker::registerGenericPre(DelayedSubmissionCallbackType callback, uint64_t } signalWorkerFunction(); - auto ret = callbackNotifier.wait(period, signalWorkerFunction); - - if (!ret) _delayedSubmissionCollection->cancelGenericPre(id); - - return ret; + size_t retryCount = 0; + while (true) { + auto ret = callbackNotifier.wait(period, signalWorkerFunction); + + try { + if (!ret) _delayedSubmissionCollection->cancelGenericPre(id); + return ret; + } catch (const std::runtime_error& e) { + if (++retryCount % 10 == 0) + ucxx_warn( + "Could not cancel after %lu attempts, the callback has not returned and the process " + "may stop responding.", + retryCount); + } + } } } @@ -384,11 +394,21 @@ bool Worker::registerGenericPost(DelayedSubmissionCallbackType callback, uint64_ } signalWorkerFunction(); - auto ret = callbackNotifier.wait(period, signalWorkerFunction); - - if (!ret) _delayedSubmissionCollection->cancelGenericPost(id); - - return ret; + size_t retryCount = 0; + while (true) { + auto ret = callbackNotifier.wait(period, signalWorkerFunction); + + try { + if (!ret) _delayedSubmissionCollection->cancelGenericPost(id); + return ret; + } catch (const std::runtime_error& e) { + if (++retryCount % 10 == 0) + ucxx_warn( + "Could not cancel after %lu attempts, the callback has not returned and the process " + "may stop responding.", + retryCount); + } + } } } diff --git a/cpp/src/worker_progress_thread.cpp b/cpp/src/worker_progress_thread.cpp index 21e56573..c645be3f 100644 --- a/cpp/src/worker_progress_thread.cpp +++ b/cpp/src/worker_progress_thread.cpp @@ -83,7 +83,7 @@ void WorkerProgressThread::stop() }); _signalWorkerFunction(); if (!callbackNotifierPost.wait(3000000000)) { - _delayedSubmissionCollection->cancelGenericPre(idPost); + _delayedSubmissionCollection->cancelGenericPost(idPost); } _thread.join(); diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index 77667045..c941c20a 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -17,6 +18,18 @@ namespace { using ::testing::Combine; using ::testing::Values; +enum class GenericCallbackType { + None = 0, + Pre, + Post, + PrePost, + PostPre, +}; + +struct ExtraParams { + GenericCallbackType genericCallbackType{GenericCallbackType::None}; +}; + class WorkerTest : public ::testing::Test { protected: std::shared_ptr _context{ @@ -43,16 +56,18 @@ class WorkerCapabilityTest : public ::testing::Test, } }; -class WorkerProgressTest : public WorkerTest, - public ::testing::WithParamInterface> { +class WorkerProgressTest + : public WorkerTest, + public ::testing::WithParamInterface> { protected: std::function _progressWorker; bool _enableDelayedSubmission; ProgressMode _progressMode; + ExtraParams _extraParams; void SetUp() { - std::tie(_enableDelayedSubmission, _progressMode) = GetParam(); + std::tie(_enableDelayedSubmission, _progressMode, _extraParams) = GetParam(); _worker = _context->createWorker(_enableDelayedSubmission); @@ -74,6 +89,10 @@ class WorkerProgressTest : public WorkerTest, } }; +class WorkerGenericCallbackTest : public WorkerProgressTest {}; + +class WorkerGenericCallbackSingleTest : public WorkerProgressTest {}; + TEST_F(WorkerTest, HandleIsValid) { ASSERT_TRUE(_worker->getHandle() != nullptr); } TEST_P(WorkerCapabilityTest, CheckCapability) @@ -327,6 +346,144 @@ TEST_P(WorkerProgressTest, ProgressTagMulti) } } +TEST_P(WorkerGenericCallbackTest, RegisterGeneric) +{ + bool done1 = false; + bool done2 = false; + auto callback1 = [&done1]() { done1 = true; }; + auto callback2 = [&done2]() { done2 = true; }; + + if (_extraParams.genericCallbackType == GenericCallbackType::Pre) { + ASSERT_TRUE(_worker->registerGenericPre(callback1)); + ASSERT_TRUE(done1); + } else if (_extraParams.genericCallbackType == GenericCallbackType::Post) { + ASSERT_TRUE(_worker->registerGenericPre(callback1)); + ASSERT_TRUE(done1); + } else if (_extraParams.genericCallbackType == GenericCallbackType::PrePost) { + ASSERT_TRUE(_worker->registerGenericPre(callback1)); + ASSERT_TRUE(_worker->registerGenericPost(callback2)); + ASSERT_TRUE(done1); + ASSERT_TRUE(done2); + } else if (_extraParams.genericCallbackType == GenericCallbackType::PostPre) { + ASSERT_TRUE(_worker->registerGenericPost(callback1)); + ASSERT_TRUE(_worker->registerGenericPre(callback2)); + ASSERT_TRUE(done1); + ASSERT_TRUE(done2); + } +} + +TEST_P(WorkerGenericCallbackTest, RegisterGenericCancel) +{ + bool threadStarted = false; + bool terminateThread = false; + bool done = false; + auto callback = [&done] { done = true; }; + + std::mutex m{}; + std::condition_variable conditionVariable{}; + + std::thread thread = + std::thread([this, &threadStarted, &terminateThread, &m, &conditionVariable]() { + auto threadCallback = [&threadStarted, &terminateThread, &m, &conditionVariable]() { + // Allow main thread to test for generic callback cancelation. + threadStarted = true; + conditionVariable.notify_one(); + + { + std::unique_lock l(m); + // Wait until the main thread had a generic callback cancelled + conditionVariable.wait(l, [&terminateThread] { return terminateThread; }); + } + }; + + if (_extraParams.genericCallbackType == GenericCallbackType::Pre || + _extraParams.genericCallbackType == GenericCallbackType::PrePost) { + ASSERT_TRUE(_worker->registerGenericPre(threadCallback)); + } else if (_extraParams.genericCallbackType == GenericCallbackType::Post || + _extraParams.genericCallbackType == GenericCallbackType::PostPre) { + ASSERT_TRUE(_worker->registerGenericPost(threadCallback)); + } + }); + + { + std::unique_lock l(m); + // Wait until thread starts and blocks. + conditionVariable.wait(l, [&threadStarted] { return threadStarted; }); + } + + // The thread should be running, therefore the callback will be canceled before running. + // Note here `PrePost`/`PostPre` order is the opposite as from `thread`. + if (_extraParams.genericCallbackType == GenericCallbackType::Pre || + _extraParams.genericCallbackType == GenericCallbackType::PostPre) { + ASSERT_FALSE(_worker->registerGenericPre(callback, 1)); + } else if (_extraParams.genericCallbackType == GenericCallbackType::Post || + _extraParams.genericCallbackType == GenericCallbackType::PrePost) { + ASSERT_FALSE(_worker->registerGenericPost(callback, 1)); + } + ASSERT_FALSE(done); + + // Unblock thread to terminate. + terminateThread = true; + conditionVariable.notify_one(); + thread.join(); + + // Nothing should be blocking the progress thread now, the callback should succeed. + // Note here `PrePost`/`PostPre` order is the opposite as from `thread`. + if (_extraParams.genericCallbackType == GenericCallbackType::Pre || + _extraParams.genericCallbackType == GenericCallbackType::PostPre) { + ASSERT_TRUE(_worker->registerGenericPre(callback)); + } else if (_extraParams.genericCallbackType == GenericCallbackType::Post || + _extraParams.genericCallbackType == GenericCallbackType::PrePost) { + ASSERT_TRUE(_worker->registerGenericPost(callback)); + } + ASSERT_TRUE(done); +} + +TEST_P(WorkerGenericCallbackSingleTest, RegisterGenericPreUncancelable) +{ + bool terminateThread = false; + bool match = false; + + std::mutex m{}; + std::condition_variable conditionVariable{}; + + std::thread thread = std::thread([this, &terminateThread, &m, &conditionVariable]() { + auto threadCallback = [&terminateThread, &m, &conditionVariable]() { + { + std::unique_lock l(m); + conditionVariable.wait(l, [&terminateThread] { return terminateThread; }); + } + }; + + // This will submit the callback and attempt to cancel once every 1ms, + // a warning is logged when multiples of 10 attempts to cancel are made. + if (_extraParams.genericCallbackType == GenericCallbackType::Pre) + ASSERT_TRUE(_worker->registerGenericPre(threadCallback, 1000000 /* 1ms */)); + else if (_extraParams.genericCallbackType == GenericCallbackType::Post) + ASSERT_TRUE(_worker->registerGenericPost(threadCallback, 1000000 /* 1ms */)); + }); + + loopWithTimeout(std::chrono::milliseconds(5000), [&match] { + testing::internal::CaptureStdout(); + + // We need to allow some time for stdout to be populated, + // `GetCapturedStdout()` does not return the cumulative log. + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + match = ::testing::Matches(::testing::ContainsRegex( + "Could not cancel after .* attempts, the callback has not returned and the process may stop " + "responding."))(::testing::internal::GetCapturedStdout()); + return match; + }); + + // Unblock thread to terminate. + terminateThread = true; + conditionVariable.notify_one(); + thread.join(); + + ASSERT_TRUE(match); +} + INSTANTIATE_TEST_SUITE_P(ProgressModes, WorkerProgressTest, Combine(Values(false), @@ -334,11 +491,31 @@ INSTANTIATE_TEST_SUITE_P(ProgressModes, ProgressMode::Blocking, ProgressMode::Wait, ProgressMode::ThreadPolling, - ProgressMode::ThreadBlocking))); + ProgressMode::ThreadBlocking), + Values(ExtraParams{}))); + +INSTANTIATE_TEST_SUITE_P(DelayedSubmission, + WorkerProgressTest, + Combine(Values(true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(ExtraParams{}))); + +INSTANTIATE_TEST_SUITE_P( + GenericCallbacks, + WorkerGenericCallbackTest, + Combine(Values(false, true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(ExtraParams{.genericCallbackType = GenericCallbackType::Pre}, + ExtraParams{.genericCallbackType = GenericCallbackType::Post}, + ExtraParams{.genericCallbackType = GenericCallbackType::PrePost}, + ExtraParams{.genericCallbackType = GenericCallbackType::PostPre}))); INSTANTIATE_TEST_SUITE_P( - DelayedSubmission, - WorkerProgressTest, - Combine(Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking))); + GenericCallbacksSingle, + WorkerGenericCallbackSingleTest, + Combine(Values(false, true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(ExtraParams{.genericCallbackType = GenericCallbackType::Pre}, + ExtraParams{.genericCallbackType = GenericCallbackType::Post}))); } // namespace