Skip to content

Commit

Permalink
failure handler when shutdown
Browse files Browse the repository at this point in the history
Signed-off-by: paer <chenqixiang.cqx@antgroup.com>
  • Loading branch information
paer committed Sep 1, 2023
1 parent 6d29ea5 commit 26c6e2f
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 27 deletions.
21 changes: 14 additions & 7 deletions fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# limitations under the License.

from fed.cleanup import CleanupManager
from typing import Callable


class GlobalContext:
def __init__(self, job_name: str) -> None:
def __init__(self, job_name: str,
failure_handler: Callable[[], None] ) -> None:
self._job_name = job_name
self._seq_count = 0
self._cleanup_manager = CleanupManager()
self._failure_handler = failure_handler

def next_seq_id(self) -> int:
self._seq_count += 1
Expand All @@ -31,24 +34,28 @@ def get_cleanup_manager(self) -> CleanupManager:
def job_name(self) -> str:
return self._job_name

def failure_handler(self) -> Callable[[], None]:
return self._failure_handler


_global_context = None


def init_global_context(job_name: str) -> None:
def init_global_context(job_name: str, failure_handler: Callable[[], None]) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(job_name)
_global_context = GlobalContext(job_name, failure_handler)


def get_global_context():
global _global_context
if _global_context is None:
_global_context = GlobalContext()
# if _global_context is None:
# _global_context = GlobalContext()
return _global_context


def clear_global_context():
global _global_context
_global_context.get_cleanup_manager().graceful_stop()
_global_context = None
if _global_context is not None:
_global_context.get_cleanup_manager().graceful_stop()
_global_context = None
3 changes: 2 additions & 1 deletion fed/_private/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _loop():
if self._thread is None or not self._thread.is_alive():
logger.debug(f"Starting new thread[{self._name}] for message polling.")
self._queue = deque()
self._thread = threading.Thread(target=_loop)
self._thread = threading.Thread(target=_loop, name=self._name)
self._thread.start()

def push(self, message):
Expand Down Expand Up @@ -80,6 +80,7 @@ def stop(self, graceful=True):

if graceful:
if self.is_started():
logger.debug(f"Gracefully killing thread[{self._name}].")
self.notify_to_exit()
self._thread.join()
else:
Expand Down
31 changes: 22 additions & 9 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import inspect
import logging
import signal
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Callable

import cloudpickle
import ray
Expand Down Expand Up @@ -55,8 +55,11 @@
def _signal_handler(signum, frame):
if signum == signal.SIGINT:
signal.signal(signal.SIGINT, original_sigint)
logger.warning("Receiving SIGINT, try to shutdown fed.")
shutdown()
logger.warning(
"Stop signal received (e.g. via SIGINT/Ctrl+C), "
"try to shutdown fed. Press CTRL+C "
"(or send SIGINT/SIGKILL/SIGTERM) to skip.")
shutdown(intended=False)


def init(
Expand All @@ -68,7 +71,8 @@ def init(
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
receiver_sender_proxy_cls: SenderReceiverProxy = None,
job_name: str = constants.RAYFED_DEFAULT_JOB_NAME
job_name: str = constants.RAYFED_DEFAULT_JOB_NAME,
failure_handler: Callable[[], None] = None,
):
"""
Initialize a RayFed client.
Expand Down Expand Up @@ -129,7 +133,7 @@ def init(
assert party in addresses, f"Party {party} is not in the addresses {addresses}."

fed_utils.validate_addresses(addresses)
init_global_context(job_name=job_name)
init_global_context(job_name=job_name, failure_handler=failure_handler)
tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
Expand Down Expand Up @@ -228,13 +232,22 @@ def init(
ping_others(addresses=addresses, self_party=party, max_retries=3600)


def shutdown():
def shutdown(intended=True):
"""
Shutdown a RayFed client.
Args:
intended: (Optional) Whether this is an intended exit. If not, the failure
handler (if one was registered) will be triggered after cleanup.
"""
compatible_utils._clear_internal_kv()
clear_global_context()
logger.info('Shutdowned rayfed.')
if (get_global_context() is not None):
# Job has inited, can be shutdown
failure_handler = get_global_context().failure_handler()
compatible_utils._clear_internal_kv()
clear_global_context()
if(not intended and failure_handler is not None):
failure_handler()
logger.info('Shutdowned rayfed.')


def _get_addresses(job_name: str = None):
Expand Down
17 changes: 7 additions & 10 deletions fed/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _signal_exit(self):
Exit the current process immediately. The signal will be captured
in the main thread, where `stop` will be called.
"""
logger.debug("Signal SIGINT to exit.")
os.kill(os.getpid(), signal.SIGINT)

def _process_data_message(self, message):
Expand All @@ -134,16 +135,12 @@ def _process_data_message(self, message):
res = False

if not res and self._exit_on_sending_failure:
# NOTE(NKcqx): this will exit the data sending thread and
# the error sending thread. However, the former will IGNORE
# the stop command because this function is called inside
# the data sending thread but it can't kill itself. The
# data sending thread is exiting by the return value.
# self._signal_exit()
# Notify main thread to clear all sub-threads
os.kill(os.getpid(), signal.SIGINT)
# This will notify the queue to break the for-loop and
# exit the thread.
# NOTE(NKcqx): Send a signal to the main thread so that it can
# do some cleanup, e.g. kill the error-sending thread.
self._signal_exit()
# Return False to exit the loop in this sub-thread. Note that
# the signal sent above will also make the main thread kill
# this sub-thread eventually.
return False
return True

Expand Down
1 change: 1 addition & 0 deletions tests/test_exit_on_failure_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def run(party):
'timeout_ms': 20 * 1000,
},
},
failure_handler= lambda : os.kill(os.getpid(), signal.SIGTERM)
)

o = f.party("alice").remote()
Expand Down

0 comments on commit 26c6e2f

Please sign in to comment.