diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3ff0dac..68289cd 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -33,4 +33,4 @@ jobs: - name: Lint run: | . py3/bin/activate - black --check --diff . + black --check --diff . --exclude fed/grpc diff --git a/.isort.cfg b/.isort.cfg index 5f06acb..eb16c8b 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -9,6 +9,5 @@ use_parentheses=True float_to_top=True filter_files=True -known_local_folder=ray -known_third_party=grpc +known_local_folder=fed sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER \ No newline at end of file diff --git a/benchmarks/many_tiny_tasks_benchmark.py b/benchmarks/many_tiny_tasks_benchmark.py index 5fbbf38..1e42032 100644 --- a/benchmarks/many_tiny_tasks_benchmark.py +++ b/benchmarks/many_tiny_tasks_benchmark.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ray -import time import sys +import time + +import ray + import fed @@ -31,11 +33,11 @@ def aggr(self, val1, val2): def main(party): - ray.init(address='local') + ray.init(address="local") addresses = { - 'alice': '127.0.0.1:11010', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11010", + "bob": "127.0.0.1:11011", } fed.init(addresses=addresses, party=party) @@ -53,13 +55,13 @@ def main(party): if i % 100 == 0: print(f"Running {i}th call") print(f"num calls: {num_calls}") - print("total time (ms) = ", (time.time() - start)*1000) - print("per task overhead (ms) =", (time.time() - start)*1000/num_calls) + print("total time (ms) = ", (time.time() - start) * 1000) + print("per task overhead (ms) =", (time.time() - start) * 1000 / num_calls) fed.shutdown() ray.shutdown() if __name__ == "__main__": - assert len(sys.argv) == 2, 'Please run this script with party.' + assert len(sys.argv) == 2, "Please run this script with party." main(sys.argv[1]) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6939e8f..356d368 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2,34 +2,34 @@ # -- Project information -project = 'RayFed' -copyright = '2022, The RayFed Team' -author = 'The RayFed Authors' +project = "RayFed" +copyright = "2022, The RayFed Team" +author = "The RayFed Authors" -release = '0.1' -version = '0.1.0' +release = "0.1" +version = "0.1.0" # -- General configuration extensions = [ - 'sphinx.ext.duration', - 'sphinx.ext.doctest', - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.intersphinx', + "sphinx.ext.duration", + "sphinx.ext.doctest", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", ] intersphinx_mapping = { - 'python': ('https://docs.python.org/3/', None), - 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), + "python": ("https://docs.python.org/3/", None), + "sphinx": ("https://www.sphinx-doc.org/en/master/", None), } -intersphinx_disabled_domains = ['std'] +intersphinx_disabled_domains = ["std"] -templates_path = ['_templates'] +templates_path = ["_templates"] # -- Options for HTML output -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # -- Options for EPUB output -epub_show_urls = 'footnote' +epub_show_urls = "footnote" diff --git a/fed/__init__.py b/fed/__init__.py index 7636dd4..c5c4502 100644 --- a/fed/__init__.py +++ b/fed/__init__.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from fed.api import (get, init, kill, remote,
-                     shutdown)
+from fed.api import get, init, kill, remote, shutdown
 from fed.proxy.barriers import recv, send
 from fed.fed_object import FedObject
 from fed.exceptions import FedRemoteError
@@ -27,5 +26,5 @@
     "recv",
     "send",
     "FedObject",
-    "FedRemoteError"
+    "FedRemoteError",
 ]
diff --git a/fed/_private/compatible_utils.py b/fed/_private/compatible_utils.py
index 83d2a09..75bfb62 100644
--- a/fed/_private/compatible_utils.py
+++ b/fed/_private/compatible_utils.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 import abc
-import ray
 
-import fed._private.constants as fed_constants
+import ray
 import ray.experimental.internal_kv as ray_internal_kv
+
+import fed._private.constants as fed_constants
 from fed._private import constants
 
 
@@ -26,8 +27,8 @@ def _compare_version_strings(version1, version2):
         True if version1 is greater, and False if they're equal, and
         False if version2 is greater.
     """
-    v1_list = version1.split('.')
-    v2_list = version2.split('.')
+    v1_list = version1.split(".")
+    v2_list = version2.split(".")
 
     len1 = len(v1_list)
     len2 = len(v2_list)
@@ -41,16 +42,15 @@ def _ray_version_less_than_2_0_0():
-    """ Whther the current ray version is less 2.0.0.
-    """
+    """Whether the current ray version is less than 2.0.0."""
     return _compare_version_strings(
-        fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__)
+        fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__
+    )
 
 
 def init_ray(address: str = None, **kwargs):
-    """A compatible API to init Ray.
-    """
-    if address == 'local' and _ray_version_less_than_2_0_0():
+    """A compatible API to init Ray."""
+    if address == "local" and _ray_version_less_than_2_0_0():
         # Ignore the `local` when ray < 2.0.0
         ray.init(**kwargs)
     else:
@@ -58,8 +58,7 @@ def _get_gcs_address_from_ray_worker():
-    """A compatible API to get the gcs address from the ray worker module.
-    """
+    """A compatible API to get the gcs address from the ray worker module."""
     try:
         return ray._private.worker._global_node.gcs_address
     except AttributeError:
@@ -67,19 +66,19 @@ def wrap_kv_key(job_name, key: str):
-    """Add an prefix to the key to avoid conflict with other jobs.
-    """
-    assert isinstance(key, str), \
-        f"The key of KV data must be `str` type, got {type(key)}."
+    """Add a prefix to the key to avoid conflict with other jobs."""
+    assert isinstance(
+        key, str
+    ), f"The key of KV data must be `str` type, got {type(key)}."
 
-    return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(
-        job_name, key)
+    return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(job_name, key)
 
 
 class AbstractInternalKv(abc.ABC):
-    """ An abstract class that represents for bridging Ray internal kv in
+    """An abstract class for bridging Ray internal kv in
     both Ray client mode and non Ray client mode.
     """
+
     def __init__(self) -> None:
         pass
 
@@ -105,8 +104,8 @@ def reset(self):
 
 
 class InternalKv(AbstractInternalKv):
-    """The internal kv class for non Ray client mode.
- """ + """The internal kv class for non Ray client mode.""" + def __init__(self, job_name: str) -> None: super().__init__() self._job_name = job_name @@ -120,21 +119,18 @@ def initialize(self): from ray._raylet import GcsClient gcs_client = GcsClient( - address=_get_gcs_address_from_ray_worker(), - nums_reconnect_retry=10) + address=_get_gcs_address_from_ray_worker(), nums_reconnect_retry=10 + ) return ray_internal_kv._initialize_internal_kv(gcs_client) def put(self, k, v): - return ray_internal_kv._internal_kv_put( - wrap_kv_key(self._job_name, k), v) + return ray_internal_kv._internal_kv_put(wrap_kv_key(self._job_name, k), v) def get(self, k): - return ray_internal_kv._internal_kv_get( - wrap_kv_key(self._job_name, k)) + return ray_internal_kv._internal_kv_get(wrap_kv_key(self._job_name, k)) def delete(self, k): - return ray_internal_kv._internal_kv_del( - wrap_kv_key(self._job_name, k)) + return ray_internal_kv._internal_kv_del(wrap_kv_key(self._job_name, k)) def reset(self): return ray_internal_kv._internal_kv_reset() @@ -144,8 +140,8 @@ def _ping(self): class ClientModeInternalKv(AbstractInternalKv): - """The internal kv class for Ray client mode. - """ + """The internal kv class for Ray client mode.""" + def __init__(self) -> None: super().__init__() self._internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR") @@ -176,9 +172,13 @@ def _init_internal_kv(job_name): global kv if kv is None: from ray._private.client_mode_hook import is_client_mode_enabled + if is_client_mode_enabled: - kv_actor = ray.remote(InternalKv).options( - name="_INTERNAL_KV_ACTOR").remote(job_name) + kv_actor = ( + ray.remote(InternalKv) + .options(name="_INTERNAL_KV_ACTOR") + .remote(job_name) + ) response = kv_actor._ping.remote() ray.get(response) kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name) @@ -192,6 +192,7 @@ def _clear_internal_kv(): kv.delete(constants.KEY_OF_JOB_CONFIG) kv.reset() from ray._private.client_mode_hook import is_client_mode_enabled + if is_client_mode_enabled: _internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR") ray.kill(_internal_kv_actor) diff --git a/fed/_private/constants.py b/fed/_private/constants.py index f21f3d0..a77201b 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -27,7 +27,7 @@ KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT" -RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa +RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S" diff --git a/fed/_private/fed_actor.py b/fed/_private/fed_actor.py index dd88057..c4da4ca 100644 --- a/fed/_private/fed_actor.py +++ b/fed/_private/fed_actor.py @@ -16,6 +16,7 @@ import ray from ray.util.client.common import ClientActorHandle + from fed._private.fed_call_holder import FedCallHolder from fed.fed_object import FedObject @@ -90,22 +91,20 @@ def _execute_impl(self, cls_args, cls_kwargs): ) def _execute_remote_method( - self, - method_name, - options, - _ray_wrappered_method, - args, - kwargs, + self, + method_name, + options, + _ray_wrappered_method, + args, + kwargs, ): num_returns = 1 - if options and 'num_returns' in options: - num_returns = options['num_returns'] - logger.debug( - f"Actor method call: {method_name}, num_returns: {num_returns}" - ) + if options and "num_returns" in options: + num_returns = options["num_returns"] + logger.debug(f"Actor method call: 
{method_name}, num_returns: {num_returns}") return _ray_wrappered_method.options( - name='', + name="", num_returns=num_returns, ).remote( *args, diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py index 9e349c4..c9ccce5 100644 --- a/fed/_private/fed_call_holder.py +++ b/fed/_private/fed_call_holder.py @@ -14,15 +14,16 @@ import logging -# Set config in the very beginning to avoid being overwritten by other packages. -logging.basicConfig(level=logging.INFO) - +import fed.config as fed_config from fed._private.global_context import get_global_context -from fed.proxy.barriers import send from fed.fed_object import FedObject -from fed.utils import resolve_dependencies +from fed.proxy.barriers import send from fed.tree_util import tree_flatten -import fed.config as fed_config +from fed.utils import resolve_dependencies + +# Set config in the very beginning to avoid being overwritten by other packages. +logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) @@ -97,10 +98,10 @@ def internal_remote(self, *args, **kwargs): ) if ( self._options - and 'num_returns' in self._options - and self._options['num_returns'] > 1 + and "num_returns" in self._options + and self._options["num_returns"] > 1 ): - num_returns = self._options['num_returns'] + num_returns = self._options["num_returns"] return [ FedObject(self._node_party, fed_task_id, None, i) for i in range(num_returns) diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py index cd1337f..8e5c550 100644 --- a/fed/_private/global_context.py +++ b/fed/_private/global_context.py @@ -12,22 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from fed.cleanup import CleanupManager -from typing import Callable import threading +from typing import Callable + +from fed.cleanup import CleanupManager class GlobalContext: - def __init__(self, job_name: str, - current_party: str, - failure_handler: Callable[[], None]) -> None: + def __init__( + self, job_name: str, current_party: str, failure_handler: Callable[[], None] + ) -> None: self._job_name = job_name self._seq_count = 0 self._failure_handler = failure_handler self._atomic_shutdown_flag_lock = threading.Lock() self._atomic_shutdown_flag = True self._cleanup_manager = CleanupManager( - current_party, self.acquire_shutdown_flag) + current_party, self.acquire_shutdown_flag + ) def next_seq_id(self) -> int: self._seq_count += 1 @@ -65,9 +67,9 @@ def acquire_shutdown_flag(self) -> bool: _global_context = None -def init_global_context(current_party: str, - job_name: str, - failure_handler: Callable[[], None] = None) -> None: +def init_global_context( + current_party: str, job_name: str, failure_handler: Callable[[], None] = None +) -> None: global _global_context if _global_context is None: _global_context = GlobalContext(job_name, current_party, failure_handler) diff --git a/fed/_private/message_queue.py b/fed/_private/message_queue.py index c4a1b6e..fb6d211 100644 --- a/fed/_private/message_queue.py +++ b/fed/_private/message_queue.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging
 import threading
-from collections import deque
 import time
-import logging
-
+from collections import deque
 
 logger = logging.getLogger(__name__)
 
@@ -27,7 +26,7 @@ class MessageQueueManager:
-    def __init__(self, msg_handler, failure_handler=None, thread_name=''):
+    def __init__(self, msg_handler, failure_handler=None, thread_name=""):
         assert callable(msg_handler), "msg_handler must be a callable function"
         # `deque()` is thread safe on `popleft` and `append` operations.
         # See https://docs.python.org/3/library/collections.html#deque-objects
@@ -55,7 +54,8 @@ def _loop():
 
         if self._thread is None or not self._thread.is_alive():
             logger.debug(
-                f"Starting new thread[{self._thread_name}] for message polling.")
+                f"Starting new thread[{self._thread_name}] for message polling."
+            )
             self._queue = deque()
             self._thread = threading.Thread(target=_loop, name=self._thread_name)
             self._thread.start()
@@ -79,9 +79,11 @@ def stop(self):
             If False: forcelly kill the for-loop sub-thread.
         """
         if threading.current_thread() == self._thread:
-            logger.error(f"Can't stop the message queue in the message "
-                         f"polling thread[{self._thread_name}]. Ignore it as this"
-                         f"could bring unknown time sequence problems.")
+            logger.error(
+                f"Can't stop the message queue in the message "
+                f"polling thread[{self._thread_name}]. Ignore it as this "
+                f"could bring unknown time sequence problems."
+            )
             raise RuntimeError("Thread can't kill itself")
 
         # TODO(NKcqx): Force kill sub-thread by calling `._stop()` will
diff --git a/fed/_private/serialization_utils.py b/fed/_private/serialization_utils.py
index c1b2e73..0c0f3c6 100644
--- a/fed/_private/serialization_utils.py
+++ b/fed/_private/serialization_utils.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 
 import io
+
 import cloudpickle
 
 import fed.config as fed_config
 
-
 _pickle_whitelist = None
diff --git a/fed/api.py b/fed/api.py
index 66d4aba..42d949c 100644
--- a/fed/api.py
+++ b/fed/api.py
@@ -69,7 +69,7 @@ def init(
     party: str = None,
     config: Dict = {},
     tls_config: Dict = None,
-    logging_level: str = 'info',
+    logging_level: str = "info",
     sender_proxy_cls: SenderProxy = None,
     receiver_proxy_cls: ReceiverProxy = None,
     receiver_sender_proxy_cls: SenderReceiverProxy = None,
@@ -169,8 +169,8 @@ def init(
     tls_config = {} if tls_config is None else tls_config
     if tls_config:
         assert (
-            'cert' in tls_config and 'key' in tls_config
-        ), 'Cert or key are not in tls_config.'
+            "cert" in tls_config and "key" in tls_config
+        ), "Cert or key are not in tls_config."
 
     # A Ray private accessing, should be replaced in public API.
         compatible_utils._init_internal_kv(job_name)
@@ -201,7 +201,7 @@ def init(
         job_name=job_name,
     )
 
-    logger.info(f'Started rayfed with {cluster_config}')
+    logger.info(f"Started rayfed with {cluster_config}")
     cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
     signal.signal(signal.SIGINT, _signal_handler)
     get_global_context().get_cleanup_manager().start(
@@ -291,7 +291,7 @@ def _shutdown(intended=True):
     clear_global_context()
     if not intended and failure_handler is not None:
         failure_handler()
-    logger.info('Shutdowned rayfed.')
+    logger.info("Shut down rayfed.")
 
 
 def _get_addresses(job_name: str = None):
diff --git a/fed/cleanup.py b/fed/cleanup.py
index d68b290..33682c8 100644
--- a/fed/cleanup.py
+++ b/fed/cleanup.py
@@ -16,11 +16,12 @@
 import os
 import signal
 import threading
-from fed._private.message_queue import MessageQueueManager
-from fed.exceptions import FedRemoteError
-from ray.exceptions import RayError
 
 import ray
+from ray.exceptions import RayError
+
+from fed._private.message_queue import MessageQueueManager
+from fed.exceptions import FedRemoteError
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +45,13 @@ class CleanupManager:
     def __init__(self, current_party, acquire_shutdown_flag) -> None:
         self._sending_data_q = MessageQueueManager(
             lambda msg: self._process_data_sending_task_return(msg),
-            thread_name='DataSendingQueueThread')
+            thread_name="DataSendingQueueThread",
+        )
 
         self._sending_error_q = MessageQueueManager(
             lambda msg: self._process_error_sending_task_return(msg),
-            thread_name="ErrorSendingQueueThread")
+            thread_name="ErrorSendingQueueThread",
+        )
 
         self._monitor_thread = None
 
@@ -60,9 +63,9 @@ def start(self, exit_on_sending_failure=False, expose_error_trace=False):
         self._expose_error_trace = expose_error_trace
 
         self._sending_data_q.start()
-        logger.debug('Start check sending thread.')
+        logger.debug("Start check sending thread.")
         self._sending_error_q.start()
-        logger.debug('Start check error sending thread.')
+        logger.debug("Start check error sending thread.")
 
         def _main_thread_monitor():
             main_thread = threading.main_thread()
@@ -71,7 +74,7 @@ def _main_thread_monitor():
 
         self._monitor_thread = threading.Thread(target=_main_thread_monitor)
         self._monitor_thread.start()
-        logger.info('Start check sending monitor thread.')
+        logger.info("Start check sending monitor thread.")
 
     def stop(self):
         # NOTE(NKcqx): MUST firstly stop the data queue, because it
@@ -80,12 +83,14 @@ def stop(self):
         self._sending_data_q.stop()
         self._sending_error_q.stop()
 
-    def push_to_sending(self,
-                        obj_ref: ray.ObjectRef,
-                        dest_party: str = None,
-                        upstream_seq_id: int = -1,
-                        downstream_seq_id: int = -1,
-                        is_error: bool = False):
+    def push_to_sending(
+        self,
+        obj_ref: ray.ObjectRef,
+        dest_party: str = None,
+        upstream_seq_id: int = -1,
+        downstream_seq_id: int = -1,
+        is_error: bool = False,
+    ):
         """
         Push the sending remote task's return value, i.e. `obj_ref`
         to the corresponding message queue.
@@ -104,7 +109,7 @@ def push_to_sending(self,
             queue instead.
         """
         msg_pack = (obj_ref, dest_party, upstream_seq_id, downstream_seq_id)
-        if (is_error):
+        if is_error:
             self._sending_error_q.append(msg_pack)
         else:
             self._sending_data_q.append(msg_pack)
@@ -123,7 +128,7 @@ def _signal_exit(self):
         # will cause dead lock. In order to ensure executing `shutdown` exactly
         # once and avoid dead lock, the lock must be checked before sending
         # signals.
-        if (self._acquire_shutdown_flag()):
+        if self._acquire_shutdown_flag():
             logger.debug("Signal SIGINT to exit.")
             os.kill(os.getpid(), signal.SIGINT)
 
@@ -151,16 +156,24 @@ def _process_data_sending_task_return(self, message):
         try:
             res = ray.get(obj_ref)
         except Exception as e:
-            logger.warn(f'Failed to send {obj_ref} to {dest_party}, error: {e},'
-                        f'upstream_seq_id: {upstream_seq_id}, '
-                        f'downstream_seq_id: {downstream_seq_id}.')
-            if (isinstance(e, RayError)):
+            logger.warning(
+                f"Failed to send {obj_ref} to {dest_party}, error: {e}, "
+                f"upstream_seq_id: {upstream_seq_id}, "
+                f"downstream_seq_id: {downstream_seq_id}."
+            )
+            if isinstance(e, RayError):
                 logger.info(f"Sending error {e.cause} to {dest_party}.")
                 from fed.proxy.barriers import send
+
                 # TODO(NKcqx): Cascade broadcast to all parties
                 error_trace = e.cause if self._expose_error_trace else None
-                send(dest_party, FedRemoteError(self._current_party, error_trace),
-                     upstream_seq_id, downstream_seq_id, True)
+                send(
+                    dest_party,
+                    FedRemoteError(self._current_party, error_trace),
+                    upstream_seq_id,
+                    downstream_seq_id,
+                    True,
+                )
 
             res = False
 
@@ -183,10 +196,12 @@ def _process_error_sending_task_return(self, error_msg):
             res = False
 
         if not res:
-            logger.warning(f"Failed to send error {error_ref} to {dest_party}, "
-                           f"upstream_seq_id: {upstream_seq_id} "
-                           f"downstream_seq_id: {downstream_seq_id}. "
-                           "In this case, other parties won't sense "
-                           "this error and may cause unknown behaviour.")
+            logger.warning(
+                f"Failed to send error {error_ref} to {dest_party}, "
+                f"upstream_seq_id: {upstream_seq_id} "
+                f"downstream_seq_id: {downstream_seq_id}. "
+                "In this case, other parties won't sense "
+                "this error, which may cause unknown behaviour."
+            )
 
         # Return True so that remaining error objects can be sent
         return True
diff --git a/fed/config.py b/fed/config.py
index 230d979..2bd22a8 100644
--- a/fed/config.py
+++ b/fed/config.py
@@ -130,7 +130,7 @@ def from_json(cls, json_str):
         return cls(**data)
 
     @classmethod
-    def from_dict(cls, data: Dict) -> 'CrossSiloMessageConfig':
+    def from_dict(cls, data: Dict) -> "CrossSiloMessageConfig":
         """Initialize CrossSiloMessageConfig from a dictionary.
 
         Args:
diff --git a/fed/exceptions.py b/fed/exceptions.py
index dad4abf..25efed8 100644
--- a/fed/exceptions.py
+++ b/fed/exceptions.py
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 class FedRemoteError(Exception):
     def __init__(self, src_party: str, cause: Exception) -> None:
         self._src_party = src_party
         self._cause = cause
 
     def __str__(self):
-        error_msg = f'FedRemoteError occurred at {self._src_party}'
+        error_msg = f"FedRemoteError occurred at {self._src_party}"
         if self._cause is not None:
             error_msg += f" caused by {str(self._cause)}"
         return error_msg
diff --git a/fed/fed_object.py b/fed/fed_object.py
index 6e62faa..fd71945 100644
--- a/fed/fed_object.py
+++ b/fed/fed_object.py
@@ -17,6 +17,7 @@
 
 class FedObjectSendingContext:
     """The class that's used for holding the all contexts about sending side."""
+
     def __init__(self) -> None:
         # This field holds the target(downstream) parties that this fed object
         # is sending or sent to.
@@ -33,6 +34,7 @@ def was_sending_or_sent_to_party(self, target_party: str): class FedObjectReceivingContext: """The class that's used for holding the all contexts about receiving side.""" + pass @@ -60,7 +62,7 @@ def get_ray_object_ref(self): return self._ray_object_ref def get_fed_task_id(self): - return f'{self._fed_task_id}#{self._idx_in_task}' + return f"{self._fed_task_id}#{self._idx_in_task}" def get_party(self): return self._node_party diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index aa207ae..1ac8345 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) -_SENDER_PROXY_ACTOR_NAME = 'SenderProxyActor' -_RECEIVER_PROXY_ACTOR_NAME = 'ReceiverProxyActor' +_SENDER_PROXY_ACTOR_NAME = "SenderProxyActor" +_RECEIVER_PROXY_ACTOR_NAME = "ReceiverProxyActor" def sender_proxy_actor_name() -> str: @@ -51,9 +51,9 @@ def set_receiver_proxy_actor_name(name: str): _RECEIVER_PROXY_ACTOR_NAME = name -def set_proxy_actor_name(job_name: str, - use_global_proxy: bool, - sender_recvr_proxy: bool = False): +def set_proxy_actor_name( + job_name: str, use_global_proxy: bool, sender_recvr_proxy: bool = False +): """ Generate the name of the proxy actor. @@ -136,7 +136,8 @@ def __init__( job_config = fed_config.get_job_config(job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance: SenderProxy = proxy_cls( - addresses, party, job_name, tls_config, cross_silo_comm_config) + addresses, party, job_name, tls_config, cross_silo_comm_config + ) async def is_ready(self): res = await self._proxy_instance.is_ready() @@ -152,21 +153,21 @@ async def send( self._stats["send_op_count"] += 1 assert ( dest_party in self._addresses - ), f'Failed to find {dest_party} in addresses {self._addresses}.' + ), f"Failed to find {dest_party} in addresses {self._addresses}." send_log_msg = ( - f'send data to seq_id {downstream_seq_id} of {dest_party} ' - f'from {upstream_seq_id}' + f"send data to seq_id {downstream_seq_id} of {dest_party} " + f"from {upstream_seq_id}" ) logger.debug( f'Sending {send_log_msg} with{"out" if not self._tls_config else ""}' - ' credentials.' + " credentials." ) try: response = await self._proxy_instance.send( dest_party, data, upstream_seq_id, downstream_seq_id ) except Exception as e: - logger.error(f'Failed to {send_log_msg}, error: {e}') + logger.error(f"Failed to {send_log_msg}, error: {e}") return False logger.debug(f"Succeeded to send {send_log_msg}. Response is {response}") return True # True indicates it's sent successfully. @@ -207,7 +208,8 @@ def __init__( job_config = fed_config.get_job_config(job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance: ReceiverProxy = proxy_cls( - listening_address, party, job_name, tls_config, cross_silo_comm_config) + listening_address, party, job_name, tls_config, cross_silo_comm_config + ) async def start(self): await self._proxy_instance.start() @@ -222,9 +224,11 @@ async def get_data(self, src_party, upstream_seq_id, curr_seq_id): src_party, upstream_seq_id, curr_seq_id ) if isinstance(data, Exception): - logger.debug(f"Receiving exception: {type(data)}, {data} from {src_party}, " - f"upstream_seq_id: {upstream_seq_id}, " - f"curr_seq_id: {curr_seq_id}. Re-raise it.") + logger.debug( + f"Receiving exception: {type(data)}, {data} from {src_party}, " + f"upstream_seq_id: {upstream_seq_id}, " + f"curr_seq_id: {curr_seq_id}. Re-raise it." 
+            )
             raise data
         return data
 
@@ -350,7 +354,7 @@ def __init__(
             job_name=job_name,
         )
 
-        self._stats = {'send_op_count': 0, 'receive_op_count': 0}
+        self._stats = {"send_op_count": 0, "receive_op_count": 0}
         self._addresses = addresses
         self._party = party
         self._tls_config = tls_config
@@ -381,21 +385,21 @@ def send(
         self._stats["send_op_count"] += 1
         assert (
             dest_party in self._addresses
-        ), f'Failed to find {dest_party} in cluster {self._addresses}.'
+        ), f"Failed to find {dest_party} in cluster {self._addresses}."
         send_log_msg = (
-            f'send data to seq_id {downstream_seq_id} of {dest_party} '
-            f'from {upstream_seq_id}'
+            f"send data to seq_id {downstream_seq_id} of {dest_party} "
+            f"from {upstream_seq_id}"
         )
         logger.debug(
             f'Sending {send_log_msg} with{"out" if not self._tls_config else ""}'
-            ' credentials.'
+            " credentials."
        )
         try:
             response = self._proxy_instance.send(
                 dest_party, data, upstream_seq_id, downstream_seq_id
             )
         except Exception as e:
-            logger.error(f'Failed to {send_log_msg}, error: {e}')
+            logger.error(f"Failed to {send_log_msg}, error: {e}")
             return False
         logger.debug(f"Succeeded to {send_log_msg}. Response is {response}")
         return True  # True indicates it's sent successfully.
@@ -437,13 +441,14 @@ def _start_sender_receiver_proxy(
 
     global _SENDER_RECEIVER_PROXY_ACTOR
     _SENDER_RECEIVER_PROXY_ACTOR = SenderReceiverProxyActor.options(
-        **actor_options).remote(
-            addresses=addresses,
-            party=party,
-            job_name=job_name,
-            tls_config=tls_config,
-            logging_level=logging_level,
-            proxy_cls=proxy_cls,
+        **actor_options
+    ).remote(
+        addresses=addresses,
+        party=party,
+        job_name=job_name,
+        tls_config=tls_config,
+        logging_level=logging_level,
+        proxy_cls=proxy_cls,
     )
     _SENDER_RECEIVER_PROXY_ACTOR.start.remote()
     server_state = ray.get(
@@ -453,13 +458,7 @@ def _start_sender_receiver_proxy(
     logger.info("Succeeded to create receiver proxy actor.")
 
 
-def send(
-    dest_party,
-    data,
-    upstream_seq_id,
-    downstream_seq_id,
-    is_error=False
-):
+def send(dest_party, data, upstream_seq_id, downstream_seq_id, is_error=False):
     """
     Args:
         is_error: Whether the `data` is an error object or not. Default is False.
@@ -473,12 +472,13 @@ def send(
         downstream_seq_id=downstream_seq_id,
     )
     get_global_context().get_cleanup_manager().push_to_sending(
-        res, dest_party, upstream_seq_id, downstream_seq_id, is_error)
+        res, dest_party, upstream_seq_id, downstream_seq_id, is_error
+    )
     return res
 
 
 def recv(party: str, src_party: str, upstream_seq_id, curr_seq_id):
-    assert party, 'Party can not be None.'
+    assert party, "Party can not be None."
     receiver_proxy = ray.get_actor(receiver_proxy_actor_name())
     return receiver_proxy.get_data.remote(src_party, upstream_seq_id, curr_seq_id)
 
@@ -490,13 +490,13 @@ def ping_others(addresses: Dict[str, Dict], self_party: str, max_retries=3600):
 
     while tried < max_retries and others:
         logger.info(
-            f'Try ping {others} at {tried} attemp, up to {max_retries} attemps.'
+            f"Try ping {others} at attempt {tried}, up to {max_retries} attempts."
         )
         tried += 1
         _party_ping_obj = {}  # {$party_name: $ObjectRef}
         # Batch ping all the other parties
         for other in others:
-            _party_ping_obj[other] = send(other, b'data', 'ping', 'ping')
+            _party_ping_obj[other] = send(other, b"data", "ping", "ping")
         _, _unready = ray.wait(list(_party_ping_obj.values()), timeout=1)
 
         # Keep the unready party for the next ping.
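Note on the `send`/`recv` barrier pair reformatted above: the two sides meet on matching sequence ids. Below is a minimal illustrative sketch, not part of the patch; it assumes `fed.init` has already brought up the proxy actors on both parties, and the literal sequence ids are hypothetical (real calls derive them from the global context's `next_seq_id()`).

```python
import fed

# On party "alice": enqueue an object for bob. The returned Ray object ref is
# also pushed onto the CleanupManager's sending queue (see fed/cleanup.py).
fed.send("bob", data=b"payload", upstream_seq_id=1, downstream_seq_id=2)

# On party "bob": ask the local receiver proxy actor for the object that
# alice produced as its seq_id 1 and that bob consumes as seq_id 2.
obj_ref = fed.recv("bob", src_party="alice", upstream_seq_id=1, curr_seq_id=2)
```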
diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py
index b2eba26..c62285b 100644
--- a/fed/proxy/base_proxy.py
+++ b/fed/proxy/base_proxy.py
@@ -51,7 +51,7 @@ def __init__(
         party: str,
         job_name: str,
         tls_config: Dict,
-        proxy_config: CrossSiloMessageConfig = None
+        proxy_config: CrossSiloMessageConfig = None,
     ) -> None:
         self._listen_addr = listen_addr
         self._party = party
diff --git a/fed/proxy/brpc_link/link.py b/fed/proxy/brpc_link/link.py
new file mode 100644
index 0000000..5452af5
--- /dev/null
+++ b/fed/proxy/brpc_link/link.py
@@ -0,0 +1,114 @@
+import logging
+import threading
+from typing import Dict
+
+import cloudpickle
+import spu.libspu.link as link
+
+from fed.proxy.barriers import (
+    add_two_dim_dict,
+    key_exists_in_two_dim_dict,
+    pop_from_two_dim_dict,
+)
+from fed.proxy.base_proxy import SenderReceiverProxy
+from fed.proxy.brpc_link.link_config import BrpcLinkCrossSiloMessageConfig
+
+logger = logging.getLogger(__name__)
+
+
+def _fill_link_ssl_opts(tls_config: Dict, link_ssl_opts: link.SSLOptions):
+    ca_cert = tls_config["ca_cert"]
+    cert = tls_config["cert"]
+    key = tls_config["key"]
+    link_ssl_opts.cert.certificate_path = cert
+    link_ssl_opts.cert.private_key_path = key
+    link_ssl_opts.verify.ca_file_path = ca_cert
+    link_ssl_opts.verify.verify_depth = 1
+
+
+class BrpcLinkSenderReceiverProxy(SenderReceiverProxy):
+    def __init__(
+        self,
+        addresses: Dict,
+        self_party: str,
+        tls_config: Dict = None,
+        proxy_config: Dict = None,
+    ) -> None:
+        proxy_config = BrpcLinkCrossSiloMessageConfig.from_dict(proxy_config)
+        super().__init__(addresses, self_party, tls_config, proxy_config)
+        self._parties_rank = {
+            party: i for i, party in enumerate(self._addresses.keys())
+        }
+        self._rank = list(self._addresses).index(self_party)
+
+        desc = link.Desc()
+        for party, addr in self._addresses.items():
+            desc.add_party(party, addr)
+        if tls_config:
+            _fill_link_ssl_opts(tls_config, desc.server_ssl_opts)
+            _fill_link_ssl_opts(tls_config, desc.client_ssl_opts)
+        if isinstance(proxy_config, BrpcLinkCrossSiloMessageConfig):
+            proxy_config.dump_to_link_desc(desc)
+        self._desc = desc
+
+        self._all_data = {}
+        self._server_ready_with_msg = (False, "")
+        self._server_ready_event = threading.Event()
+
+    def start(self):
+        try:
+            self._linker = link.create_brpc(self._desc, self._rank)
+            self._server_ready_with_msg = (
+                True,
+                f"Succeeded to listen on {self._addresses[self._party]}.",
+            )
+            self._server_ready_event.set()
+
+        except Exception as e:
+            self._server_ready_with_msg = (
+                False,
+                f"Failed to listen on {self._addresses[self._party]} as exception:\n{e}",
+            )
+            self._server_ready_event.set()
+
+    def is_ready(self):
+        self._server_ready_event.wait()
+        return self._server_ready_with_msg
+
+    def send(self, dest_party, data, upstream_seq_id, downstream_seq_id):
+        msg = {
+            "upstream_seq_id": upstream_seq_id,
+            "downstream_seq_id": downstream_seq_id,
+            "payload": data,
+        }
+        msg_bytes = cloudpickle.dumps(msg)
+        self._linker.send_async(self._parties_rank[dest_party], msg_bytes)
+
+        return True
+
+    def get_data(self, src_party, upstream_seq_id, curr_seq_id):
+        data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_party}"
+        logger.debug(f"Getting {data_log_msg}")
+        all_data = self._all_data
+        rank = self._parties_rank[src_party]
+        if key_exists_in_two_dim_dict(all_data, upstream_seq_id, curr_seq_id):
+            logger.debug(f"Got {data_log_msg}.")
+            return pop_from_two_dim_dict(all_data, upstream_seq_id, curr_seq_id)
+
+        while True:
+            msg = self._linker.recv(rank)
+            msg = cloudpickle.loads(msg)
+            upstream_seq_id_in_msg = msg["upstream_seq_id"]
+            downstream_seq_id_in_msg = msg["downstream_seq_id"]
+            data = msg["payload"]
+            if (
+                upstream_seq_id_in_msg == upstream_seq_id
+                and downstream_seq_id_in_msg == curr_seq_id
+            ):
+                logger.debug(f"Got {data_log_msg}.")
+                return data
+            else:
+                logger.debug(f"Received {data_log_msg}.")
+                add_two_dim_dict(
+                    all_data, upstream_seq_id_in_msg, downstream_seq_id_in_msg, data
+                )
diff --git a/fed/proxy/brpc_link/link_config.py b/fed/proxy/brpc_link/link_config.py
new file mode 100644
index 0000000..e7259f2
--- /dev/null
+++ b/fed/proxy/brpc_link/link_config.py
@@ -0,0 +1,50 @@
+import logging
+from dataclasses import dataclass
+
+from spu.libspu import link
+
+from fed.config import CrossSiloMessageConfig
+
+
+@dataclass
+class BrpcLinkCrossSiloMessageConfig(CrossSiloMessageConfig):
+    connect_retry_times: int = None
+    connect_retry_interval_ms: int = None
+    recv_timeout_ms: int = None
+    http_timeout_ms: int = None
+    http_max_payload_size: int = None
+    throttle_window_size: int = None
+    brpc_channel_protocol: str = None
+    brpc_channel_connection_type: str = None
+
+    def dump_to_link_desc(self, link_desc: link.Desc):
+        if self.timeout_in_ms is not None:
+            link_desc.http_timeout_ms = self.timeout_in_ms
+
+        if self.connect_retry_times is not None:
+            link_desc.connect_retry_times = self.connect_retry_times
+        if self.connect_retry_interval_ms is not None:
+            link_desc.connect_retry_interval_ms = self.connect_retry_interval_ms
+        if self.recv_timeout_ms is not None:
+            link_desc.recv_timeout_ms = self.recv_timeout_ms
+        if self.http_timeout_ms is not None:
+            logging.warning(
+                "http_timeout_ms and timeout_ms are set at the same time, "
+                f"http_timeout_ms {self.http_timeout_ms} will be used."
+            )
+            link_desc.http_timeout_ms = self.http_timeout_ms
+        if self.http_max_payload_size is not None:
+            link_desc.http_max_payload_size = self.http_max_payload_size
+        if self.throttle_window_size is not None:
+            link_desc.throttle_window_size = self.throttle_window_size
+        if self.brpc_channel_protocol is not None:
+            link_desc.brpc_channel_protocol = self.brpc_channel_protocol
+        if self.brpc_channel_connection_type is not None:
+            link_desc.brpc_channel_connection_type = self.brpc_channel_connection_type
+
+        if not hasattr(link_desc, "recv_timeout_ms"):
+            # set default timeout 3600s
+            link_desc.recv_timeout_ms = 3600 * 1000
+        if not hasattr(link_desc, "http_timeout_ms"):
+            # set default timeout 120s
+            link_desc.http_timeout_ms = 120 * 1000
diff --git a/fed/proxy/grpc/grpc_options.py b/fed/proxy/grpc/grpc_options.py
index 6e4b2d1..e37245d 100644
--- a/fed/proxy/grpc/grpc_options.py
+++ b/fed/proxy/grpc/grpc_options.py
@@ -30,21 +30,20 @@
 _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH = 500 * 1024 * 1024
 
 _DEFAULT_GRPC_CHANNEL_OPTIONS = {
-    'grpc.enable_retries': 1,
-    'grpc.so_reuseport': 0,
-    'grpc.max_send_message_length': _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH,
-    'grpc.max_receive_message_length': _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH,
-    'grpc.service_config':
-        json.dumps(
-            {
-                'methodConfig': [
-                    {
-                        'name': [{'service': _GRPC_SERVICE}],
-                        'retryPolicy': _DEFAULT_GRPC_RETRY_POLICY,
-                    }
-                ]
-            }
-        ),
+    "grpc.enable_retries": 1,
+    "grpc.so_reuseport": 0,
+    "grpc.max_send_message_length": _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH,
+    "grpc.max_receive_message_length": _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH,
+    "grpc.service_config": json.dumps(
+        {
+            "methodConfig": [
+                {
+                    "name": [{"service": _GRPC_SERVICE}],
+                    "retryPolicy": _DEFAULT_GRPC_RETRY_POLICY,
+                }
+            ]
+        }
+    ),
 }
 
 
@@ -60,26 +59,26 @@ def 
get_grpc_options( return [ ( - 'grpc.max_send_message_length', + "grpc.max_send_message_length", max_send_message_length, ), ( - 'grpc.max_receive_message_length', + "grpc.max_receive_message_length", max_receive_message_length, ), - ('grpc.enable_retries', 1), + ("grpc.enable_retries", 1), ( - 'grpc.service_config', + "grpc.service_config", json.dumps( { - 'methodConfig': [ + "methodConfig": [ { - 'name': [{'service': _GRPC_SERVICE}], - 'retryPolicy': retry_policy, + "name": [{"service": _GRPC_SERVICE}], + "retryPolicy": retry_policy, } ] } ), ), - ('grpc.so_reuseport', 0), + ("grpc.so_reuseport", 0), ] diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 1fdfa4e..cfe1457 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -14,27 +14,29 @@ import asyncio import copy -import cloudpickle -import grpc +import json import logging import threading -import json from typing import Dict -import fed.utils as fed_utils +import cloudpickle +import grpc -from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig import fed._private.compatible_utils as compatible_utils -from fed.proxy.grpc.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE +import fed.utils as fed_utils +from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig from fed.proxy.barriers import ( add_two_dim_dict, get_from_two_dim_dict, - pop_from_two_dim_dict, key_exists_in_two_dim_dict, + pop_from_two_dim_dict, ) -from fed.proxy.base_proxy import SenderProxy, ReceiverProxy +from fed.proxy.base_proxy import ReceiverProxy, SenderProxy +from fed.proxy.grpc.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE + if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): + fed_utils.get_package_version("protobuf"), "4.0.0" +): from fed.grpc.pb4 import fed_pb2 as fed_pb2 from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc else: @@ -67,43 +69,44 @@ def parse_grpc_options(proxy_config: CrossSiloMessageConfig): # However, `GrpcCrossSiloMessageConfig` provides a more flexible way # to configure grpc channel options, i.e. the `grpc_channel_options` # field, which may override the `messages_max_size_in_bytes` field. 
- if (isinstance(proxy_config, CrossSiloMessageConfig)): - if (proxy_config.messages_max_size_in_bytes is not None): - grpc_channel_options.update({ - 'grpc.max_send_message_length': - proxy_config.messages_max_size_in_bytes, - 'grpc.max_receive_message_length': - proxy_config.messages_max_size_in_bytes, - }) + if isinstance(proxy_config, CrossSiloMessageConfig): + if proxy_config.messages_max_size_in_bytes is not None: + grpc_channel_options.update( + { + "grpc.max_send_message_length": proxy_config.messages_max_size_in_bytes, + "grpc.max_receive_message_length": proxy_config.messages_max_size_in_bytes, + } + ) if isinstance(proxy_config, GrpcCrossSiloMessageConfig): if proxy_config.grpc_channel_options is not None: grpc_channel_options.update(proxy_config.grpc_channel_options) if proxy_config.grpc_retry_policy is not None: - grpc_channel_options.update({ - 'grpc.service_config': - json.dumps( - { - 'methodConfig': [ - { - 'name': [{'service': _GRPC_SERVICE}], - 'retryPolicy': proxy_config.grpc_retry_policy, - } - ] - } - ), - }) + grpc_channel_options.update( + { + "grpc.service_config": json.dumps( + { + "methodConfig": [ + { + "name": [{"service": _GRPC_SERVICE}], + "retryPolicy": proxy_config.grpc_retry_policy, + } + ] + } + ), + } + ) return grpc_channel_options class GrpcSenderProxy(SenderProxy): def __init__( - self, - cluster: Dict, - party: str, - job_name: str, - tls_config: Dict, - proxy_config: Dict = None + self, + cluster: Dict, + party: str, + job_name: str, + tls_config: Dict, + proxy_config: Dict = None, ) -> None: proxy_config = GrpcCrossSiloMessageConfig.from_dict(proxy_config) super().__init__(cluster, party, job_name, tls_config, proxy_config) @@ -113,29 +116,27 @@ def __init__( # Mapping the destination party name to the reused client stub. 
         self._stubs = {}
 
-    async def send(
-            self,
-            dest_party,
-            data,
-            upstream_seq_id,
-            downstream_seq_id):
+    async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id):
         dest_addr = self._addresses[dest_party]
         grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party)
         tls_enabled = fed_utils.tls_enabled(self._tls_config)
         if dest_party not in self._stubs:
             if tls_enabled:
                 ca_cert, private_key, cert_chain = fed_utils.load_cert_config(
-                    self._tls_config)
+                    self._tls_config
+                )
                 credentials = grpc.ssl_channel_credentials(
                     certificate_chain=cert_chain,
                     private_key=private_key,
                     root_certificates=ca_cert,
                 )
                 channel = grpc.aio.secure_channel(
-                    dest_addr, credentials, options=grpc_channel_options)
+                    dest_addr, credentials, options=grpc_channel_options
+                )
             else:
                 channel = grpc.aio.insecure_channel(
-                    dest_addr, options=grpc_channel_options)
+                    dest_addr, options=grpc_channel_options
+                )
             stub = fed_pb2_grpc.GrpcServiceStub(channel)
             self._stubs[dest_party] = stub
 
@@ -153,8 +154,7 @@ async def send(
         return response.result
 
     def get_grpc_config_by_party(self, dest_party):
-        """Overide global config by party specific config
-        """
+        """Override the global config with the party-specific config."""
         grpc_metadata = self._grpc_metadata
         grpc_options = self._grpc_options
 
@@ -162,14 +162,9 @@ def get_grpc_config_by_party(self, dest_party):
         if dest_party_msg_config is not None:
             if dest_party_msg_config.http_header is not None:
                 dest_party_grpc_metadata = dict(dest_party_msg_config.http_header)
-                grpc_metadata = {
-                    **grpc_metadata,
-                    **dest_party_grpc_metadata
-                }
+                grpc_metadata = {**grpc_metadata, **dest_party_grpc_metadata}
             dest_party_grpc_options = parse_grpc_options(dest_party_msg_config)
-            grpc_options = {
-                **grpc_options, **dest_party_grpc_options
-            }
+            grpc_options = {**grpc_options, **dest_party_grpc_options}
         return grpc_metadata, fed_utils.dict2tuple(grpc_options)
 
     async def get_proxy_config(self, dest_party=None):
@@ -178,7 +173,7 @@ async def get_proxy_config(self, dest_party=None):
         else:
             _, grpc_options = self.get_grpc_config_by_party(dest_party)
         proxy_config = self._proxy_config.__dict__
-        proxy_config.update({'grpc_options': grpc_options})
+        proxy_config.update({"grpc_options": grpc_options})
         return proxy_config
 
     def handle_response_error(self, response):
@@ -188,8 +183,10 @@ def handle_response_error(self, response):
         if 400 <= response.code < 500:
             # Request error should also be identified as a sending failure,
             # though the request was physically sent.
-            logger.warning(f"Request was successfully sent but got error response, "
-                           f"code: {response.code}, message: {response.result}.")
+            logger.warning(
+                f"Request was successfully sent but got error response, "
+                f"code: {response.code}, message: {response.result}."
+            )
             raise RuntimeError(response.result)
 
 
@@ -216,21 +213,21 @@ async def send_data_grpc(
             timeout=timeout,
         )
         logger.debug(
-            f'Received data response from seq_id {downstream_seq_id}, '
-            f'code: {response.code}, '
-            f'result: {response.result}.'
+            f"Received data response from seq_id {downstream_seq_id}, "
+            f"code: {response.code}, "
+            f"result: {response.result}."
) return response class GrpcReceiverProxy(ReceiverProxy): def __init__( - self, - listen_addr: str, - party: str, - job_name: str, - tls_config: Dict, - proxy_config: Dict + self, + listen_addr: str, + party: str, + job_name: str, + tls_config: Dict, + proxy_config: Dict, ) -> None: proxy_config = GrpcCrossSiloMessageConfig.from_dict(proxy_config) super().__init__(listen_addr, party, job_name, tls_config, proxy_config) @@ -246,7 +243,7 @@ def __init__( self._lock = threading.Lock() async def start(self): - port = self._listen_addr[self._listen_addr.index(':') + 1 :] + port = self._listen_addr[self._listen_addr.index(":") + 1 :] try: await _run_grpc_server( port, @@ -260,9 +257,11 @@ async def start(self): fed_utils.dict2tuple(self._grpc_options), ) except RuntimeError as err: - msg = f'Grpc server failed to listen to port: {port}' \ - f' Try another port by setting `listen_addr` into `cluster` config' \ - f' when calling `fed.init`. Grpc error msg: {err}' + msg = ( + f"Grpc server failed to listen to port: {port}" + f" Try another port by setting `listen_addr` into `cluster` config" + f" when calling `fed.init`. Grpc error msg: {err}" + ) self._server_ready_future.set_result((False, msg)) async def is_ready(self): @@ -289,12 +288,13 @@ async def get_data(self, src_party, upstream_seq_id, curr_seq_id): # NOTE(qwang): This is used to avoid the conflict with pickle5 in Ray. import fed._private.serialization_utils as fed_ser_utils + fed_ser_utils._apply_loads_function_with_whitelist() return cloudpickle.loads(data) async def get_proxy_config(self): proxy_config = self._proxy_config.__dict__ - proxy_config.update({'grpc_options': fed_utils.dict2tuple(self._grpc_options)}) + proxy_config.update({"grpc_options": fed_utils.dict2tuple(self._grpc_options)}) return proxy_config @@ -309,17 +309,20 @@ def __init__(self, all_events, all_data, party, lock, job_name): async def SendData(self, request, context): job_name = request.job_name if job_name != self._job_name: - logger.warning(f"Receive data from job {job_name}, ignore it. " - f"The reason may be that the ReceiverProxy is listening " - f"on the same address with that job.") + logger.warning( + f"Receive data from job {job_name}, ignore it. " + f"The reason may be that the ReceiverProxy is listening " + f"on the same address with that job." + ) return fed_pb2.SendDataResponse( code=417, - result=f"JobName mis-match, expected {self._job_name}, got {job_name}.") + result=f"JobName mis-match, expected {self._job_name}, got {job_name}.", + ) upstream_seq_id = request.upstream_seq_id downstream_seq_id = request.downstream_seq_id logger.debug( - f'Received a grpc data request from {upstream_seq_id} to ' - f'{downstream_seq_id}.' + f"Received a grpc data request from {upstream_seq_id} to " + f"{downstream_seq_id}." 
) with self._lock: @@ -340,8 +343,15 @@ async def SendData(self, request, context): async def _run_grpc_server( - port, event, all_data, party, lock, job_name, - server_ready_future, tls_config=None, grpc_options=None + port, + event, + all_data, + party, + lock, + job_name, + server_ready_future, + tls_config=None, + grpc_options=None, ): logger.info(f"ReceiverProxy binding port {port}, options: {grpc_options}...") server = grpc.aio.server(options=grpc_options) @@ -357,15 +367,15 @@ async def _run_grpc_server( root_certificates=ca_cert, require_client_auth=ca_cert is not None, ) - server.add_secure_port(f'[::]:{port}', server_credentials) + server.add_secure_port(f"[::]:{port}", server_credentials) else: - server.add_insecure_port(f'[::]:{port}') + server.add_insecure_port(f"[::]:{port}") msg = f"Succeeded to add port {port}." await server.start() logger.info( f'Successfully start Grpc service with{"out" if not tls_enabled else ""} ' - 'credentials.' + "credentials." ) server_ready_future.set_result((True, msg)) await server.wait_for_termination() diff --git a/fed/tests/client_mode_tests/test_basic_client_mode.py b/fed/tests/client_mode_tests/test_basic_client_mode.py index 9807802..d5740cf 100644 --- a/fed/tests/client_mode_tests/test_basic_client_mode.py +++ b/fed/tests/client_mode_tests/test_basic_client_mode.py @@ -16,9 +16,10 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils -from fed.tests.test_utils import ray_client_mode_setup # noqa +from fed.tests.test_utils import ray_client_mode_setup # noqa @fed.remote @@ -49,15 +50,18 @@ def mean(x, y): def run(party): import time - if party == 'alice': + + if party == "alice": time.sleep(1.4) - address = 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' # noqa + address = ( + "ray://127.0.0.1:21012" if party == "alice" else "ray://127.0.0.1:21011" + ) # noqa compatible_utils.init_ray(address=address) addresses = { - 'alice': '127.0.0.1:31012', - 'bob': '127.0.0.1:31011', + "alice": "127.0.0.1:31012", + "bob": "127.0.0.1:31011", } fed.init(addresses=addresses, party=party) @@ -83,9 +87,9 @@ def run(party): ray.shutdown() -def test_fed_get_in_2_parties(ray_client_mode_setup): # noqa - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) +def test_fed_get_in_2_parties(ray_client_mode_setup): # noqa + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/multi-jobs/test_ignore_other_job_msg.py b/fed/tests/multi-jobs/test_ignore_other_job_msg.py index a75d8f7..cd09878 100644 --- a/fed/tests/multi-jobs/test_ignore_other_job_msg.py +++ b/fed/tests/multi-jobs/test_ignore_other_job_msg.py @@ -13,32 +13,30 @@ # limitations under the License. 
import multiprocessing -import fed -import ray + import grpc import pytest -import fed.utils as fed_utils +import ray + +import fed import fed._private.compatible_utils as compatible_utils +import fed.utils as fed_utils from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, send_data_grpc + if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): + fed_utils.get_package_version("protobuf"), "4.0.0" +): from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc else: from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc class TestGrpcSenderProxy(GrpcSenderProxy): - async def send( - self, - dest_party, - data, - upstream_seq_id, - downstream_seq_id): + async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): dest_addr = self._addresses[dest_party] grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) if dest_party not in self._stubs: - channel = grpc.aio.insecure_channel( - dest_addr, options=grpc_channel_options) + channel = grpc.aio.insecure_channel(dest_addr, options=grpc_channel_options) stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub @@ -74,23 +72,26 @@ def agg_fn(obj1, obj2): addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } def run(party, job_name): - ray.init(address='local') - fed.init(addresses=addresses, - party=party, - job_name=job_name, - sender_proxy_cls=TestGrpcSenderProxy, - config={ - 'cross_silo_comm': { - 'exit_on_sending_failure': True, - }}) + ray.init(address="local") + fed.init( + addresses=addresses, + party=party, + job_name=job_name, + sender_proxy_cls=TestGrpcSenderProxy, + config={ + "cross_silo_comm": { + "exit_on_sending_failure": True, + } + }, + ) # 'bob' only needs to start the proxy actors - if party == 'alice': + if party == "alice": ds1, ds2 = [123, 789] actor_alice = MyActor.party("alice").remote(party, ds1) actor_bob = MyActor.party("bob").remote(party, ds2) @@ -103,13 +104,14 @@ def run(party, job_name): fed.shutdown() ray.shutdown() import time + # Wait for SIGTERM as failure on sending. time.sleep(86400) def test_ignore_other_job_msg(): - p_alice = multiprocessing.Process(target=run, args=('alice', 'job1')) - p_bob = multiprocessing.Process(target=run, args=('bob', 'job2')) + p_alice = multiprocessing.Process(target=run, args=("alice", "job1")) + p_bob = multiprocessing.Process(target=run, args=("bob", "job2")) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/multi-jobs/test_multi_proxy_actor.py b/fed/tests/multi-jobs/test_multi_proxy_actor.py index 5021ec0..ab0ec33 100644 --- a/fed/tests/multi-jobs/test_multi_proxy_actor.py +++ b/fed/tests/multi-jobs/test_multi_proxy_actor.py @@ -13,32 +13,30 @@ # limitations under the License. 
import multiprocessing -import fed -import ray + import grpc import pytest -import fed.utils as fed_utils +import ray + +import fed import fed._private.compatible_utils as compatible_utils +import fed.utils as fed_utils from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy, send_data_grpc + if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): + fed_utils.get_package_version("protobuf"), "4.0.0" +): from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc else: from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc class TestGrpcSenderProxy(GrpcSenderProxy): - async def send( - self, - dest_party, - data, - upstream_seq_id, - downstream_seq_id): + async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): dest_addr = self._addresses[dest_party] grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) if dest_party not in self._stubs: - channel = grpc.aio.insecure_channel( - dest_addr, options=grpc_channel_options) + channel = grpc.aio.insecure_channel(dest_addr, options=grpc_channel_options) stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub @@ -74,29 +72,32 @@ def agg_fn(obj1, obj2): addresses = { - 'job1': { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "job1": { + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", }, - 'job2': { - 'alice': '127.0.0.1:12012', - 'bob': '127.0.0.1:12011', + "job2": { + "alice": "127.0.0.1:12012", + "bob": "127.0.0.1:12011", }, } def run(party, job_name): - ray.init(address='local') - fed.init(addresses=addresses[job_name], - party=party, - job_name=job_name, - sender_proxy_cls=TestGrpcSenderProxy, - config={ - 'cross_silo_comm': { - 'exit_on_sending_failure': True, - # Create unique proxy for current job - 'use_global_proxy': False - }}) + ray.init(address="local") + fed.init( + addresses=addresses[job_name], + party=party, + job_name=job_name, + sender_proxy_cls=TestGrpcSenderProxy, + config={ + "cross_silo_comm": { + "exit_on_sending_failure": True, + # Create unique proxy for current job + "use_global_proxy": False, + } + }, + ) sender_proxy_actor_name = f"SenderProxyActor_{job_name}" receiver_proxy_actor_name = f"ReceiverProxyActor_{job_name}" @@ -108,8 +109,8 @@ def run(party, job_name): def test_multi_proxy_actor(): - p_alice_job1 = multiprocessing.Process(target=run, args=('alice', 'job1')) - p_alice_job2 = multiprocessing.Process(target=run, args=('alice', 'job2')) + p_alice_job1 = multiprocessing.Process(target=run, args=("alice", "job1")) + p_alice_job2 = multiprocessing.Process(target=run, args=("alice", "job2")) p_alice_job1.start() p_alice_job2.start() p_alice_job1.join() diff --git a/fed/tests/serializations_tests/test_unpickle_with_whitelist.py b/fed/tests/serializations_tests/test_unpickle_with_whitelist.py index 43e3c3f..3420f66 100644 --- a/fed/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/fed/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -41,10 +41,10 @@ def pass_arg(d): def run(party): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } allowed_list = { "numpy.core.numeric": ["*"], @@ -53,7 +53,7 @@ def run(party): fed.init( addresses=addresses, party=party, - config={"cross_silo_comm": {'serializing_allowed_list': allowed_list}}, + config={"cross_silo_comm": {"serializing_allowed_list": allowed_list}}, ) # 
Test passing an allowed type. @@ -80,8 +80,8 @@ def run(party): def test_restricted_loads(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/simple_example.py b/fed/tests/simple_example.py index 4ca1095..73d7b08 100644 --- a/fed/tests/simple_example.py +++ b/fed/tests/simple_example.py @@ -13,9 +13,11 @@ # limitations under the License. import multiprocessing -import fed + import ray +import fed + @fed.remote class MyActor: @@ -43,13 +45,13 @@ def agg_fn(obj1, obj2): addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } def run(party): - ray.init(address='local') + ray.init(address="local") fed.init(addresses=addresses, party=party) print(f"Running the script in party {party}") @@ -71,8 +73,8 @@ def run(party): def main(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_api.py b/fed/tests/test_api.py index 9a3b00a..006f80b 100644 --- a/fed/tests/test_api.py +++ b/fed/tests/test_api.py @@ -13,17 +13,19 @@ # limitations under the License. import multiprocessing + import pytest +import ray + import fed import fed._private.compatible_utils as compatible_utils -import ray import fed.config as fed_config def run(): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11012', + "alice": "127.0.0.1:11012", } fed.init(addresses=addresses, party="alice") config = fed_config.get_cluster_config() @@ -41,9 +43,9 @@ def test_fed_apis(): def _run(): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11012', + "alice": "127.0.0.1:11012", } fed.init(addresses=addresses, party="alice") diff --git a/fed/tests/test_async_startup_2_clusters.py b/fed/tests/test_async_startup_2_clusters.py index 9542f87..3cee62c 100644 --- a/fed/tests/test_async_startup_2_clusters.py +++ b/fed/tests/test_async_startup_2_clusters.py @@ -15,8 +15,8 @@ import multiprocessing import pytest - import ray + import fed import fed._private.compatible_utils as compatible_utils @@ -42,10 +42,10 @@ def _run(party: str): time.sleep(10) - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } fed.init(addresses=addresses, party=party) @@ -61,8 +61,8 @@ def _run(party: str): # This case is used to test that we start 2 clusters not at the same time. 
def test_async_startup_2_clusters(): - p_alice = multiprocessing.Process(target=_run, args=('alice',)) - p_bob = multiprocessing.Process(target=_run, args=('bob',)) + p_alice = multiprocessing.Process(target=_run, args=("alice",)) + p_bob = multiprocessing.Process(target=_run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_basic_pass_fed_objects.py b/fed/tests/test_basic_pass_fed_objects.py index ebca911..5068990 100644 --- a/fed/tests/test_basic_pass_fed_objects.py +++ b/fed/tests/test_basic_pass_fed_objects.py @@ -16,6 +16,7 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils @@ -35,10 +36,10 @@ def get_value(self): def run(party, is_inner_party): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } fed.init(addresses=addresses, party=party) @@ -57,8 +58,8 @@ def run(party, is_inner_party): def test_pass_fed_objects_for_actor_creation_inner_party(): - p_alice = multiprocessing.Process(target=run, args=('alice', True)) - p_bob = multiprocessing.Process(target=run, args=('bob', True)) + p_alice = multiprocessing.Process(target=run, args=("alice", True)) + p_bob = multiprocessing.Process(target=run, args=("bob", True)) p_alice.start() p_bob.start() p_alice.join() @@ -67,8 +68,8 @@ def test_pass_fed_objects_for_actor_creation_inner_party(): def test_pass_fed_objects_for_actor_creation_across_party(): - p_alice = multiprocessing.Process(target=run, args=('alice', False)) - p_bob = multiprocessing.Process(target=run, args=('bob', False)) + p_alice = multiprocessing.Process(target=run, args=("alice", False)) + p_bob = multiprocessing.Process(target=run, args=("bob", False)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_brpc_link.py b/fed/tests/test_brpc_link.py new file mode 100644 index 0000000..4ad93d1 --- /dev/null +++ b/fed/tests/test_brpc_link.py @@ -0,0 +1,99 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
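+
+# End-to-end test of the brpc_link transport: alice and bob each train a toy
+# model, a cross-party `mean` task averages the weights every epoch, and both
+# parties assert on the aggregated results. Apart from passing
+# `receiver_sender_proxy_cls=BrpcLinkSenderReceiverProxy` to `fed.init`, the
+# flow mirrors the plain gRPC tests.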
+ +import multiprocessing + +import pytest +import ray +import fed +import fed._private.compatible_utils as compatible_utils + + +@fed.remote +class MyModel: + def __init__(self, party, step_length): + self._trained_steps = 0 + self._step_length = step_length + self._weights = 0 + self._party = party + + def train(self): + self._trained_steps += 1 + self._weights += self._step_length + return self._weights + + def get_weights(self): + return self._weights + + def set_weights(self, new_weights): + self._weights = new_weights + return new_weights + + +@fed.remote +def mean(x, y): + return (x + y) / 2 + + +def run(party): + compatible_utils.init_ray(address="local") + addresses = { + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", + } + from fed.proxy.brpc_link.link import BrpcLinkSenderReceiverProxy + + fed.init( + addresses=addresses, + party=party, + receiver_sender_proxy_cls=BrpcLinkSenderReceiverProxy, + logging_level="debug", + ) + + epochs = 3 + alice_model = MyModel.party("alice").remote("alice", 2) + bob_model = MyModel.party("bob").remote("bob", 4) + + all_mean_weights = [] + for epoch in range(epochs): + w1 = alice_model.train.remote() + w2 = bob_model.train.remote() + new_weights = mean.party("alice").remote(w1, w2) + result = fed.get(new_weights) + alice_model.set_weights.remote(new_weights) + bob_model.set_weights.remote(new_weights) + all_mean_weights.append(result) + assert all_mean_weights == [3, 6, 9] + latest_weights = fed.get( + [alice_model.get_weights.remote(), bob_model.get_weights.remote()] + ) + assert latest_weights == [9, 9] + fed.shutdown() + ray.shutdown() + + +def test_fed_get_in_2_parties(): + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-sv", __file__])) diff --git a/fed/tests/test_cache_fed_objects.py b/fed/tests/test_cache_fed_objects.py index 3c47711..826ca69 100644 --- a/fed/tests/test_cache_fed_objects.py +++ b/fed/tests/test_cache_fed_objects.py @@ -33,10 +33,10 @@ def g(x, index): def run(party): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } fed.init(addresses=addresses, party=party) @@ -62,8 +62,8 @@ def run(party): def test_cache_fed_object_if_sent(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_cross_silo_error.py b/fed/tests/test_cross_silo_error.py index 17547a0..8837f00 100644 --- a/fed/tests/test_cross_silo_error.py +++ b/fed/tests/test_cross_silo_error.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
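+# These tests check cross-silo error propagation: the party that raised the
+# error sees the original exception (MyError), while the peer receives a
+# FedRemoteError whose message carries the remote traceback only when
+# `expose_error_trace` is enabled in the `cross_silo_comm` config.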
-import ray
import multiprocessing
+import sys
+from unittest.mock import Mock

import pytest
+import ray
+
import fed
import fed._private.compatible_utils as compatible_utils
-import sys
-
-from unittest.mock import Mock
from fed.exceptions import FedRemoteError
@@ -45,33 +45,33 @@ def error_func(self):


def run(party):
    my_failure_handler = Mock()
-    compatible_utils.init_ray(address='local')
+    compatible_utils.init_ray(address="local")
    addresses = {
-        'alice': '127.0.0.1:11012',
-        'bob': '127.0.0.1:11011',
+        "alice": "127.0.0.1:11012",
+        "bob": "127.0.0.1:11011",
    }
    fed.init(
        addresses=addresses,
        party=party,
-        logging_level='debug',
+        logging_level="debug",
        config={
-            'cross_silo_comm': {
-                'exit_on_sending_failure': True,
-                'timeout_ms': 20 * 1000,
-                'expose_error_trace': True
+            "cross_silo_comm": {
+                "exit_on_sending_failure": True,
+                "timeout_ms": 20 * 1000,
+                "expose_error_trace": True,
            },
        },
-        failure_handler=my_failure_handler
+        failure_handler=my_failure_handler,
    )

    # Both parties should catch the error
    o = error_func.party("alice").remote()
    with pytest.raises(Exception) as e:
        fed.get(o)
-    if party == 'bob':
+    if party == "bob":
        assert isinstance(e.value.cause, FedRemoteError)
-        assert 'RemoteError occurred at alice' in str(e.value.cause)
+        assert "RemoteError occurred at alice" in str(e.value.cause)
        assert "normal task Error" in str(e.value.cause)
    else:
        assert isinstance(e.value.cause, MyError)
@@ -82,8 +82,8 @@ def run(party):


def test_cross_silo_normal_task_error():
-    p_alice = multiprocessing.Process(target=run, args=('alice',))
-    p_bob = multiprocessing.Process(target=run, args=('bob',))
+    p_alice = multiprocessing.Process(target=run, args=("alice",))
+    p_bob = multiprocessing.Process(target=run, args=("bob",))
    p_alice.start()
    p_bob.start()
    p_alice.join()
@@ -94,34 +94,34 @@ def test_cross_silo_normal_task_error():


def run2(party):
    my_failure_handler = Mock()
-    compatible_utils.init_ray(address='local')
+    compatible_utils.init_ray(address="local")
    addresses = {
-        'alice': '127.0.0.1:11012',
-        'bob': '127.0.0.1:11011',
+        "alice": "127.0.0.1:11012",
+        "bob": "127.0.0.1:11011",
    }
    fed.init(
        addresses=addresses,
        party=party,
-        logging_level='debug',
+        logging_level="debug",
        config={
-            'cross_silo_comm': {
-                'exit_on_sending_failure': True,
-                'timeout_ms': 20 * 1000,
-                'expose_error_trace': True
+            "cross_silo_comm": {
+                "exit_on_sending_failure": True,
+                "timeout_ms": 20 * 1000,
+                "expose_error_trace": True,
            },
        },
-        failure_handler=my_failure_handler
+        failure_handler=my_failure_handler,
    )

    # Both parties should catch the error
-    my = My.party('alice').remote()
+    my = My.party("alice").remote()
    o = my.error_func.remote()
    with pytest.raises(Exception) as e:
        fed.get(o)
-    if party == 'bob':
+    if party == "bob":
        assert isinstance(e.value.cause, FedRemoteError)
-        assert 'RemoteError occurred at alice' in str(e.value.cause)
+        assert "RemoteError occurred at alice" in str(e.value.cause)
        assert "actor task Error" in str(e.value.cause)
        my_failure_handler.assert_called_once()
    else:
@@ -134,8 +134,8 @@ def run2(party):


def test_cross_silo_actor_task_error():
-    p_alice = multiprocessing.Process(target=run2, args=('alice',))
-    p_bob = multiprocessing.Process(target=run2, args=('bob',))
+    p_alice = multiprocessing.Process(target=run2, args=("alice",))
+    p_bob = multiprocessing.Process(target=run2, args=("bob",))
    p_alice.start()
    p_bob.start()
    p_alice.join()
@@ -146,33 +146,33 @@ def run3(party):
    my_failure_handler = Mock()
-    compatible_utils.init_ray(address='local')
+    compatible_utils.init_ray(address="local")
    addresses = {
-        'alice': '127.0.0.1:11012',
-        'bob': '127.0.0.1:11011',
+        "alice": "127.0.0.1:11012",
+        "bob": "127.0.0.1:11011",
    }
    fed.init(
        addresses=addresses,
        party=party,
-        logging_level='debug',
+        logging_level="debug",
        config={
-            'cross_silo_comm': {
-                'exit_on_sending_failure': True,
-                'timeout_ms': 20 * 1000,
+            "cross_silo_comm": {
+                "exit_on_sending_failure": True,
+                "timeout_ms": 20 * 1000,
            },
        },
-        failure_handler=my_failure_handler
+        failure_handler=my_failure_handler,
    )

    # Both parties should catch the error
    o = error_func.party("alice").remote()
    with pytest.raises(Exception) as e:
        fed.get(o)
-    if party == 'bob':
+    if party == "bob":
        assert isinstance(e.value.cause, FedRemoteError)
-        assert 'RemoteError occurred at alice' in str(e.value.cause)
-        assert 'caused by' not in str(e.value.cause)
+        assert "RemoteError occurred at alice" in str(e.value.cause)
+        assert "caused by" not in str(e.value.cause)
    else:
        assert isinstance(e.value.cause, MyError)
        assert "normal task Error" in str(e.value.cause)
@@ -182,8 +182,8 @@


def test_cross_silo_not_expose_error_trace():
-    p_alice = multiprocessing.Process(target=run3, args=('alice',))
-    p_bob = multiprocessing.Process(target=run3, args=('bob',))
+    p_alice = multiprocessing.Process(target=run3, args=("alice",))
+    p_bob = multiprocessing.Process(target=run3, args=("bob",))
    p_alice.start()
    p_bob.start()
    p_alice.join()
diff --git a/fed/tests/test_enable_tls_across_parties.py b/fed/tests/test_enable_tls_across_parties.py
index 5c3c1c3..3df349f 100644
--- a/fed/tests/test_enable_tls_across_parties.py
+++ b/fed/tests/test_enable_tls_across_parties.py
@@ -16,8 +16,8 @@
import os

import pytest
-
import ray
+
import fed
import fed._private.compatible_utils as compatible_utils
@@ -38,7 +38,7 @@ def add(x, y):


def _run(party: str):
-    compatible_utils.init_ray(address='local')
+    compatible_utils.init_ray(address="local")
    cert_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "/tmp/rayfed/test-certs/"
    )
@@ -49,8 +49,8 @@ def _run(party: str):
    }

    addresses = {
-        'alice': '127.0.0.1:11012',
-        'bob': '127.0.0.1:11011',
+        "alice": "127.0.0.1:11012",
+        "bob": "127.0.0.1:11011",
    }

    fed.init(addresses=addresses, party=party, tls_config=cert_config)
@@ -65,8 +65,8 @@ def _run(party: str):


def test_enable_tls_across_parties():
-    p_alice = multiprocessing.Process(target=_run, args=('alice',))
-    p_bob = multiprocessing.Process(target=_run, args=('bob',))
+    p_alice = multiprocessing.Process(target=_run, args=("alice",))
+    p_bob = multiprocessing.Process(target=_run, args=("bob",))
    p_alice.start()
    p_bob.start()
    p_alice.join()
diff --git a/fed/tests/test_exit_on_failure_sending.py b/fed/tests/test_exit_on_failure_sending.py
index ab1a861..493878f 100644
--- a/fed/tests/test_exit_on_failure_sending.py
+++ b/fed/tests/test_exit_on_failure_sending.py
@@ -13,17 +13,16 @@
# limitations under the License.
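+
+# Exercises `exit_on_sending_failure`: only alice is started, so sending to
+# bob fails; the `failure_handler` delivers SIGTERM to the current process and
+# the signal handler shuts down cleanly, which the test asserts via exitcode 0.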
import multiprocessing +import os +import signal +import sys import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils -import signal - -import os -import sys - def signal_handler(sig, frame): if sig == signal.SIGTERM.value: @@ -49,10 +48,10 @@ def get_value(self): def run(party): signal.signal(signal.SIGTERM, signal_handler) - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } retry_policy = { "maxAttempts": 2, @@ -65,15 +64,15 @@ def run(party): fed.init( addresses=addresses, party=party, - logging_level='debug', + logging_level="debug", config={ - 'cross_silo_comm': { - 'grpc_retry_policy': retry_policy, - 'exit_on_sending_failure': True, - 'timeout_ms': 20 * 1000, + "cross_silo_comm": { + "grpc_retry_policy": retry_policy, + "exit_on_sending_failure": True, + "timeout_ms": 20 * 1000, }, }, - failure_handler=lambda : os.kill(os.getpid(), signal.SIGTERM) + failure_handler=lambda: os.kill(os.getpid(), signal.SIGTERM), ) o = f.party("alice").remote() @@ -85,7 +84,7 @@ def run(party): def test_exit_when_failure_on_sending(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) p_alice.start() p_alice.join() assert p_alice.exitcode == 0 diff --git a/fed/tests/test_fed_get.py b/fed/tests/test_fed_get.py index 5752f77..76386a0 100644 --- a/fed/tests/test_fed_get.py +++ b/fed/tests/test_fed_get.py @@ -16,6 +16,7 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils @@ -48,16 +49,17 @@ def mean(x, y): def run(party): import time - if party == 'alice': + + if party == "alice": time.sleep(1.4) # address = 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' # noqa # compatible_utils.init_ray(address=address) - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:31012', - 'bob': '127.0.0.1:31011', + "alice": "127.0.0.1:31012", + "bob": "127.0.0.1:31011", } fed.init(addresses=addresses, party=party) @@ -84,8 +86,8 @@ def run(party): def test_fed_get_in_2_parties(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_grpc_options_on_proxies.py b/fed/tests/test_grpc_options_on_proxies.py index cb14e92..4dee07b 100644 --- a/fed/tests/test_grpc_options_on_proxies.py +++ b/fed/tests/test_grpc_options_on_proxies.py @@ -13,11 +13,12 @@ # limitations under the License. 
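+
+# Verifies that gRPC channel options reach both proxy actors, whether they are
+# given as raw `grpc_channel_options`, derived from
+# `messages_max_size_in_bytes`, or both (in which case the explicit channel
+# option takes precedence for the send direction).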
import multiprocessing + import pytest -import fed -import fed._private.compatible_utils as compatible_utils import ray +import fed +import fed._private.compatible_utils as compatible_utils from fed.proxy.barriers import receiver_proxy_actor_name, sender_proxy_actor_name @@ -27,34 +28,34 @@ def dummpy(): def run(party): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11019', - 'bob': '127.0.0.1:11018', + "alice": "127.0.0.1:11019", + "bob": "127.0.0.1:11018", } fed.init( addresses=addresses, party=party, config={ "cross_silo_comm": { - "grpc_channel_options": [('grpc.max_send_message_length', 100)], + "grpc_channel_options": [("grpc.max_send_message_length", 100)], }, }, ) def _assert_on_proxy(proxy_actor): config = ray.get(proxy_actor._get_proxy_config.remote()) - options = config['grpc_options'] + options = config["grpc_options"] assert ("grpc.max_send_message_length", 100) in options - assert ('grpc.so_reuseport', 0) in options + assert ("grpc.so_reuseport", 0) in options sender_proxy = ray.get_actor(sender_proxy_actor_name()) receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) _assert_on_proxy(sender_proxy) _assert_on_proxy(receiver_proxy) - a = dummpy.party('alice').remote() - b = dummpy.party('bob').remote() + a = dummpy.party("alice").remote() + b = dummpy.party("bob").remote() fed.get([a, b]) fed.shutdown() @@ -62,8 +63,8 @@ def _assert_on_proxy(proxy_actor): def test_grpc_max_size_by_channel_options(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() @@ -72,10 +73,10 @@ def test_grpc_max_size_by_channel_options(): def run2(party): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11019', - 'bob': '127.0.0.1:11018', + "alice": "127.0.0.1:11019", + "bob": "127.0.0.1:11018", } fed.init( addresses=addresses, @@ -89,18 +90,18 @@ def run2(party): def _assert_on_proxy(proxy_actor): config = ray.get(proxy_actor._get_proxy_config.remote()) - options = config['grpc_options'] + options = config["grpc_options"] assert ("grpc.max_send_message_length", 100) in options assert ("grpc.max_receive_message_length", 100) in options - assert ('grpc.so_reuseport', 0) in options + assert ("grpc.so_reuseport", 0) in options sender_proxy = ray.get_actor(sender_proxy_actor_name()) receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) _assert_on_proxy(sender_proxy) _assert_on_proxy(receiver_proxy) - a = dummpy.party('alice').remote() - b = dummpy.party('bob').remote() + a = dummpy.party("alice").remote() + b = dummpy.party("bob").remote() fed.get([a, b]) fed.shutdown() @@ -108,8 +109,8 @@ def _assert_on_proxy(proxy_actor): def test_grpc_max_size_by_common_config(): - p_alice = multiprocessing.Process(target=run2, args=('alice',)) - p_bob = multiprocessing.Process(target=run2, args=('bob',)) + p_alice = multiprocessing.Process(target=run2, args=("alice",)) + p_bob = multiprocessing.Process(target=run2, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() @@ -118,10 +119,10 @@ def test_grpc_max_size_by_common_config(): def run3(party): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11019', - 'bob': '127.0.0.1:11018', 
+ "alice": "127.0.0.1:11019", + "bob": "127.0.0.1:11018", } fed.init( addresses=addresses, @@ -130,26 +131,26 @@ def run3(party): "cross_silo_comm": { "messages_max_size_in_bytes": 100, "grpc_channel_options": [ - ('grpc.max_send_message_length', 200), - ], + ("grpc.max_send_message_length", 200), + ], }, }, ) def _assert_on_proxy(proxy_actor): config = ray.get(proxy_actor._get_proxy_config.remote()) - options = config['grpc_options'] + options = config["grpc_options"] assert ("grpc.max_send_message_length", 200) in options assert ("grpc.max_receive_message_length", 100) in options - assert ('grpc.so_reuseport', 0) in options + assert ("grpc.so_reuseport", 0) in options sender_proxy = ray.get_actor(sender_proxy_actor_name()) receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) _assert_on_proxy(sender_proxy) _assert_on_proxy(receiver_proxy) - a = dummpy.party('alice').remote() - b = dummpy.party('bob').remote() + a = dummpy.party("alice").remote() + b = dummpy.party("bob").remote() fed.get([a, b]) fed.shutdown() @@ -157,8 +158,8 @@ def _assert_on_proxy(proxy_actor): def test_grpc_max_size_by_both_config(): - p_alice = multiprocessing.Process(target=run3, args=('alice',)) - p_bob = multiprocessing.Process(target=run3, args=('bob',)) + p_alice = multiprocessing.Process(target=run3, args=("alice",)) + p_bob = multiprocessing.Process(target=run3, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_internal_kv.py b/fed/tests/test_internal_kv.py index bb04823..460372a 100644 --- a/fed/tests/test_internal_kv.py +++ b/fed/tests/test_internal_kv.py @@ -1,17 +1,19 @@ import multiprocessing +import time + import pytest import ray +import ray.experimental.internal_kv as ray_internal_kv + import fed -import time import fed._private.compatible_utils as compatible_utils -import ray.experimental.internal_kv as ray_internal_kv def run(party): compatible_utils.init_ray("local") addresses = { - 'alice': '127.0.0.1:11010', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11010", + "bob": "127.0.0.1:11011", } assert compatible_utils.kv is None fed.init(addresses=addresses, party=party, job_name="test_job_name") @@ -21,8 +23,10 @@ def run(party): # Test that a prefix key name is added under the hood. assert ray_internal_kv._internal_kv_get(b"test_key") is None - assert ray_internal_kv._internal_kv_get( - b"RAYFED#test_job_name#test_key") == b"test_val" + assert ( + ray_internal_kv._internal_kv_get(b"RAYFED#test_job_name#test_key") + == b"test_val" + ) time.sleep(5) fed.shutdown() @@ -35,8 +39,8 @@ def run(party): def test_kv_init(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_listening_address.py b/fed/tests/test_listening_address.py index a278722..c78d1b5 100644 --- a/fed/tests/test_listening_address.py +++ b/fed/tests/test_listening_address.py @@ -16,6 +16,7 @@ import pytest import ray + import fed import fed._private.compatible_utils as compatible_utils @@ -23,7 +24,7 @@ def _run(party): import socket - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") occupied_port = 11020 # NOTE(NKcqx): Firstly try to bind IPv6 because the grpc server will do so. 
    # Otherwise this UT will fail because socket bind $occupied_port
@@ -36,9 +37,7 @@ def _run(party):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("127.0.0.1", occupied_port))

-    addresses = {
-        'alice': f'127.0.0.1:{occupied_port}'
-    }
+    addresses = {"alice": f"127.0.0.1:{occupied_port}"}

    # Starting a grpc server on a used port will cause an AssertionError
    with pytest.raises(AssertionError):
@@ -56,7 +55,7 @@ def _run(party):


def test_listen_used_address():
-    p_alice = multiprocessing.Process(target=_run, args=('alice',))
+    p_alice = multiprocessing.Process(target=_run, args=("alice",))
    p_alice.start()
    p_alice.join()
    assert p_alice.exitcode == 0
diff --git a/fed/tests/test_options.py b/fed/tests/test_options.py
index 13bbb6f..87385ca 100644
--- a/fed/tests/test_options.py
+++ b/fed/tests/test_options.py
@@ -16,6 +16,7 @@

import pytest
import ray
+
import fed
import fed._private.compatible_utils as compatible_utils
@@ -32,10 +33,10 @@ def bar(x):


def run(party):
-    compatible_utils.init_ray(address='local')
+    compatible_utils.init_ray(address="local")
    addresses = {
-        'alice': '127.0.0.1:11012',
-        'bob': '127.0.0.1:11011',
+        "alice": "127.0.0.1:11012",
+        "bob": "127.0.0.1:11011",
    }
    fed.init(addresses=addresses, party=party)
@@ -51,8 +52,8 @@ def run(party):


def test_fed_get_in_2_parties():
-    p_alice = multiprocessing.Process(target=run, args=('alice',))
-    p_bob = multiprocessing.Process(target=run, args=('bob',))
+    p_alice = multiprocessing.Process(target=run, args=("alice",))
+    p_bob = multiprocessing.Process(target=run, args=("bob",))
    p_alice.start()
    p_bob.start()
    p_alice.join()
diff --git a/fed/tests/test_pass_fed_objects_in_containers_in_actor.py b/fed/tests/test_pass_fed_objects_in_containers_in_actor.py
index 08d3a2d..31ecb98 100644
--- a/fed/tests/test_pass_fed_objects_in_containers_in_actor.py
+++ b/fed/tests/test_pass_fed_objects_in_containers_in_actor.py
@@ -15,8 +15,8 @@
import multiprocessing

import pytest
-
import ray
+
import fed
import fed._private.compatible_utils as compatible_utils
@@ -38,13 +38,13 @@ def bar(self, li):

addresses = {
-    'alice': '127.0.0.1:11012',
-    'bob': '127.0.0.1:11011',
+    "alice": "127.0.0.1:11012",
+    "bob": "127.0.0.1:11011",
}


def run(party):
-    compatible_utils.init_ray(address='local')
+    compatible_utils.init_ray(address="local")
    fed.init(addresses=addresses, party=party)
    my1 = My.party("alice").remote()
    my2 = My.party("bob").remote()
@@ -60,8 +60,8 @@ def run(party):


def test_pass_fed_objects_in_list():
-    p_alice = multiprocessing.Process(target=run, args=('alice',))
-    p_bob = multiprocessing.Process(target=run, args=('bob',))
+    p_alice = multiprocessing.Process(target=run, args=("alice",))
+    p_bob = multiprocessing.Process(target=run, args=("bob",))
    p_alice.start()
    p_bob.start()
    p_alice.join()
diff --git a/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py b/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py
index 7b78cb6..cbdc5f3 100644
--- a/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py
+++ b/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py
@@ -15,8 +15,8 @@
import multiprocessing

import pytest
-
import ray
+
import fed
import fed._private.compatible_utils as compatible_utils
@@ -38,10 +38,10 @@ def bar(li):


def run(party):
-    compatible_utils.init_ray(address='local')
+    compatible_utils.init_ray(address="local")
    addresses = {
-        'alice': '127.0.0.1:11012',
-        'bob': '127.0.0.1:11011',
+        "alice": "127.0.0.1:11012",
+        "bob": "127.0.0.1:11011",
    }
    fed.init(addresses=addresses,
party=party) o1 = foo.party("alice").remote(0) @@ -56,8 +56,8 @@ def run(party): def test_pass_fed_objects_in_list(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_ping_others.py b/fed/tests/test_ping_others.py index 4753dde..0c1ff49 100644 --- a/fed/tests/test_ping_others.py +++ b/fed/tests/test_ping_others.py @@ -12,42 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import multiprocessing +import time + +import pytest +import ray + import fed import fed._private.compatible_utils as compatible_utils -import ray -import time from fed.proxy.barriers import ping_others - addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } def test_ping_non_started_party(): def run(party): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") fed.init(addresses=addresses, party=party) - if (party == 'alice'): + if party == "alice": with pytest.raises(RuntimeError): ping_others(addresses, party, 5) fed.shutdown() ray.shutdown() - p_alice = multiprocessing.Process(target=run, args=('alice',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) p_alice.start() p_alice.join() def test_ping_started_party(): def run(party): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") fed.init(addresses=addresses, party=party) - if (party == 'alice'): + if party == "alice": ping_success = ping_others(addresses, party, 5) assert ping_success is True else: @@ -57,8 +58,8 @@ def run(party): fed.shutdown() ray.shutdown() - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_repeat_init.py b/fed/tests/test_repeat_init.py index 8926c78..e3b6601 100644 --- a/fed/tests/test_repeat_init.py +++ b/fed/tests/test_repeat_init.py @@ -16,9 +16,10 @@ import multiprocessing import pytest +import ray + import fed import fed._private.compatible_utils as compatible_utils -import ray @fed.remote @@ -38,14 +39,14 @@ def bar(self, li): addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } def run(party): def _run(): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") fed.init(addresses=addresses, party=party) my1 = My.party("alice").remote() @@ -66,8 +67,8 @@ def _run(): def test_pass_fed_objects_in_list(): - p_alice = multiprocessing.Process(target=run, args=('alice',)) - p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() p_bob.start() p_alice.join() diff --git a/fed/tests/test_reset_context.py b/fed/tests/test_reset_context.py index 95c6e53..b1f75f7 100644 --- a/fed/tests/test_reset_context.py +++ b/fed/tests/test_reset_context.py @@ -1,12 +1,14 @@ import multiprocessing -import fed + +import pytest 
import ray + +import fed import fed._private.compatible_utils as compatible_utils -import pytest addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } @@ -20,17 +22,15 @@ def get(self): def run(party): - compatible_utils.init_ray(address='local') - fed.init( - addresses=addresses, - party=party) + compatible_utils.init_ray(address="local") + fed.init(addresses=addresses, party=party) - actor = A.party('alice').remote(10) + actor = A.party("alice").remote(10) alice_fed_obj = actor.get.remote() alice_first_fed_obj_id = alice_fed_obj.get_fed_task_id() assert fed.get(alice_fed_obj) == 10 - actor = A.party('bob').remote(12) + actor = A.party("bob").remote(12) bob_fed_obj = actor.get.remote() bob_first_fed_obj_id = bob_fed_obj.get_fed_task_id() assert fed.get(bob_fed_obj) == 12 @@ -44,18 +44,16 @@ def run(party): # `AttributeError` compatible_utils.kv.put("key2", "val2") - compatible_utils.init_ray(address='local') - fed.init( - addresses=addresses, - party=party) + compatible_utils.init_ray(address="local") + fed.init(addresses=addresses, party=party) - actor = A.party('alice').remote(10) + actor = A.party("alice").remote(10) alice_fed_obj = actor.get.remote() alice_second_fed_obj_id = alice_fed_obj.get_fed_task_id() assert fed.get(alice_fed_obj) == 10 assert alice_first_fed_obj_id == alice_second_fed_obj_id - actor = A.party('bob').remote(12) + actor = A.party("bob").remote(12) bob_fed_obj = actor.get.remote() bob_second_fed_obj_id = bob_fed_obj.get_fed_task_id() assert fed.get(bob_fed_obj) == 12 @@ -70,8 +68,8 @@ def run(party): def test_reset_context(): - p_alice = multiprocessing.Process(target=run, args=('alice', )) - p_bob = multiprocessing.Process(target=run, args=('bob', )) + p_alice = multiprocessing.Process(target=run, args=("alice",)) + p_bob = multiprocessing.Process(target=run, args=("bob",)) p_alice.start() import time @@ -85,4 +83,5 @@ def test_reset_context(): if __name__ == "__main__": import sys + sys.exit(pytest.main(["-sv", __file__])) diff --git a/fed/tests/test_retry_policy.py b/fed/tests/test_retry_policy.py index 574ce9f..d33b5f7 100644 --- a/fed/tests/test_retry_policy.py +++ b/fed/tests/test_retry_policy.py @@ -15,6 +15,7 @@ import multiprocessing from unittest import TestCase + import pytest import ray @@ -38,10 +39,10 @@ def get_value(self): def run(): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + "alice": "127.0.0.1:11012", + "bob": "127.0.0.1:11011", } retry_policy = { "maxAttempts": 4, @@ -50,13 +51,13 @@ def run(): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } - test_job_name = 'test_retry_policy' + test_job_name = "test_retry_policy" fed.init( addresses=addresses, - party='alice', + party="alice", config={ - 'cross_silo_comm': { - 'grpc_retry_policy': retry_policy, + "cross_silo_comm": { + "grpc_retry_policy": retry_policy, } }, ) @@ -64,7 +65,7 @@ def run(): job_config = config.get_job_config(test_job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict TestCase().assertDictEqual( - cross_silo_comm_config['grpc_retry_policy'], retry_policy + cross_silo_comm_config["grpc_retry_policy"], retry_policy ) fed.shutdown() diff --git a/fed/tests/test_setup_proxy_actor.py b/fed/tests/test_setup_proxy_actor.py index ca72ea2..2901825 100644 --- a/fed/tests/test_setup_proxy_actor.py +++ b/fed/tests/test_setup_proxy_actor.py @@ -24,10 +24,10 @@ 
def run(party):
-    compatible_utils.init_ray(address='local', resources={"127.0.0.1": 2})
+    compatible_utils.init_ray(address="local", resources={"127.0.0.1": 2})
    addresses = {
-        'alice': '127.0.0.1:11010',
-        'bob': '127.0.0.1:11011',
+        "alice": "127.0.0.1:11010",
+        "bob": "127.0.0.1:11011",
    }
    fed.init(
        addresses=addresses,
        party=party,
@@ -42,10 +42,10 @@ def run(party):


def run_failure(party):
-    compatible_utils.init_ray(address='local', resources={"127.0.0.1": 1})
+    compatible_utils.init_ray(address="local", resources={"127.0.0.1": 1})
    addresses = {
-        'alice': '127.0.0.1:11010',
-        'bob': '127.0.0.1:11011',
+        "alice": "127.0.0.1:11010",
+        "bob": "127.0.0.1:11011",
    }
    sender_proxy_resources = {"127.0.0.2": 1}  # Insufficient resource
    receiver_proxy_resources = {"127.0.0.2": 1}  # Insufficient resource
@@ -54,10 +54,10 @@ def run_failure(party):
        addresses=addresses,
        party=party,
        config={
-            'cross_silo_comm': {
-                'send_resource_label': sender_proxy_resources,
-                'recv_resource_label': receiver_proxy_resources,
-                'timeout_in_ms': 10 * 1000,
+            "cross_silo_comm": {
+                "send_resource_label": sender_proxy_resources,
+                "recv_resource_label": receiver_proxy_resources,
+                "timeout_in_ms": 10 * 1000,
            }
        },
    )
@@ -67,8 +67,8 @@ def test_setup_proxy_success():
-    p_alice = multiprocessing.Process(target=run, args=('alice',))
-    p_bob = multiprocessing.Process(target=run, args=('bob',))
+    p_alice = multiprocessing.Process(target=run, args=("alice",))
+    p_bob = multiprocessing.Process(target=run, args=("bob",))
    p_alice.start()
    p_bob.start()
    p_alice.join()
@@ -77,8 +77,8 @@ def test_setup_proxy_failed():
-    p_alice = multiprocessing.Process(target=run_failure, args=('alice',))
-    p_bob = multiprocessing.Process(target=run_failure, args=('bob',))
+    p_alice = multiprocessing.Process(target=run_failure, args=("alice",))
+    p_bob = multiprocessing.Process(target=run_failure, args=("bob",))
    p_alice.start()
    p_bob.start()
    p_alice.join()
diff --git a/fed/tests/test_transport_proxy.py b/fed/tests/test_transport_proxy.py
index bb6f3f2..d3d3009 100644
--- a/fed/tests/test_transport_proxy.py
+++ b/fed/tests/test_transport_proxy.py
@@ -31,7 +31,8 @@
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy, GrpcSenderProxy

if compatible_utils._compare_version_strings(
-    fed_utils.get_package_version('protobuf'), '4.0.0'):
+    fed_utils.get_package_version("protobuf"), "4.0.0"
+):
    from fed.grpc.pb4 import fed_pb2 as fed_pb2
    from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc
else:
@@ -44,9 +45,9 @@ def test_n_to_1_transport():
    sending data to the target receiver proxy, and there are also N receivers
    to `get_data` from the receiver proxy at that time.
""" - compatible_utils.init_ray(address='local') - test_job_name = 'test_n_to_1_transport' - party = 'test_party' + compatible_utils.init_ray(address="local") + test_job_name = "test_n_to_1_transport" + party = "test_party" global_context.init_global_context(party, test_job_name) global_context.get_global_context().get_cleanup_manager().start() cluster_config = { @@ -62,18 +63,18 @@ def test_n_to_1_transport(): NUM_DATA = 10 SERVER_ADDRESS = "127.0.0.1:12344" - addresses = {'test_party': SERVER_ADDRESS} + addresses = {"test_party": SERVER_ADDRESS} _start_receiver_proxy( addresses, party, - logging_level='info', + logging_level="info", proxy_cls=GrpcReceiverProxy, proxy_config={}, ) _start_sender_proxy( addresses, party, - logging_level='info', + logging_level="info", proxy_cls=GrpcSenderProxy, proxy_config={}, ) @@ -99,8 +100,9 @@ def test_n_to_1_transport(): class TestSendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock, - expected_metadata, expected_jobname): + def __init__( + self, all_events, all_data, party, lock, expected_metadata, expected_jobname + ): self.expected_metadata = expected_metadata or {} self._expected_jobname = expected_jobname or "" @@ -109,8 +111,9 @@ async def SendData(self, request, context): assert self._expected_jobname == job_name metadata = dict(context.invocation_metadata()) for k, v in self.expected_metadata.items(): - assert k in metadata, \ - f"The expected key {k} is not in the metadata keys: {metadata.keys()}." + assert ( + k in metadata + ), f"The expected key {k} is not in the metadata keys: {metadata.keys()}." assert v == metadata[k] event = asyncio.Event() event.set() @@ -129,11 +132,12 @@ async def _test_run_grpc_server( ): server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( - TestSendDataService(event, all_data, party, lock, - expected_metadata, expected_jobname), - server + TestSendDataService( + event, all_data, party, lock, expected_metadata, expected_jobname + ), + server, ) - server.add_insecure_port(f'[::]:{port}') + server.add_insecure_port(f"[::]:{port}") await server.start() await server.wait_for_termination() @@ -154,13 +158,13 @@ def __init__( async def run_grpc_server(self): return await _test_run_grpc_server( - self._listen_addr[self._listen_addr.index(':') + 1:], + self._listen_addr[self._listen_addr.index(":") + 1 :], None, None, self._party, None, expected_metadata=self._expected_metadata, - expected_jobname=self._expected_jobname + expected_jobname=self._expected_jobname, ) async def is_ready(self): @@ -178,27 +182,30 @@ def _test_start_receiver_proxy( address = addresses[party] receiver_proxy_actor = TestReceiverProxyActor.options( name=receiver_proxy_actor_name(), max_concurrency=1000 - ).remote(listen_addr=address, party=party, - expected_metadata=expected_metadata, - expected_jobname=expected_jobname) + ).remote( + listen_addr=address, + party=party, + expected_metadata=expected_metadata, + expected_jobname=expected_jobname, + ) receiver_proxy_actor.run_grpc_server.remote() assert ray.get(receiver_proxy_actor.is_ready.remote()) def test_send_grpc_with_meta(): - compatible_utils.init_ray(address='local') + compatible_utils.init_ray(address="local") cluster_config = { constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", } metadata = {"key": "value"} - config = {'http_header': metadata} + config = {"http_header": metadata} job_config = { 
        constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: config,
    }
-    test_job_name = 'test_send_grpc_with_meta'
-    party_name = 'test_party'
+    test_job_name = "test_send_grpc_with_meta"
+    party_name = "test_party"
    global_context.init_global_context(party_name, test_job_name)
    compatible_utils._init_internal_kv(test_job_name)
    compatible_utils.kv.put(
@@ -214,12 +221,12 @@ def test_send_grpc_with_meta():
        addresses,
        party_name,
        expected_metadata=metadata,
-        expected_jobname=test_job_name
+        expected_jobname=test_job_name,
    )
    _start_sender_proxy(
        addresses,
        party_name,
-        logging_level='info',
+        logging_level="info",
        proxy_cls=GrpcSenderProxy,
        proxy_config=config,
    )
diff --git a/fed/tests/test_transport_proxy_tls.py b/fed/tests/test_transport_proxy_tls.py
index e4af7ce..97b6b70 100644
--- a/fed/tests/test_transport_proxy_tls.py
+++ b/fed/tests/test_transport_proxy_tls.py
@@ -34,8 +34,8 @@ def test_n_to_1_transport():
    sending data to the target receiver proxy, and there are also N receivers
    to `get_data` from the receiver proxy at that time.
    """
-    compatible_utils.init_ray(address='local')
-    test_job_name = 'test_n_to_1_transport'
+    compatible_utils.init_ray(address="local")
+    test_job_name = "test_n_to_1_transport"
    cert_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "/tmp/rayfed/test-certs/"
    )
@@ -44,7 +44,7 @@
        "cert": os.path.join(cert_dir, "server.crt"),
        "key": os.path.join(cert_dir, "server.key"),
    }
-    party = 'test_party'
+    party = "test_party"
    cluster_config = {
        constants.KEY_OF_CLUSTER_ADDRESSES: "",
        constants.KEY_OF_CURRENT_PARTY_NAME: "",
@@ -59,11 +59,11 @@

    NUM_DATA = 10
    SERVER_ADDRESS = "127.0.0.1:65422"
-    addresses = {'test_party': SERVER_ADDRESS}
+    addresses = {"test_party": SERVER_ADDRESS}
    _start_receiver_proxy(
        addresses,
        party,
-        logging_level='info',
+        logging_level="info",
        tls_config=tls_config,
        proxy_cls=GrpcReceiverProxy,
        proxy_config={},
    )
    _start_sender_proxy(
        addresses,
        party,
-        logging_level='info',
+        logging_level="info",
        tls_config=tls_config,
        proxy_cls=GrpcSenderProxy,
        proxy_config={},
diff --git a/fed/tests/test_utils.py b/fed/tests/test_utils.py
index f17f1a6..3e4b35f 100644
--- a/fed/tests/test_utils.py
+++ b/fed/tests/test_utils.py
@@ -13,25 +13,26 @@
# limitations under the License.

import time
+
import pytest

import fed.utils as fed_utils


def start_ray_cluster(
-    ray_port,
-    client_server_port,
-    dashboard_port,
+    ray_port,
+    client_server_port,
+    dashboard_port,
):
    command = [
-        'ray',
-        'start',
-        '--head',
-        f'--port={ray_port}',
-        f'--ray-client-server-port={client_server_port}',
-        f'--dashboard-port={dashboard_port}',
+        "ray",
+        "start",
+        "--head",
+        f"--port={ray_port}",
+        f"--ray-client-server-port={client_server_port}",
+        f"--dashboard-port={dashboard_port}",
    ]
-    command_str = ' '.join(command)
+    command_str = " ".join(command)
    try:
        _ = fed_utils.start_command(command_str)
    except RuntimeError as e:
@@ -45,8 +46,9 @@ def start_ray_cluster(
        # container, you can increase /dev/shm size by passing '--shm-size=1.97gb' to
        # 'docker run' (or add it to the run_options list in a Ray cluster config).
        # Make sure to set this to more than 0% of available RAM.
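+        # Only these two benign startup messages are tolerated; any other
+        # RuntimeError fails the assertion below.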
- assert 'Overwriting previous Ray address' in str(e) \ - or 'WARNING: The object store is using /tmp instead of /dev/shm' in str(e) + assert "Overwriting previous Ray address" in str( + e + ) or "WARNING: The object store is using /tmp instead of /dev/shm" in str(e) @pytest.fixture @@ -57,4 +59,4 @@ def ray_client_mode_setup(): start_ray_cluster(ray_port=41011, client_server_port=21011, dashboard_port=9111) yield - fed_utils.start_command('ray stop --force') + fed_utils.start_command("ray stop --force") diff --git a/fed/tests/without_ray_tests/test_tree_utils.py b/fed/tests/without_ray_tests/test_tree_utils.py index 41729ad..416b5d3 100644 --- a/fed/tests/without_ray_tests/test_tree_utils.py +++ b/fed/tests/without_ray_tests/test_tree_utils.py @@ -1,4 +1,3 @@ - # Copyright 2023 The RayFed Team # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Tuple, Union + import pytest -from typing import Any, Union, List, Tuple, Dict import fed.tree_util as tree_utils @@ -28,7 +28,6 @@ def test_flatten_none(): def test_flatten_single_primivite_elements(): - def _assert_flatten_single_element(target: Any): li, tree_def = tree_utils.tree_flatten(target) assert isinstance(li, list) diff --git a/fed/tests/without_ray_tests/test_utils.py b/fed/tests/without_ray_tests/test_utils.py index b00e042..9823325 100644 --- a/fed/tests/without_ray_tests/test_utils.py +++ b/fed/tests/without_ray_tests/test_utils.py @@ -1,4 +1,3 @@ - # Copyright 2023 The RayFed Team # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,18 +17,21 @@ import fed -@pytest.mark.parametrize("input_address, is_valid_address", [ - ("192.168.0.1:8080", True), - ("sa127032as:80", True), - ("https://www.example.com", True), - ("http://www.example.com", True), - ("local", True), - ("localhost", True), - (None, False), - ("invalid_string", False), - ("http", False), - ("example.com", False), -]) +@pytest.mark.parametrize( + "input_address, is_valid_address", + [ + ("192.168.0.1:8080", True), + ("sa127032as:80", True), + ("https://www.example.com", True), + ("http://www.example.com", True), + ("local", True), + ("localhost", True), + (None, False), + ("invalid_string", False), + ("http", False), + ("example.com", False), + ], +) def test_validate_address(input_address, is_valid_address): if is_valid_address: fed.utils.validate_address(input_address) @@ -43,4 +45,5 @@ def test_validate_address(input_address, is_valid_address): if __name__ == "__main__": import sys + sys.exit(pytest.main(["-sv", __file__])) diff --git a/fed/tree_util.py b/fed/tree_util.py index 14a99fc..3f8dd60 100644 --- a/fed/tree_util.py +++ b/fed/tree_util.py @@ -14,15 +14,14 @@ # Most codes are copied from https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/_pytree.py # noqa -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, TypeVar -from collections import namedtuple, OrderedDict +from collections import OrderedDict, namedtuple from dataclasses import dataclass +from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Type, TypeVar, cast - -T = TypeVar('T') -S = TypeVar('S') -U = TypeVar('U') -R = TypeVar('R') +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") """ Contains utility functions for working with nested python data structures. 
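
An illustrative round-trip (a sketch assuming the standard dict/list/tuple
registrations this module copies from PyTorch's pytree):

    >>> leaves, spec = tree_flatten({"a": [1, 2], "b": (3,)})
    >>> leaves
    [1, 2, 3]
    >>> tree_unflatten(leaves, spec)
    {'a': [1, 2], 'b': (3,)}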
@@ -63,9 +62,8 @@ class NodeDef(NamedTuple): def _register_pytree_node( - typ: Any, - flatten_fn: FlattenFunc, - unflatten_fn: UnflattenFunc) -> None: + typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc +) -> None: SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) @@ -101,11 +99,11 @@ def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple: return cast(NamedTuple, context(*values)) -def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Context]: +def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Context]: return list(d.values()), list(d.keys()) -def _odict_unflatten(values: List[Any], context: Context) -> 'OrderedDict[Any, Any]': +def _odict_unflatten(values: List[Any], context: Context) -> "OrderedDict[Any, Any]": return OrderedDict((key, value) for key, value in zip(context, values)) @@ -122,7 +120,7 @@ def _is_namedtuple_instance(pytree: Any) -> bool: bases = typ.__bases__ if len(bases) != 1 or bases[0] != tuple: return False - fields = getattr(typ, '_fields', None) + fields = getattr(typ, "_fields", None) if not isinstance(fields, tuple): return False return all(isinstance(entry, str) for entry in fields) @@ -148,22 +146,25 @@ def _is_leaf(pytree: PyTree) -> bool: class TreeSpec: type: Any context: Context - children_specs: List['TreeSpec'] + children_specs: List["TreeSpec"] def __post_init__(self) -> None: self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f'TreeSpec({self.type.__name__}, {self.context}, [' - children_specs_str: str = '' + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + children_specs_str: str = "" if len(self.children_specs): indent += len(repr_prefix) children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += ',' if len(self.children_specs) > 1 else '' - children_specs_str += ','.join( - ['\n' + ' ' * indent + child.__repr__(indent) - for child in self.children_specs[1:]]) - repr_suffix: str = f'{children_specs_str}])' + children_specs_str += "," if len(self.children_specs) > 1 else "" + children_specs_str += ",".join( + [ + "\n" + " " * indent + child.__repr__(indent) + for child in self.children_specs[1:] + ] + ) + repr_suffix: str = f"{children_specs_str}])" return repr_prefix + repr_suffix @@ -173,7 +174,7 @@ def __init__(self) -> None: self.num_leaves = 1 def __repr__(self, indent: int = 0) -> str: - return '*' + return "*" def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: @@ -188,8 +189,8 @@ def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: child_pytrees, context = flatten_fn(pytree) # Recursively flatten the children - result : List[Any] = [] - children_specs : List['TreeSpec'] = [] + result: List[Any] = [] + children_specs: List["TreeSpec"] = [] for child in child_pytrees: flat, child_spec = tree_flatten(child) result += flat @@ -204,13 +205,15 @@ def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: """ if not isinstance(spec, TreeSpec): raise ValueError( - f'tree_unflatten(values, spec): Expected `spec` to be instance of ' - f'TreeSpec but got item of type {type(spec)}.') + f"tree_unflatten(values, spec): Expected `spec` to be instance of " + f"TreeSpec but got item of type {type(spec)}." 
+ ) if len(values) != spec.num_leaves: raise ValueError( - f'tree_unflatten(values, spec): `values` has length {len(values)} ' - f'but the spec refers to a pytree that holds {spec.num_leaves} ' - f'items ({spec}).') + f"tree_unflatten(values, spec): `values` has length {len(values)} " + f"but the spec refers to a pytree that holds {spec.num_leaves} " + f"items ({spec})." + ) if isinstance(spec, LeafSpec): return values[0] diff --git a/fed/utils.py b/fed/utils.py index b5450f2..7aa20b2 100644 --- a/fed/utils.py +++ b/fed/utils.py @@ -14,8 +14,8 @@ import logging import re -import sys import subprocess +import sys import ray @@ -35,7 +35,7 @@ def get_package_version(package_name: str) -> str: When using Python 3.8 and above, it uses `importlib.metadata`. """ curr_python_version = sys.version.split(" ")[0] - if _compare_version_strings(curr_python_version, '3.7.99'): + if _compare_version_strings(curr_python_version, "3.7.99"): import importlib.metadata return importlib.metadata.version(package_name) @@ -59,8 +59,8 @@ def resolve_dependencies(current_party, current_fed_task_id, *args, **kwargs): resolved.append(arg.get_ray_object_ref()) else: logger.debug( - f'Insert recv_op, arg task id {arg.get_fed_task_id()}, current ' - f'task id {current_fed_task_id}' + f"Insert recv_op, arg task id {arg.get_fed_task_id()}, current " + f"task id {current_fed_task_id}" ) if arg.get_ray_object_ref() is not None: # This code path indicates the ray object is already received in @@ -200,21 +200,21 @@ def validate_address(address: str) -> None: raise ValueError("The address shouldn't be None.") # The specific case for `local` or `localhost`. - if address == 'local' or address == 'localhost': + if address == "local" or address == "localhost": return # Rule 1: "ip:port" format - ip_port_pattern = r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d+$' + ip_port_pattern = r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d+$" if re.match(ip_port_pattern, address): return # Rule 2: "hostname:port" format - hostname_port_pattern = r'^[a-zA-Z0-9.-]+:\d+$' + hostname_port_pattern = r"^[a-zA-Z0-9.-]+:\d+$" if re.match(hostname_port_pattern, address): return # Rule 3: https or http link - link_pattern = r'^(https?://).*' + link_pattern = r"^(https?://).*" if re.match(link_pattern, address): return @@ -235,21 +235,20 @@ def validate_addresses(addresses: dict): for address in addresses.values(): assert ( isinstance(address, str) and address - ), f'Address should be string but got {address}.' + ), f"Address should be string but got {address}." validate_address(address) -def start_command(command: str, timeout=60) : +def start_command(command: str, timeout=60): """ A util to start a shell command. 
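
    Runs `command` through the shell and waits up to `timeout` seconds for it
    to finish. If the command writes anything to stderr, a RuntimeError
    carrying the decoded stderr output is raised; otherwise the captured
    stdout bytes are returned.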
""" process = subprocess.Popen( - command, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) output, error = process.communicate(timeout=timeout) if len(error) != 0: raise RuntimeError( - f'Failed to start command [{command}], the error is:\n {error.decode()}') + f"Failed to start command [{command}], the error is:\n {error.decode()}" + ) return output diff --git a/setup.py b/setup.py index 468ff6a..0b87365 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ VERSION = BASE_VERSION + ".dev0" this_directory = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f: +with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: long_description = f.read() plat_name = "any" @@ -38,7 +38,7 @@ def read_requirements(): requirements = [] - with open('requirements.txt') as file: + with open("requirements.txt") as file: requirements = file.read().splitlines() print("Requirements: ", requirements) return requirements @@ -58,23 +58,23 @@ def finalize_options(self): self._cwd = os.getcwd() def run(self): - assert os.getcwd() == self._cwd, 'Must be in package root: %s' % self._cwd - os.system('rm -rf ./build ./dist') + assert os.getcwd() == self._cwd, "Must be in package root: %s" % self._cwd + os.system("rm -rf ./build ./dist") setup( name=package_name, version=VERSION, - license='Apache 2.0', - description='A multiple parties joint, distributed execution engine based on Ray,' - 'to help build your own federated learning frameworks in minutes.', + license="Apache 2.0", + description="A multiple parties joint, distributed execution engine based on Ray," + "to help build your own federated learning frameworks in minutes.", long_description=long_description, - long_description_content_type='text/markdown', - author='RayFed Team', - author_email='rayfed-dev@googlegroups.com', - url='https://github.com/ray-project/rayfed', - packages=find_packages(exclude=('examples', 'tests', 'tests.*')), + long_description_content_type="text/markdown", + author="RayFed Team", + author_email="rayfed-dev@googlegroups.com", + url="https://github.com/ray-project/rayfed", + packages=find_packages(exclude=("examples", "tests", "tests.*")), install_requires=read_requirements(), - extras_require={'dev': ['pylint']}, - options={'bdist_wheel': {'plat_name': plat_name}}, + extras_require={"dev": ["pylint"]}, + options={"bdist_wheel": {"plat_name": plat_name}}, )