Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ReceiverSenderProxy. #168

Merged
merged 5 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

KEY_OF_TLS_CONFIG = "TLS_CONFIG"

KEY_OF_CROSS_SILO_MESSAGE_CONFIG = "CROSS_SILO_MESSAGE_CONFIG"
KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT"
NKcqx marked this conversation as resolved.
Show resolved Hide resolved

RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa

Expand Down
13 changes: 9 additions & 4 deletions fed/_private/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _restricted_loads(
buffers=None,
):
from sys import version_info

assert version_info.major == 3

if version_info.minor >= 8:
Expand All @@ -41,8 +42,10 @@ class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if _pickle_whitelist is None or (
module in _pickle_whitelist
and (_pickle_whitelist[module] is None or name in _pickle_whitelist[
module])
and (
_pickle_whitelist[module] is None
or name in _pickle_whitelist[module]
)
):
return super().find_class(module, name)

Expand All @@ -63,8 +66,10 @@ def find_class(self, module, name):
def _apply_loads_function_with_whitelist():
global _pickle_whitelist

_pickle_whitelist = fed_config.get_job_config() \
.cross_silo_message_config.serializing_allowed_list
cross_silo_comm_config = fed_config.CrossSiloMessageConfig.from_dict(
fed_config.get_job_config().cross_silo_comm_config_dict
)
_pickle_whitelist = cross_silo_comm_config.serializing_allowed_list
if _pickle_whitelist is None:
return

Expand Down
102 changes: 63 additions & 39 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@
send,
_start_receiver_proxy,
_start_sender_proxy,
_start_sender_receiver_proxy,
set_receiver_proxy_actor_name,
set_sender_proxy_actor_name,
)
from fed.proxy.grpc.grpc_proxy import SenderProxy, ReceiverProxy
from fed.config import GrpcCrossSiloMessageConfig
from fed.proxy.base_proxy import SenderProxy, ReceiverProxy, SenderReceiverProxy
from fed.config import CrossSiloMessageConfig
from fed.fed_object import FedObject
from fed.utils import is_ray_object_refs, setup_logger

