Skip to content

Commit

Permalink
fix missing grpc_options (#182)
Browse files Browse the repository at this point in the history
Signed-off-by: paer <chenqixiang.cqx@antgroup.com>
Co-authored-by: paer <chenqixiang.cqx@antgroup.com>
  • Loading branch information
NKcqx and paer authored Nov 14, 2023
1 parent b688dc9 commit 54304c2
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 5 deletions.
27 changes: 27 additions & 0 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,33 @@ def init(
'carol': '127.0.0.1:10003',
}
party: optional; self party.
config: optional; a dict describing general job configurations. Currently the
supported configurations are [`cross_silo_comm`, `barrier_on_initializing`].
* `cross_silo_comm`: optional; a dict describing the cross-silo common
configs. The supported configs are listed in
`fed.config.CrossSiloMessageConfig` and
`fed.config.GrpcCrossSiloMessageConfig`. Note that
`cross_silo_comm.messages_max_size_in_bytes` will be overridden
if `cross_silo_comm.grpc_channel_options` is provided and contains
`grpc.max_send_message_length` or `grpc.max_receive_message_length`.
* `barrier_on_initializing`: optional; a bool indicating whether to
wait for all parties to be ready before starting the job. If set
to True, the job starts after all parties are ready;
otherwise, it starts immediately after the current
party is ready.
Example:
.. code:: python
{
"cross_silo_comm": {
"messages_max_size_in_bytes": 500*1024,
"timeout_in_ms": 1000,
"exit_on_sending_failure": True,
"expose_error_trace": True,
},
"barrier_on_initializing": True,
}
tls_config: optional; a dict describing the TLS config. E.g.
For alice,
Expand Down
3 changes: 1 addition & 2 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ class CrossSiloMessageConfig:
cross-silo sending. If True, a SIGTERM will be signaled to self
if it fails to send cross-silo data.
messages_max_size_in_bytes: The maximum length in bytes of
cross-silo messages.
If None, the default value of 500 MB is specified.
cross-silo messages. If None, the default value of 500 MB is specified.
timeout_in_ms: The timeout in milliseconds of a cross-silo RPC call.
It's 60000 by default.
http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request.
Expand Down
17 changes: 15 additions & 2 deletions fed/proxy/grpc/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,21 @@ def parse_grpc_options(proxy_config: CrossSiloMessageConfig):
dict: A dictionary containing the gRPC channel options.
"""
grpc_channel_options = {}
if proxy_config is not None and isinstance(
proxy_config, GrpcCrossSiloMessageConfig):
if proxy_config is not None:
# NOTE(NKcqx): `messages_max_size_in_bytes` is a common cross-silo
# config that should be extracted and filled into proper grpc's
# channel options.
# However, `GrpcCrossSiloMessageConfig` provides a more flexible way
# to configure grpc channel options, i.e. the `grpc_channel_options`
# field, which may override the `messages_max_size_in_bytes` field.
if (isinstance(proxy_config, CrossSiloMessageConfig)):
if (proxy_config.messages_max_size_in_bytes is not None):
grpc_channel_options.update({
'grpc.max_send_message_length':
proxy_config.messages_max_size_in_bytes,
'grpc.max_receive_message_length':
proxy_config.messages_max_size_in_bytes,
})
if isinstance(proxy_config, GrpcCrossSiloMessageConfig):
if proxy_config.grpc_channel_options is not None:
grpc_channel_options.update(proxy_config.grpc_channel_options)
Expand Down
97 changes: 96 additions & 1 deletion fed/tests/test_grpc_options_on_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _assert_on_proxy(proxy_actor):
ray.shutdown()


def test_grpc_max_size():
def test_grpc_max_size_by_channel_options():
p_alice = multiprocessing.Process(target=run, args=('alice',))
p_bob = multiprocessing.Process(target=run, args=('bob',))
p_alice.start()
Expand All @@ -71,6 +71,101 @@ def test_grpc_max_size():
assert p_alice.exitcode == 0 and p_bob.exitcode == 0


def run2(party):
    """Run one party configured only via `messages_max_size_in_bytes`.

    Initializes Ray and fed with `cross_silo_comm.messages_max_size_in_bytes`
    set to 100 and no explicit gRPC channel options, then asserts that both
    the sender and the receiver proxy report that value for the gRPC
    send/receive message-length options (and that `grpc.so_reuseport` is 0).
    Afterwards it runs `dummpy` on both parties, fetches the results and
    shuts everything down.

    Args:
        party: Name of the party this process acts as ('alice' or 'bob').
    """
    compatible_utils.init_ray(address='local')
    fed.init(
        addresses={
            'alice': '127.0.0.1:11019',
            'bob': '127.0.0.1:11018',
        },
        party=party,
        config={
            "cross_silo_comm": {
                "messages_max_size_in_bytes": 100,
            },
        },
    )

    # Every (name, value) pair that must show up in each proxy's options.
    expected_options = (
        ("grpc.max_send_message_length", 100),
        ("grpc.max_receive_message_length", 100),
        ('grpc.so_reuseport', 0),
    )

    def _assert_on_proxy(proxy_actor):
        proxy_config = ray.get(proxy_actor._get_proxy_config.remote())
        grpc_options = proxy_config['grpc_options']
        for option in expected_options:
            assert option in grpc_options

    # Resolve both actors up front, then check them in sender/receiver order.
    sender_proxy = ray.get_actor(sender_proxy_actor_name())
    receiver_proxy = ray.get_actor(receiver_proxy_actor_name())
    for proxy in (sender_proxy, receiver_proxy):
        _assert_on_proxy(proxy)

    alice_obj = dummpy.party('alice').remote()
    bob_obj = dummpy.party('bob').remote()
    fed.get([alice_obj, bob_obj])

    fed.shutdown()
    ray.shutdown()


def test_grpc_max_size_by_common_config():
    """Run `run2` for alice and bob in separate processes; both must exit 0."""
    workers = [
        multiprocessing.Process(target=run2, args=(party,))
        for party in ('alice', 'bob')
    ]
    # Start every party before joining any, so the two run concurrently.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    assert all(worker.exitcode == 0 for worker in workers)


def run3(party):
    """Run one party with both the common size config and gRPC options set.

    Initializes Ray and fed with `messages_max_size_in_bytes` set to 100 and
    a `grpc_channel_options` entry that sets `grpc.max_send_message_length`
    to 200. It then asserts that both proxies report 200 for the send length
    (the explicit channel option wins), 100 for the receive length (taken
    from the common config), and `grpc.so_reuseport` as 0. Afterwards it runs
    `dummpy` on both parties, fetches the results and shuts everything down.

    Args:
        party: Name of the party this process acts as ('alice' or 'bob').
    """
    compatible_utils.init_ray(address='local')
    fed.init(
        addresses={
            'alice': '127.0.0.1:11019',
            'bob': '127.0.0.1:11018',
        },
        party=party,
        config={
            "cross_silo_comm": {
                "messages_max_size_in_bytes": 100,
                "grpc_channel_options": [
                    ('grpc.max_send_message_length', 200),
                ],
            },
        },
    )

    # Every (name, value) pair that must show up in each proxy's options.
    expected_options = (
        ("grpc.max_send_message_length", 200),
        ("grpc.max_receive_message_length", 100),
        ('grpc.so_reuseport', 0),
    )

    def _assert_on_proxy(proxy_actor):
        proxy_config = ray.get(proxy_actor._get_proxy_config.remote())
        grpc_options = proxy_config['grpc_options']
        for option in expected_options:
            assert option in grpc_options

    # Resolve both actors up front, then check them in sender/receiver order.
    sender_proxy = ray.get_actor(sender_proxy_actor_name())
    receiver_proxy = ray.get_actor(receiver_proxy_actor_name())
    for proxy in (sender_proxy, receiver_proxy):
        _assert_on_proxy(proxy)

    alice_obj = dummpy.party('alice').remote()
    bob_obj = dummpy.party('bob').remote()
    fed.get([alice_obj, bob_obj])

    fed.shutdown()
    ray.shutdown()


def test_grpc_max_size_by_both_config():
    """Run `run3` for alice and bob in separate processes; both must exit 0."""
    workers = [
        multiprocessing.Process(target=run3, args=(party,))
        for party in ('alice', 'bob')
    ]
    # Start every party before joining any, so the two run concurrently.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    assert all(worker.exitcode == 0 for worker in workers)


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit 54304c2

Please sign in to comment.