Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorg config. #164

Merged
merged 24 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ The above codes:
def main(party):
ray.init(address='local')

cluster = {
'alice': {'address': '127.0.0.1:11012'},
'bob': {'address': '127.0.0.1:11011'},
addresses = {
'alice': '127.0.0.1:11012',
'bob': '127.0.0.1:11011',
}
fed.init(cluster=cluster, party=party)
fed.init(addresses=addresses, party=party)
```
This first declares a two-party cluster, whose addresses corresponding to '127.0.0.1:11012' in 'alice' and '127.0.0.1:11011' in 'bob'.
And then, the `fed.init` create a cluster in the specified party.
Expand Down Expand Up @@ -145,11 +145,11 @@ def aggregate(val1, val2):
def main(party):
ray.init(address='local')

cluster = {
'alice': {'address': '127.0.0.1:11012'},
'bob': {'address': '127.0.0.1:11011'},
addresses = {
'alice': '127.0.0.1:11012',
'bob': '127.0.0.1:11011',
}
fed.init(cluster=cluster, party=party)
fed.init(addresses=addresses, party=party)

actor_alice = MyActor.party("alice").remote(1)
actor_bob = MyActor.party("bob").remote(1)
Expand Down
8 changes: 4 additions & 4 deletions benchmarks/many_tiny_tasks_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def aggr(self, val1, val2):
def main(party):
ray.init(address='local')

cluster = {
'alice': {'address': '127.0.0.1:11010'},
'bob': {'address': '127.0.0.1:11011'},
addresses = {
'alice': '127.0.0.1:11010',
'bob': '127.0.0.1:11011',
}
fed.init(cluster=cluster, party=party)
fed.init(addresses=addresses, party=party)

actor_alice = MyActor.party("alice").remote()
actor_bob = MyActor.party("bob").remote()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ For example:
>>> import ray
>>> import fed
>>> ray.init()
>>> fed.init(cluster=cluster, party="Alice", tls_config=tls_config)
>>> fed.init(addresses=addresses, party="Alice", tls_config=tls_config)

Successfully to connect to current Ray cluster in party `Alice`
10 changes: 5 additions & 5 deletions fed/_private/fed_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class FedActorHandle:
def __init__(
self,
fed_class_task_id,
cluster,
addresses,
cls,
party,
node_party,
options,
) -> None:
self._fed_class_task_id = fed_class_task_id
self._cluster = cluster
self._addresses = addresses
self._body = cls
self._party = party
self._node_party = node_party
Expand All @@ -46,7 +46,7 @@ def __getattr__(self, method_name: str):
# Raise an error if the method is invalid.
getattr(self._body, method_name)
call_node = FedActorMethod(
self._cluster,
self._addresses,
self._party,
self._node_party,
self,
Expand Down Expand Up @@ -90,13 +90,13 @@ def _execute_remote_method(self, method_name, options, args, kwargs):
class FedActorMethod:
def __init__(
self,
cluster,
addresses,
party,
node_party,
fed_actor_handle,
method_name,
) -> None:
self._cluster = cluster
self._addresses = addresses
self._party = party # Current party
self._node_party = node_party
self._fed_actor_handle = fed_actor_handle
Expand Down
99 changes: 37 additions & 62 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import functools
import inspect
import logging
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union

import cloudpickle
import ray
Expand All @@ -35,58 +35,36 @@
_start_sender_proxy,
)
from fed.proxy.grpc.grpc_proxy import SenderProxy, ReceiverProxy
from fed.config import CrossSiloMessageConfig
from fed.config import GrpcCrossSiloMessageConfig
from fed.fed_object import FedObject
from fed.utils import is_ray_object_refs, setup_logger

logger = logging.getLogger(__name__)


def init(
cluster: Dict = None,
addresses: Dict = None,
party: str = None,
config: Dict = {},
tls_config: Dict = None,
logging_level: str = 'info',
enable_waiting_for_other_parties_ready: bool = False,
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
global_cross_silo_message_config: Optional[CrossSiloMessageConfig] = None,
**kwargs,
):
"""
Initialize a RayFed client.

Args:
cluster: optional; a dict describes the cluster config. E.g.
addresses: optional; a dict describes the addresses configurations. E.g.

.. code:: python
{
'alice': {
# The address for other parties.
'address': '127.0.0.1:10001',
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10001',
'cross_silo_message_config': CrossSiloMessageConfig
},
'bob': {
# The address for other parties.
'address': '127.0.0.1:10002',
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10002',
# (Optional) The party specific metadata sent with grpc requests
'grpc_metadata': (('token', 'bob-token'),),
},
'carol': {
# The address for other parties.
'address': '127.0.0.1:10003',
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10003',
# (Optional) The party specific metadata sent with grpc requests
'grpc_metadata': (('token', 'carol-token'),),
},
# The address that can be connected to `alice` by other parties.
'alice': '127.0.0.1:10001',
# The address that can be connected to `bob` by other parties.
'bob': '127.0.0.1:10002',
# The address that can be connected to `carol` by other parties.
'carol': '127.0.0.1:10003',
}
party: optional; self party.
tls_config: optional; a dict describes the tls config. E.g.
Expand All @@ -109,50 +87,46 @@ def init(
}
logging_level: optional; the logging level, could be `debug`, `info`,
`warning`, `error`, `critical`, not case sensititive.
enable_waiting_for_other_parties_ready: ping other parties until they
are all ready if True.
global_cross_silo_message_config: Global cross-silo message related
configs that are applied to all connections. Supported configs
can refer to CrossSiloMessageConfig in config.py.

Examples:
>>> import fed
>>> import ray
>>> ray.init(address='local')
>>> cluster = {
>>> 'alice': {'address': '127.0.0.1:10001'},
>>> 'bob': {'address': '127.0.0.1:10002'},
>>> 'carol': {'address': '127.0.0.1:10003'},
>>> addresses = {
>>> 'alice': '127.0.0.1:10001',
>>> 'bob': '127.0.0.1:10002',
>>> 'carol': '127.0.0.1:10003',
>>> }
>>> # Start as alice.
>>> fed.init(cluster=cluster, self_party='alice')
>>> fed.init(addresses=addresses, party='alice')
"""
assert cluster, "Cluster should be provided."
assert addresses, "Addresses should be provided."
assert party, "Party should be provided."
assert party in cluster, f"Party {party} is not in cluster {cluster}."
assert party in addresses, f"Party {party} is not in the addresses {addresses}."

fed_utils.validate_cluster_info(cluster)
fed_utils.validate_addresses(addresses)

tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

global_cross_silo_message_config = \
global_cross_silo_message_config or CrossSiloMessageConfig()
cross_silo_message_dict = config.get("cross_silo_message", {})
cross_silo_message_config = GrpcCrossSiloMessageConfig.from_dict(
jovany-wang marked this conversation as resolved.
Show resolved Hide resolved
cross_silo_message_dict)
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()

cluster_config = {
constants.KEY_OF_CLUSTER_ADDRESSES: cluster,
constants.KEY_OF_CLUSTER_ADDRESSES: addresses,
constants.KEY_OF_CURRENT_PARTY_NAME: party,
constants.KEY_OF_TLS_CONFIG: tls_config,
}

job_config = {
constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG:
global_cross_silo_message_config,
cross_silo_message_config,
}
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
Expand All @@ -170,7 +144,7 @@ def init(

logger.info(f'Started rayfed with {cluster_config}')
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=global_cross_silo_message_config.exit_on_sending_failure) # noqa
exit_when_failure_sending=cross_silo_message_config.exit_on_sending_failure) # noqa

if receiver_proxy_cls is None:
logger.debug(
Expand All @@ -179,12 +153,12 @@ def init(
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy
receiver_proxy_cls = GrpcReceiverProxy
_start_receiver_proxy(
cluster=cluster,
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=receiver_proxy_cls,
proxy_config=global_cross_silo_message_config
proxy_config=cross_silo_message_config
)

if sender_proxy_cls is None:
Expand All @@ -194,17 +168,18 @@ def init(
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy
sender_proxy_cls = GrpcSenderProxy
_start_sender_proxy(
cluster=cluster,
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=sender_proxy_cls,
proxy_config=global_cross_silo_message_config
# TODO(qwang): proxy_config -> cross_silo_message_config
proxy_config=cross_silo_message_config
)

if enable_waiting_for_other_parties_ready:
if config.get("barrier_on_initializing", False):
# TODO(zhouaihui): can be removed after we have a better retry strategy.
ping_others(cluster=cluster, self_party=party, max_retries=3600)
ping_others(addresses=addresses, self_party=party, max_retries=3600)


def shutdown():
Expand All @@ -216,9 +191,9 @@ def shutdown():
logger.info('Shutdowned rayfed.')


def _get_cluster():
def _get_addresses():
"""
Get the RayFed cluster configration.
Get the RayFed addresses configration.
"""
return fed_config.get_cluster_config().cluster_addresses

Expand Down Expand Up @@ -290,7 +265,7 @@ def remote(self, *cls_args, **cls_kwargs):
fed_class_task_id = get_global_context().next_seq_id()
fed_actor_handle = FedActorHandle(
fed_class_task_id,
_get_cluster(),
_get_addresses(),
self._cls,
_get_party(),
self._party,
Expand Down Expand Up @@ -341,7 +316,7 @@ def get(
# A fake fed_task_id for a `fed.get()` operator. This is useful
# to help contruct the whole DAG within `fed.get`.
fake_fed_task_id = get_global_context().next_seq_id()
cluster = _get_cluster()
addresses = _get_addresses()
current_party = _get_party()
is_individual_id = isinstance(fed_objects, FedObject)
if is_individual_id:
Expand All @@ -357,7 +332,7 @@ def get(
assert ray_object_ref is not None
ray_refs.append(ray_object_ref)

for party_name in cluster:
for party_name in addresses:
if party_name == current_party:
continue
else:
Expand Down
8 changes: 7 additions & 1 deletion fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class CrossSiloMessageConfig:
send_resource_label: Optional[Dict[str, str]] = None
recv_resource_label: Optional[Dict[str, str]] = None
http_header: Optional[Dict[str, str]] = None
# (Optional) The address this party is going to listen on.
# If not provided, the `address` will be used.
listening_address: str = None

def __json__(self):
return json.dumps(self.__dict__)
Expand All @@ -132,7 +135,10 @@ def from_dict(cls, data: Dict):
CrossSiloMessageConfig: An instance of CrossSiloMessageConfig.
"""
# Get the attributes of the class
attrs = {attr for attr, _ in cls.__annotations__.items()}

data = data or {}
all_annotations = {**cls.__annotations__, **cls.__base__.__annotations__}
attrs = {attr for attr, _ in all_annotations.items()}
# Filter the dictionary to only include keys that are attributes of the class
filtered_data = {key: value for key, value in data.items() if key in attrs}
return cls(**filtered_data)
Expand Down
Loading