diff --git a/fed/_private/fed_call_holder.py b/fed/_private/fed_call_holder.py index 5d84b9a..1fc339e 100644 --- a/fed/_private/fed_call_holder.py +++ b/fed/_private/fed_call_holder.py @@ -56,6 +56,9 @@ def options(self, **options): return self def internal_remote(self, *args, **kwargs): + if not self._node_party: + raise ValueError("You should specify a party name on the fed actor.") + # Generate a new fed task id for this call. fed_task_id = get_global_context().next_seq_id() if self._party == self._node_party: diff --git a/fed/api.py b/fed/api.py index 8548c5f..460fbde 100644 --- a/fed/api.py +++ b/fed/api.py @@ -300,9 +300,9 @@ def options(self, **options): return self def remote(self, *args, **kwargs): - assert ( - self._node_party is not None - ), "A fed function should be specified within a party to execute." + if not self._node_party: + raise ValueError("You should specify a party name on the fed function.") + return self._fed_call_holder.internal_remote(*args, **kwargs) def _execute_impl(self, args, kwargs): diff --git a/tests/test_api.py b/tests/test_api.py index 55ac1db..71774f2 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -40,6 +40,30 @@ def test_fed_apis(): assert p_alice.exitcode == 0 +def test_miss_party_name_on_actor(): + def run(): + compatible_utils.init_ray(address='local') + cluster = { + 'alice': {'address': '127.0.0.1:11012'}, + } + fed.init(cluster=cluster, party="alice") + + @fed.remote + class MyActor: + pass + + with pytest.raises(ValueError): + MyActor.remote() + + fed.shutdown() + ray.shutdown() + + p_alice = multiprocessing.Process(target=run) + p_alice.start() + p_alice.join() + assert p_alice.exitcode == 0 + + if __name__ == "__main__": import sys