diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py index 673bd00..d62f2e0 100644 --- a/fed/_private/global_context.py +++ b/fed/_private/global_context.py @@ -13,13 +13,16 @@ # limitations under the License. from fed.cleanup import CleanupManager +from typing import Callable class GlobalContext: - def __init__(self, job_name: str) -> None: + def __init__(self, job_name: str, + failure_handler: Callable[[], None] ) -> None: self._job_name = job_name self._seq_count = 0 self._cleanup_manager = CleanupManager() + self._failure_handler = failure_handler def next_seq_id(self) -> int: self._seq_count += 1 @@ -31,24 +34,28 @@ def get_cleanup_manager(self) -> CleanupManager: def job_name(self) -> str: return self._job_name + def failure_handler(self) -> Callable[[], None]: + return self._failure_handler + _global_context = None -def init_global_context(job_name: str) -> None: +def init_global_context(job_name: str, failure_handler: Callable[[], None]) -> None: global _global_context if _global_context is None: - _global_context = GlobalContext(job_name) + _global_context = GlobalContext(job_name, failure_handler) def get_global_context(): global _global_context - if _global_context is None: - _global_context = GlobalContext() + # if _global_context is None: + # _global_context = GlobalContext() return _global_context def clear_global_context(): global _global_context - _global_context.get_cleanup_manager().graceful_stop() - _global_context = None + if _global_context is not None: + _global_context.get_cleanup_manager().graceful_stop() + _global_context = None diff --git a/fed/_private/queue.py b/fed/_private/queue.py index 6ff9854..947bf99 100644 --- a/fed/_private/queue.py +++ b/fed/_private/queue.py @@ -51,7 +51,7 @@ def _loop(): if self._thread is None or not self._thread.is_alive(): logger.debug(f"Starting new thread[{self._name}] for message polling.") self._queue = deque() - self._thread = threading.Thread(target=_loop) + self._thread = 
threading.Thread(target=_loop, name=self._name) self._thread.start() def push(self, message): @@ -80,6 +80,7 @@ def stop(self, graceful=True): if graceful: if self.is_started(): + logger.debug(f"Gracefully killing thread[{self._name}].") self.notify_to_exit() self._thread.join() else: diff --git a/fed/api.py b/fed/api.py index 59bb9d5..781ef4f 100644 --- a/fed/api.py +++ b/fed/api.py @@ -16,7 +16,7 @@ import inspect import logging import signal -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Callable import cloudpickle import ray @@ -55,8 +55,11 @@ def _signal_handler(signum, frame): if signum == signal.SIGINT: signal.signal(signal.SIGINT, original_sigint) - logger.warning("Receiving SIGINT, try to shutdown fed.") - shutdown() + logger.warning( + "Stop signal received (e.g. via SIGINT/Ctrl+C), " + "try to shutdown fed. Press CTRL+C " + "(or send SIGINT/SIGKILL/SIGTERM) to skip.") + shutdown(intended=False) def init( @@ -68,7 +71,8 @@ def init( sender_proxy_cls: SenderProxy = None, receiver_proxy_cls: ReceiverProxy = None, receiver_sender_proxy_cls: SenderReceiverProxy = None, - job_name: str = constants.RAYFED_DEFAULT_JOB_NAME + job_name: str = constants.RAYFED_DEFAULT_JOB_NAME, + failure_handler: Callable[[], None] = None, ): """ Initialize a RayFed client. @@ -129,7 +133,7 @@ def init( assert party in addresses, f"Party {party} is not in the addresses {addresses}." fed_utils.validate_addresses(addresses) - init_global_context(job_name=job_name) + init_global_context(job_name=job_name, failure_handler=failure_handler) tls_config = {} if tls_config is None else tls_config if tls_config: assert ( @@ -228,13 +232,22 @@ def init( ping_others(addresses=addresses, self_party=party, max_retries=3600) -def shutdown(): +def shutdown(intended=True): """ Shutdown a RayFed client. + + Args: + intended: (Optional) Whether this is an intended exit. If not, the failure handler + will be triggered.
""" - compatible_utils._clear_internal_kv() - clear_global_context() - logger.info('Shutdowned rayfed.') + if (get_global_context() is not None): + # Job has been initialized, so it can be shut down + failure_handler = get_global_context().failure_handler() + compatible_utils._clear_internal_kv() + clear_global_context() + if(not intended and failure_handler is not None): + failure_handler() + logger.info('Shutdowned rayfed.') def _get_addresses(job_name: str = None): diff --git a/fed/cleanup.py b/fed/cleanup.py index e9ddb98..6a26254 100644 --- a/fed/cleanup.py +++ b/fed/cleanup.py @@ -111,6 +111,7 @@ def _signal_exit(self): Exit the current process immediately. The signal will be captured in main thread where the `stop` will be called. """ + logger.debug("Signal SIGINT to exit.") os.kill(os.getpid(), signal.SIGINT) def _process_data_message(self, message): @@ -134,16 +135,12 @@ def _process_data_message(self, message): res = False if not res and self._exit_on_sending_failure: - # NOTE(NKcqx): this will exit the data sending thread and - # the error sending thread. However, the former will IGNORE - # the stop command because this function is called inside - # the data sending thread but it can't kill itself. The - # data sending thread is exiting by the return value. - # self._signal_exit() - # Notify main thread to clear all sub-threads - os.kill(os.getpid(), signal.SIGINT) - # This will notify the queue to break the for-loop and - # exit the thread. + # NOTE(NKcqx): Send signal to main thread so that it can + # do some cleaning, e.g. kill the error sending thread. + self._signal_exit() + # Return False to exit the loop in sub-thread. Note that + # the above signal will also cause the main thread to kill + # the sub-thread eventually.
return False return True diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 1de27d3..9b59553 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -73,6 +73,7 @@ def run(party): 'timeout_ms': 20 * 1000, }, }, + failure_handler= lambda : os.kill(os.getpid(), signal.SIGTERM) ) o = f.party("alice").remote()