Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support running RayFed job in Ray client mode. #173

Merged
merged 25 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions fed/_private/fed_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging

import ray
from ray.util.client.common import ClientActorHandle
from fed._private.fed_call_holder import FedCallHolder
from fed.fed_object import FedObject

Expand All @@ -37,22 +38,41 @@ def __init__(
self._party = party
self._node_party = node_party
self._options = options
self._actor_handle = None
self._ray_actor_handle = None

def __getattr__(self, method_name: str):
    """Return a FedActorMethod wrapper for the actor method `method_name`.

    When this process belongs to the party that owns the actor
    (`self._party == self._node_party`), the bound Ray method is looked up
    on the underlying Ray actor handle and passed to the wrapper. Otherwise
    `None` is passed in its place.

    Raises:
        AttributeError: If `method_name` is "remote" and the wrapped class
            has no `remote` attribute, or if the wrapped class has no
            attribute named `method_name`.
    """
    # `.remote()` is only valid here when the wrapped class itself defines
    # a `remote` attribute.
    if method_name == "remote" and "remote" not in dir(self._body):
        raise AttributeError(f".remote() cannot be used again on {type(self)} ")
    # Raise an error if the method is invalid.
    getattr(self._body, method_name)

    ray_wrappered_method = None
    if self._party == self._node_party:
        ray_actor_handle = self._ray_actor_handle
        try:
            ray_wrappered_method = ray_actor_handle.__getattribute__(method_name)
        except AttributeError:
            # The code path in Ray client mode.
            assert isinstance(ray_actor_handle, ClientActorHandle)
            ray_wrappered_method = ray_actor_handle.__getattr__(method_name)

    return FedActorMethod(
        self._addresses,
        self._party,
        self._node_party,
        self,
        method_name,
        ray_wrappered_method,
    ).options(**self._options)

def _execute_impl(self, cls_args, cls_kwargs):
"""Executor of ClassNode by ray.remote()
Expand All @@ -63,28 +83,34 @@ def _execute_impl(self, cls_args, cls_kwargs):
current node is executed.
"""
if self._node_party == self._party:
self._actor_handle = (
self._ray_actor_handle = (
ray.remote(self._body)
.options(**self._options)
.remote(*cls_args, **cls_kwargs)
)

def _execute_remote_method(
    self,
    method_name,
    options,
    _ray_wrappered_method,
    args,
    kwargs,
):
    """Invoke an actor method through its already-resolved Ray method object.

    Args:
        method_name: Name of the actor method; used only for the debug log.
        options: Optional mapping of call options. Only `num_returns` is
            read here; it defaults to 1 when absent.
        _ray_wrappered_method: The Ray method object to call.
        args: Positional arguments forwarded to the method.
        kwargs: Keyword arguments forwarded to the method.

    Returns:
        Whatever `_ray_wrappered_method.options(...).remote(...)` returns.
    """
    num_returns = 1
    if options and 'num_returns' in options:
        num_returns = options['num_returns']
    logger.debug(
        f"Actor method call: {method_name}, num_returns: {num_returns}"
    )

    return _ray_wrappered_method.options(
        name='',
        num_returns=num_returns,
    ).remote(
        *args,
        **kwargs,
    )


class FedActorMethod:
Expand All @@ -95,13 +121,15 @@ def __init__(
node_party,
fed_actor_handle,
method_name,
ray_wrappered_method,
) -> None:
self._addresses = addresses
self._party = party # Current party
self._node_party = node_party
self._fed_actor_handle = fed_actor_handle
self._method_name = method_name
self._options = {}
self._ray_wrappered_method = ray_wrappered_method
self._fed_call_holder = FedCallHolder(node_party, self._execute_impl)

def remote(self, *args, **kwargs) -> FedObject:
Expand All @@ -114,5 +142,5 @@ def options(self, **options):

def _execute_impl(self, args, kwargs):
    """Forward this method call to the owning fed actor handle.

    Passes the method name, the stored call options and the stored Ray
    method object, followed by the call's positional and keyword arguments.
    """
    return self._fed_actor_handle._execute_remote_method(
        self._method_name,
        self._options,
        self._ray_wrappered_method,
        args,
        kwargs,
    )
File renamed without changes.
99 changes: 99 additions & 0 deletions fed/tests/client_mode_tests/test_basic_client_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023 The RayFed Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing

import pytest
import ray
import fed
import fed._private.compatible_utils as compatible_utils
from fed.tests.test_utils import ray_client_mode_setup # noqa


@fed.remote
class MyModel:
    """Toy model holding a single numeric weight, used to exercise fed actors."""

    def __init__(self, party, step_length):
        self._party = party
        self._step_length = step_length
        self._weights = 0
        self._trained_steps = 0

    def train(self):
        """Add one step length to the weight, count the step, return the weight."""
        self._trained_steps += 1
        self._weights += self._step_length
        return self._weights

    def get_weights(self):
        """Return the current weight."""
        return self._weights

    def set_weights(self, new_weights):
        """Overwrite the weight with `new_weights` and return that value."""
        self._weights = new_weights
        return new_weights


@fed.remote
def mean(x, y):
    """Return the arithmetic mean of `x` and `y`."""
    total = x + y
    return total / 2


def run(party):
    """Run the two-party training scenario as `party` against a Ray client address.

    Connects to a `ray://` address chosen by party, initialises fed, trains
    an alice model and a bob model for three rounds while averaging their
    weights on alice, checks the averaged values, then shuts fed and Ray down.
    """
    import time

    is_alice = party == 'alice'
    if is_alice:
        # NOTE(review): presumably staggers alice's startup behind bob's — confirm.
        time.sleep(1.4)

    if is_alice:
        address = 'ray://127.0.0.1:21012'
    else:
        address = 'ray://127.0.0.1:21011'
    compatible_utils.init_ray(address=address)

    fed.init(
        addresses={
            'alice': '127.0.0.1:31012',
            'bob': '127.0.0.1:31011',
        },
        party=party,
    )

    alice_model = MyModel.party("alice").remote("alice", 2)
    bob_model = MyModel.party("bob").remote("bob", 4)

    all_mean_weights = []
    for _ in range(3):
        alice_w = alice_model.train.remote()
        bob_w = bob_model.train.remote()
        averaged = mean.party("alice").remote(alice_w, bob_w)
        result = fed.get(averaged)
        # Push the averaged weight back into both models before the next round.
        alice_model.set_weights.remote(averaged)
        bob_model.set_weights.remote(averaged)
        all_mean_weights.append(result)
    assert all_mean_weights == [3, 6, 9]

    alice_final = alice_model.get_weights.remote()
    bob_final = bob_model.get_weights.remote()
    latest_weights = fed.get([alice_final, bob_final])
    assert latest_weights == [9, 9]

    fed.shutdown()
    ray.shutdown()


def test_fed_get_in_2_parties(ray_client_mode_setup):  # noqa
    """Run `run` for alice and bob in separate processes; both must exit with 0."""
    workers = [
        multiprocessing.Process(target=run, args=(party_name,))
        for party_name in ('alice', 'bob')
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    assert all(worker.exitcode == 0 for worker in workers)


if __name__ == "__main__":
    import sys

    # Run this file's tests under pytest and use its result as the exit status.
    exit_code = pytest.main(["-sv", __file__])
    sys.exit(exit_code)
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 9 additions & 2 deletions tests/test_fed_get.py → fed/tests/test_fed_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,17 @@ def mean(x, y):


def run(party):
import time
if party == 'alice':
time.sleep(1.4)

# address = 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' # noqa
# compatible_utils.init_ray(address=address)
compatible_utils.init_ray(address='local')

addresses = {
'alice': '127.0.0.1:11012',
'bob': '127.0.0.1:11011',
'alice': '127.0.0.1:31012',
'bob': '127.0.0.1:31011',
}
fed.init(addresses=addresses, party=party)

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import fed._private.compatible_utils as compatible_utils
import fed.utils as fed_utils
from fed._private import constants, global_context
from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig
from fed.proxy.barriers import (
_start_receiver_proxy,
_start_sender_proxy,
Expand Down
60 changes: 60 additions & 0 deletions fed/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2023 The RayFed Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import pytest

import fed.utils as fed_utils


def start_ray_cluster(
    ray_port,
    client_server_port,
    dashboard_port,
):
    """Start a Ray head node with `ray start --head` on the given ports.

    Args:
        ray_port: Value passed as `--port`.
        client_server_port: Value passed as `--ray-client-server-port`.
        dashboard_port: Value passed as `--dashboard-port`.

    Raises:
        RuntimeError: If the command reports an error other than the
            known, harmless warnings listed below.
    """
    command_str = ' '.join([
        'ray',
        'start',
        '--head',
        f'--port={ray_port}',
        f'--ray-client-server-port={client_server_port}',
        f'--dashboard-port={dashboard_port}',
    ])
    # `start_command` raises on any stderr output, but the following warnings
    # are harmless and should not fail the setup. Example of the second one:
    #   RuntimeError: Failed to start command [ray start --head --port=41012
    #   --ray-client-server-port=21012 --dashboard-port=9112], the error is:
    #   2023-09-13 13:04:11,520 WARNING services.py:1882 -- WARNING: The
    #   object store is using /tmp instead of /dev/shm because /dev/shm has
    #   only 67108864 bytes available. This will harm performance! ...
    tolerated_messages = (
        'Overwriting previous Ray address',
        'WARNING: The object store is using /tmp instead of /dev/shm',
    )
    try:
        fed_utils.start_command(command_str)
    except RuntimeError as e:
        message = str(e)
        # Re-raise explicitly instead of using `assert`, which is stripped
        # under `python -O` and would then swallow every startup failure.
        if not any(marker in message for marker in tolerated_messages):
            raise


@pytest.fixture
def ray_client_mode_setup():
    """Start two Ray head nodes for the test, then run `ray stop --force` afterwards."""
    first_cluster = dict(
        ray_port=41012, client_server_port=21012, dashboard_port=9112
    )
    second_cluster = dict(
        ray_port=41011, client_server_port=21011, dashboard_port=9111
    )

    start_ray_cluster(**first_cluster)
    time.sleep(1)
    start_ray_cluster(**second_cluster)

    yield

    fed_utils.start_command('ray stop --force')
17 changes: 17 additions & 0 deletions fed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import re
import sys
import subprocess

import ray

Expand Down Expand Up @@ -236,3 +237,19 @@ def validate_addresses(addresses: dict):
isinstance(address, str) and address
), f'Address should be string but got {address}.'
validate_address(address)


def start_command(command: str, timeout=60):
    """Run a shell command and return what it wrote to standard output.

    Args:
        command: Command line; it is executed through the shell.
        timeout: Maximum number of seconds to wait for the command.

    Returns:
        The raw bytes captured from the command's stdout.

    Raises:
        RuntimeError: If the command wrote anything to stderr.
        subprocess.TimeoutExpired: If the command did not finish within
            `timeout` seconds. The child process is killed before the
            exception propagates.
    """
    # The context manager closes the pipes and reaps the child on every exit
    # path, including the error paths below.
    with subprocess.Popen(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    ) as process:
        try:
            output, error = process.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() does not kill the child on timeout, so do it
            # here to avoid leaving a stray process behind.
            process.kill()
            raise
    if error:
        # Decode tolerantly so an undecodable byte cannot mask the real error.
        error_text = error.decode(errors='replace')
        raise RuntimeError(
            f'Failed to start command [{command}], the error is:\n {error_text}')
    return output
3 changes: 2 additions & 1 deletion test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ export RAY_TLS_SERVER_CERT="/tmp/rayfed/test-certs/server.crt"
export RAY_TLS_SERVER_KEY="/tmp/rayfed/test-certs/server.key"
export RAY_TLS_CA_CERT="/tmp/rayfed/test-certs/server.crt"

cd tests
cd fed/tests
python3 -m pytest -v -s test_*
python3 -m pytest -v -s serializations_tests/test_*
python3 -m pytest -v -s without_ray_tests/test_*
python3 -m pytest -v -s client_mode_tests/test_*
cd -

echo "All tests finished."
Loading