diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py index 1d73b64..b0367b1 100644 --- a/fed/_private/global_context.py +++ b/fed/_private/global_context.py @@ -12,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from fed.cleanup import CleanupManager + class GlobalContext: def __init__(self) -> None: self._seq_count = 0 + self._cleanup_manager = CleanupManager() - def next_seq_id(self): + def next_seq_id(self) -> int: self._seq_count += 1 return self._seq_count + def get_cleanup_manager(self) -> CleanupManager: + return self._cleanup_manager + _global_context = None @@ -34,4 +40,5 @@ def get_global_context(): def clear_global_context(): global _global_context + _global_context.get_cleanup_manager().graceful_stop() _global_context = None diff --git a/fed/api.py b/fed/api.py index b2fbfcb..8548c5f 100644 --- a/fed/api.py +++ b/fed/api.py @@ -34,7 +34,6 @@ start_recv_proxy, start_send_proxy, ) -from fed.cleanup import set_exit_on_failure_sending, wait_sending from fed.fed_object import FedObject from fed.utils import is_ray_object_refs, setup_logger @@ -215,7 +214,9 @@ def init( ) logger.info(f'Started rayfed with {cluster_config}') - set_exit_on_failure_sending(exit_on_failure_cross_silo_sending) + get_global_context().get_cleanup_manager().start( + exit_when_failure_sending=exit_on_failure_cross_silo_sending) + recv_actor_config = fed_config.ProxyActorConfig( resource_label=cross_silo_recv_resource_label) # Start recv proxy @@ -249,7 +250,6 @@ def shutdown(): """ Shutdown a RayFed client. 
""" - wait_sending() compatible_utils._clear_internal_kv() clear_global_context() logger.info('Shutdowned rayfed.') diff --git a/fed/cleanup.py b/fed/cleanup.py index 495ff0d..cf06a1d 100644 --- a/fed/cleanup.py +++ b/fed/cleanup.py @@ -23,97 +23,84 @@ logger = logging.getLogger(__name__) -_sending_obj_refs_q = None -_check_send_thread = None - -_EXIT_ON_FAILURE_SENDING = False - - -def set_exit_on_failure_sending(exit_when_failure_sending: bool): - global _EXIT_ON_FAILURE_SENDING - _EXIT_ON_FAILURE_SENDING = exit_when_failure_sending - - -def get_exit_when_failure_sending(): - global _EXIT_ON_FAILURE_SENDING - return _EXIT_ON_FAILURE_SENDING - - -def _check_sending_objs(): - def _signal_exit(): - os.kill(os.getpid(), signal.SIGTERM) - - global _sending_obj_refs_q - if not _sending_obj_refs_q: - _sending_obj_refs_q = deque() - - while True: - try: - obj_ref = _sending_obj_refs_q.popleft() - except IndexError: - time.sleep(0.5) - continue - if isinstance(obj_ref, bool): - break - try: - res = ray.get(obj_ref) - except Exception as e: - logger.warn(f'Failed to send {obj_ref} with error: {e}') - res = False - if not res and get_exit_when_failure_sending(): - logger.warn('Signal self to exit.') - _signal_exit() - break - - logger.info('Check sending thread was exited.') - global _check_send_thread - _check_send_thread = None - logger.info('Clearing sending queue.') - _sending_obj_refs_q = None - - -def _main_thread_monitor(): - main_thread = threading.main_thread() - main_thread.join() - notify_to_exit() - - -_monitor_thread = None - - -def _start_check_sending(): - global _sending_obj_refs_q - if not _sending_obj_refs_q: - _sending_obj_refs_q = deque() - - global _check_send_thread - if not _check_send_thread: - _check_send_thread = threading.Thread(target=_check_sending_objs) - _check_send_thread.start() +class CleanupManager: + """ + This class is used to manage the related works when the fed driver exiting. 
+ It monitors whether the main thread is broken and it needs to wait until all sending + objects get responses. + + The main logic path is: + A. If `fed.shutdown()` is invoked in the main thread and everything works well, + the `graceful_stop()` will be invoked as well and the checking thread will be + notified to exit gracefully. + + B. If the main thread is broken before sending the notification flag to the + sending thread, the monitor thread will detect that and it joins until the main + thread exits, then notifies the checking thread. + """ + + def __init__(self) -> None: + # `deque()` is thread safe on `popleft` and `append` operations. + # See https://docs.python.org/3/library/collections.html#deque-objects + self._sending_obj_refs_q = deque() + self._check_send_thread = None + self._monitor_thread = None + + def start(self, exit_when_failure_sending=False): + self._exit_when_failure_sending = exit_when_failure_sending + + def __check_func(): + self._check_sending_objs() + + self._check_send_thread = threading.Thread(target=__check_func) + self._check_send_thread.start() logger.info('Start check sending thread.') - global _monitor_thread - if not _monitor_thread: - _monitor_thread = threading.Thread(target=_main_thread_monitor) - _monitor_thread.start() - logger.info('Start check sending monitor thread.') - - -def push_to_sending(obj_ref: ray.ObjectRef): - _start_check_sending() - global _sending_obj_refs_q - _sending_obj_refs_q.append(obj_ref) - - -def notify_to_exit(): - # Sending the termination signal - push_to_sending(True) - logger.info('Notify check sending thread to exit.') - - -def wait_sending(): - global _check_send_thread - if _check_send_thread: - notify_to_exit() - _check_send_thread.join() + def _main_thread_monitor(): + main_thread = threading.main_thread() + main_thread.join() + self._notify_to_exit() + + self._monitor_thread = threading.Thread(target=_main_thread_monitor) + self._monitor_thread.start() + logger.info('Start check sending 
monitor thread.') + + def graceful_stop(self): + assert self._check_send_thread is not None + self._notify_to_exit() + self._check_send_thread.join() + + def _notify_to_exit(self): + # Sending the termination signal + self.push_to_sending(True) + logger.info('Notify check sending thread to exit.') + + def push_to_sending(self, obj_ref: ray.ObjectRef): + self._sending_obj_refs_q.append(obj_ref) + + def _check_sending_objs(self): + def _signal_exit(): + os.kill(os.getpid(), signal.SIGTERM) + + assert self._sending_obj_refs_q is not None + + while True: + try: + obj_ref = self._sending_obj_refs_q.popleft() + except IndexError: + time.sleep(0.1) + continue + if isinstance(obj_ref, bool): + break + try: + res = ray.get(obj_ref) + except Exception as e: + logger.warn(f'Failed to send {obj_ref} with error: {e}') + res = False + if not res and self._exit_when_failure_sending: + logger.warn('Signal self to exit.') + _signal_exit() + break + + logger.info('Check sending thread was exited.') diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 341515a..8ec0531 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -27,10 +27,10 @@ import fed.utils as fed_utils from fed._private import constants from fed._private.grpc_options import get_grpc_options, set_max_message_length -from fed.cleanup import push_to_sending from fed.config import get_cluster_config from fed.grpc import fed_pb2, fed_pb2_grpc from fed.utils import setup_logger +from fed._private.global_context import get_global_context logger = logging.getLogger(__name__) @@ -452,7 +452,7 @@ def send( upstream_seq_id=upstream_seq_id, downstream_seq_id=downstream_seq_id, ) - push_to_sending(res) + get_global_context().get_cleanup_manager().push_to_sending(res) return res diff --git a/tests/test_internal_kv.py b/tests/test_internal_kv.py index b36a773..c57dec4 100644 --- a/tests/test_internal_kv.py +++ b/tests/test_internal_kv.py @@ -6,28 +6,29 @@ import fed._private.compatible_utils as 
compatible_utils -def test_kv_init(): - def run(party): - compatible_utils.init_ray("local") - cluster = { - 'alice': {'address': '127.0.0.1:11010', 'listen_addr': '0.0.0.0:11010'}, - 'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'}, - } - assert compatible_utils.kv is None - fed.init(cluster=cluster, party=party) - assert compatible_utils.kv - assert not compatible_utils.kv.put(b"test_key", b"test_val") - assert compatible_utils.kv.get(b"test_key") == b"test_val" - - time.sleep(5) - fed.shutdown() - ray.shutdown() - - assert compatible_utils.kv is None - with pytest.raises(ValueError): - # Make sure the kv actor is non-exist no matter whether it's in client mode - ray.get_actor("_INTERNAL_KV_ACTOR") +def run(party): + compatible_utils.init_ray("local") + cluster = { + 'alice': {'address': '127.0.0.1:11010', 'listen_addr': '0.0.0.0:11010'}, + 'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'}, + } + assert compatible_utils.kv is None + fed.init(cluster=cluster, party=party) + assert compatible_utils.kv + assert not compatible_utils.kv.put(b"test_key", b"test_val") + assert compatible_utils.kv.get(b"test_key") == b"test_val" + + time.sleep(5) + fed.shutdown() + + assert compatible_utils.kv is None + with pytest.raises(ValueError): + # Make sure the kv actor is non-exist no matter whether it's in client mode + ray.get_actor("_INTERNAL_KV_ACTOR") + ray.shutdown() + +def test_kv_init(): p_alice = multiprocessing.Process(target=run, args=('alice',)) p_bob = multiprocessing.Process(target=run, args=('bob',)) p_alice.start() diff --git a/tests/test_listen_addr.py b/tests/test_listen_addr.py index b0ab900..72753e2 100644 --- a/tests/test_listen_addr.py +++ b/tests/test_listen_addr.py @@ -34,28 +34,29 @@ def get_value(self): return self._value -def test_listen_addr(): - def run(party, is_inner_party): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012', 'listen_addr': '0.0.0.0:11012'}, - 
'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'}, - } - fed.init(cluster=cluster, party=party) - - o = f.party("alice").remote() - actor_location = "alice" if is_inner_party else "bob" - my = My.party(actor_location).remote(o) - val = my.get_value.remote() - result = fed.get(val) - assert result == 100 - assert fed.get(o) == 100 - import time - - time.sleep(5) - fed.shutdown() - ray.shutdown() +def run(party, is_inner_party): + compatible_utils.init_ray(address='local') + cluster = { + 'alice': {'address': '127.0.0.1:11012', 'listen_addr': '0.0.0.0:11012'}, + 'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'}, + } + fed.init(cluster=cluster, party=party) + + o = f.party("alice").remote() + actor_location = "alice" if is_inner_party else "bob" + my = My.party(actor_location).remote(o) + val = my.get_value.remote() + result = fed.get(val) + assert result == 100 + assert fed.get(o) == 100 + import time + + time.sleep(5) + fed.shutdown() + ray.shutdown() + +def test_listen_addr(): p_alice = multiprocessing.Process(target=run, args=('alice', True)) p_bob = multiprocessing.Process(target=run, args=('bob', True)) p_alice.start() diff --git a/tests/test_repeat_init.py b/tests/test_repeat_init.py index 42eaa74..5e2d49f 100644 --- a/tests/test_repeat_init.py +++ b/tests/test_repeat_init.py @@ -16,13 +16,10 @@ import multiprocessing import pytest -import time import fed import fed._private.compatible_utils as compatible_utils import ray -from fed.cleanup import _start_check_sending, push_to_sending - @fed.remote class My: @@ -48,16 +45,8 @@ def bar(self, li): def run(party): def _run(): - assert fed.cleanup._sending_obj_refs_q is None compatible_utils.init_ray(address='local') fed.init(cluster=cluster, party=party) - _start_check_sending() - time.sleep(0.5) - assert fed.cleanup._sending_obj_refs_q is not None - push_to_sending(True) - # Slightly longer than the queue polling - time.sleep(0.6) - assert fed.cleanup._sending_obj_refs_q is 
None my1 = My.party("alice").remote() my2 = My.party("bob").remote() @@ -71,7 +60,6 @@ def _run(): fed.shutdown() ray.shutdown() - assert fed.cleanup._sending_obj_refs_q is None _run() _run() diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index 7a3aabd..5b8cf8b 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -21,32 +21,58 @@ import ray -def test_setup_proxy_success(): - def run(party): - compatible_utils.init_ray(address='local', resources={"127.0.0.1": 2}) - cluster = { - 'alice': {'address': '127.0.0.1:11010'}, - 'bob': {'address': '127.0.0.1:11011'}, - } - send_proxy_resources = { - "127.0.0.1": 1 - } - recv_proxy_resources = { - "127.0.0.1": 1 - } +def run(party): + compatible_utils.init_ray(address='local', resources={"127.0.0.1": 2}) + cluster = { + 'alice': {'address': '127.0.0.1:11010'}, + 'bob': {'address': '127.0.0.1:11011'}, + } + send_proxy_resources = { + "127.0.0.1": 1 + } + recv_proxy_resources = { + "127.0.0.1": 1 + } + fed.init( + cluster=cluster, + party=party, + cross_silo_send_resource_label=send_proxy_resources, + cross_silo_recv_resource_label=recv_proxy_resources, + ) + + assert ray.get_actor("SendProxyActor") is not None + assert ray.get_actor(f"RecverProxyActor-{party}") is not None + + fed.shutdown() + ray.shutdown() + + +def run_failure(party): + compatible_utils.init_ray(address='local', resources={"127.0.0.1": 1}) + cluster = { + 'alice': {'address': '127.0.0.1:11010'}, + 'bob': {'address': '127.0.0.1:11011'}, + } + send_proxy_resources = { + "127.0.0.2": 1 # Insufficient resource + } + recv_proxy_resources = { + "127.0.0.2": 1 # Insufficient resource + } + with pytest.raises(ray.exceptions.GetTimeoutError): fed.init( cluster=cluster, party=party, cross_silo_send_resource_label=send_proxy_resources, cross_silo_recv_resource_label=recv_proxy_resources, + cross_silo_timeout_in_seconds=10, # Quick fail in test ) - assert ray.get_actor("SendProxyActor") is not None - 
assert ray.get_actor(f"RecverProxyActor-{party}") is not None + fed.shutdown() + ray.shutdown() - fed.shutdown() - ray.shutdown() +def test_setup_proxy_success(): p_alice = multiprocessing.Process(target=run, args=('alice',)) p_bob = multiprocessing.Process(target=run, args=('bob',)) p_alice.start() @@ -57,32 +83,8 @@ def run(party): def test_setup_proxy_failed(): - def run(party): - compatible_utils.init_ray(address='local', resources={"127.0.0.1": 1}) - cluster = { - 'alice': {'address': '127.0.0.1:11010'}, - 'bob': {'address': '127.0.0.1:11011'}, - } - send_proxy_resources = { - "127.0.0.2": 1 # Insufficient resource - } - recv_proxy_resources = { - "127.0.0.2": 1 # Insufficient resource - } - with pytest.raises(ray.exceptions.GetTimeoutError): - fed.init( - cluster=cluster, - party=party, - cross_silo_send_resource_label=send_proxy_resources, - cross_silo_recv_resource_label=recv_proxy_resources, - cross_silo_timeout_in_seconds=10, # Quick fail in test - ) - - fed.shutdown() - ray.shutdown() - - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run_failure, args=('alice',)) + p_bob = multiprocessing.Process(target=run_failure, args=('bob',)) p_alice.start() p_bob.start() p_alice.join() diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 89a3ecd..ae5274c 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -21,9 +21,9 @@ import fed._private.compatible_utils as compatible_utils from fed._private import constants +from fed._private import global_context from fed.grpc import fed_pb2, fed_pb2_grpc from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy -from fed.cleanup import wait_sending def test_n_to_1_transport(): @@ -33,6 +33,7 @@ def test_n_to_1_transport(): """ compatible_utils.init_ray(address='local') + global_context.get_global_context().get_cleanup_manager().start() 
cluster_config = { constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", @@ -70,7 +71,8 @@ def test_n_to_1_transport(): for i in range(NUM_DATA): assert f"data-{i}" in ray.get(get_objs) - wait_sending() + global_context.get_global_context().get_cleanup_manager().graceful_stop() + global_context.clear_global_context() ray.shutdown() @@ -171,6 +173,7 @@ def test_send_grpc_with_meta(): cloudpickle.dumps(cluster_config)) compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, cloudpickle.dumps(job_config)) + global_context.get_global_context().get_cleanup_manager().start() SERVER_ADDRESS = "127.0.0.1:12344" party = 'test_party' @@ -186,7 +189,8 @@ def test_send_grpc_with_meta(): for result in ray.get(sent_objs): assert result - wait_sending() + global_context.get_global_context().get_cleanup_manager().graceful_stop() + global_context.clear_global_context() ray.shutdown() @@ -210,6 +214,7 @@ def test_send_grpc_with_party_specific_meta(): cloudpickle.dumps(cluster_config)) compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, cloudpickle.dumps(job_config)) + global_context.get_global_context().get_cleanup_manager().start() SERVER_ADDRESS = "127.0.0.1:12344" party = 'test_party' @@ -230,7 +235,8 @@ def test_send_grpc_with_party_specific_meta(): for result in ray.get(sent_objs): assert result - wait_sending() + global_context.get_global_context().get_cleanup_manager().graceful_stop() + global_context.clear_global_context() ray.shutdown() diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index 0fe5d95..65d24f9 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -20,8 +20,8 @@ import fed._private.compatible_utils as compatible_utils from fed._private import constants +from fed._private import global_context from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy -from fed.cleanup import wait_sending def test_n_to_1_transport(): @@ -48,6 +48,8 @@ def 
test_n_to_1_transport(): constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } + + global_context.get_global_context().get_cleanup_manager().start() compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)) @@ -88,7 +90,8 @@ def test_n_to_1_transport(): for i in range(NUM_DATA): assert f"data-{i}" in ray.get(get_objs) - wait_sending() + global_context.get_global_context().get_cleanup_manager().graceful_stop() + global_context.clear_global_context() ray.shutdown()