Skip to content

Commit

Permalink
Reorg config. (#164)
Browse files Browse the repository at this point in the history
- [x] deprecated specifying config per party and also remove the related unit test.
- [x] made all face config API as dict type instead of an internal strong type.
- [x] simplified param address.
  • Loading branch information
jovany-wang authored Jul 21, 2023
1 parent 44a0556 commit 270546d
Show file tree
Hide file tree
Showing 34 changed files with 374 additions and 590 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ The above codes:
def main(party):
ray.init(address='local')

cluster = {
'alice': {'address': '127.0.0.1:11012'},
'bob': {'address': '127.0.0.1:11011'},
addresses = {
'alice': '127.0.0.1:11012',
'bob': '127.0.0.1:11011',
}
fed.init(cluster=cluster, party=party)
fed.init(addresses=addresses, party=party)
```
This first declares a two-party cluster whose addresses are '127.0.0.1:11012' for 'alice' and '127.0.0.1:11011' for 'bob'.
Then `fed.init` creates a cluster in the specified party.
Expand Down Expand Up @@ -145,11 +145,11 @@ def aggregate(val1, val2):
def main(party):
ray.init(address='local')

cluster = {
'alice': {'address': '127.0.0.1:11012'},
'bob': {'address': '127.0.0.1:11011'},
addresses = {
'alice': '127.0.0.1:11012',
'bob': '127.0.0.1:11011',
}
fed.init(cluster=cluster, party=party)
fed.init(addresses=addresses, party=party)

actor_alice = MyActor.party("alice").remote(1)
actor_bob = MyActor.party("bob").remote(1)
Expand Down
8 changes: 4 additions & 4 deletions benchmarks/many_tiny_tasks_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def aggr(self, val1, val2):
def main(party):
ray.init(address='local')

cluster = {
'alice': {'address': '127.0.0.1:11010'},
'bob': {'address': '127.0.0.1:11011'},
addresses = {
'alice': '127.0.0.1:11010',
'bob': '127.0.0.1:11011',
}
fed.init(cluster=cluster, party=party)
fed.init(addresses=addresses, party=party)

actor_alice = MyActor.party("alice").remote()
actor_bob = MyActor.party("bob").remote()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ For example:
>>> import ray
>>> import fed
>>> ray.init()
>>> fed.init(cluster=cluster, party="Alice", tls_config=tls_config)
>>> fed.init(addresses=addresses, party="Alice", tls_config=tls_config)

Successfully to connect to current Ray cluster in party `Alice`
10 changes: 5 additions & 5 deletions fed/_private/fed_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class FedActorHandle:
def __init__(
self,
fed_class_task_id,
cluster,
addresses,
cls,
party,
node_party,
options,
) -> None:
self._fed_class_task_id = fed_class_task_id
self._cluster = cluster
self._addresses = addresses
self._body = cls
self._party = party
self._node_party = node_party
Expand All @@ -46,7 +46,7 @@ def __getattr__(self, method_name: str):
# Raise an error if the method is invalid.
getattr(self._body, method_name)
call_node = FedActorMethod(
self._cluster,
self._addresses,
self._party,
self._node_party,
self,
Expand Down Expand Up @@ -90,13 +90,13 @@ def _execute_remote_method(self, method_name, options, args, kwargs):
class FedActorMethod:
def __init__(
self,
cluster,
addresses,
party,
node_party,
fed_actor_handle,
method_name,
) -> None:
self._cluster = cluster
self._addresses = addresses
self._party = party # Current party
self._node_party = node_party
self._fed_actor_handle = fed_actor_handle
Expand Down
99 changes: 37 additions & 62 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import functools
import inspect
import logging
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union

import cloudpickle
import ray
Expand All @@ -35,58 +35,36 @@
_start_sender_proxy,
)
from fed.proxy.grpc.grpc_proxy import SenderProxy, ReceiverProxy
from fed.config import CrossSiloMessageConfig
from fed.config import GrpcCrossSiloMessageConfig
from fed.fed_object import FedObject
from fed.utils import is_ray_object_refs, setup_logger

logger = logging.getLogger(__name__)


def init(
cluster: Dict = None,
addresses: Dict = None,
party: str = None,
config: Dict = {},
tls_config: Dict = None,
logging_level: str = 'info',
enable_waiting_for_other_parties_ready: bool = False,
sender_proxy_cls: SenderProxy = None,
receiver_proxy_cls: ReceiverProxy = None,
global_cross_silo_message_config: Optional[CrossSiloMessageConfig] = None,
**kwargs,
):
"""
Initialize a RayFed client.
Args:
cluster: optional; a dict describes the cluster config. E.g.
addresses: optional; a dict describing the address configuration. E.g.
.. code:: python
{
'alice': {
# The address for other parties.
'address': '127.0.0.1:10001',
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10001',
'cross_silo_message_config': CrossSiloMessageConfig
},
'bob': {
# The address for other parties.
'address': '127.0.0.1:10002',
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10002',
# (Optional) The party specific metadata sent with grpc requests
'grpc_metadata': (('token', 'bob-token'),),
},
'carol': {
# The address for other parties.
'address': '127.0.0.1:10003',
# (Optional) the listen address, the `address` will be
# used if not provided.
'listen_addr': '0.0.0.0:10003',
# (Optional) The party specific metadata sent with grpc requests
'grpc_metadata': (('token', 'carol-token'),),
},
# The address that can be connected to `alice` by other parties.
'alice': '127.0.0.1:10001',
# The address that can be connected to `bob` by other parties.
'bob': '127.0.0.1:10002',
# The address that can be connected to `carol` by other parties.
'carol': '127.0.0.1:10003',
}
party: optional; self party.
tls_config: optional; a dict describes the tls config. E.g.
Expand All @@ -109,50 +87,46 @@ def init(
}
logging_level: optional; the logging level, could be `debug`, `info`,
`warning`, `error`, `critical`, not case sensitive.
enable_waiting_for_other_parties_ready: ping other parties until they
are all ready if True.
global_cross_silo_message_config: Global cross-silo message related
configs that are applied to all connections. Supported configs
can refer to CrossSiloMessageConfig in config.py.
Examples:
>>> import fed
>>> import ray
>>> ray.init(address='local')
>>> cluster = {
>>> 'alice': {'address': '127.0.0.1:10001'},
>>> 'bob': {'address': '127.0.0.1:10002'},
>>> 'carol': {'address': '127.0.0.1:10003'},
>>> addresses = {
>>> 'alice': '127.0.0.1:10001',
>>> 'bob': '127.0.0.1:10002',
>>> 'carol': '127.0.0.1:10003',
>>> }
>>> # Start as alice.
>>> fed.init(cluster=cluster, self_party='alice')
>>> fed.init(addresses=addresses, party='alice')
"""
assert cluster, "Cluster should be provided."
assert addresses, "Addresses should be provided."
assert party, "Party should be provided."
assert party in cluster, f"Party {party} is not in cluster {cluster}."
assert party in addresses, f"Party {party} is not in the addresses {addresses}."

fed_utils.validate_cluster_info(cluster)
fed_utils.validate_addresses(addresses)

tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

global_cross_silo_message_config = \
global_cross_silo_message_config or CrossSiloMessageConfig()
cross_silo_message_dict = config.get("cross_silo_message", {})
cross_silo_message_config = GrpcCrossSiloMessageConfig.from_dict(
cross_silo_message_dict)
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv()

cluster_config = {
constants.KEY_OF_CLUSTER_ADDRESSES: cluster,
constants.KEY_OF_CLUSTER_ADDRESSES: addresses,
constants.KEY_OF_CURRENT_PARTY_NAME: party,
constants.KEY_OF_TLS_CONFIG: tls_config,
}

job_config = {
constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG:
global_cross_silo_message_config,
cross_silo_message_config,
}
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
Expand All @@ -170,7 +144,7 @@ def init(

logger.info(f'Started rayfed with {cluster_config}')
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=global_cross_silo_message_config.exit_on_sending_failure) # noqa
exit_when_failure_sending=cross_silo_message_config.exit_on_sending_failure) # noqa

if receiver_proxy_cls is None:
logger.debug(
Expand All @@ -179,12 +153,12 @@ def init(
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy
receiver_proxy_cls = GrpcReceiverProxy
_start_receiver_proxy(
cluster=cluster,
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=receiver_proxy_cls,
proxy_config=global_cross_silo_message_config
proxy_config=cross_silo_message_config
)

if sender_proxy_cls is None:
Expand All @@ -194,17 +168,18 @@ def init(
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy
sender_proxy_cls = GrpcSenderProxy
_start_sender_proxy(
cluster=cluster,
addresses=addresses,
party=party,
logging_level=logging_level,
tls_config=tls_config,
proxy_cls=sender_proxy_cls,
proxy_config=global_cross_silo_message_config
# TODO(qwang): proxy_config -> cross_silo_message_config
proxy_config=cross_silo_message_config
)

if enable_waiting_for_other_parties_ready:
if config.get("barrier_on_initializing", False):
# TODO(zhouaihui): can be removed after we have a better retry strategy.
ping_others(cluster=cluster, self_party=party, max_retries=3600)
ping_others(addresses=addresses, self_party=party, max_retries=3600)


def shutdown():
Expand All @@ -216,9 +191,9 @@ def shutdown():
logger.info('Shutdowned rayfed.')


def _get_cluster():
def _get_addresses():
"""
Get the RayFed cluster configration.
Get the RayFed addresses configuration.
"""
return fed_config.get_cluster_config().cluster_addresses

Expand Down Expand Up @@ -290,7 +265,7 @@ def remote(self, *cls_args, **cls_kwargs):
fed_class_task_id = get_global_context().next_seq_id()
fed_actor_handle = FedActorHandle(
fed_class_task_id,
_get_cluster(),
_get_addresses(),
self._cls,
_get_party(),
self._party,
Expand Down Expand Up @@ -341,7 +316,7 @@ def get(
# A fake fed_task_id for a `fed.get()` operator. This is useful
# to help construct the whole DAG within `fed.get`.
fake_fed_task_id = get_global_context().next_seq_id()
cluster = _get_cluster()
addresses = _get_addresses()
current_party = _get_party()
is_individual_id = isinstance(fed_objects, FedObject)
if is_individual_id:
Expand All @@ -357,7 +332,7 @@ def get(
assert ray_object_ref is not None
ray_refs.append(ray_object_ref)

for party_name in cluster:
for party_name in addresses:
if party_name == current_party:
continue
else:
Expand Down
8 changes: 7 additions & 1 deletion fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class CrossSiloMessageConfig:
send_resource_label: Optional[Dict[str, str]] = None
recv_resource_label: Optional[Dict[str, str]] = None
http_header: Optional[Dict[str, str]] = None
# (Optional) The address this party is going to listen on.
# If not provided, the `address` will be used.
listening_address: str = None

def __json__(self):
return json.dumps(self.__dict__)
Expand All @@ -132,7 +135,10 @@ def from_dict(cls, data: Dict):
CrossSiloMessageConfig: An instance of CrossSiloMessageConfig.
"""
# Get the attributes of the class
attrs = {attr for attr, _ in cls.__annotations__.items()}

data = data or {}
all_annotations = {**cls.__annotations__, **cls.__base__.__annotations__}
attrs = {attr for attr, _ in all_annotations.items()}
# Filter the dictionary to only include keys that are attributes of the class
filtered_data = {key: value for key, value in data.items() if key in attrs}
return cls(**filtered_data)
Expand Down
Loading

0 comments on commit 270546d

Please sign in to comment.