diff --git a/fed/api.py b/fed/api.py index 561006b..edb467a 100644 --- a/fed/api.py +++ b/fed/api.py @@ -43,6 +43,8 @@ def init( cross_silo_grpc_retry_policy: Dict = None, cross_silo_send_max_retries: int = None, cross_silo_serializing_allowed_list: Dict = None, + cross_silo_send_resource_label: Dict = None, + cross_silo_recv_resource_label: Dict = None, exit_on_failure_cross_silo_sending: bool = False, cross_silo_messages_max_size_in_bytes: int = None, cross_silo_timeout_in_seconds: int = 60, @@ -131,6 +133,14 @@ def init( cross_silo_serializing_allowed_list: The package or class list allowed for serializing(deserializating) cross silos. It's used for avoiding pickle deserializing execution attack when crossing solis. + cross_silo_send_resource_label: Customized resource label. The SendProxyActor + will be scheduled based on the declared resource label. For example, + when set to `{"my_label": 1}`, the SendProxyActor will be started + only on nodes with `{"resources": {"my_label": $NUM}}` where $NUM >= 1. + cross_silo_recv_resource_label: Customized resource label. The RecverProxyActor + will be scheduled based on the declared resource label. For example, + when set to `{"my_label": 1}`, the RecverProxyActor will be started + only on nodes with `{"resources": {"my_label": $NUM}}` where $NUM >= 1. exit_on_failure_cross_silo_sending: whether exit when failure on cross-silo sending. If True, a SIGTERM will be signaled to self if failed to sending cross-silo data. 
@@ -200,6 +210,8 @@ def init( logger.info(f'Started rayfed with {cluster_config}') set_exit_on_failure_sending(exit_on_failure_cross_silo_sending) + recv_actor_config = fed_config.ProxyActorConfig( + resource_label=cross_silo_recv_resource_label) # Start recv proxy start_recv_proxy( cluster=cluster, @@ -207,7 +219,11 @@ def init( logging_level=logging_level, tls_config=tls_config, retry_policy=cross_silo_grpc_retry_policy, + actor_config=recv_actor_config ) + + send_actor_config = fed_config.ProxyActorConfig( + resource_label=cross_silo_send_resource_label) start_send_proxy( cluster=cluster, party=party, @@ -215,6 +231,7 @@ def init( tls_config=tls_config, retry_policy=cross_silo_grpc_retry_policy, max_retries=cross_silo_send_max_retries, + actor_config=send_actor_config ) if enable_waiting_for_other_parties_ready: diff --git a/fed/barriers.py b/fed/barriers.py index 0278589..f1664e3 100644 --- a/fed/barriers.py +++ b/fed/barriers.py @@ -16,7 +16,8 @@ import logging import threading import time -from typing import Dict +import copy +from typing import Dict, Optional import cloudpickle import grpc @@ -27,6 +28,7 @@ from fed._private import constants from fed._private.grpc_options import get_grpc_options, set_max_message_length from fed.cleanup import push_to_sending +from fed.config import get_cluster_config from fed.grpc import fed_pb2, fed_pb2_grpc from fed.utils import setup_logger @@ -133,7 +135,6 @@ async def send_data_grpc( ): grpc_options = get_grpc_options(retry_policy=retry_policy) if \ grpc_options is None else fed_utils.dict2tuple(grpc_options) - tls_enabled = fed_utils.tls_enabled(tls_config) cluster_config = fed_config.get_cluster_config() metadata = fed_utils.dict2tuple(metadata) @@ -366,13 +367,20 @@ async def _get_grpc_options(self): return get_grpc_options() +_DEFAULT_RECV_PROXY_OPTIONS = { + "max_concurrency": 1000, +} + + def start_recv_proxy( cluster: str, party: str, logging_level: str, tls_config=None, retry_policy=None, + actor_config: 
Optional[fed_config.ProxyActorConfig] = None ): + # Create RecevrProxyActor # Not that this is now a threaded actor. # NOTE(NKcqx): This is not just addr, but a party dict containing 'address' @@ -381,8 +389,14 @@ def start_recv_proxy( if not listen_addr: listen_addr = party_addr['address'] + actor_options = copy.deepcopy(_DEFAULT_RECV_PROXY_OPTIONS) + if actor_config is not None and actor_config.resource_label is not None: + actor_options.update({"resources": actor_config.resource_label}) + + logger.debug(f"Starting RecvProxyActor with options: {actor_options}") + recver_proxy_actor = RecverProxyActor.options( - name=f"RecverProxyActor-{party}", max_concurrency=1000 + name=f"RecverProxyActor-{party}", **actor_options ).remote( listen_addr=listen_addr, party=party, @@ -391,12 +405,16 @@ def start_recv_proxy( retry_policy=retry_policy, ) recver_proxy_actor.run_grpc_server.remote() - server_state = ray.get(recver_proxy_actor.is_ready.remote()) + timeout = get_cluster_config().cross_silo_timeout + server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout) assert server_state[0], server_state[1] - logger.info("RecverProxy was successfully created.") + logger.info("RecverProxy has successfully created.") _SEND_PROXY_ACTOR = None +_DEFAULT_SEND_PROXY_OPTIONS = { + "max_concurrency": 1000, +} def start_send_proxy( @@ -406,20 +424,24 @@ def start_send_proxy( tls_config: Dict = None, retry_policy=None, max_retries=None, + actor_config: Optional[fed_config.ProxyActorConfig] = None ): # Create SendProxyActor global _SEND_PROXY_ACTOR + + actor_options = copy.deepcopy(_DEFAULT_SEND_PROXY_OPTIONS) if max_retries is not None: - _SEND_PROXY_ACTOR = SendProxyActor.options( - name="SendProxyActor", - max_concurrency=1000, - max_task_retries=max_retries, - max_restarts=1, - ) - else: - _SEND_PROXY_ACTOR = SendProxyActor.options( - name="SendProxyActor", max_concurrency=1000 - ) + actor_options.update({ + "max_task_retries": max_retries, + "max_restarts": 1, + }) + 
if actor_config is not None and actor_config.resource_label is not None: + actor_options.update({"resources": actor_config.resource_label}) + + logger.debug(f"Starting SendProxyActor with options: {actor_options}") + _SEND_PROXY_ACTOR = SendProxyActor.options( + name="SendProxyActor", **actor_options) + _SEND_PROXY_ACTOR = _SEND_PROXY_ACTOR.remote( cluster=cluster, party=party, @@ -427,8 +449,9 @@ def start_send_proxy( logging_level=logging_level, retry_policy=retry_policy, ) - assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote()) - logger.info("SendProxy was successfully created.") + timeout = get_cluster_config().cross_silo_timeout + assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout) + logger.info("SendProxyActor has successfully created.") def send( diff --git a/fed/config.py b/fed/config.py index 4476df5..f3946c1 100644 --- a/fed/config.py +++ b/fed/config.py @@ -7,6 +7,7 @@ import fed._private.compatible_utils as compatible_utils import fed._private.constants as fed_constants import cloudpickle +from typing import Dict, Optional class ClusterConfig: @@ -77,3 +78,16 @@ def get_job_config(): raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG) _job_config = JobConfig(raw_dict) return _job_config + + +class ProxyActorConfig: + """A class to store parameters used for proxy actors. + + Attributes: + resource_label: The customized resources for the actor. This will be + filled into the "resources" field of Ray ActorClass.options. + """ + def __init__( + self, + resource_label: Optional[Dict[str, str]] = None) -> None: + self.resource_label = resource_label diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py new file mode 100644 index 0000000..7a3aabd --- /dev/null +++ b/tests/test_setup_proxy_actor.py @@ -0,0 +1,96 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import multiprocessing + +import pytest +import fed +import fed._private.compatible_utils as compatible_utils +import ray + + +def test_setup_proxy_success(): + def run(party): + compatible_utils.init_ray(address='local', resources={"127.0.0.1": 2}) + cluster = { + 'alice': {'address': '127.0.0.1:11010'}, + 'bob': {'address': '127.0.0.1:11011'}, + } + send_proxy_resources = { + "127.0.0.1": 1 + } + recv_proxy_resources = { + "127.0.0.1": 1 + } + fed.init( + cluster=cluster, + party=party, + cross_silo_send_resource_label=send_proxy_resources, + cross_silo_recv_resource_label=recv_proxy_resources, + ) + + assert ray.get_actor("SendProxyActor") is not None + assert ray.get_actor(f"RecverProxyActor-{party}") is not None + + fed.shutdown() + ray.shutdown() + + p_alice = multiprocessing.Process(target=run, args=('alice',)) + p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + +def test_setup_proxy_failed(): + def run(party): + compatible_utils.init_ray(address='local', resources={"127.0.0.1": 1}) + cluster = { + 'alice': {'address': '127.0.0.1:11010'}, + 'bob': {'address': '127.0.0.1:11011'}, + } + send_proxy_resources = { + "127.0.0.2": 1 # Insufficient resource + } + recv_proxy_resources = { + "127.0.0.2": 1 # Insufficient resource + } + with pytest.raises(ray.exceptions.GetTimeoutError): + fed.init( + cluster=cluster, + party=party, + cross_silo_send_resource_label=send_proxy_resources, + 
cross_silo_recv_resource_label=recv_proxy_resources, + cross_silo_timeout_in_seconds=10, # Quick fail in test + ) + + fed.shutdown() + ray.shutdown() + + p_alice = multiprocessing.Process(target=run, args=('alice',)) + p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-sv", __file__]))