Expand All @@ -50,6 +53,7 @@ def init(
logging_level: str = 'info',
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
receiver_sender_proxy_cls: SenderReceiverProxy = None,
):
"""
Initialize a RayFed client.
Expand Down Expand Up @@ -112,9 +116,7 @@ def init(
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

cross_silo_message_dict = config.get("cross_silo_message", {})
cross_silo_message_config = GrpcCrossSiloMessageConfig.from_dict(
cross_silo_message_dict)
cross_silo_comm_dict = config.get("cross_silo_comm", {})
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()

Expand All @@ -125,11 +127,11 @@ def init(
}

job_config = {
constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG:
cross_silo_message_config,
constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: cross_silo_comm_dict,
}
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
compatible_utils.kv.put(
constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
)
compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, cloudpickle.dumps(job_config))
# Set logger.
# Note(NKcqx): This should be called after internal_kv has party value, i.e.
Expand All @@ -143,39 +145,61 @@ def init(
)

logger.info(f'Started rayfed with {cluster_config}')
cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=cross_silo_message_config.exit_on_sending_failure) # noqa

if receiver_proxy_cls is None:
logger.debug(
"There is no receiver proxy class specified, it uses `GrpcRecvProxy` by "
"default.")
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy
receiver_proxy_cls = GrpcReceiverProxy
_start_receiver_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=receiver_proxy_cls,
proxy_config=cross_silo_message_config
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure
)
if receiver_sender_proxy_cls is not None:
proxy_actor_name = 'sender_recevier_actor'
set_sender_proxy_actor_name(proxy_actor_name)
set_receiver_proxy_actor_name(proxy_actor_name)
_start_sender_receiver_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=receiver_sender_proxy_cls,
proxy_config=cross_silo_comm_dict,
ready_timeout_second=cross_silo_comm_config.timeout_in_ms / 1000,
)
else:
if receiver_proxy_cls is None:
logger.debug(
(
"There is no receiver proxy class specified, "
"it uses `GrpcRecvProxy` by default."
)
)
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy

receiver_proxy_cls = GrpcReceiverProxy
_start_receiver_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=receiver_proxy_cls,
proxy_config=cross_silo_comm_dict,
ready_timeout_second=cross_silo_comm_config.timeout_in_ms / 1000,
)

if sender_proxy_cls is None:
logger.debug(
"There is no sender proxy class specified, it uses `GrpcRecvProxy` by "
"default.")
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy
sender_proxy_cls = GrpcSenderProxy
_start_sender_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=sender_proxy_cls,
# TODO(qwang): proxy_config -> cross_silo_message_config
proxy_config=cross_silo_message_config
)
if sender_proxy_cls is None:
logger.debug(
"There is no sender proxy class specified, it uses `GrpcRecvProxy` by "
"default."
)
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy

sender_proxy_cls = GrpcSenderProxy
_start_sender_proxy(
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=sender_proxy_cls,
proxy_config=cross_silo_comm_dict,
ready_timeout_second=cross_silo_comm_config.timeout_in_ms / 1000,
)

if config.get("barrier_on_initializing", False):
# TODO(zhouaihui): can be removed after we have a better retry strategy.
Expand Down
6 changes: 3 additions & 3 deletions fed/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(self) -> None:
self._check_send_thread = None
self._monitor_thread = None

def start(self, exit_when_failure_sending=False):
self._exit_when_failure_sending = exit_when_failure_sending
def start(self, exit_on_sending_failure=False):
self._exit_on_sending_failure = exit_on_sending_failure

def __check_func():
self._check_sending_objs()
Expand Down Expand Up @@ -98,7 +98,7 @@ def _signal_exit():
except Exception as e:
logger.warn(f'Failed to send {obj_ref} with error: {e}')
res = False
if not res and self._exit_when_failure_sending:
if not res and self._exit_on_sending_failure:
logger.warn('Signal self to exit.')
_signal_exit()
break
Expand Down
24 changes: 10 additions & 14 deletions fed/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


"""This module should be cached locally due to all configurations
are mutable.
"""
Expand All @@ -10,11 +8,12 @@
import json

from typing import Dict, List, Optional
from dataclasses import dataclass
from dataclasses import dataclass, fields


class ClusterConfig:
"""A local cache of cluster configuration items."""

def __init__(self, raw_bytes: bytes) -> None:
self._data = cloudpickle.loads(raw_bytes)

Expand All @@ -39,10 +38,8 @@ def __init__(self, raw_bytes: bytes) -> None:
self._data = cloudpickle.loads(raw_bytes)

@property
def cross_silo_message_config(self):
return self._data.get(
fed_constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG,
CrossSiloMessageConfig())
def cross_silo_comm_config_dict(self) -> Dict:
return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT, {})


# A module level cache for the cluster configurations.
Expand Down Expand Up @@ -103,7 +100,9 @@ class CrossSiloMessageConfig:
http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request.
This won't override basic tcp headers, such as `user-agent`, but concat
them together.
max_concurrency: the max_concurrency of the sender/receiver proxy actor.
"""

proxy_max_restarts: int = None
timeout_in_ms: int = 60000
messages_max_size_in_bytes: int = None
Expand All @@ -112,9 +111,7 @@ class CrossSiloMessageConfig:
send_resource_label: Optional[Dict[str, str]] = None
recv_resource_label: Optional[Dict[str, str]] = None
http_header: Optional[Dict[str, str]] = None
# (Optional) The address this party is going to listen on.
# If not provided, the `address` will be used.
listening_address: str = None
max_concurrency: Optional[int] = None

def __json__(self):
return json.dumps(self.__dict__)
Expand All @@ -125,7 +122,7 @@ def from_json(cls, json_str):
return cls(**data)

@classmethod
def from_dict(cls, data: Dict):
def from_dict(cls, data: Dict) -> 'CrossSiloMessageConfig':
"""Initialize CrossSiloMessageConfig from a dictionary.

Args:
Expand All @@ -135,10 +132,8 @@ def from_dict(cls, data: Dict):
CrossSiloMessageConfig: An instance of CrossSiloMessageConfig.
"""
# Get the attributes of the class

data = data or {}
all_annotations = {**cls.__annotations__, **cls.__base__.__annotations__}
attrs = {attr for attr, _ in all_annotations.items()}
attrs = [field.name for field in fields(cls)]
# Filter the dictionary to only include keys that are attributes of the class
filtered_data = {key: value for key, value in data.items() if key in attrs}
return cls(**filtered_data)
Expand Down Expand Up @@ -170,5 +165,6 @@ class GrpcCrossSiloMessageConfig(CrossSiloMessageConfig):
('grpc.max_send_message_length', 50 * 1024 * 1024)
]
"""

grpc_channel_options: List = None
grpc_retry_policy: Dict[str, str] = None
Loading
Loading