From 29aa57932bc953d8d336673f1b5c3778bf1acc2e Mon Sep 17 00:00:00 2001 From: "albert.zah" Date: Tue, 23 Jan 2024 08:57:14 +0000 Subject: [PATCH 1/2] fix: add shutdown lock. --- fed/api.py | 4 ++- fed/tests/test_cross_silo_error.py | 53 ++++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/fed/api.py b/fed/api.py index a8d2cbf..fa9ff1a 100644 --- a/fed/api.py +++ b/fed/api.py @@ -294,7 +294,9 @@ def shutdown(): """ Shutdown a RayFed client. """ - _shutdown(True) + global_context = get_global_context() + if global_context is not None and global_context.acquire_shutdown_flag(): + _shutdown(True) def _shutdown(intended=True): diff --git a/fed/tests/test_cross_silo_error.py b/fed/tests/test_cross_silo_error.py index c79a108..d080608 100644 --- a/fed/tests/test_cross_silo_error.py +++ b/fed/tests/test_cross_silo_error.py @@ -77,7 +77,7 @@ def run(party): ray.shutdown() -def test_cross_silo_normal_task_error(): +def cross_silo_normal_task_error(): p_alice = multiprocessing.Process(target=run, args=('alice',)) p_bob = multiprocessing.Process(target=run, args=('bob',)) p_alice.start() @@ -124,7 +124,7 @@ def run2(party): ray.shutdown() -def test_cross_silo_actor_task_error(): +def cross_silo_actor_task_error(): p_alice = multiprocessing.Process(target=run2, args=('alice',)) p_bob = multiprocessing.Process(target=run2, args=('bob',)) p_alice.start() @@ -169,7 +169,7 @@ def run3(party): ray.shutdown() -def test_cross_silo_not_expose_error_trace(): +def cross_silo_not_expose_error_trace(): p_alice = multiprocessing.Process(target=run3, args=('alice',)) p_bob = multiprocessing.Process(target=run3, args=('bob',)) p_alice.start() @@ -180,5 +180,52 @@ def test_cross_silo_not_expose_error_trace(): assert p_bob.exitcode == 0 +@fed.remote +def foo(e): + print(e) + + +def run4(party): + compatible_utils.init_ray(address='local') + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', + } + + fed.init( + addresses=addresses, + 
party=party, + logging_level='debug', + config={ + 'cross_silo_comm': { + 'timeout_ms': 20 * 1000, + 'expose_error_trace': False, + }, + }, + ) + + a = error_func.party("alice").remote() + o = foo.party('bob').remote(a) + if party == 'bob': + # Wait a while to receive error from alice. + import time + + time.sleep(10) + # Alice will shut down exactly once. + fed.shutdown() + ray.shutdown() + + +def test_cross_silo_alice_send_error_and_shutdown_once(): + p_alice = multiprocessing.Process(target=run4, args=('alice',)) + p_bob = multiprocessing.Process(target=run4, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 + assert p_bob.exitcode == 0 + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) From 58630985848c6d85e40ac28609c79ed465160369 Mon Sep 17 00:00:00 2001 From: "albert.zah" Date: Tue, 23 Jan 2024 09:00:36 +0000 Subject: [PATCH 2/2] Revert ut. --- fed/tests/test_cross_silo_error.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fed/tests/test_cross_silo_error.py b/fed/tests/test_cross_silo_error.py index d080608..88d7f87 100644 --- a/fed/tests/test_cross_silo_error.py +++ b/fed/tests/test_cross_silo_error.py @@ -77,7 +77,7 @@ def run(party): ray.shutdown() -def cross_silo_normal_task_error(): +def test_cross_silo_normal_task_error(): p_alice = multiprocessing.Process(target=run, args=('alice',)) p_bob = multiprocessing.Process(target=run, args=('bob',)) p_alice.start() @@ -124,7 +124,7 @@ def run2(party): ray.shutdown() -def cross_silo_actor_task_error(): +def test_cross_silo_actor_task_error(): p_alice = multiprocessing.Process(target=run2, args=('alice',)) p_bob = multiprocessing.Process(target=run2, args=('bob',)) p_alice.start() @@ -169,7 +169,7 @@ def run3(party): ray.shutdown() -def cross_silo_not_expose_error_trace(): +def test_cross_silo_not_expose_error_trace(): p_alice = multiprocessing.Process(target=run3, args=('alice',)) p_bob =
multiprocessing.Process(target=run3, args=('bob',)) p_alice.start()