Skip to content

Commit

Permalink
Revert unintended modifications.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouaihui committed Dec 29, 2023
1 parent 1e78372 commit db971f9
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 15 deletions.
11 changes: 0 additions & 11 deletions fed/_private/fed_call_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,6 @@
# Set config in the very beginning to avoid being overwritten by other packages.
logging.basicConfig(level=logging.INFO)

from fed._private.global_context import get_global_context
from fed.fed_object import FedObject
from fed.proxy.barriers import send
from fed.utils import resolve_dependencies

try:
from jax.tree_util import tree_flatten
except ImportError:
from fed.tree_util import tree_flatten

import fed.config as fed_config

logger = logging.getLogger(__name__)

Expand Down
8 changes: 6 additions & 2 deletions fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,15 @@ def acquire_shutdown_flag(self) -> bool:


def init_global_context(
current_party: str, job_name: str, sending_failure_handler: Callable[[], None] = None
current_party: str,
job_name: str,
sending_failure_handler: Callable[[], None] = None,
) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(job_name, current_party, sending_failure_handler)
_global_context = GlobalContext(
job_name, current_party, sending_failure_handler
)


def get_global_context():
Expand Down
4 changes: 3 additions & 1 deletion fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def init(

fed_utils.validate_addresses(addresses)
init_global_context(
current_party=party, job_name=job_name, failure_handler=sending_failure_handler
current_party=party,
job_name=job_name,
sending_failure_handler=sending_failure_handler,
)
tls_config = {} if tls_config is None else tls_config
if tls_config:
Expand Down
3 changes: 2 additions & 1 deletion fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ray

import fed.config as fed_config
from fed.exceptions import FedRemoteError
from fed._private import constants
from fed._private.global_context import get_global_context
from fed.proxy.base_proxy import ReceiverProxy, SenderProxy, SenderReceiverProxy
Expand Down Expand Up @@ -223,7 +224,7 @@ async def get_data(self, src_party, upstream_seq_id, curr_seq_id):
data = await self._proxy_instance.get_data(
src_party, upstream_seq_id, curr_seq_id
)
if isinstance(data, Exception):
if isinstance(data, FedRemoteError):
logger.debug(
f"Receiving exception: {type(data)}, {data} from {src_party}, "
f"upstream_seq_id: {upstream_seq_id}, "
Expand Down

0 comments on commit db971f9

Please sign in to comment.