From 2b8abf122790c4a0a6e573853410e1560851ef26 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Tue, 11 Jul 2023 11:05:27 +0800 Subject: [PATCH 1/4] Enhance the value error when missing a reqiured param. --- fed/_private/fed_call_holder.py | 3 +++ tests/test_api.py | 18 ++++++++++++++++++ tests/test_fed_get.py | 2 +- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py index 5d84b9a..0773901 100644 --- a/fed/_private/fed_call_holder.py +++ b/fed/_private/fed_call_holder.py @@ -56,6 +56,9 @@ def options(self, **options): return self def internal_remote(self, *args, **kwargs): + if self._node_party is None or len(self._node_party) == 0: + raise ValueError("You should specify a party name on the fed actor.") + # Generate a new fed task id for this call. fed_task_id = get_global_context().next_seq_id() if self._party == self._node_party: diff --git a/tests/test_api.py b/tests/test_api.py index 55ac1db..34d0805 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -40,6 +40,24 @@ def test_fed_apis(): assert p_alice.exitcode == 0 +def test_miss_party_name_on_actor(): + compatible_utils.init_ray(address='local') + cluster = { + 'alice': {'address': '127.0.0.1:11012'}, + } + fed.init(cluster=cluster, party="alice") + + @fed.remote + class MyActor: + pass + + with pytest.raises(ValueError): + MyActor.remote() + + fed.shutdown() + ray.shutdown() + + if __name__ == "__main__": import sys diff --git a/tests/test_fed_get.py b/tests/test_fed_get.py index f49fc6a..714ef7c 100644 --- a/tests/test_fed_get.py +++ b/tests/test_fed_get.py @@ -55,7 +55,7 @@ def run(party): fed.init(cluster=cluster, party=party) epochs = 3 - alice_model = MyModel.party("alice").remote("alice", 2) + alice_model = MyModel.remote("alice", 2) bob_model = MyModel.party("bob").remote("bob", 4) all_mean_weights = [] From 992aac6819280dd0bfda9a6c93afde43714ea4b7 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Tue, 11 Jul 2023 11:08:24 +0800 Subject: [PATCH 2/4] revert Signed-off-by: Qing Wang --- tests/test_fed_get.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fed_get.py b/tests/test_fed_get.py index 714ef7c..f49fc6a 100644 --- a/tests/test_fed_get.py +++ b/tests/test_fed_get.py @@ -55,7 +55,7 @@ def run(party): fed.init(cluster=cluster, party=party) epochs = 3 - alice_model = MyModel.remote("alice", 2) + alice_model = MyModel.party("alice").remote("alice", 2) bob_model = MyModel.party("bob").remote("bob", 4) all_mean_weights = [] From 5c2fba0f94d21773f8dd7ef996254f56fd9ee64e Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Tue, 11 Jul 2023 11:40:55 +0800 Subject: [PATCH 3/4] Address comments. Signed-off-by: Qing Wang --- fed/_private/fed_call_holder.py | 2 +- fed/api.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py index 0773901..1fc339e 100644 --- a/fed/_private/fed_call_holder.py +++ b/fed/_private/fed_call_holder.py @@ -56,7 +56,7 @@ def options(self, **options): return self def internal_remote(self, *args, **kwargs): - if self._node_party is None or len(self._node_party) == 0: + if not self._node_party: raise ValueError("You should specify a party name on the fed actor.") # Generate a new fed task id for this call. diff --git a/fed/api.py b/fed/api.py index 8548c5f..460fbde 100644 --- a/fed/api.py +++ b/fed/api.py @@ -300,9 +300,9 @@ def options(self, **options): return self def remote(self, *args, **kwargs): - assert ( - self._node_party is not None - ), "A fed function should be specified within a party to execute." + if not self._node_party: + raise ValueError("You should specify a party name on the fed function.") + return self._fed_call_holder.internal_remote(*args, **kwargs) def _execute_impl(self, args, kwargs): From a1a72d676596dcd48d0e504fb77c7aee692bf410 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Tue, 11 Jul 2023 11:57:36 +0800 Subject: [PATCH 4/4] Fix CI Signed-off-by: Qing Wang --- tests/test_api.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_api.py b/tests/test_api.py index 34d0805..71774f2 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -41,21 +41,27 @@ def test_fed_apis(): def test_miss_party_name_on_actor(): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - } - fed.init(cluster=cluster, party="alice") + def run(): + compatible_utils.init_ray(address='local') + cluster = { + 'alice': {'address': '127.0.0.1:11012'}, + } + fed.init(cluster=cluster, party="alice") - @fed.remote - class MyActor: - pass + @fed.remote + class MyActor: + pass - with pytest.raises(ValueError): - MyActor.remote() + with pytest.raises(ValueError): + MyActor.remote() - fed.shutdown() - ray.shutdown() + fed.shutdown() + ray.shutdown() + + p_alice = multiprocessing.Process(target=run) + p_alice.start() + p_alice.join() + assert p_alice.exitcode == 0 if __name__ == "__main__":