Skip to content

Commit

Permalink
feat: expose sending error to drive main thread and fix some cross-silo
Browse files Browse the repository at this point in the history
error bugs.
  • Loading branch information
zhouaihui committed Dec 29, 2023
1 parent a615474 commit 1e78372
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 76 deletions.
11 changes: 11 additions & 0 deletions fed/_private/fed_call_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@
# Set config in the very beginning to avoid being overwritten by other packages.
logging.basicConfig(level=logging.INFO)

from fed._private.global_context import get_global_context
from fed.fed_object import FedObject
from fed.proxy.barriers import send
from fed.utils import resolve_dependencies

try:
from jax.tree_util import tree_flatten
except ImportError:
from fed.tree_util import tree_flatten

import fed.config as fed_config

logger = logging.getLogger(__name__)

Expand Down
26 changes: 17 additions & 9 deletions fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@

class GlobalContext:
def __init__(
self, job_name: str, current_party: str, failure_handler: Callable[[], None]
self,
job_name: str,
current_party: str,
sending_failure_handler: Callable[[], None],
exit_on_sending_failure=False,
) -> None:
self._job_name = job_name
self._seq_count = 0
self._failure_handler = failure_handler
self._sending_failure_handler = sending_failure_handler
self._exit_on_sending_failure = exit_on_sending_failure
self._atomic_shutdown_flag_lock = threading.Lock()
self._atomic_shutdown_flag = True
self._cleanup_manager = CleanupManager(
Expand All @@ -41,13 +46,16 @@ def get_cleanup_manager(self) -> CleanupManager:
def get_job_name(self) -> str:
return self._job_name

def get_failure_handler(self) -> Callable[[], None]:
return self._failure_handler
def get_sending_failure_handler(self) -> Callable[[], None]:
return self._sending_failure_handler

def get_exit_on_sending_failure(self) -> bool:
return self._exit_on_sending_failure

def acquire_shutdown_flag(self) -> bool:
"""
Acquiring a lock and set the flag to make sure
`fed.shutdown(intended=False)` can be called only once.
`fed.shutdown()` can be called only once.
The unintended shutdown, i.e. `fed.shutdown(intended=False)`, needs to
be executed only once. However, `fed.shutdown` may get called during
Expand All @@ -68,20 +76,20 @@ def acquire_shutdown_flag(self) -> bool:


def init_global_context(
current_party: str, job_name: str, failure_handler: Callable[[], None] = None
current_party: str, job_name: str, sending_failure_handler: Callable[[], None] = None
) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(job_name, current_party, failure_handler)
_global_context = GlobalContext(job_name, current_party, sending_failure_handler)


def get_global_context():
global _global_context
return _global_context


def clear_global_context():
def clear_global_context(graceful=True):
global _global_context
if _global_context is not None:
_global_context.get_cleanup_manager().stop()
_global_context.get_cleanup_manager().stop(graceful=graceful)
_global_context = None
29 changes: 17 additions & 12 deletions fed/_private/message_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,26 @@ def _loop():
def append(self, message):
self._queue.append(message)

def notify_to_exit(self):
def appendleft(self, message):
self._queue.appendleft(message)

def _notify_to_exit(self, immediately=False):
logger.info(f"Notify message polling thread[{self._thread_name}] to exit.")
self.append(STOP_SYMBOL)
if immediately:
self.appendleft(STOP_SYMBOL)
else:
self.append(STOP_SYMBOL)

def stop(self):
def stop(self, immediately=False):
"""
Stop the message queue.
Args:
graceful (bool): A flag indicating whether to stop the queue
gracefully or not. Default is True.
If True: insert the STOP_SYMBOL at the end of the queue
and wait for it to be processed, which will break the for-loop;
If False: forcelly kill the for-loop sub-thread.
immediately (bool): A flag indicating whether to stop the queue
immediately or not. Default is False.
If True: insert the STOP_SYMBOL at the beginning of the queue.
If False: insert the STOP_SYMBOL at the end of the queue, which means
the for-loop is not stopped until all messages in the queue have been sent.
"""
if threading.current_thread() == self._thread:
logger.error(
Expand All @@ -90,11 +96,10 @@ def stop(self):
# encounter AssertionError because sub-thread's lock is not released.
# Therefore, forcibly killing the thread is currently not supported.
if self.is_started():
logger.debug(f"Gracefully killing thread[{self._thread_name}].")
self.notify_to_exit()
logger.debug(f"Killing thread[{self._thread_name}].")
self._notify_to_exit(immediately=immediately)
self._thread.join()

logger.info(f"The message polling thread[{self._thread_name}] was exited.")
logger.info(f"The message polling thread[{self._thread_name}] was exited.")

def is_started(self):
return self._thread is not None and self._thread.is_alive()
63 changes: 45 additions & 18 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import cloudpickle
import ray
from ray.exceptions import RayError
import sys

import fed._private.compatible_utils as compatible_utils
import fed.config as fed_config
Expand Down Expand Up @@ -74,7 +75,7 @@ def init(
receiver_proxy_cls: ReceiverProxy = None,
receiver_sender_proxy_cls: SenderReceiverProxy = None,
job_name: str = constants.RAYFED_DEFAULT_JOB_NAME,
failure_handler: Callable[[], None] = None,
sending_failure_handler: Callable[[Exception], None] = None,
):
"""
Initialize a RayFed client.
Expand Down Expand Up @@ -146,6 +147,9 @@ def init(
default fixed name will be assigned, therefore messages of all anonymous
jobs will be mixed together, which should only be used in the single job
scenario or test mode.
sending_failure_handler: optional; a callback which will be triggered with
the last sending error if cross-silo message sending fails and
exit_on_sending_failure in config is True.
Examples:
>>> import fed
>>> import ray
Expand All @@ -164,7 +168,7 @@ def init(

fed_utils.validate_addresses(addresses)
init_global_context(
current_party=party, job_name=job_name, failure_handler=failure_handler
current_party=party, job_name=job_name, failure_handler=sending_failure_handler
)
tls_config = {} if tls_config is None else tls_config
if tls_config:
Expand Down Expand Up @@ -281,16 +285,42 @@ def _shutdown(intended=True):
Shutdown a RayFed client.
Args:
intended: (Optional) Whether this is a intended exit. If not
a "failure handler" will be triggered.
intended: (Optional) Whether this is an intended shutdown. If not,
a "failure handler" will be triggered and then sys.exit will be called.
"""
if get_global_context() is not None:
# Job has inited, can be shutdown
failure_handler = get_global_context().get_failure_handler()

if get_global_context() is None:
# Do nothing since job has not been inited or is cleaned already.
return

if intended:
logger.info('Shutdowning rayfed intendedly...')
else:
logger.warn('Shutdowning rayfed unintendedly...')
global_context = get_global_context()
last_sending_error = global_context.get_cleanup_manager().get_last_sending_error()
if last_sending_error is not None:
logging.error(f'Cross-silo sending error occured. {last_sending_error}')

if not intended:
# Execute failure_handler first.
failure_handler = global_context.get_sending_failure_handler()
if failure_handler is not None:
logger.info('Executing failure handler...')
failure_handler(last_sending_error)

# Clean context.
compatible_utils._clear_internal_kv()
clear_global_context(graceful=intended)
logger.info('Shutdowned rayfed.')

# Exit with error.
logger.critical('Exit now due to the previous error.')
sys.exit(1)
else:
# Clean context.
compatible_utils._clear_internal_kv()
clear_global_context()
if not intended and failure_handler is not None:
failure_handler()
clear_global_context(graceful=intended)
logger.info('Shutdowned rayfed.')


Expand Down Expand Up @@ -474,14 +504,11 @@ def get(
if is_individual_id:
values = values[0]
return values
except RayError as e:
if isinstance(e, FedRemoteError):
logger.warning(
"Encounter RemoteError happend in other parties"
f", prepare to exit, error message: {e.cause}"
)
if get_global_context().acquire_shutdown_flag():
_shutdown(intended=False)
except FedRemoteError as e:
logger.warning(
"Encounter RemoteError happend in other parties"
f", error message: {e.cause}"
)
raise e


Expand Down
22 changes: 17 additions & 5 deletions fed/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, current_party, acquire_shutdown_flag) -> None:

self._current_party = current_party
self._acquire_shutdown_flag = acquire_shutdown_flag
self._last_sending_error = None

def start(self, exit_on_sending_failure=False, expose_error_trace=False):
self._exit_on_sending_failure = exit_on_sending_failure
Expand All @@ -70,18 +71,25 @@ def start(self, exit_on_sending_failure=False, expose_error_trace=False):
def _main_thread_monitor():
main_thread = threading.main_thread()
main_thread.join()
self.stop()
logging.debug('Stoping sending queue ...')
self.stop(graceful=True)

self._monitor_thread = threading.Thread(target=_main_thread_monitor)
self._monitor_thread.start()
logger.info('Start check sending monitor thread.')

def stop(self):
def stop(self, graceful=True):
# NOTE(NKcqx): MUST firstly stop the data queue, because it
# may still throw errors during the termination which need to
# be sent to the error queue.
self._sending_data_q.stop()
self._sending_error_q.stop()
if graceful:
self._sending_data_q.stop(immediately=False)
self._sending_error_q.stop(immediately=False)
else:
# Stop the data queue immediately, but never stop the error queue immediately,
# to make sure that errors can still be sent to peers.
self._sending_data_q.stop(immediately=True)
self._sending_error_q.stop(immediately=False)

def push_to_sending(
self,
Expand Down Expand Up @@ -114,6 +122,9 @@ def push_to_sending(
else:
self._sending_data_q.append(msg_pack)

def get_last_sending_error(self):
return self._last_sending_error

def _signal_exit(self):
"""
Exit the current process immediately. The signal will be captured
Expand All @@ -129,7 +140,7 @@ def _signal_exit(self):
# once and avoid dead lock, the lock must be checked before sending
# signals.
if self._acquire_shutdown_flag():
logger.debug("Signal SIGINT to exit.")
logger.warn("Signal SIGINT to exit.")
os.kill(os.getpid(), signal.SIGINT)

def _process_data_sending_task_return(self, message):
Expand Down Expand Up @@ -161,6 +172,7 @@ def _process_data_sending_task_return(self, message):
f'upstream_seq_id: {upstream_seq_id}, '
f'downstream_seq_id: {downstream_seq_id}.'
)
self._last_sending_error = e
if isinstance(e, RayError):
logger.info(f"Sending error {e.cause} to {dest_party}.")
from fed.proxy.barriers import send
Expand Down
14 changes: 1 addition & 13 deletions fed/tests/test_cross_silo_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def error_func(self):


def run(party):
my_failure_handler = Mock()
compatible_utils.init_ray(address='local')
addresses = {
'alice': '127.0.0.1:11012',
Expand All @@ -57,12 +56,10 @@ def run(party):
logging_level='debug',
config={
'cross_silo_comm': {
'exit_on_sending_failure': True,
'timeout_ms': 20 * 1000,
'expose_error_trace': True,
},
},
failure_handler=my_failure_handler,
)

# Both party should catch the error
Expand All @@ -76,7 +73,6 @@ def run(party):
else:
assert isinstance(e.value.cause, MyError)
assert "normal task Error" in str(e.value.cause)
my_failure_handler.assert_called_once()
fed.shutdown()
ray.shutdown()

Expand All @@ -93,7 +89,6 @@ def test_cross_silo_normal_task_error():


def run2(party):
my_failure_handler = Mock()
compatible_utils.init_ray(address='local')
addresses = {
'alice': '127.0.0.1:11012',
Expand All @@ -105,12 +100,10 @@ def run2(party):
logging_level='debug',
config={
'cross_silo_comm': {
'exit_on_sending_failure': True,
'timeout_ms': 20 * 1000,
'expose_error_trace': True,
},
},
failure_handler=my_failure_handler,
)

# Both party should catch the error
Expand All @@ -123,11 +116,9 @@ def run2(party):
assert isinstance(e.value.cause, FedRemoteError)
assert 'RemoteError occurred at alice' in str(e.value.cause)
assert "actor task Error" in str(e.value.cause)
my_failure_handler.assert_called_once()
else:
assert isinstance(e.value.cause, MyError)
assert "actor task Error" in str(e.value.cause)
my_failure_handler.assert_called_once()

fed.shutdown()
ray.shutdown()
Expand All @@ -145,7 +136,6 @@ def test_cross_silo_actor_task_error():


def run3(party):
my_failure_handler = Mock()
compatible_utils.init_ray(address='local')
addresses = {
'alice': '127.0.0.1:11012',
Expand All @@ -158,11 +148,10 @@ def run3(party):
logging_level='debug',
config={
'cross_silo_comm': {
'exit_on_sending_failure': True,
'timeout_ms': 20 * 1000,
'expose_error_trace': False,
},
},
failure_handler=my_failure_handler,
)

# Both party should catch the error
Expand All @@ -176,7 +165,6 @@ def run3(party):
else:
assert isinstance(e.value.cause, MyError)
assert "normal task Error" in str(e.value.cause)
my_failure_handler.assert_called_once()
fed.shutdown()
ray.shutdown()

Expand Down
Loading

0 comments on commit 1e78372

Please sign in to comment.