diff --git a/fed/api.py b/fed/api.py index a8d2cbf..fa9ff1a 100644 --- a/fed/api.py +++ b/fed/api.py @@ -294,7 +294,9 @@ def shutdown(): """ Shutdown a RayFed client. """ - _shutdown(True) + global_context = get_global_context() + if global_context is not None and global_context.acquire_shutdown_flag(): + _shutdown(True) def _shutdown(intended=True): diff --git a/fed/tests/test_cross_silo_error.py b/fed/tests/test_cross_silo_error.py index c79a108..88d7f87 100644 --- a/fed/tests/test_cross_silo_error.py +++ b/fed/tests/test_cross_silo_error.py @@ -180,5 +180,52 @@ def test_cross_silo_not_expose_error_trace(): assert p_bob.exitcode == 0 +@fed.remote +def foo(e): + print(e) + + +def run4(party): + compatible_utils.init_ray(address='local') + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', + } + + fed.init( + addresses=addresses, + party=party, + logging_level='debug', + config={ + 'cross_silo_comm': { + 'timeout_ms': 20 * 1000, + 'expose_error_trace': False, + }, + }, + ) + + a = error_func.party("alice").remote() + o = foo.party('bob').remote(a) + if party == 'bob': + # Wait a while to receive error from alice. + import time + + time.sleep(10) + # Alice will shutdown once exactly. + fed.shutdown() + ray.shutdown() + + +def test_cross_silo_alice_send_error_and_shutdown_once(): + p_alice = multiprocessing.Process(target=run4, args=('alice',)) + p_bob = multiprocessing.Process(target=run4, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 + assert p_bob.exitcode == 0 + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__]))