diff --git a/fed/_private/fed_actor.py b/fed/_private/fed_actor.py index aa73bfa..dd88057 100644 --- a/fed/_private/fed_actor.py +++ b/fed/_private/fed_actor.py @@ -15,6 +15,7 @@ import logging import ray +from ray.util.client.common import ClientActorHandle from fed._private.fed_call_holder import FedCallHolder from fed.fed_object import FedObject @@ -37,7 +38,7 @@ def __init__( self._party = party self._node_party = node_party self._options = options - self._actor_handle = None + self._ray_actor_handle = None def __getattr__(self, method_name: str): # User trying to call .bind() without a bind class method @@ -45,14 +46,33 @@ def __getattr__(self, method_name: str): raise AttributeError(f".remote() cannot be used again on {type(self)} ") # Raise an error if the method is invalid. getattr(self._body, method_name) - call_node = FedActorMethod( - self._addresses, - self._party, - self._node_party, - self, - method_name, - ).options(**self._options) - return call_node + + if self._party == self._node_party: + ray_actor_handle = self._ray_actor_handle + try: + ray_wrappered_method = ray_actor_handle.__getattribute__(method_name) + except AttributeError: + # The code path in Ray client mode. + assert isinstance(ray_actor_handle, ClientActorHandle) + ray_wrappered_method = ray_actor_handle.__getattr__(method_name) + + return FedActorMethod( + self._addresses, + self._party, + self._node_party, + self, + method_name, + ray_wrappered_method, + ).options(**self._options) + else: + return FedActorMethod( + self._addresses, + self._party, + self._node_party, + self, + method_name, + None, + ).options(**self._options) def _execute_impl(self, cls_args, cls_kwargs): """Executor of ClassNode by ray.remote() @@ -63,28 +83,34 @@ def _execute_impl(self, cls_args, cls_kwargs): current node is executed. """ if self._node_party == self._party: - self._actor_handle = ( + self._ray_actor_handle = ( ray.remote(self._body) .options(**self._options) .remote(*cls_args, **cls_kwargs) ) - def _execute_remote_method(self, method_name, options, args, kwargs): + def _execute_remote_method( + self, + method_name, + options, + _ray_wrappered_method, + args, + kwargs, + ): num_returns = 1 if options and 'num_returns' in options: num_returns = options['num_returns'] logger.debug( f"Actor method call: {method_name}, num_returns: {num_returns}" ) - ray_object_ref = self._actor_handle._actor_method_call( - method_name, - args=args, - kwargs=kwargs, - name="", + + return _ray_wrappered_method.options( + name='', num_returns=num_returns, - concurrency_group_name="", + ).remote( + *args, + **kwargs, ) - return ray_object_ref class FedActorMethod: @@ -95,6 +121,7 @@ def __init__( node_party, fed_actor_handle, method_name, + ray_wrappered_method, ) -> None: self._addresses = addresses self._party = party # Current party @@ -102,6 +129,7 @@ def __init__( self._fed_actor_handle = fed_actor_handle self._method_name = method_name self._options = {} + self._ray_wrappered_method = ray_wrappered_method self._fed_call_holder = FedCallHolder(node_party, self._execute_impl) def remote(self, *args, **kwargs) -> FedObject: @@ -114,5 +142,5 @@ def options(self, **options): def _execute_impl(self, args, kwargs): return self._fed_actor_handle._execute_remote_method( - self._method_name, self._options, args, kwargs + self._method_name, self._options, self._ray_wrappered_method, args, kwargs ) diff --git a/tests/__init__.py b/fed/tests/__init__.py similarity index 100% rename from tests/__init__.py rename to fed/tests/__init__.py diff --git a/fed/tests/client_mode_tests/test_basic_client_mode.py b/fed/tests/client_mode_tests/test_basic_client_mode.py new file mode 100644 index 0000000..9807802 --- /dev/null +++ b/fed/tests/client_mode_tests/test_basic_client_mode.py @@ -0,0 +1,99 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing + +import pytest +import ray +import fed +import fed._private.compatible_utils as compatible_utils +from fed.tests.test_utils import ray_client_mode_setup # noqa + + +@fed.remote +class MyModel: + def __init__(self, party, step_length): + self._trained_steps = 0 + self._step_length = step_length + self._weights = 0 + self._party = party + + def train(self): + self._trained_steps += 1 + self._weights += self._step_length + return self._weights + + def get_weights(self): + return self._weights + + def set_weights(self, new_weights): + self._weights = new_weights + return new_weights + + +@fed.remote +def mean(x, y): + return (x + y) / 2 + + +def run(party): + import time + if party == 'alice': + time.sleep(1.4) + + address = 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' # noqa + compatible_utils.init_ray(address=address) + + addresses = { + 'alice': '127.0.0.1:31012', + 'bob': '127.0.0.1:31011', + } + fed.init(addresses=addresses, party=party) + + epochs = 3 + alice_model = MyModel.party("alice").remote("alice", 2) + bob_model = MyModel.party("bob").remote("bob", 4) + + all_mean_weights = [] + for epoch in range(epochs): + w1 = alice_model.train.remote() + w2 = bob_model.train.remote() + new_weights = mean.party("alice").remote(w1, w2) + result = fed.get(new_weights) + alice_model.set_weights.remote(new_weights) + bob_model.set_weights.remote(new_weights) + all_mean_weights.append(result) + assert all_mean_weights == [3, 6, 9] + latest_weights = fed.get( + [alice_model.get_weights.remote(), bob_model.get_weights.remote()] + ) + assert latest_weights == [9, 9] + fed.shutdown() + ray.shutdown() + + +def test_fed_get_in_2_parties(ray_client_mode_setup): # noqa + p_alice = multiprocessing.Process(target=run, args=('alice',)) + p_bob = multiprocessing.Process(target=run, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-sv", __file__])) diff --git a/tests/multi-jobs/test_ignore_other_job_msg.py b/fed/tests/multi-jobs/test_ignore_other_job_msg.py similarity index 100% rename from tests/multi-jobs/test_ignore_other_job_msg.py rename to fed/tests/multi-jobs/test_ignore_other_job_msg.py diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/fed/tests/serializations_tests/test_unpickle_with_whitelist.py similarity index 100% rename from tests/serializations_tests/test_unpickle_with_whitelist.py rename to fed/tests/serializations_tests/test_unpickle_with_whitelist.py diff --git a/tests/simple_example.py b/fed/tests/simple_example.py similarity index 100% rename from tests/simple_example.py rename to fed/tests/simple_example.py diff --git a/tests/test_api.py b/fed/tests/test_api.py similarity index 100% rename from tests/test_api.py rename to fed/tests/test_api.py diff --git a/tests/test_async_startup_2_clusters.py b/fed/tests/test_async_startup_2_clusters.py similarity index 100% rename from tests/test_async_startup_2_clusters.py rename to fed/tests/test_async_startup_2_clusters.py diff --git a/tests/test_basic_pass_fed_objects.py b/fed/tests/test_basic_pass_fed_objects.py similarity index 100% rename from tests/test_basic_pass_fed_objects.py rename to fed/tests/test_basic_pass_fed_objects.py diff --git a/tests/test_cache_fed_objects.py b/fed/tests/test_cache_fed_objects.py similarity index 100% rename from tests/test_cache_fed_objects.py rename to fed/tests/test_cache_fed_objects.py diff --git a/tests/test_enable_tls_across_parties.py b/fed/tests/test_enable_tls_across_parties.py similarity index 100% rename from tests/test_enable_tls_across_parties.py rename to fed/tests/test_enable_tls_across_parties.py diff --git a/tests/test_exit_on_failure_sending.py b/fed/tests/test_exit_on_failure_sending.py similarity index 100% rename from tests/test_exit_on_failure_sending.py rename to fed/tests/test_exit_on_failure_sending.py diff --git a/tests/test_fed_get.py b/fed/tests/test_fed_get.py similarity index 89% rename from tests/test_fed_get.py rename to fed/tests/test_fed_get.py index 3aaf71d..5752f77 100644 --- a/tests/test_fed_get.py +++ b/fed/tests/test_fed_get.py @@ -47,10 +47,17 @@ def mean(x, y): def run(party): + import time + if party == 'alice': + time.sleep(1.4) + + # address = 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' # noqa + # compatible_utils.init_ray(address=address) compatible_utils.init_ray(address='local') + addresses = { - 'alice': '127.0.0.1:11012', - 'bob': '127.0.0.1:11011', + 'alice': '127.0.0.1:31012', + 'bob': '127.0.0.1:31011', } fed.init(addresses=addresses, party=party) diff --git a/tests/test_grpc_options_on_proxies.py b/fed/tests/test_grpc_options_on_proxies.py similarity index 100% rename from tests/test_grpc_options_on_proxies.py rename to fed/tests/test_grpc_options_on_proxies.py diff --git a/tests/test_internal_kv.py b/fed/tests/test_internal_kv.py similarity index 100% rename from tests/test_internal_kv.py rename to fed/tests/test_internal_kv.py diff --git a/tests/test_listening_address.py b/fed/tests/test_listening_address.py similarity index 100% rename from tests/test_listening_address.py rename to fed/tests/test_listening_address.py diff --git a/tests/test_options.py b/fed/tests/test_options.py similarity index 100% rename from tests/test_options.py rename to fed/tests/test_options.py diff --git a/tests/test_pass_fed_objects_in_containers_in_actor.py b/fed/tests/test_pass_fed_objects_in_containers_in_actor.py similarity index 100% rename from tests/test_pass_fed_objects_in_containers_in_actor.py rename to fed/tests/test_pass_fed_objects_in_containers_in_actor.py diff --git a/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py b/fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py similarity index 100% rename from tests/test_pass_fed_objects_in_containers_in_normal_tasks.py rename to fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py diff --git a/tests/test_ping_others.py b/fed/tests/test_ping_others.py similarity index 100% rename from tests/test_ping_others.py rename to fed/tests/test_ping_others.py diff --git a/tests/test_repeat_init.py b/fed/tests/test_repeat_init.py similarity index 100% rename from tests/test_repeat_init.py rename to fed/tests/test_repeat_init.py diff --git a/tests/test_reset_context.py b/fed/tests/test_reset_context.py similarity index 100% rename from tests/test_reset_context.py rename to fed/tests/test_reset_context.py diff --git a/tests/test_retry_policy.py b/fed/tests/test_retry_policy.py similarity index 100% rename from tests/test_retry_policy.py rename to fed/tests/test_retry_policy.py diff --git a/tests/test_setup_proxy_actor.py b/fed/tests/test_setup_proxy_actor.py similarity index 100% rename from tests/test_setup_proxy_actor.py rename to fed/tests/test_setup_proxy_actor.py diff --git a/tests/test_transport_proxy.py b/fed/tests/test_transport_proxy.py similarity index 99% rename from tests/test_transport_proxy.py rename to fed/tests/test_transport_proxy.py index 9ebb787..55a9b56 100644 --- a/tests/test_transport_proxy.py +++ b/fed/tests/test_transport_proxy.py @@ -22,7 +22,6 @@ import fed._private.compatible_utils as compatible_utils import fed.utils as fed_utils from fed._private import constants, global_context -from fed.config import CrossSiloMessageConfig, GrpcCrossSiloMessageConfig from fed.proxy.barriers import ( _start_receiver_proxy, _start_sender_proxy, diff --git a/tests/test_transport_proxy_tls.py b/fed/tests/test_transport_proxy_tls.py similarity index 100% rename from tests/test_transport_proxy_tls.py rename to fed/tests/test_transport_proxy_tls.py diff --git a/fed/tests/test_utils.py b/fed/tests/test_utils.py new file mode 100644 index 0000000..f17f1a6 --- /dev/null +++ b/fed/tests/test_utils.py @@ -0,0 +1,60 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import pytest + +import fed.utils as fed_utils + + +def start_ray_cluster( + ray_port, + client_server_port, + dashboard_port, +): + command = [ + 'ray', + 'start', + '--head', + f'--port={ray_port}', + f'--ray-client-server-port={client_server_port}', + f'--dashboard-port={dashboard_port}', + ] + command_str = ' '.join(command) + try: + _ = fed_utils.start_command(command_str) + except RuntimeError as e: + # As we should treat the following warning messages is ok to use. + # E RuntimeError: Failed to start command [ray start --head --port=41012 + # --ray-client-server-port=21012 --dashboard-port=9112], the error is: + # E 2023-09-13 13:04:11,520 WARNING services.py:1882 -- WARNING: The + # object store is using /tmp instead of /dev/shm because /dev/shm has only + # 67108864 bytes available. This will harm performance! You may be able to + # free up space by deleting files in /dev/shm. If you are inside a Docker + # container, you can increase /dev/shm size by passing '--shm-size=1.97gb' to + # 'docker run' (or add it to the run_options list in a Ray cluster config). + # Make sure to set this to more than 0% of available RAM. + assert 'Overwriting previous Ray address' in str(e) \ + or 'WARNING: The object store is using /tmp instead of /dev/shm' in str(e) + + +@pytest.fixture +def ray_client_mode_setup(): + # Start 2 Ray clusters. + start_ray_cluster(ray_port=41012, client_server_port=21012, dashboard_port=9112) + time.sleep(1) + start_ray_cluster(ray_port=41011, client_server_port=21011, dashboard_port=9111) + + yield + fed_utils.start_command('ray stop --force') diff --git a/tests/without_ray_tests/test_tree_utils.py b/fed/tests/without_ray_tests/test_tree_utils.py similarity index 100% rename from tests/without_ray_tests/test_tree_utils.py rename to fed/tests/without_ray_tests/test_tree_utils.py diff --git a/tests/without_ray_tests/test_utils.py b/fed/tests/without_ray_tests/test_utils.py similarity index 100% rename from tests/without_ray_tests/test_utils.py rename to fed/tests/without_ray_tests/test_utils.py diff --git a/fed/utils.py b/fed/utils.py index 0a82bc0..b5450f2 100644 --- a/fed/utils.py +++ b/fed/utils.py @@ -15,6 +15,7 @@ import logging import re import sys +import subprocess import ray @@ -236,3 +237,19 @@ def validate_addresses(addresses: dict): isinstance(address, str) and address ), f'Address should be string but got {address}.' validate_address(address) + + +def start_command(command: str, timeout=60) : + """ + A util to start a shell command. + """ + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + output, error = process.communicate(timeout=timeout) + if len(error) != 0: + raise RuntimeError( + f'Failed to start command [{command}], the error is:\n {error.decode()}') + return output diff --git a/test.sh b/test.sh index 7199064..312b53e 100755 --- a/test.sh +++ b/test.sh @@ -10,10 +10,11 @@ export RAY_TLS_SERVER_CERT="/tmp/rayfed/test-certs/server.crt" export RAY_TLS_SERVER_KEY="/tmp/rayfed/test-certs/server.key" export RAY_TLS_CA_CERT="/tmp/rayfed/test-certs/server.crt" -cd tests +cd fed/tests python3 -m pytest -v -s test_* python3 -m pytest -v -s serializations_tests/test_* python3 -m pytest -v -s without_ray_tests/test_* +python3 -m pytest -v -s client_mode_tests/test_* cd - echo "All tests finished."