feat: no wait for data sending on error. #210

Merged
27 changes: 24 additions & 3 deletions fed/_private/global_context.py
@@ -16,6 +16,7 @@
from typing import Callable

from fed.cleanup import CleanupManager
from fed.exceptions import FedRemoteError


class GlobalContext:
@@ -25,6 +26,7 @@ def __init__(
current_party: str,
sending_failure_handler: Callable[[Exception], None],
exit_on_sending_failure=False,
continue_waiting_for_data_sending_on_error=False,
) -> None:
self._job_name = job_name
self._seq_count = 0
@@ -35,6 +37,10 @@ def __init__(
self._cleanup_manager = CleanupManager(
current_party, self.acquire_shutdown_flag
)
self._last_received_error: FedRemoteError = None
self._continue_waiting_for_data_sending_on_error = (
continue_waiting_for_data_sending_on_error
)

def next_seq_id(self) -> int:
self._seq_count += 1
@@ -52,6 +58,15 @@ def get_sending_failure_handler(self) -> Callable[[], None]:
def get_exit_on_sending_failure(self) -> bool:
return self._exit_on_sending_failure

def get_last_recevied_error(self) -> FedRemoteError:
return self._last_received_error

def set_last_recevied_error(self, err):
self._last_received_error = err

def get_continue_waiting_for_data_sending_on_error(self) -> bool:
return self._continue_waiting_for_data_sending_on_error

def acquire_shutdown_flag(self) -> bool:
"""
Acquiring a lock and set the flag to make sure
@@ -78,12 +93,18 @@ def acquire_shutdown_flag(self) -> bool:
def init_global_context(
current_party: str,
job_name: str,
exit_on_sending_failure: bool,
continue_waiting_for_data_sending_on_error: bool,
sending_failure_handler: Callable[[Exception], None] = None,
) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(
job_name, current_party, sending_failure_handler
job_name,
current_party,
exit_on_sending_failure=exit_on_sending_failure,
continue_waiting_for_data_sending_on_error=continue_waiting_for_data_sending_on_error,
sending_failure_handler=sending_failure_handler,
)


@@ -92,8 +113,8 @@ def get_global_context():
return _global_context


def clear_global_context(graceful=True):
def clear_global_context(wait_for_sending=False):
global _global_context
if _global_context is not None:
_global_context.get_cleanup_manager().stop(graceful=graceful)
_global_context.get_cleanup_manager().stop(wait_for_sending=wait_for_sending)
_global_context = None
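
As a quick orientation for reviewers, here is a minimal driver for the reworked context API. This is an illustrative sketch, not part of the diff; the party/job names, the handler, and the shutdown decision are assumptions built on the accessors added above.

# Hypothetical usage of the extended GlobalContext API (illustration only).
from fed._private.global_context import (
    clear_global_context,
    get_global_context,
    init_global_context,
)

init_global_context(
    current_party="alice",  # assumed party name
    job_name="demo_job",    # assumed job name
    exit_on_sending_failure=True,
    continue_waiting_for_data_sending_on_error=False,
    sending_failure_handler=lambda err: print(f"sending failed: {err}"),
)

ctx = get_global_context()
# A FedRemoteError received from a peer is now remembered on the context
# (see the fed.api.get change below), and shutdown consults it to decide
# whether the outgoing queues should still be drained.
wait = (
    ctx.get_last_recevied_error() is None
    or ctx.get_continue_waiting_for_data_sending_on_error()
)
clear_global_context(wait_for_sending=wait)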
20 changes: 10 additions & 10 deletions fed/_private/message_queue.py
@@ -26,7 +26,7 @@


class MessageQueueManager:
def __init__(self, msg_handler, failure_handler=None, thread_name=''):
def __init__(self, msg_handler, failure_handler=None, thread_name=""):
assert callable(msg_handler), "msg_handler must be a callable function"
# `deque()` is thread safe on `popleft` and `append` operations.
# See https://docs.python.org/3/library/collections.html#deque-objects
@@ -73,16 +73,13 @@ def _notify_to_exit(self, immediately=False):
else:
self.append(STOP_SYMBOL)

def stop(self, immediately=False):
def stop(self, wait_for_sending=True):
"""
Stop the message queue.

Args:
immediately (bool): A flag indicating whether to stop the queue
immediately or not. Default is True.
If True: insert the STOP_SYMBOL at the begin of the queue.
If False: insert the STOP_SYMBOL at the end of the queue, which means
stop the for loop until all messages in queue are all sent.
wait_for_sending (bool): Whether to drain the queue before returning.
If True, the STOP_SYMBOL is appended to the tail and the polling thread
is joined, so all queued messages are handled first; if False, the
STOP_SYMBOL is pushed to the head and pending messages are discarded
without joining. Defaults to True.
"""
if threading.current_thread() == self._thread:
logger.error(
@@ -97,9 +94,12 @@ def stop(self, immediately=False):
# Therefore, forcibly killing the thread is currently not supported.
if self.is_started():
logger.debug(f"Killing thread[{self._thread_name}].")
self._notify_to_exit(immediately=immediately)
self._thread.join()
logger.info(f"The message polling thread[{self._thread_name}] was exited.")
self._notify_to_exit(immediately=not wait_for_sending)
if wait_for_sending:
self._thread.join()
logger.info(
f"The message polling thread[{self._thread_name}] was exited."
)

def is_started(self):
return self._thread is not None and self._thread.is_alive()
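
In short, stop(wait_for_sending=True) drains the queue (STOP_SYMBOL at the tail, thread joined) while stop(wait_for_sending=False) abandons whatever is still pending (STOP_SYMBOL at the head, no join). A rough usage sketch follows; the handler, its return contract, and the thread name are assumptions, not the module's documented API.

# Illustrative use of MessageQueueManager.stop (not part of the diff).
from fed._private.message_queue import MessageQueueManager

delivered = []
mq = MessageQueueManager(
    # Assumed handler contract: process one message, truthy return keeps polling.
    msg_handler=lambda msg: delivered.append(msg) or True,
    thread_name="DemoSendingThread",
)
mq.start()
for i in range(10):
    mq.append(i)

# Drain mode: wait until everything queued so far has been handled.
mq.stop(wait_for_sending=True)
print(delivered)  # expected: all 10 items handled before stop() returned

# Drop mode (what shutdown-on-error now uses): exit the loop as soon as
# possible, leaving unsent messages behind and not joining the thread.
# mq.stop(wait_for_sending=False)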
65 changes: 42 additions & 23 deletions fed/api.py
@@ -16,12 +16,11 @@
import inspect
import logging
import signal
import sys
from typing import Any, Callable, Dict, List, Union

import cloudpickle
import ray
from ray.exceptions import RayError
import sys

import fed._private.compatible_utils as compatible_utils
import fed.config as fed_config
@@ -70,7 +69,7 @@ def init(
party: str = None,
config: Dict = {},
tls_config: Dict = None,
logging_level: str = 'info',
logging_level: str = "info",
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
receiver_sender_proxy_cls: SenderReceiverProxy = None,
@@ -125,6 +124,7 @@ def init(
"exit_on_sending_failure": True,
"expose_error_trace": True,
"use_global_proxy": True,
"continue_waiting_for_data_sending_on_error": False,
},
"barrier_on_initializing": True,
}
@@ -182,16 +182,23 @@ def init(
job_name = constants.RAYFED_DEFAULT_JOB_NAME

fed_utils.validate_addresses(addresses)

cross_silo_comm_dict = config.get("cross_silo_comm", {})
cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)

init_global_context(
current_party=party,
job_name=job_name,
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
continue_waiting_for_data_sending_on_error=cross_silo_comm_config.continue_waiting_for_data_sending_on_error,
sending_failure_handler=sending_failure_handler,
)

tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'
"cert" in tls_config and "key" in tls_config
), "Cert or key are not in tls_config."

# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv(job_name)
@@ -201,15 +208,15 @@ def init(
constants.KEY_OF_CURRENT_PARTY_NAME: party,
constants.KEY_OF_TLS_CONFIG: tls_config,
}
compatible_utils.kv.put(
constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
)

cross_silo_comm_dict = config.get("cross_silo_comm", {})
job_config = {
constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: cross_silo_comm_dict,
}
compatible_utils.kv.put(
constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
)
compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, cloudpickle.dumps(job_config))

# Set logger.
# Note(NKcqx): This should be called after internal_kv has party value, i.e.
# after `ray.init` and
@@ -222,8 +229,7 @@ def init(
job_name=job_name,
)

logger.info(f'Started rayfed with {cluster_config}')
cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
logger.info(f"Started rayfed with {cluster_config}")
signal.signal(signal.SIGINT, _signal_handler)
get_global_context().get_cleanup_manager().start(
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
@@ -305,42 +311,53 @@ def _shutdown(intended=True):

Args:
intended: (Optional) Whether this is an intended shutdown. If not,
a "failure handler" will be triggered and sys.exit will be called then.
a "failure handler" will be triggered and data sending will not be waited for.
"""

if get_global_context() is None:
# Do nothing since job has not been inited or is cleaned already.
return

if intended:
logger.info('Shutdowning rayfed intendedly...')
logger.info("Shutdowning rayfed intendedly...")
else:
logger.warn('Shutdowning rayfed unintendedly...')
logger.warn("Shutdowning rayfed unintendedly...")
global_context = get_global_context()
last_sending_error = global_context.get_cleanup_manager().get_last_sending_error()
last_received_error = global_context.get_last_recevied_error()
if last_sending_error is not None:
logging.error(f'Cross-silo sending error occured. {last_sending_error}')
logging.error(f"Cross-silo sending error occured. {last_sending_error}")

wait_for_sending = True
if (
last_sending_error is not None or last_received_error is not None
) and not global_context.get_continue_waiting_for_data_sending_on_error():
wait_for_sending = False
logging.info(f'{"Wait" if wait_for_sending else "No wait"} for data sending.')

if not intended:
# Execute the failure handler first.
failure_handler = global_context.get_sending_failure_handler()
if failure_handler is not None:
logger.info('Executing failure handler...')
logger.info(f"Executing failure handler {failure_handler} ...")
failure_handler(last_sending_error)

exit_on_sending_failure = global_context.get_exit_on_sending_failure()

# Clean context.
compatible_utils._clear_internal_kv()
clear_global_context(graceful=intended)
logger.info('Shutdowned rayfed.')
clear_global_context(wait_for_sending=wait_for_sending)
logger.info("Shutdowned rayfed.")

# Exit with error.
logger.critical('Exit now due to the previous error.')
sys.exit(1)
if exit_on_sending_failure:
# Exit with error.
logger.critical("Exit now due to the previous error.")
sys.exit(1)
else:
# Clean context.
compatible_utils._clear_internal_kv()
clear_global_context(graceful=intended)
logger.info('Shutdowned rayfed.')
clear_global_context(wait_for_sending=wait_for_sending)
logger.info("Shutdowned rayfed.")


def _get_addresses(job_name: str = None):
Expand Down Expand Up @@ -586,6 +603,8 @@ def get(
"Encounter RemoteError happend in other parties"
f", error message: {e.cause}"
)
if get_global_context() is not None:
get_global_context().set_last_recevied_error(e)
raise e


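Putting the api.py changes together, here is an end-to-end usage sketch. The addresses, party names, and the remote function are invented for illustration; only the cross_silo_comm keys mirror the init docstring above.

# Illustrative rayfed job exercising the new option (assumptions noted inline).
import ray
import fed

@fed.remote
def make_data():
    return 42

ray.init()
fed.init(
    addresses={"alice": "127.0.0.1:11012", "bob": "127.0.0.1:11011"},  # assumed
    party="alice",
    config={
        "cross_silo_comm": {
            "exit_on_sending_failure": True,
            # New in this PR: on a sending or received error, shut down
            # without continuing to wait for queued cross-silo data.
            "continue_waiting_for_data_sending_on_error": False,
        }
    },
    sending_failure_handler=lambda err: print(f"cross-silo sending failed: {err}"),
)

obj = make_data.party("alice").remote()
print(fed.get(obj))

fed.shutdown()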
34 changes: 9 additions & 25 deletions fed/cleanup.py
@@ -45,7 +45,7 @@ class CleanupManager:
def __init__(self, current_party, acquire_shutdown_flag) -> None:
self._sending_data_q = MessageQueueManager(
lambda msg: self._process_data_sending_task_return(msg),
thread_name='DataSendingQueueThread',
thread_name="DataSendingQueueThread",
)

self._sending_error_q = MessageQueueManager(
@@ -64,32 +64,16 @@ def start(self, exit_on_sending_failure=False, expose_error_trace=False):
self._expose_error_trace = expose_error_trace

self._sending_data_q.start()
logger.debug('Start check sending thread.')
logger.debug("Start check sending thread.")
self._sending_error_q.start()
logger.debug('Start check error sending thread.')
logger.debug("Start check error sending thread.")

def _main_thread_monitor():
main_thread = threading.main_thread()
main_thread.join()
logging.debug('Stoping sending queue ...')
self.stop(graceful=True)

self._monitor_thread = threading.Thread(target=_main_thread_monitor)
self._monitor_thread.start()
logger.info('Start check sending monitor thread.')

def stop(self, graceful=True):
def stop(self, wait_for_sending=False):
# NOTE(NKcqx): MUST firstly stop the data queue, because it
# may still throw errors during the termination which need to
# be sent to the error queue.
if graceful:
self._sending_data_q.stop(immediately=False)
self._sending_error_q.stop(immediately=False)
else:
# Stop data queue immediately, but stop error queue not immediately always
# to sure that error can be sent to peers.
self._sending_data_q.stop(immediately=True)
self._sending_error_q.stop(immediately=False)
self._sending_data_q.stop(wait_for_sending=wait_for_sending)
self._sending_error_q.stop(wait_for_sending=wait_for_sending)

def push_to_sending(
self,
@@ -168,9 +152,9 @@ def _process_data_sending_task_return(self, message):
res = ray.get(obj_ref)
except Exception as e:
logger.warn(
f'Failed to send {obj_ref} to {dest_party}, error: {e},'
f'upstream_seq_id: {upstream_seq_id}, '
f'downstream_seq_id: {downstream_seq_id}.'
f"Failed to send {obj_ref} to {dest_party}, error: {e},"
f"upstream_seq_id: {upstream_seq_id}, "
f"downstream_seq_id: {downstream_seq_id}."
)
self._last_sending_error = e
if isinstance(e, RayError):
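With the monitor thread removed, CleanupManager.stop simply forwards wait_for_sending to both queues, and the NOTE above explains why the data queue goes first. A toy restatement of that dependency; the class and attribute names are illustrative, not rayfed code.

# Why the data queue stops before the error queue (illustration only).
class TwoQueueShutdown:
    def __init__(self, data_q, error_q):
        self.data_q = data_q    # outgoing cross-silo data
        self.error_q = error_q  # error notifications for peers

    def stop(self, wait_for_sending=False):
        # Handling the data queue's tail can still fail and push fresh
        # error messages onto the error queue...
        self.data_q.stop(wait_for_sending=wait_for_sending)
        # ...so the error queue is stopped last, giving those messages a
        # chance to go out (or to be dropped together when not waiting).
        self.error_q.stop(wait_for_sending=wait_for_sending)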
5 changes: 5 additions & 0 deletions fed/config.py
@@ -101,6 +101,10 @@ class CrossSiloMessageConfig:
exit_on_sending_failure:
whether exit when failure on cross-silo sending. If True, a SIGINT will be
signaled to self if failed to sending cross-silo data and exit then.
continue_waiting_for_data_sending_on_error:
Whether to continue waiting for data sending if an error occurs, including
data-sending errors and receiving errors from the peer. If True, wait until
all data has been sent.
messages_max_size_in_bytes:
The maximum length in bytes of cross-silo messages. If None, the default
value of 500 MB is specified.
@@ -122,6 +126,7 @@ class CrossSiloMessageConfig:
timeout_in_ms: int = 60000
messages_max_size_in_bytes: int = None
exit_on_sending_failure: Optional[bool] = False
continue_waiting_for_data_sending_on_error: Optional[bool] = False
serializing_allowed_list: Optional[Dict[str, str]] = None
send_resource_label: Optional[Dict[str, str]] = None
recv_resource_label: Optional[Dict[str, str]] = None
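A quick sanity check of how the extended dataclass is consumed; this mirrors the CrossSiloMessageConfig.from_dict call in fed/api.py above, with an illustrative dict literal.

# Building the config from a plain dict (illustration only).
from fed.config import CrossSiloMessageConfig

cfg = CrossSiloMessageConfig.from_dict(
    {
        "exit_on_sending_failure": True,
        "continue_waiting_for_data_sending_on_error": False,
    }
)
assert cfg.exit_on_sending_failure is True
# With False (also the default), an error makes shutdown skip waiting
# for any data still queued for cross-silo sending.
assert cfg.continue_waiting_for_data_sending_on_error is False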