diff --git a/README.md b/README.md index d78a17f0..86bf867b 100644 --- a/README.md +++ b/README.md @@ -107,11 +107,11 @@ The above codes: def main(party): ray.init(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) ``` This first declares a two-party cluster, whose addresses corresponding to '127.0.0.1:11012' in 'alice' and '127.0.0.1:11011' in 'bob'. And then, the `fed.init` create a cluster in the specified party. @@ -145,11 +145,11 @@ def aggregate(val1, val2): def main(party): ray.init(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) actor_alice = MyActor.party("alice").remote(1) actor_bob = MyActor.party("bob").remote(1) diff --git a/benchmarks/many_tiny_tasks_benchmark.py b/benchmarks/many_tiny_tasks_benchmark.py index ab3c2995..5fbbf387 100644 --- a/benchmarks/many_tiny_tasks_benchmark.py +++ b/benchmarks/many_tiny_tasks_benchmark.py @@ -33,11 +33,11 @@ def aggr(self, val1, val2): def main(party): ray.init(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11010'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11010', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) actor_alice = MyActor.party("alice").remote() actor_bob = MyActor.party("bob").remote() diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 356e191b..c76bc51f 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -24,6 +24,6 @@ For example: >>> import ray >>> import fed >>> ray.init() ->>> fed.init(cluster=cluster, party="Alice", tls_config=tls_config) +>>> fed.init(addresses=addresses, party="Alice", tls_config=tls_config) Successfully to connect to current Ray cluster in party `Alice` diff --git a/fed/_private/fed_actor.py b/fed/_private/fed_actor.py index fc5496ba..aa73bfa4 100644 --- a/fed/_private/fed_actor.py +++ b/fed/_private/fed_actor.py @@ -25,14 +25,14 @@ class FedActorHandle: def __init__( self, fed_class_task_id, - cluster, + addresses, cls, party, node_party, options, ) -> None: self._fed_class_task_id = fed_class_task_id - self._cluster = cluster + self._addresses = addresses self._body = cls self._party = party self._node_party = node_party @@ -46,7 +46,7 @@ def __getattr__(self, method_name: str): # Raise an error if the method is invalid. getattr(self._body, method_name) call_node = FedActorMethod( - self._cluster, + self._addresses, self._party, self._node_party, self, @@ -90,13 +90,13 @@ def _execute_remote_method(self, method_name, options, args, kwargs): class FedActorMethod: def __init__( self, - cluster, + addresses, party, node_party, fed_actor_handle, method_name, ) -> None: - self._cluster = cluster + self._addresses = addresses self._party = party # Current party self._node_party = node_party self._fed_actor_handle = fed_actor_handle diff --git a/fed/api.py b/fed/api.py index 209cb49e..fdf97fbe 100644 --- a/fed/api.py +++ b/fed/api.py @@ -15,7 +15,7 @@ import functools import inspect import logging -from typing import Any, Dict, List, Union, Optional +from typing import Any, Dict, List, Union import cloudpickle import ray @@ -35,7 +35,7 @@ _start_sender_proxy, ) from fed.proxy.grpc.grpc_proxy import SenderProxy, ReceiverProxy -from fed.config import CrossSiloMessageConfig +from fed.config import GrpcCrossSiloMessageConfig from fed.fed_object import FedObject from fed.utils import is_ray_object_refs, setup_logger @@ -43,50 +43,28 @@ def init( - cluster: Dict = None, + addresses: Dict = None, party: str = None, + config: Dict = {}, tls_config: Dict = None, logging_level: str = 'info', - enable_waiting_for_other_parties_ready: bool = False, sender_proxy_cls: SenderProxy = None, receiver_proxy_cls: ReceiverProxy = None, - global_cross_silo_message_config: Optional[CrossSiloMessageConfig] = None, - **kwargs, ): """ Initialize a RayFed client. Args: - cluster: optional; a dict describes the cluster config. E.g. + addresses: optional; a dict describes the addresses configurations. E.g. .. code:: python { - 'alice': { - # The address for other parties. - 'address': '127.0.0.1:10001', - # (Optional) the listen address, the `address` will be - # used if not provided. - 'listen_addr': '0.0.0.0:10001', - 'cross_silo_message_config': CrossSiloMessageConfig - }, - 'bob': { - # The address for other parties. - 'address': '127.0.0.1:10002', - # (Optional) the listen address, the `address` will be - # used if not provided. - 'listen_addr': '0.0.0.0:10002', - # (Optional) The party specific metadata sent with grpc requests - 'grpc_metadata': (('token', 'bob-token'),), - }, - 'carol': { - # The address for other parties. - 'address': '127.0.0.1:10003', - # (Optional) the listen address, the `address` will be - # used if not provided. - 'listen_addr': '0.0.0.0:10003', - # (Optional) The party specific metadata sent with grpc requests - 'grpc_metadata': (('token', 'carol-token'),), - }, + # The address that can be connected to `alice` by other parties. + 'alice': '127.0.0.1:10001', + # The address that can be connected to `bob` by other parties. + 'bob': '127.0.0.1:10002', + # The address that can be connected to `carol` by other parties. + 'carol': '127.0.0.1:10003', } party: optional; self party. tls_config: optional; a dict describes the tls config. E.g. @@ -109,29 +87,24 @@ def init( } logging_level: optional; the logging level, could be `debug`, `info`, `warning`, `error`, `critical`, not case sensititive. - enable_waiting_for_other_parties_ready: ping other parties until they - are all ready if True. - global_cross_silo_message_config: Global cross-silo message related - configs that are applied to all connections. Supported configs - can refer to CrossSiloMessageConfig in config.py. Examples: >>> import fed >>> import ray >>> ray.init(address='local') - >>> cluster = { - >>> 'alice': {'address': '127.0.0.1:10001'}, - >>> 'bob': {'address': '127.0.0.1:10002'}, - >>> 'carol': {'address': '127.0.0.1:10003'}, + >>> addresses = { + >>> 'alice': '127.0.0.1:10001', + >>> 'bob': '127.0.0.1:10002', + >>> 'carol': '127.0.0.1:10003', >>> } >>> # Start as alice. - >>> fed.init(cluster=cluster, self_party='alice') + >>> fed.init(addresses=addresses, party='alice') """ - assert cluster, "Cluster should be provided." + assert addresses, "Addresses should be provided." assert party, "Party should be provided." - assert party in cluster, f"Party {party} is not in cluster {cluster}." + assert party in addresses, f"Party {party} is not in the addresses {addresses}." - fed_utils.validate_cluster_info(cluster) + fed_utils.validate_addresses(addresses) tls_config = {} if tls_config is None else tls_config if tls_config: @@ -139,20 +112,21 @@ def init( 'cert' in tls_config and 'key' in tls_config ), 'Cert or key are not in tls_config.' - global_cross_silo_message_config = \ - global_cross_silo_message_config or CrossSiloMessageConfig() + cross_silo_message_dict = config.get("cross_silo_message", {}) + cross_silo_message_config = GrpcCrossSiloMessageConfig.from_dict( + cross_silo_message_dict) # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv() cluster_config = { - constants.KEY_OF_CLUSTER_ADDRESSES: cluster, + constants.KEY_OF_CLUSTER_ADDRESSES: addresses, constants.KEY_OF_CURRENT_PARTY_NAME: party, constants.KEY_OF_TLS_CONFIG: tls_config, } job_config = { constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG: - global_cross_silo_message_config, + cross_silo_message_config, } compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)) @@ -170,7 +144,7 @@ def init( logger.info(f'Started rayfed with {cluster_config}') get_global_context().get_cleanup_manager().start( - exit_when_failure_sending=global_cross_silo_message_config.exit_on_sending_failure) # noqa + exit_when_failure_sending=cross_silo_message_config.exit_on_sending_failure) # noqa if receiver_proxy_cls is None: logger.debug( @@ -179,12 +153,12 @@ def init( from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy receiver_proxy_cls = GrpcReceiverProxy _start_receiver_proxy( - cluster=cluster, + addresses=addresses, party=party, logging_level=logging_level, tls_config=tls_config, proxy_cls=receiver_proxy_cls, - proxy_config=global_cross_silo_message_config + proxy_config=cross_silo_message_config ) if sender_proxy_cls is None: @@ -194,17 +168,18 @@ def init( from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy sender_proxy_cls = GrpcSenderProxy _start_sender_proxy( - cluster=cluster, + addresses=addresses, party=party, logging_level=logging_level, tls_config=tls_config, proxy_cls=sender_proxy_cls, - proxy_config=global_cross_silo_message_config + # TODO(qwang): proxy_config -> cross_silo_message_config + proxy_config=cross_silo_message_config ) - if enable_waiting_for_other_parties_ready: + if config.get("barrier_on_initializing", False): # TODO(zhouaihui): can be removed after we have a better retry strategy. - ping_others(cluster=cluster, self_party=party, max_retries=3600) + ping_others(addresses=addresses, self_party=party, max_retries=3600) def shutdown(): @@ -216,9 +191,9 @@ def shutdown(): logger.info('Shutdowned rayfed.') -def _get_cluster(): +def _get_addresses(): """ - Get the RayFed cluster configration. + Get the RayFed addresses configration. """ return fed_config.get_cluster_config().cluster_addresses @@ -290,7 +265,7 @@ def remote(self, *cls_args, **cls_kwargs): fed_class_task_id = get_global_context().next_seq_id() fed_actor_handle = FedActorHandle( fed_class_task_id, - _get_cluster(), + _get_addresses(), self._cls, _get_party(), self._party, @@ -341,7 +316,7 @@ def get( # A fake fed_task_id for a `fed.get()` operator. This is useful # to help contruct the whole DAG within `fed.get`. fake_fed_task_id = get_global_context().next_seq_id() - cluster = _get_cluster() + addresses = _get_addresses() current_party = _get_party() is_individual_id = isinstance(fed_objects, FedObject) if is_individual_id: @@ -357,7 +332,7 @@ def get( assert ray_object_ref is not None ray_refs.append(ray_object_ref) - for party_name in cluster: + for party_name in addresses: if party_name == current_party: continue else: diff --git a/fed/config.py b/fed/config.py index 15130bb0..3ad9eb4d 100644 --- a/fed/config.py +++ b/fed/config.py @@ -112,6 +112,9 @@ class CrossSiloMessageConfig: send_resource_label: Optional[Dict[str, str]] = None recv_resource_label: Optional[Dict[str, str]] = None http_header: Optional[Dict[str, str]] = None + # (Optional) The address this party is going to listen on. + # If not provided, the `address` will be used. + listening_address: str = None def __json__(self): return json.dumps(self.__dict__) @@ -132,7 +135,10 @@ def from_dict(cls, data: Dict): CrossSiloMessageConfig: An instance of CrossSiloMessageConfig. """ # Get the attributes of the class - attrs = {attr for attr, _ in cls.__annotations__.items()} + + data = data or {} + all_annotations = {**cls.__annotations__, **cls.__base__.__annotations__} + attrs = {attr for attr, _ in all_annotations.items()} # Filter the dictionary to only include keys that are attributes of the class filtered_data = {key: value for key, value in data.items() if key in attrs} return cls(**filtered_data) diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 647e8d22..0599cef3 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -58,7 +58,7 @@ def pop_from_two_dim_dict(the_dict, key_a, key_b): class SenderProxyActor: def __init__( self, - cluster: Dict, + addresses: Dict, party: str, tls_config: Dict = None, logging_level: str = None, @@ -72,13 +72,13 @@ def __init__( ) self._stats = {"send_op_count": 0} - self._cluster = cluster + self._addresses = addresses self._party = party self._tls_config = tls_config job_config = fed_config.get_job_config() cross_silo_message_config = job_config.cross_silo_message_config self._proxy_instance: SenderProxy = proxy_cls( - cluster, party, tls_config, cross_silo_message_config) + addresses, party, tls_config, cross_silo_message_config) async def is_ready(self): res = await self._proxy_instance.is_ready() @@ -93,8 +93,8 @@ async def send( ): self._stats["send_op_count"] += 1 assert ( - dest_party in self._cluster - ), f'Failed to find {dest_party} in cluster {self._cluster}.' + dest_party in self._addresses + ), f'Failed to find {dest_party} in addresses {self._addresses}.' send_log_msg = ( f'send data to seq_id {downstream_seq_id} of {dest_party} ' f'from {upstream_seq_id}' @@ -115,8 +115,8 @@ async def send( async def _get_stats(self): return self._stats - async def _get_cluster_info(self): - return self._cluster + async def _get_addresses_info(self): + return self._addresses async def _get_proxy_config(self, dest_party=None): return await self._proxy_instance.get_proxy_config(dest_party) @@ -126,7 +126,7 @@ async def _get_proxy_config(self, dest_party=None): class ReceiverProxyActor: def __init__( self, - listen_addr: str, + listening_address: str, party: str, logging_level: str, tls_config=None, @@ -139,13 +139,13 @@ def __init__( party_val=party, ) self._stats = {"receive_op_count": 0} - self._listen_addr = listen_addr + self._listening_address = listening_address self._party = party self._tls_config = tls_config job_config = fed_config.get_job_config() cross_silo_message_config = job_config.cross_silo_message_config self._proxy_instance: ReceiverProxy = proxy_cls( - listen_addr, party, tls_config, cross_silo_message_config) + listening_address, party, tls_config, cross_silo_message_config) async def start(self): await self._proxy_instance.start() @@ -173,7 +173,7 @@ async def _get_proxy_config(self): def _start_receiver_proxy( - cluster: str, + addresses: str, party: str, logging_level: str, tls_config=None, @@ -184,10 +184,10 @@ def _start_receiver_proxy( # Create RecevrProxyActor # Not that this is now a threaded actor. # NOTE(NKcqx): This is not just addr, but a party dict containing 'address' - party_addr = cluster[party] - listen_addr = party_addr.get('listen_addr', None) - if not listen_addr: - listen_addr = party_addr['address'] + party_addr = addresses[party] + listening_address = proxy_config.listening_address + if not listening_address: + listening_address = party_addr actor_options = copy.deepcopy(_DEFAULT_RECEIVER_PROXY_OPTIONS) if proxy_config is not None and proxy_config.recv_resource_label is not None: @@ -198,7 +198,7 @@ def _start_receiver_proxy( receiver_proxy_actor = ReceiverProxyActor.options( name=f"ReceiverProxyActor-{party}", **actor_options ).remote( - listen_addr=listen_addr, + listening_address=listening_address, party=party, tls_config=tls_config, logging_level=logging_level, @@ -218,7 +218,7 @@ def _start_receiver_proxy( def _start_sender_proxy( - cluster: Dict, + addresses: Dict, party: str, logging_level: str, tls_config: Dict = None, @@ -242,7 +242,7 @@ def _start_sender_proxy( name="SenderProxyActor", **actor_options) _SENDER_PROXY_ACTOR = _SENDER_PROXY_ACTOR.remote( - cluster=cluster, + addresses=addresses, party=party, tls_config=tls_config, logging_level=logging_level, @@ -276,9 +276,9 @@ def recv(party: str, src_party: str, upstream_seq_id, curr_seq_id): return receiver_proxy.get_data.remote(src_party, upstream_seq_id, curr_seq_id) -def ping_others(cluster: Dict[str, Dict], self_party: str, max_retries=3600): +def ping_others(addresses: Dict[str, Dict], self_party: str, max_retries=3600): """Ping other parties until all are ready or timeout.""" - others = [party for party in cluster if not party == self_party] + others = [party for party in addresses if not party == self_party] tried = 0 while tried < max_retries and others: diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py index ededca86..ecd4a336 100644 --- a/fed/proxy/base_proxy.py +++ b/fed/proxy/base_proxy.py @@ -21,12 +21,12 @@ class SenderProxy(abc.ABC): def __init__( self, - cluster: Dict, + addresses: Dict, party: str, tls_config: Dict, proxy_config: CrossSiloMessageConfig = None ) -> None: - self._cluster = cluster + self._addresses = addresses self._party = party self._tls_config = tls_config self._proxy_config = proxy_config diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 38317140..50ad372f 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -105,7 +105,7 @@ async def send( data, upstream_seq_id, downstream_seq_id): - dest_addr = self._cluster[dest_party]['address'] + dest_addr = self._addresses[dest_party] grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) tls_enabled = fed_utils.tls_enabled(self._tls_config) if dest_party not in self._stubs: @@ -142,8 +142,7 @@ def get_grpc_config_by_party(self, dest_party): grpc_metadata = self._grpc_metadata grpc_options = self._grpc_options - dest_party_msg_config = self._cluster[dest_party].get( - 'cross_silo_message_config', None) + dest_party_msg_config = self._proxy_config if dest_party_msg_config is not None: if dest_party_msg_config.http_header is not None: dest_party_grpc_metadata = dict(dest_party_msg_config.http_header) diff --git a/fed/utils.py b/fed/utils.py index 3632c721..7e28f3d6 100644 --- a/fed/utils.py +++ b/fed/utils.py @@ -209,9 +209,9 @@ def validate_address(address: str) -> None: raise ValueError(error_msg) -def validate_cluster_info(cluster: dict): +def validate_addresses(addresses: dict): """ - Validate whether the cluster information is in correct forms. + Validate whether the addresses is in correct forms. """ - for _, info in cluster.items(): - validate_address(info['address']) + for _, address in addresses.items(): + validate_address(address) diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/tests/serializations_tests/test_unpickle_with_whitelist.py index 98f39b53..a072731a 100644 --- a/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -19,8 +19,6 @@ import multiprocessing import numpy -from fed.config import CrossSiloMessageConfig - @fed.remote def generate_wrong_type(): @@ -42,20 +40,23 @@ def pass_arg(d): def run(party): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } allowed_list = { "numpy.core.numeric": ["*"], "numpy": ["dtype"], } fed.init( - cluster=cluster, + addresses=addresses, party=party, - global_cross_silo_message_config=CrossSiloMessageConfig( - serializing_allowed_list=allowed_list - )) + config={ + "cross_silo_message": { + 'serializing_allowed_list': allowed_list + } + }, + ) # Test passing an allowed type. o1 = generate_allowed_type.party("alice").remote() diff --git a/tests/simple_example.py b/tests/simple_example.py index dc7b1c70..4ca1095f 100644 --- a/tests/simple_example.py +++ b/tests/simple_example.py @@ -42,15 +42,15 @@ def agg_fn(obj1, obj2): return f"agg-{obj1}-{obj2}" -cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, +addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } def run(party): ray.init(address='local') - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) print(f"Running the script in party {party}") ds1, ds2 = [123, 789] diff --git a/tests/test_api.py b/tests/test_api.py index 71774f2f..9a3b00af 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -22,12 +22,12 @@ def run(): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, + addresses = { + 'alice': '127.0.0.1:11012', } - fed.init(cluster=cluster, party="alice") + fed.init(addresses=addresses, party="alice") config = fed_config.get_cluster_config() - assert config.cluster_addresses == cluster + assert config.cluster_addresses == addresses assert config.current_party == "alice" fed.shutdown() ray.shutdown() @@ -40,25 +40,26 @@ def test_fed_apis(): assert p_alice.exitcode == 0 -def test_miss_party_name_on_actor(): - def run(): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - } - fed.init(cluster=cluster, party="alice") +def _run(): + compatible_utils.init_ray(address='local') + addresses = { + 'alice': '127.0.0.1:11012', + } + fed.init(addresses=addresses, party="alice") + + @fed.remote + class MyActor: + pass - @fed.remote - class MyActor: - pass + with pytest.raises(ValueError): + MyActor.remote() - with pytest.raises(ValueError): - MyActor.remote() + fed.shutdown() + ray.shutdown() - fed.shutdown() - ray.shutdown() - p_alice = multiprocessing.Process(target=run) +def test_miss_party_name_on_actor(): + p_alice = multiprocessing.Process(target=_run) p_alice.start() p_alice.join() assert p_alice.exitcode == 0 diff --git a/tests/test_async_startup_2_clusters.py b/tests/test_async_startup_2_clusters.py index b90fdb72..9542f87e 100644 --- a/tests/test_async_startup_2_clusters.py +++ b/tests/test_async_startup_2_clusters.py @@ -43,11 +43,11 @@ def _run(party: str): time.sleep(10) compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) my1 = My.party("alice").remote() my2 = My.party("bob").remote() diff --git a/tests/test_basic_pass_fed_objects.py b/tests/test_basic_pass_fed_objects.py index 98307893..ebca911f 100644 --- a/tests/test_basic_pass_fed_objects.py +++ b/tests/test_basic_pass_fed_objects.py @@ -36,11 +36,11 @@ def get_value(self): def run(party, is_inner_party): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) o = f.party("alice").remote() actor_location = "alice" if is_inner_party else "bob" diff --git a/tests/test_cache_fed_objects.py b/tests/test_cache_fed_objects.py index 22b02333..ac181397 100644 --- a/tests/test_cache_fed_objects.py +++ b/tests/test_cache_fed_objects.py @@ -32,11 +32,11 @@ def g(x, index): def run(party): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) o = f.party("alice").remote() o1 = g.party("bob").remote(o, 1) diff --git a/tests/test_enable_tls_across_parties.py b/tests/test_enable_tls_across_parties.py index 2f4ee60e..5c3c1c38 100644 --- a/tests/test_enable_tls_across_parties.py +++ b/tests/test_enable_tls_across_parties.py @@ -48,11 +48,11 @@ def _run(party: str): "key": os.path.join(cert_dir, "server.key"), } - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party, tls_config=cert_config) + fed.init(addresses=addresses, party=party, tls_config=cert_config) my1 = My.party("alice").remote() my2 = My.party("bob").remote() diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 594fc891..47a15376 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -19,8 +19,6 @@ import fed import fed._private.compatible_utils as compatible_utils -from fed.config import GrpcCrossSiloMessageConfig - import signal import os @@ -48,13 +46,13 @@ def get_value(self): return self._value -def run(party, is_inner_party): +def run(party): signal.signal(signal.SIGTERM, signal_handler) compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } retry_policy = { "maxAttempts": 2, @@ -63,15 +61,17 @@ def run(party, is_inner_party): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } - cross_silo_message_config = GrpcCrossSiloMessageConfig( - grpc_retry_policy=retry_policy, - exit_on_sending_failure=True - ) + fed.init( - cluster=cluster, + addresses=addresses, party=party, logging_level='debug', - global_cross_silo_message_config=cross_silo_message_config + config={ + 'cross_silo_message': { + 'grpc_retry_policy': retry_policy, + 'exit_on_sending_failure': True, + }, + }, ) o = f.party("alice").remote() @@ -83,7 +83,7 @@ def run(party, is_inner_party): def test_exit_when_failure_on_sending(): - p_alice = multiprocessing.Process(target=run, args=('alice', True)) + p_alice = multiprocessing.Process(target=run, args=('alice',)) p_alice.start() p_alice.join() assert p_alice.exitcode == 0 diff --git a/tests/test_fed_get.py b/tests/test_fed_get.py index f49fc6aa..3aaf71d6 100644 --- a/tests/test_fed_get.py +++ b/tests/test_fed_get.py @@ -48,11 +48,11 @@ def mean(x, y): def run(party): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) epochs = 3 alice_model = MyModel.party("alice").remote("alice", 2) diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index a43007d1..2abf78d0 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -18,8 +18,6 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import GrpcCrossSiloMessageConfig - @fed.remote def dummpy(): @@ -28,18 +26,18 @@ def dummpy(): def run(party): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11019'}, - 'bob': {'address': '127.0.0.1:11018'}, + addresses = { + 'alice': '127.0.0.1:11019', + 'bob': '127.0.0.1:11018', } fed.init( - cluster=cluster, + addresses=addresses, party=party, - global_cross_silo_message_config=GrpcCrossSiloMessageConfig( - grpc_channel_options=[( - 'grpc.max_send_message_length', 100 - )] - ) + config={ + "cross_silo_message": { + "grpc_channel_options": [('grpc.max_send_message_length', 100)], + }, + }, ) def _assert_on_proxy(proxy_actor): diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py deleted file mode 100644 index 018294fb..00000000 --- a/tests/test_grpc_options_per_party.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2023 The RayFed Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import multiprocessing -import pytest -import fed -import fed._private.compatible_utils as compatible_utils -import ray - -from fed.config import GrpcCrossSiloMessageConfig - - -@fed.remote -def dummpy(): - return 2 - - -def run(party): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': { - 'address': '127.0.0.1:11010', - 'cross_silo_message_config': GrpcCrossSiloMessageConfig( - grpc_channel_options=[ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 200) - ]) - }, - 'bob': {'address': '127.0.0.1:11011'}, - } - fed.init( - cluster=cluster, - party=party, - global_cross_silo_message_config=GrpcCrossSiloMessageConfig( - grpc_channel_options=[( - 'grpc.max_send_message_length', 100 - )] - ) - ) - - def _assert_on_sender_proxy(proxy_actor): - alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) - # print(f"【NKcqx】alice config: {alice_config}") - assert 'grpc_options' in alice_config - alice_options = alice_config['grpc_options'] - assert ('grpc.max_send_message_length', 200) in alice_options - assert ('grpc.default_authority', 'alice') in alice_options - - bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob')) - assert 'grpc_options' in bob_config - bob_options = bob_config['grpc_options'] - assert ('grpc.max_send_message_length', 100) in bob_options - assert not any(o[0] == 'grpc.default_authority' for o in bob_options) - - sender_proxy = ray.get_actor("SenderProxyActor") - _assert_on_sender_proxy(sender_proxy) - - a = dummpy.party('alice').remote() - b = dummpy.party('bob').remote() - fed.get([a, b]) - - fed.shutdown() - ray.shutdown() - - -def test_grpc_options(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) - p_alice.start() - p_bob.start() - p_alice.join() - p_bob.join() - assert p_alice.exitcode == 0 and p_bob.exitcode == 0 - - -def party_grpc_options(party): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': { - 'address': '127.0.0.1:11010', - 'cross_silo_message_config': GrpcCrossSiloMessageConfig( - grpc_channel_options=[ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 51 * 1024 * 1024) - ]) - }, - 'bob': { - 'address': '127.0.0.1:11011', - 'cross_silo_message_config': GrpcCrossSiloMessageConfig( - grpc_channel_options=[ - ('grpc.default_authority', 'bob'), - ('grpc.max_send_message_length', 50 * 1024 * 1024) - ]) - }, - } - fed.init( - cluster=cluster, - party=party, - global_cross_silo_message_config=GrpcCrossSiloMessageConfig( - grpc_channel_options=[( - 'grpc.max_send_message_length', 100 - )] - ) - ) - - def _assert_on_sender_proxy(proxy_actor): - alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) - assert 'grpc_options' in alice_config - alice_options = alice_config['grpc_options'] - assert ('grpc.max_send_message_length', 51 * 1024 * 1024) in alice_options - assert ('grpc.default_authority', 'alice') in alice_options - - bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob')) - assert 'grpc_options' in bob_config - bob_options = bob_config['grpc_options'] - assert ('grpc.max_send_message_length', 50 * 1024 * 1024) in bob_options - assert ('grpc.default_authority', 'bob') in bob_options - - sender_proxy = ray.get_actor("SenderProxyActor") - _assert_on_sender_proxy(sender_proxy) - - a = dummpy.party('alice').remote() - b = dummpy.party('bob').remote() - fed.get([a, b]) - - fed.shutdown() - ray.shutdown() - - -def test_party_specific_grpc_options(): - p_alice = multiprocessing.Process( - target=party_grpc_options, args=('alice',)) - p_bob = multiprocessing.Process( - target=party_grpc_options, args=('bob',)) - p_alice.start() - p_bob.start() - p_alice.join() - p_bob.join() - assert p_alice.exitcode == 0 and p_bob.exitcode == 0 - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-sv", __file__])) diff --git a/tests/test_internal_kv.py b/tests/test_internal_kv.py index c57dec48..f2b5372b 100644 --- a/tests/test_internal_kv.py +++ b/tests/test_internal_kv.py @@ -8,12 +8,12 @@ def run(party): compatible_utils.init_ray("local") - cluster = { - 'alice': {'address': '127.0.0.1:11010', 'listen_addr': '0.0.0.0:11010'}, - 'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'}, + addresses = { + 'alice': '127.0.0.1:11010', + 'bob': '127.0.0.1:11011', } assert compatible_utils.kv is None - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) assert compatible_utils.kv assert not compatible_utils.kv.put(b"test_key", b"test_val") assert compatible_utils.kv.get(b"test_key") == b"test_val" diff --git a/tests/test_listen_addr.py b/tests/test_listen_addr.py deleted file mode 100644 index 05960e2f..00000000 --- a/tests/test_listen_addr.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2023 The RayFed Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import multiprocessing - -import pytest -import ray -import fed -import fed._private.compatible_utils as compatible_utils - - -@fed.remote -def f(): - return 100 - - -@fed.remote -class My: - def __init__(self, value) -> None: - self._value = value - - def get_value(self): - return self._value - - -def run(party, is_inner_party): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012', 'listen_addr': '0.0.0.0:11012'}, - 'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'}, - } - fed.init(cluster=cluster, party=party) - - o = f.party("alice").remote() - actor_location = "alice" if is_inner_party else "bob" - my = My.party(actor_location).remote(o) - val = my.get_value.remote() - result = fed.get(val) - assert result == 100 - assert fed.get(o) == 100 - import time - - time.sleep(5) - fed.shutdown() - ray.shutdown() - - -def test_listen_addr(): - p_alice = multiprocessing.Process(target=run, args=('alice', True)) - p_bob = multiprocessing.Process(target=run, args=('bob', True)) - p_alice.start() - p_bob.start() - p_alice.join() - p_bob.join() - assert p_alice.exitcode == 0 and p_bob.exitcode == 0 - - -def test_listen_used_addr(): - def run(party): - import socket - - compatible_utils.init_ray(address='local') - occupied_port = 11020 - # NOTE(NKcqx): Firstly try to bind IPv6 because the grpc server will do so. - # Otherwise this UT will false because socket bind $occupied_port - # on IPv4 address while grpc server listendn Ipv6 address. - try: - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - # Pre-occuping the port - s.bind(("::", occupied_port)) - except OSError: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("127.0.0.1", occupied_port)) - - cluster = { - 'alice': { - 'address': '127.0.0.1:11012', - 'listen_addr': f'0.0.0.0:{occupied_port}'}, - 'bob': { - 'address': '127.0.0.1:11011', - 'listen_addr': '0.0.0.0:11011'}, - } - - # Starting grpc server on an used port will cause AssertionError - with pytest.raises(AssertionError): - fed.init(cluster=cluster, party=party) - - import time - - time.sleep(5) - s.close() - fed.shutdown() - ray.shutdown() - - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_alice.start() - p_alice.join() - assert p_alice.exitcode == 0 - - -if __name__ == "__main__": - # import sys - - # sys.exit(pytest.main(["-sv", __file__])) - test_listen_used_addr() diff --git a/tests/test_listening_address.py b/tests/test_listening_address.py new file mode 100644 index 00000000..bb7a01cd --- /dev/null +++ b/tests/test_listening_address.py @@ -0,0 +1,132 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing + +import pytest +import ray +import fed +import fed._private.compatible_utils as compatible_utils + + +@fed.remote +def f(): + return 100 + + +@fed.remote +class My: + def __init__(self, value) -> None: + self._value = value + + def get_value(self): + return self._value + + +def run(party, is_inner_party): + compatible_utils.init_ray(address='local') + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', + } + listening_address = '0.0.0.0:11012' if party == 'alice' else '0.0.0.0:11011' + fed.init( + addresses=addresses, + party=party, + config={ + 'cross_silo_message': { + 'listening_address': listening_address + } + }, + ) + + o = f.party("alice").remote() + actor_location = "alice" if is_inner_party else "bob" + my = My.party(actor_location).remote(o) + val = my.get_value.remote() + result = fed.get(val) + assert result == 100 + assert fed.get(o) == 100 + import time + + time.sleep(5) + fed.shutdown() + ray.shutdown() + + +def test_listening_address(): + p_alice = multiprocessing.Process(target=run, args=('alice', True)) + p_bob = multiprocessing.Process(target=run, args=('bob', True)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + +def _run(party): + import socket + + compatible_utils.init_ray(address='local') + occupied_port = 11020 + # NOTE(NKcqx): Firstly try to bind IPv6 because the grpc server will do so. + # Otherwise this UT will false because socket bind $occupied_port + # on IPv4 address while grpc server listendn Ipv6 address. + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + # Pre-occuping the port + s.bind(("::", occupied_port)) + except OSError: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", occupied_port)) + + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', + } + listening_address = f'0.0.0.0:{occupied_port}' \ + if party == 'alice' else '0.0.0.0:11011' + + # Starting grpc server on an used port will cause AssertionError + with pytest.raises(AssertionError): + fed.init( + addresses=addresses, + party=party, + config={ + 'cross_silo_message': { + 'listening_address': listening_address + } + }, + ) + + import time + + time.sleep(5) + s.close() + fed.shutdown() + ray.shutdown() + + +def test_listen_used_address(): + p_alice = multiprocessing.Process(target=_run, args=('alice',)) + p_alice.start() + p_alice.join() + assert p_alice.exitcode == 0 + + +if __name__ == "__main__": + # import sys + + # sys.exit(pytest.main(["-sv", __file__])) + test_listen_used_address() diff --git a/tests/test_options.py b/tests/test_options.py index e76b2dd9..13bbb6f3 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -33,11 +33,11 @@ def bar(x): def run(party): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) foo = Foo.party("alice").remote() a, b = fed.get(foo.run.options(num_returns=2).remote()) diff --git a/tests/test_pass_fed_objects_in_containers_in_actor.py b/tests/test_pass_fed_objects_in_containers_in_actor.py index 1d815375..08d3a2d6 100644 --- a/tests/test_pass_fed_objects_in_containers_in_actor.py +++ b/tests/test_pass_fed_objects_in_containers_in_actor.py @@ -37,15 +37,15 @@ def bar(self, li): return True -cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, +addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } def run(party): compatible_utils.init_ray(address='local') - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) my1 = My.party("alice").remote() my2 = My.party("bob").remote() o1 = my1.foo.remote(0) diff --git a/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py b/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py index 580e031e..7b78cb60 100644 --- a/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py +++ b/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py @@ -39,11 +39,11 @@ def bar(li): def run(party): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) o1 = foo.party("alice").remote(0) o2 = foo.party("bob").remote(1) li = ["hello", [o1], ["world", [o2]]] diff --git a/tests/test_ping_others.py b/tests/test_ping_others.py index 1f9c6b30..0782bb7a 100644 --- a/tests/test_ping_others.py +++ b/tests/test_ping_others.py @@ -20,19 +20,19 @@ from fed.proxy.barriers import ping_others -cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, +addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } def test_ping_non_started_party(): def run(party): compatible_utils.init_ray(address='local') - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) if (party == 'alice'): with pytest.raises(RuntimeError): - ping_others(cluster, party, 5) + ping_others(addresses, party, 5) fed.shutdown() ray.shutdown() @@ -45,9 +45,9 @@ def run(party): def test_ping_started_party(): def run(party): compatible_utils.init_ray(address='local') - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) if (party == 'alice'): - ping_success = ping_others(cluster, party, 5) + ping_success = ping_others(addresses, party, 5) assert ping_success is True fed.shutdown() diff --git a/tests/test_repeat_init.py b/tests/test_repeat_init.py index 5e2d49fd..8926c786 100644 --- a/tests/test_repeat_init.py +++ b/tests/test_repeat_init.py @@ -37,16 +37,16 @@ def bar(self, li): return True -cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, +addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } def run(party): def _run(): compatible_utils.init_ray(address='local') - fed.init(cluster=cluster, party=party) + fed.init(addresses=addresses, party=party) my1 = My.party("alice").remote() my2 = My.party("bob").remote() diff --git a/tests/test_reset_context.py b/tests/test_reset_context.py index aee5ba0c..95c6e53f 100644 --- a/tests/test_reset_context.py +++ b/tests/test_reset_context.py @@ -4,9 +4,9 @@ import fed._private.compatible_utils as compatible_utils import pytest -cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, +addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } @@ -22,7 +22,7 @@ def get(self): def run(party): compatible_utils.init_ray(address='local') fed.init( - cluster=cluster, + addresses=addresses, party=party) actor = A.party('alice').remote(10) @@ -46,7 +46,7 @@ def run(party): compatible_utils.init_ray(address='local') fed.init( - cluster=cluster, + addresses=addresses, party=party) actor = A.party('alice').remote(10) diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 9450dac6..92981aae 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -20,8 +20,6 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import GrpcCrossSiloMessageConfig - @fed.remote def f(): @@ -39,9 +37,9 @@ def get_value(self): def run(party, is_inner_party): compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11012', + 'bob': '127.0.0.1:11011', } retry_policy = { "maxAttempts": 4, @@ -51,11 +49,11 @@ def run(party, is_inner_party): "retryableStatusCodes": ["UNAVAILABLE"], } fed.init( - cluster=cluster, + addresses=addresses, party=party, - global_cross_silo_message_config=GrpcCrossSiloMessageConfig( - grpc_retry_policy=retry_policy - ) + config={'cross_silo_message': { + 'grpc_retry_policy': retry_policy, + }}, ) o = f.party("alice").remote() diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index 2a136cbe..21f6941e 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -20,26 +20,16 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloMessageConfig - def run(party): compatible_utils.init_ray(address='local', resources={"127.0.0.1": 2}) - cluster = { - 'alice': {'address': '127.0.0.1:11010'}, - 'bob': {'address': '127.0.0.1:11011'}, - } - sender_proxy_resources = { - "127.0.0.1": 1 - } - receiver_proxy_resources = { - "127.0.0.1": 1 + addresses = { + 'alice': '127.0.0.1:11010', + 'bob': '127.0.0.1:11011', } fed.init( - cluster=cluster, + addresses=addresses, party=party, - cross_silo_send_resource_label=sender_proxy_resources, - cross_silo_recv_resource_label=receiver_proxy_resources, ) assert ray.get_actor("SenderProxyActor") is not None @@ -51,9 +41,9 @@ def run(party): def run_failure(party): compatible_utils.init_ray(address='local', resources={"127.0.0.1": 1}) - cluster = { - 'alice': {'address': '127.0.0.1:11010'}, - 'bob': {'address': '127.0.0.1:11011'}, + addresses = { + 'alice': '127.0.0.1:11010', + 'bob': '127.0.0.1:11011', } sender_proxy_resources = { "127.0.0.2": 1 # Insufficient resource @@ -63,13 +53,15 @@ def run_failure(party): } with pytest.raises(ray.exceptions.GetTimeoutError): fed.init( - cluster=cluster, + addresses=addresses, party=party, - global_cross_silo_message_config=CrossSiloMessageConfig( - send_resource_label=sender_proxy_resources, - recv_resource_label=receiver_proxy_resources, - timeout_in_ms=10*1000, - ) + config={ + 'cross_silo_message': { + 'send_resource_label': sender_proxy_resources, + 'recv_resource_label': receiver_proxy_resources, + 'timeout_in_ms': 10*1000, + } + } ) fed.shutdown() diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 7e9c3fe6..4828240b 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -58,17 +58,17 @@ def test_n_to_1_transport(): NUM_DATA = 10 SERVER_ADDRESS = "127.0.0.1:12344" party = 'test_party' - cluster_config = {'test_party': {'address': SERVER_ADDRESS}} + addresses = {'test_party': SERVER_ADDRESS} config = GrpcCrossSiloMessageConfig() _start_receiver_proxy( - cluster_config, + addresses, party, logging_level='info', proxy_cls=GrpcReceiverProxy, proxy_config=config ) _start_sender_proxy( - cluster_config, + addresses, party, logging_level='info', proxy_cls=GrpcSenderProxy, @@ -148,22 +148,23 @@ async def is_ready(self): def _test_start_receiver_proxy( - cluster: str, + addresses: str, + config: dict, party: str, logging_level: str, expected_metadata: dict, ): # Create RecevrProxyActor # Not that this is now a threaded actor. - party_addr = cluster[party] - listen_addr = party_addr.get('listen_addr', None) - if not listen_addr: - listen_addr = party_addr['address'] + address = addresses[party] + listening_address = config['cross_silo_message'].get('listening_address', None) + if not listening_address: + listening_address = address receiver_proxy_actor = TestReceiverProxyActor.options( name=f"ReceiverProxyActor-{party}", max_concurrency=1000 ).remote( - listen_addr=listen_addr, + listen_addr=listening_address, party=party, expected_metadata=expected_metadata ) @@ -194,70 +195,23 @@ def test_send_grpc_with_meta(): global_context.get_global_context().get_cleanup_manager().start() SERVER_ADDRESS = "127.0.0.1:12344" - party = 'test_party' - cluster_config = {'test_party': {'address': SERVER_ADDRESS}} + party_name = 'test_party' + addresses = {party_name: SERVER_ADDRESS} _test_start_receiver_proxy( - cluster_config, party, logging_level='info', + addresses, + {'cross_silo_message': {}}, + party_name, + logging_level='info', expected_metadata=metadata, ) _start_sender_proxy( - cluster_config, - party, + addresses, + party_name, logging_level='info', proxy_cls=GrpcSenderProxy, proxy_config=GrpcCrossSiloMessageConfig()) sent_objs = [] - sent_obj = send(party, "data", 0, 1) - sent_objs.append(sent_obj) - for result in ray.get(sent_objs): - assert result - - global_context.get_global_context().get_cleanup_manager().graceful_stop() - global_context.clear_global_context() - ray.shutdown() - - -def test_send_grpc_with_party_specific_meta(): - compatible_utils.init_ray(address='local') - cluster_config = { - constants.KEY_OF_CLUSTER_ADDRESSES: "", - constants.KEY_OF_CURRENT_PARTY_NAME: "", - constants.KEY_OF_TLS_CONFIG: "", - } - sender_proxy_config = CrossSiloMessageConfig( - http_header={"key": "value"}) - job_config = { - constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG: - sender_proxy_config, - } - compatible_utils._init_internal_kv() - compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, - cloudpickle.dumps(cluster_config)) - compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, - cloudpickle.dumps(job_config)) - global_context.get_global_context().get_cleanup_manager().start() - - SERVER_ADDRESS = "127.0.0.1:12344" - party = 'test_party' - cluster_parties_config = { - 'test_party': { - 'address': SERVER_ADDRESS, - 'cross_silo_message_config': CrossSiloMessageConfig( - http_header={"token": "test-party-token"}) - } - } - _test_start_receiver_proxy( - cluster_parties_config, party, logging_level='info', - expected_metadata={"key": "value", "token": "test-party-token"}, - ) - _start_sender_proxy( - cluster_parties_config, - party, - logging_level='info', - proxy_cls=GrpcSenderProxy, - proxy_config=sender_proxy_config) - sent_objs = [] - sent_obj = send(party, "data", 0, 1) + sent_obj = send(party_name, "data", 0, 1) sent_objs.append(sent_obj) for result in ray.get(sent_objs): assert result diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index 41056732..6e931084 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -56,10 +56,10 @@ def test_n_to_1_transport(): NUM_DATA = 10 SERVER_ADDRESS = "127.0.0.1:65422" party = 'test_party' - cluster_config = {'test_party': {'address': SERVER_ADDRESS}} + addresses = {'test_party': SERVER_ADDRESS} config = GrpcCrossSiloMessageConfig() _start_receiver_proxy( - cluster_config, + addresses, party, logging_level='info', tls_config=tls_config, @@ -67,7 +67,7 @@ def test_n_to_1_transport(): proxy_config=config ) _start_sender_proxy( - cluster_config, + addresses, party, logging_level='info', tls_config=tls_config,