Skip to content

Commit

Permalink
feat: no wait for data sending on error.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouaihui committed Jan 30, 2024
1 parent 15262ba commit 55d7fe1
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 96 deletions.
27 changes: 24 additions & 3 deletions fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Callable

from fed.cleanup import CleanupManager
from fed.exceptions import FedRemoteError


class GlobalContext:
Expand All @@ -25,6 +26,7 @@ def __init__(
current_party: str,
sending_failure_handler: Callable[[Exception], None],
exit_on_sending_failure=False,
continue_waiting_for_data_sending_on_error=False,
) -> None:
self._job_name = job_name
self._seq_count = 0
Expand All @@ -35,6 +37,10 @@ def __init__(
self._cleanup_manager = CleanupManager(
current_party, self.acquire_shutdown_flag
)
self._last_received_error: FedRemoteError = None
self._continue_waiting_for_data_sending_on_error = (
continue_waiting_for_data_sending_on_error
)

def next_seq_id(self) -> int:
self._seq_count += 1
Expand All @@ -52,6 +58,15 @@ def get_sending_failure_handler(self) -> Callable[[], None]:
def get_exit_on_sending_failure(self) -> bool:
return self._exit_on_sending_failure

def get_last_recevied_error(self) -> FedRemoteError:
return self._last_received_error

def set_last_recevied_error(self, err):
self._last_received_error = err

def get_continue_waiting_for_data_sending_on_error(self) -> bool:
return self._continue_waiting_for_data_sending_on_error

def acquire_shutdown_flag(self) -> bool:
"""
Acquiring a lock and set the flag to make sure
Expand All @@ -78,12 +93,18 @@ def acquire_shutdown_flag(self) -> bool:
def init_global_context(
current_party: str,
job_name: str,
exit_on_sending_failure: bool,
continue_waiting_for_data_sending_on_error: bool,
sending_failure_handler: Callable[[Exception], None] = None,
) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(
job_name, current_party, sending_failure_handler
job_name,
current_party,
exit_on_sending_failure=exit_on_sending_failure,
continue_waiting_for_data_sending_on_error=continue_waiting_for_data_sending_on_error,
sending_failure_handler=sending_failure_handler,
)


Expand All @@ -92,8 +113,8 @@ def get_global_context():
return _global_context


def clear_global_context(graceful=True):
def clear_global_context(wait_for_sending=False):
global _global_context
if _global_context is not None:
_global_context.get_cleanup_manager().stop(graceful=graceful)
_global_context.get_cleanup_manager().stop(wait_for_sending=wait_for_sending)
_global_context = None
18 changes: 8 additions & 10 deletions fed/_private/message_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class MessageQueueManager:
def __init__(self, msg_handler, failure_handler=None, thread_name=''):
def __init__(self, msg_handler, failure_handler=None, thread_name=""):
assert callable(msg_handler), "msg_handler must be a callable function"
# `deque()` is thread safe on `popleft` and `append` operations.
# See https://docs.python.org/3/library/collections.html#deque-objects
Expand Down Expand Up @@ -73,16 +73,13 @@ def _notify_to_exit(self, immediately=False):
else:
self.append(STOP_SYMBOL)

def stop(self, immediately=False):
def stop(self, wait_for_sending=True):
"""
Stop the message queue.
Args:
immediately (bool): A flag indicating whether to stop the queue
immediately or not. Default is True.
If True: insert the STOP_SYMBOL at the begin of the queue.
If False: insert the STOP_SYMBOL at the end of the queue, which means
stop the for loop until all messages in queue are all sent.
wait_for_sending (bool): A flag indicating whether joining the thread to wait for
the loop stop. If True, do not join. Defaults to True.
"""
if threading.current_thread() == self._thread:
logger.error(
Expand All @@ -97,9 +94,10 @@ def stop(self, immediately=False):
# Therefore, currently, not support forcelly kill thread
if self.is_started():
logger.debug(f"Killing thread[{self._thread_name}].")
self._notify_to_exit(immediately=immediately)
self._thread.join()
logger.info(f"The message polling thread[{self._thread_name}] was exited.")
self._notify_to_exit(immediately=not wait_for_sending)
if wait_for_sending:
self._thread.join()
logger.info(f"The message polling thread[{self._thread_name}] was exited.")

def is_started(self):
return self._thread is not None and self._thread.is_alive()
65 changes: 42 additions & 23 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
import inspect
import logging
import signal
import sys
from typing import Any, Callable, Dict, List, Union

import cloudpickle
import ray
from ray.exceptions import RayError
import sys

import fed._private.compatible_utils as compatible_utils
import fed.config as fed_config
Expand Down Expand Up @@ -70,7 +69,7 @@ def init(
party: str = None,
config: Dict = {},
tls_config: Dict = None,
logging_level: str = 'info',
logging_level: str = "info",
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
receiver_sender_proxy_cls: SenderReceiverProxy = None,
Expand Down Expand Up @@ -125,6 +124,7 @@ def init(
"exit_on_sending_failure": True,
"expose_error_trace": True,
"use_global_proxy": True,
"continue_waiting_for_data_sending_on_error": False,
},
"barrier_on_initializing": True,
}
Expand Down Expand Up @@ -182,16 +182,23 @@ def init(
job_name = constants.RAYFED_DEFAULT_JOB_NAME

fed_utils.validate_addresses(addresses)

cross_silo_comm_dict = config.get("cross_silo_comm", {})
cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)

init_global_context(
current_party=party,
job_name=job_name,
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
continue_waiting_for_data_sending_on_error=cross_silo_comm_config.continue_waiting_for_data_sending_on_error,
sending_failure_handler=sending_failure_handler,
)

tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'
"cert" in tls_config and "key" in tls_config
), "Cert or key are not in tls_config."

# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv(job_name)
Expand All @@ -201,15 +208,15 @@ def init(
constants.KEY_OF_CURRENT_PARTY_NAME: party,
constants.KEY_OF_TLS_CONFIG: tls_config,
}
compatible_utils.kv.put(
constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
)

cross_silo_comm_dict = config.get("cross_silo_comm", {})
job_config = {
constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: cross_silo_comm_dict,
}
compatible_utils.kv.put(
constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
)
compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, cloudpickle.dumps(job_config))

# Set logger.
# Note(NKcqx): This should be called after internal_kv has party value, i.e.
# after `ray.init` and
Expand All @@ -222,8 +229,7 @@ def init(
job_name=job_name,
)

logger.info(f'Started rayfed with {cluster_config}')
cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
logger.info(f"Started rayfed with {cluster_config}")
signal.signal(signal.SIGINT, _signal_handler)
get_global_context().get_cleanup_manager().start(
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
Expand Down Expand Up @@ -305,42 +311,53 @@ def _shutdown(intended=True):
Args:
intended: (Optional) Whether this is a intended shutdown. If not
a "failure handler" will be triggered and sys.exit will be called then.
a "failure handler" will be triggered and do not wait data sending.
"""

if get_global_context() is None:
# Do nothing since job has not been inited or is cleaned already.
return

if intended:
logger.info('Shutdowning rayfed intendedly...')
logger.info("Shutdowning rayfed intendedly...")
else:
logger.warn('Shutdowning rayfed unintendedly...')
logger.warn("Shutdowning rayfed unintendedly...")
global_context = get_global_context()
last_sending_error = global_context.get_cleanup_manager().get_last_sending_error()
last_received_error = global_context.get_last_recevied_error()
if last_sending_error is not None:
logging.error(f'Cross-silo sending error occured. {last_sending_error}')
logging.error(f"Cross-silo sending error occured. {last_sending_error}")

wait_for_sending = True
if (
last_sending_error is not None or last_received_error is not None
) and not global_context.get_continue_waiting_for_data_sending_on_error():
wait_for_sending = False
logging.info(f'{"Wait" if wait_for_sending else "No wait"} for data sending.')

if not intended:
# Execute failure_handler fisrtly.
failure_handler = global_context.get_sending_failure_handler()
if failure_handler is not None:
logger.info('Executing failure handler...')
logger.info(f"Executing failure handler {failure_handler} ...")
failure_handler(last_sending_error)

exit_on_sending_failure = global_context.get_exit_on_sending_failure()

# Clean context.
compatible_utils._clear_internal_kv()
clear_global_context(graceful=intended)
logger.info('Shutdowned rayfed.')
clear_global_context(wait_for_sending=wait_for_sending)
logger.info("Shutdowned rayfed.")

# Exit with error.
logger.critical('Exit now due to the previous error.')
sys.exit(1)
if exit_on_sending_failure:
# Exit with error.
logger.critical("Exit now due to the previous error.")
sys.exit(1)
else:
# Clean context.
compatible_utils._clear_internal_kv()
clear_global_context(graceful=intended)
logger.info('Shutdowned rayfed.')
clear_global_context(wait_for_sending=wait_for_sending)
logger.info("Shutdowned rayfed.")


def _get_addresses(job_name: str = None):
Expand Down Expand Up @@ -586,6 +603,8 @@ def get(
"Encounter RemoteError happend in other parties"
f", error message: {e.cause}"
)
if get_global_context() is not None:
get_global_context().set_last_recevied_error(e)
raise e


Expand Down
34 changes: 9 additions & 25 deletions fed/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class CleanupManager:
def __init__(self, current_party, acquire_shutdown_flag) -> None:
self._sending_data_q = MessageQueueManager(
lambda msg: self._process_data_sending_task_return(msg),
thread_name='DataSendingQueueThread',
thread_name="DataSendingQueueThread",
)

self._sending_error_q = MessageQueueManager(
Expand All @@ -64,32 +64,16 @@ def start(self, exit_on_sending_failure=False, expose_error_trace=False):
self._expose_error_trace = expose_error_trace

self._sending_data_q.start()
logger.debug('Start check sending thread.')
logger.debug("Start check sending thread.")
self._sending_error_q.start()
logger.debug('Start check error sending thread.')
logger.debug("Start check error sending thread.")

def _main_thread_monitor():
main_thread = threading.main_thread()
main_thread.join()
logging.debug('Stoping sending queue ...')
self.stop(graceful=True)

self._monitor_thread = threading.Thread(target=_main_thread_monitor)
self._monitor_thread.start()
logger.info('Start check sending monitor thread.')

def stop(self, graceful=True):
def stop(self, wait_for_sending=False):
# NOTE(NKcqx): MUST firstly stop the data queue, because it
# may still throw errors during the termination which need to
# be sent to the error queue.
if graceful:
self._sending_data_q.stop(immediately=False)
self._sending_error_q.stop(immediately=False)
else:
# Stop data queue immediately, but stop error queue not immediately always
# to sure that error can be sent to peers.
self._sending_data_q.stop(immediately=True)
self._sending_error_q.stop(immediately=False)
self._sending_data_q.stop(wait_for_sending=wait_for_sending)
self._sending_error_q.stop(wait_for_sending=wait_for_sending)

def push_to_sending(
self,
Expand Down Expand Up @@ -168,9 +152,9 @@ def _process_data_sending_task_return(self, message):
res = ray.get(obj_ref)
except Exception as e:
logger.warn(
f'Failed to send {obj_ref} to {dest_party}, error: {e},'
f'upstream_seq_id: {upstream_seq_id}, '
f'downstream_seq_id: {downstream_seq_id}.'
f"Failed to send {obj_ref} to {dest_party}, error: {e},"
f"upstream_seq_id: {upstream_seq_id}, "
f"downstream_seq_id: {downstream_seq_id}."
)
self._last_sending_error = e
if isinstance(e, RayError):
Expand Down
5 changes: 5 additions & 0 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ class CrossSiloMessageConfig:
exit_on_sending_failure:
whether exit when failure on cross-silo sending. If True, a SIGINT will be
signaled to self if failed to sending cross-silo data and exit then.
continue_waiting_for_data_sending_on_error:
Whether to continue waiting for data sending if an error occurs, including
data-sending errors and receiving errors from the peer. If True, wait until
all data has been sent.
messages_max_size_in_bytes:
The maximum length in bytes of cross-silo messages. If None, the default
value of 500 MB is specified.
Expand All @@ -122,6 +126,7 @@ class CrossSiloMessageConfig:
timeout_in_ms: int = 60000
messages_max_size_in_bytes: int = None
exit_on_sending_failure: Optional[bool] = False
continue_waiting_for_data_sending_on_error: Optional[bool] = False
serializing_allowed_list: Optional[Dict[str, str]] = None
send_resource_label: Optional[Dict[str, str]] = None
recv_resource_label: Optional[Dict[str, str]] = None
Expand Down
Loading

0 comments on commit 55d7fe1

Please sign in to comment.