Skip to content

Commit

Permalink
fix retry_policy update & get party grpc_options
Browse files Browse the repository at this point in the history
Signed-off-by: paer <chenqixiang.cqx@antgroup.com>
  • Loading branch information
paer committed Jul 12, 2023
1 parent a432d31 commit 7ffe367
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 117 deletions.
6 changes: 3 additions & 3 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def send(
async def is_ready(self):
    """Report readiness; this implementation unconditionally returns True."""
    return True

async def get_proxy_config(self):
async def get_proxy_config(self, dest_party=None):
return self._proxy_config


Expand Down Expand Up @@ -180,8 +180,8 @@ async def _get_stats(self):
async def _get_cluster_info(self):
    """Return the cluster information held in ``self._cluster``."""
    return self._cluster

async def _get_proxy_config(self):
return await self._proxy_instance.get_proxy_config()
async def _get_proxy_config(self, dest_party=None):
return await self._proxy_instance.get_proxy_config(dest_party)

@ray.remote
class RecverProxyActor:
Expand Down
24 changes: 15 additions & 9 deletions fed/proxy/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ def parse_grpc_options(proxy_config: CrossSiloCommConfig):
proxy_config.messages_max_size_in_bytes
})
if isinstance(proxy_config, CrossSiloGrpcCommConfig):
grpc_channel_options.update(proxy_config.grpc_channel_options)
if proxy_config.grpc_channel_options is not None:
grpc_channel_options.update(proxy_config.grpc_channel_options)
if proxy_config.grpc_retry_policy is not None:
grpc_channel_options.update({
'grpc.service_config':
json.dumps(
{
'methodConfig': [
Expand Down Expand Up @@ -122,16 +124,20 @@ def get_grpc_config_by_party(self, dest_party):
**dest_party_grpc_metadata
}
dest_party_grpc_options = parse_grpc_options(dest_party_comm_config)
grpc_options = fed_utils.dict2tuple({
grpc_options = {
**grpc_options, **dest_party_grpc_options
})
return grpc_metadata, grpc_options

async def get_proxy_config(self):
}
return grpc_metadata, fed_utils.dict2tuple(grpc_options)

async def get_proxy_config(self, dest_party=None):
    """Return the proxy configuration as a dict, including the gRPC options.

    Args:
        dest_party: Optional party name. When given, the reported options
            are the ones produced by ``get_grpc_config_by_party(dest_party)``;
            when ``None``, the proxy's own ``self._grpc_options`` are used.

    Returns:
        A new dict holding the attributes of ``self._proxy_config`` plus a
        ``'grpc_options'`` entry.
    """
    if dest_party is None:
        grpc_options = fed_utils.dict2tuple(self._grpc_options)
    else:
        _, grpc_options = self.get_grpc_config_by_party(dest_party)
    # Work on a shallow copy: updating ``__dict__`` itself would attach a
    # ``grpc_options`` attribute to the shared config object and leave the
    # last queried party's options stored on it.
    proxy_config = dict(self._proxy_config.__dict__)
    proxy_config['grpc_options'] = grpc_options
    return proxy_config


async def send_data_grpc(
data,
Expand Down Expand Up @@ -191,7 +197,7 @@ async def start(self):
self._lock,
self._server_ready_future,
self._tls_config,
self._grpc_options,
fed_utils.dict2tuple(self._grpc_options),
)
except RuntimeError as err:
msg = f'Grpc server failed to listen to port: {port}' \
Expand Down
2 changes: 1 addition & 1 deletion tests/serializations_tests/test_unpickle_with_whitelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run(party):
fed.init(
cluster=cluster,
party=party,
cross_silo_comm_config=CrossSiloCommConfig(
global_cross_silo_comm_config=CrossSiloCommConfig(
serializing_allowed_list=allowed_list
))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_exit_on_failure_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def run(party, is_inner_party):
cluster=cluster,
party=party,
logging_level='debug',
cross_silo_comm_config=cross_silo_comm_config
global_cross_silo_comm_config=cross_silo_comm_config
)

o = f.party("alice").remote()
Expand Down
6 changes: 2 additions & 4 deletions tests/test_grpc_options_on_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,14 @@ def run(party):
fed.init(
cluster=cluster,
party=party,
cross_silo_comm_config=CrossSiloCommConfig(
global_cross_silo_comm_config=CrossSiloCommConfig(
messages_max_size_in_bytes=100)
)

def _assert_on_proxy(proxy_actor):
config = ray.get(proxy_actor._get_proxy_config.remote())
print(f"==============={config}==============")
options = config['grpc_options']
assert options[0][0] == "grpc.max_send_message_length"
assert options[0][1] == 100
assert ("grpc.max_send_message_length", 100) in options
assert ('grpc.so_reuseport', 0) in options

send_proxy = ray.get_actor("SendProxyActor")
Expand Down
96 changes: 77 additions & 19 deletions tests/test_grpc_options_per_party.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import fed._private.compatible_utils as compatible_utils
import ray

from fed.config import CrossSiloCommConfig
from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig


@fed.remote
Expand All @@ -31,39 +31,34 @@ def run(party):
cluster = {
'alice': {
'address': '127.0.0.1:11010',
'grpc_options': [
('grpc.default_authority', 'alice'),
('grpc.max_send_message_length', 200)
]
'cross_silo_comm_config': CrossSiloGrpcCommConfig(
grpc_channel_options=[
('grpc.default_authority', 'alice'),
('grpc.max_send_message_length', 200)
])
},
'bob': {'address': '127.0.0.1:11011'},
}
fed.init(
cluster=cluster,
party=party,
cross_silo_comm_config=CrossSiloCommConfig(
global_cross_silo_comm_config=CrossSiloCommConfig(
messages_max_size_in_bytes=100)
)

def _assert_on_send_proxy(proxy_actor):
alice_config = ray.get(proxy_actor.setup_grpc_config.remote('alice'))
alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice'))
# print(f"【NKcqx】alice config: {alice_config}")
assert 'grpc_options' in alice_config
alice_options = alice_config['grpc_options']
assert 'grpc.max_send_message_length' in alice_options
# This should be overwritten by cluster config
assert alice_options['grpc.max_send_message_length'] == 200
assert 'grpc.default_authority' in alice_options
assert alice_options['grpc.default_authority'] == 'alice'

bob_config = ray.get(proxy_actor.setup_grpc_config.remote('bob'))
# print(f"【NKcqx】bob config: {bob_config}")
assert ('grpc.max_send_message_length', 200) in alice_options
assert ('grpc.default_authority', 'alice') in alice_options

bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob'))
assert 'grpc_options' in bob_config
bob_options = bob_config['grpc_options']
assert "grpc.max_send_message_length" in bob_options
# Not setting bob's grpc_options, should be the same with global
assert bob_options["grpc.max_send_message_length"] == 100
assert 'grpc.default_authority' not in bob_options
assert ('grpc.max_send_message_length', 100) in bob_options
assert not any(o[0] == 'grpc.default_authority' for o in bob_options)

send_proxy = ray.get_actor("SendProxyActor")
_assert_on_send_proxy(send_proxy)
Expand All @@ -86,6 +81,69 @@ def test_grpc_options():
assert p_alice.exitcode == 0 and p_bob.exitcode == 0


def party_grpc_options(party):
    """Run one party of a two-party job and check per-party gRPC options.

    Both cluster entries carry their own ``CrossSiloGrpcCommConfig``. The
    ``SendProxyActor`` is queried once per destination and must report that
    destination's channel options.
    """
    compatible_utils.init_ray(address='local')

    alice_max_size = 51 * 1024 * 1024
    bob_max_size = 50 * 1024 * 1024

    alice_comm_config = CrossSiloGrpcCommConfig(
        grpc_channel_options=[
            ('grpc.default_authority', 'alice'),
            ('grpc.max_send_message_length', alice_max_size),
        ])
    bob_comm_config = CrossSiloGrpcCommConfig(
        grpc_channel_options=[
            ('grpc.default_authority', 'bob'),
            ('grpc.max_send_message_length', bob_max_size),
        ])
    cluster = {
        'alice': {
            'address': '127.0.0.1:11010',
            'cross_silo_comm_config': alice_comm_config,
        },
        'bob': {
            'address': '127.0.0.1:11011',
            'cross_silo_comm_config': bob_comm_config,
        },
    }
    global_config = CrossSiloCommConfig(messages_max_size_in_bytes=100)
    fed.init(
        cluster=cluster,
        party=party,
        global_cross_silo_comm_config=global_config,
    )

    send_proxy = ray.get_actor("SendProxyActor")
    # Each destination is expected to report its own message-size limit and
    # authority rather than the global setting.
    expected_max_sizes = (('alice', alice_max_size), ('bob', bob_max_size))
    for dest, max_size in expected_max_sizes:
        dest_config = ray.get(send_proxy._get_proxy_config.remote(dest))
        assert 'grpc_options' in dest_config
        dest_options = dest_config['grpc_options']
        assert ('grpc.max_send_message_length', max_size) in dest_options
        assert ('grpc.default_authority', dest) in dest_options

    pending = [dummpy.party(name).remote() for name in ('alice', 'bob')]
    fed.get(pending)

    fed.shutdown()
    ray.shutdown()


def test_party_specific_grpc_options():
    """Run ``party_grpc_options`` for alice and bob in separate processes.

    Both child processes must exit with code 0.
    """
    workers = [
        multiprocessing.Process(target=party_grpc_options, args=(name,))
        for name in ('alice', 'bob')
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    assert all(worker.exitcode == 0 for worker in workers)


if __name__ == "__main__":
import sys

Expand Down
78 changes: 0 additions & 78 deletions tests/test_party_specific_grpc_options.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_retry_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run(party, is_inner_party):
fed.init(
cluster=cluster,
party=party,
cross_silo_comm_config=CrossSiloGrpcCommConfig(
global_cross_silo_comm_config=CrossSiloGrpcCommConfig(
grpc_retry_policy=retry_policy
)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_setup_proxy_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def run_failure(party):
fed.init(
cluster=cluster,
party=party,
cross_silo_comm_config=CrossSiloCommConfig(
global_cross_silo_comm_config=CrossSiloCommConfig(
send_resource_label=send_proxy_resources,
recv_resource_label=recv_proxy_resources,
timeout_in_seconds=10,
Expand Down

0 comments on commit 7ffe367

Please sign in to comment.