From 7ffe3673ff2225099f0d14d1d919f9a90e6f0633 Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 12 Jul 2023 17:45:21 +0800 Subject: [PATCH] fix retry_policy update & get party grpc_options Signed-off-by: paer --- fed/proxy/barriers.py | 6 +- fed/proxy/grpc_proxy.py | 24 +++-- .../test_unpickle_with_whitelist.py | 2 +- tests/test_exit_on_failure_sending.py | 2 +- tests/test_grpc_options_on_proxies.py | 6 +- tests/test_grpc_options_per_party.py | 96 +++++++++++++++---- tests/test_party_specific_grpc_options.py | 78 --------------- tests/test_retry_policy.py | 2 +- tests/test_setup_proxy_actor.py | 2 +- 9 files changed, 101 insertions(+), 117 deletions(-) delete mode 100644 tests/test_party_specific_grpc_options.py diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index b711695..7e86277 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -81,7 +81,7 @@ async def send( async def is_ready(self): return True - async def get_proxy_config(self): + async def get_proxy_config(self, dest_party=None): return self._proxy_config @@ -180,8 +180,8 @@ async def _get_stats(self): async def _get_cluster_info(self): return self._cluster - async def _get_proxy_config(self): - return await self._proxy_instance.get_proxy_config() + async def _get_proxy_config(self, dest_party=None): + return await self._proxy_instance.get_proxy_config(dest_party) @ray.remote class RecverProxyActor: diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py index 03a20fd..233c033 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc_proxy.py @@ -37,9 +37,11 @@ def parse_grpc_options(proxy_config: CrossSiloCommConfig): proxy_config.messages_max_size_in_bytes }) if isinstance(proxy_config, CrossSiloGrpcCommConfig): - grpc_channel_options.update(proxy_config.grpc_channel_options) + if proxy_config.grpc_channel_options is not None: + grpc_channel_options.update(proxy_config.grpc_channel_options) if proxy_config.grpc_retry_policy is not None: grpc_channel_options.update({ + 'grpc.service_config': json.dumps( { 'methodConfig': [ @@ -122,16 +124,20 @@ def get_grpc_config_by_party(self, dest_party): **dest_party_grpc_metadata } dest_party_grpc_options = parse_grpc_options(dest_party_comm_config) - grpc_options = fed_utils.dict2tuple({ + grpc_options = { **grpc_options, **dest_party_grpc_options - }) - return grpc_metadata, grpc_options - - async def get_proxy_config(self): + } + return grpc_metadata, fed_utils.dict2tuple(grpc_options) + + async def get_proxy_config(self, dest_party=None): + if dest_party is None: + grpc_options = fed_utils.dict2tuple(self._grpc_options) + else: + _, grpc_options = self.get_grpc_config_by_party(dest_party) proxy_config = self._proxy_config.__dict__ - proxy_config.update({'grpc_options': fed_utils.dict2tuple(self._grpc_options)}) + proxy_config.update({'grpc_options': grpc_options}) return proxy_config - + async def send_data_grpc( data, @@ -191,7 +197,7 @@ async def start(self): self._lock, self._server_ready_future, self._tls_config, - self._grpc_options, + fed_utils.dict2tuple(self._grpc_options), ) except RuntimeError as err: msg = f'Grpc server failed to listen to port: {port}' \ diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/tests/serializations_tests/test_unpickle_with_whitelist.py index 28b4ca8..0cf75e6 100644 --- a/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -53,7 +53,7 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloCommConfig( serializing_allowed_list=allowed_list )) diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 3f6ed52..555e755 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -71,7 +71,7 @@ def run(party, is_inner_party): cluster=cluster, party=party, logging_level='debug', - cross_silo_comm_config=cross_silo_comm_config + global_cross_silo_comm_config=cross_silo_comm_config ) o = f.party("alice").remote() diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index cf2ebb3..b70bf0e 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -35,16 +35,14 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloCommConfig( messages_max_size_in_bytes=100) ) def _assert_on_proxy(proxy_actor): config = ray.get(proxy_actor._get_proxy_config.remote()) - print(f"==============={config}==============") options = config['grpc_options'] - assert options[0][0] == "grpc.max_send_message_length" - assert options[0][1] == 100 + assert ("grpc.max_send_message_length", 100) in options assert ('grpc.so_reuseport', 0) in options send_proxy = ray.get_actor("SendProxyActor") diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index e99d3d9..53b70cb 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -18,7 +18,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloCommConfig +from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig @fed.remote @@ -31,39 +31,34 @@ def run(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'grpc_options': [ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 200) - ] + 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'alice'), + ('grpc.max_send_message_length', 200) + ]) }, 'bob': {'address': '127.0.0.1:11011'}, } fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloCommConfig( messages_max_size_in_bytes=100) ) def _assert_on_send_proxy(proxy_actor): - alice_config = ray.get(proxy_actor.setup_grpc_config.remote('alice')) + alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) # print(f"【NKcqx】alice config: {alice_config}") assert 'grpc_options' in alice_config alice_options = alice_config['grpc_options'] - assert 'grpc.max_send_message_length' in alice_options - # This should be overwritten by cluster config - assert alice_options['grpc.max_send_message_length'] == 200 - assert 'grpc.default_authority' in alice_options - assert alice_options['grpc.default_authority'] == 'alice' - - bob_config = ray.get(proxy_actor.setup_grpc_config.remote('bob')) - # print(f"【NKcqx】bob config: {bob_config}") + assert ('grpc.max_send_message_length', 200) in alice_options + assert ('grpc.default_authority', 'alice') in alice_options + + bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob')) assert 'grpc_options' in bob_config bob_options = bob_config['grpc_options'] - assert "grpc.max_send_message_length" in bob_options - # Not setting bob's grpc_options, should be the same with global - assert bob_options["grpc.max_send_message_length"] == 100 - assert 'grpc.default_authority' not in bob_options + assert ('grpc.max_send_message_length', 100) in bob_options + assert not any(o[0] == 'grpc.default_authority' for o in bob_options) send_proxy = ray.get_actor("SendProxyActor") _assert_on_send_proxy(send_proxy) @@ -86,6 +81,69 @@ def test_grpc_options(): assert p_alice.exitcode == 0 and p_bob.exitcode == 0 +def party_grpc_options(party): + compatible_utils.init_ray(address='local') + cluster = { + 'alice': { + 'address': '127.0.0.1:11010', + 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'alice'), + ('grpc.max_send_message_length', 51 * 1024 * 1024) + ]) + }, + 'bob': { + 'address': '127.0.0.1:11011', + 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'bob'), + ('grpc.max_send_message_length', 50 * 1024 * 1024) + ]) + }, + } + fed.init( + cluster=cluster, + party=party, + global_cross_silo_comm_config=CrossSiloCommConfig( + messages_max_size_in_bytes=100) + ) + + def _assert_on_send_proxy(proxy_actor): + alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) + assert 'grpc_options' in alice_config + alice_options = alice_config['grpc_options'] + assert ('grpc.max_send_message_length', 51 * 1024 * 1024) in alice_options + assert ('grpc.default_authority', 'alice') in alice_options + + bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob')) + assert 'grpc_options' in bob_config + bob_options = bob_config['grpc_options'] + assert ('grpc.max_send_message_length', 50 * 1024 * 1024) in bob_options + assert ('grpc.default_authority', 'bob') in bob_options + + send_proxy = ray.get_actor("SendProxyActor") + _assert_on_send_proxy(send_proxy) + + a = dummpy.party('alice').remote() + b = dummpy.party('bob').remote() + fed.get([a, b]) + + fed.shutdown() + ray.shutdown() + + +def test_party_specific_grpc_options(): + p_alice = multiprocessing.Process( + target=party_grpc_options, args=('alice',)) + p_bob = multiprocessing.Process( + target=party_grpc_options, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + if __name__ == "__main__": import sys diff --git a/tests/test_party_specific_grpc_options.py b/tests/test_party_specific_grpc_options.py deleted file mode 100644 index 1c017ca..0000000 --- a/tests/test_party_specific_grpc_options.py +++ /dev/null @@ -1,78 +0,0 @@ -import multiprocessing -import pytest -import fed -import fed._private.compatible_utils as compatible_utils -import ray - -from fed.config import CrossSiloCommConfig - - -@fed.remote -def dummpy(): - return 2 - - -def party_grpc_options(party): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': { - 'address': '127.0.0.1:11010', - 'grpc_channel_option': [ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 51 * 1024 * 1024) - ]}, - 'bob': { - 'address': '127.0.0.1:11011', - 'grpc_channel_option': [ - ('grpc.default_authority', 'bob'), - ('grpc.max_send_message_length', 50 * 1024 * 1024) - ]}, - } - fed.init( - cluster=cluster, - party=party, - cross_silo_comm_config=CrossSiloCommConfig( - messages_max_size_in_bytes=100) - ) - - def _assert_on_proxy(proxy_actor): - cluster_info = ray.get(proxy_actor._get_cluster_info.remote()) - assert cluster_info['alice'] is not None - assert cluster_info['alice']['grpc_channel_option'] is not None - alice_channel_options = cluster_info['alice']['grpc_channel_option'] - assert ('grpc.default_authority', 'alice') in alice_channel_options - assert ('grpc.max_send_message_length', 51 * 1024 * 1024) in alice_channel_options # noqa - - assert cluster_info['bob'] is not None - assert cluster_info['bob']['grpc_channel_option'] is not None - bob_channel_options = cluster_info['bob']['grpc_channel_option'] - assert ('grpc.default_authority', 'bob') in bob_channel_options - assert ('grpc.max_send_message_length', 50 * 1024 * 1024) in bob_channel_options # noqa - - send_proxy = ray.get_actor("SendProxyActor") - _assert_on_proxy(send_proxy) - - a = dummpy.party('alice').remote() - b = dummpy.party('bob').remote() - fed.get([a, b]) - - fed.shutdown() - ray.shutdown() - - -def test_party_specific_grpc_options(): - p_alice = multiprocessing.Process( - target=party_grpc_options, args=('alice',)) - p_bob = multiprocessing.Process( - target=party_grpc_options, args=('bob',)) - p_alice.start() - p_bob.start() - p_alice.join() - p_bob.join() - assert p_alice.exitcode == 0 and p_bob.exitcode == 0 - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-sv", __file__])) diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index c2a45d8..15f37fc 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -53,7 +53,7 @@ def run(party, is_inner_party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloGrpcCommConfig( + global_cross_silo_comm_config=CrossSiloGrpcCommConfig( grpc_retry_policy=retry_policy ) ) diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index a8f53f6..c95fcef 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -65,7 +65,7 @@ def run_failure(party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloCommConfig( send_resource_label=send_proxy_resources, recv_resource_label=recv_proxy_resources, timeout_in_seconds=10,