diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index ee61446..a7b1e0f 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -15,7 +15,7 @@ import copy import logging import time -from typing import Dict +from typing import Any, Dict import ray @@ -458,11 +458,20 @@ def _start_sender_receiver_proxy( logger.info("Succeeded to create receiver proxy actor.") -def send(dest_party, data, upstream_seq_id, downstream_seq_id, is_error=False): +def send( + dest_party: str, + data: Any, + upstream_seq_id: int, + downstream_seq_id: int, + is_error: bool = False, + check_sending: bool = True, +): """ Args: is_error: Whether the `data` is an error object or not. Default is False. If True, the data will be sent to the error message queue. + check_sending: Whether to check the data sending. If true, the data will be + checked in the sending check loop. """ sender_proxy = ray.get_actor(sender_proxy_actor_name()) res = sender_proxy.send.remote( @@ -471,13 +480,14 @@ def send(dest_party, data, upstream_seq_id, downstream_seq_id, is_error=False): upstream_seq_id=upstream_seq_id, downstream_seq_id=downstream_seq_id, ) - get_global_context().get_cleanup_manager().push_to_sending( - res, dest_party, upstream_seq_id, downstream_seq_id, is_error - ) + if check_sending: + get_global_context().get_cleanup_manager().push_to_sending( + res, dest_party, upstream_seq_id, downstream_seq_id, is_error + ) return res -def recv(party: str, src_party: str, upstream_seq_id, curr_seq_id): +def recv(party: str, src_party: str, upstream_seq_id: int, curr_seq_id: int): assert party, 'Party can not be None.' receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) return receiver_proxy.get_data.remote(src_party, upstream_seq_id, curr_seq_id) @@ -496,7 +506,9 @@ def ping_others(addresses: Dict[str, Dict], self_party: str, max_retries=3600): _party_ping_obj = {} # {$party_name: $ObjectRef} # Batch ping all the other parties for other in others: - _party_ping_obj[other] = send(other, b'data', 'ping', 'ping') + _party_ping_obj[other] = send( + other, b'data', 'ping', 'ping', check_sending=False + ) _, _unready = ray.wait(list(_party_ping_obj.values()), timeout=1) # Keep the unready party for the next ping.