Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct get exeception code and some other minor fixes. #186

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,37 @@
import inspect
import logging
import signal
from typing import Any, Dict, List, Union, Callable
from typing import Any, Callable, Dict, List, Union

import cloudpickle
import ray
from ray.exceptions import RayError

import fed._private.compatible_utils as compatible_utils
import fed.config as fed_config
import fed.utils as fed_utils
from fed._private import constants
from fed._private.fed_actor import FedActorHandle
from fed._private.fed_call_holder import FedCallHolder
from fed.exceptions import FedRemoteError
from fed._private.global_context import (
init_global_context,
clear_global_context,
get_global_context,
clear_global_context
init_global_context,
)
from fed.config import CrossSiloMessageConfig
from fed.exceptions import FedRemoteError
from fed.fed_object import FedObject
from fed.proxy.barriers import (
ping_others,
recv,
send,
_start_receiver_proxy,
_start_sender_proxy,
_start_sender_receiver_proxy,
ping_others,
recv,
send,
set_proxy_actor_name,
)
from fed.proxy.base_proxy import SenderProxy, ReceiverProxy, SenderReceiverProxy
from fed.config import CrossSiloMessageConfig
from fed.fed_object import FedObject
from fed.proxy.base_proxy import ReceiverProxy, SenderProxy, SenderReceiverProxy
from fed.utils import is_ray_object_refs, setup_logger
from ray.exceptions import RayError

logger = logging.getLogger(__name__)

Expand All @@ -59,7 +59,8 @@ def _signal_handler(signum, frame):
logger.warning(
"Stop signal received (e.g. via SIGINT/Ctrl+C), "
"try to shutdown fed. Press CTRL+C "
"(or send SIGINT/SIGKILL/SIGTERM) to skip.")
"(or send SIGINT/SIGKILL/SIGTERM) to skip."
)
_shutdown(intended=False)


Expand Down Expand Up @@ -162,8 +163,9 @@ def init(
assert party in addresses, f"Party {party} is not in the addresses {addresses}."

fed_utils.validate_addresses(addresses)
init_global_context(current_party=party, job_name=job_name,
failure_handler=failure_handler)
init_global_context(
current_party=party, job_name=job_name, failure_handler=failure_handler
)
tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
Expand Down Expand Up @@ -196,20 +198,21 @@ def init(
logging_format=constants.RAYFED_LOG_FMT,
date_format=constants.RAYFED_DATE_FMT,
party_val=_get_party(job_name),
job_name=job_name
job_name=job_name,
)

logger.info(f'Started rayfed with {cluster_config}')
cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
signal.signal(signal.SIGINT, _signal_handler)
get_global_context().get_cleanup_manager().start(
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
expose_error_trace=cross_silo_comm_config.expose_error_trace
expose_error_trace=cross_silo_comm_config.expose_error_trace,
)

if receiver_sender_proxy_cls is not None:
set_proxy_actor_name(
job_name, cross_silo_comm_dict.get("use_global_proxy", True), True)
job_name, cross_silo_comm_dict.get("use_global_proxy", True), True
)
_start_sender_receiver_proxy(
addresses=addresses,
party=party,
Expand All @@ -231,7 +234,8 @@ def init(

receiver_proxy_cls = GrpcReceiverProxy
set_proxy_actor_name(
job_name, cross_silo_comm_dict.get("use_global_proxy", True))
job_name, cross_silo_comm_dict.get("use_global_proxy", True)
)
_start_receiver_proxy(
addresses=addresses,
party=party,
Expand All @@ -244,8 +248,7 @@ def init(

if sender_proxy_cls is None:
logger.debug(
"No sender proxy class specified, use `GrpcSenderProxy` by "
"default."
"No sender proxy class specified, use `GrpcSenderProxy` by default."
)
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy

Expand Down Expand Up @@ -281,12 +284,12 @@ def _shutdown(intended=True):
intended: (Optional) Whether this is a intended exit. If not
a "failure handler" will be triggered.
"""
if (get_global_context() is not None):
if get_global_context() is not None:
# Job has inited, can be shutdown
failure_handler = get_global_context().get_failure_handler()
compatible_utils._clear_internal_kv()
clear_global_context()
if (not intended and failure_handler is not None):
if not intended and failure_handler is not None:
failure_handler()
logger.info('Shutdowned rayfed.')

Expand Down Expand Up @@ -472,10 +475,12 @@ def get(
values = values[0]
return values
except RayError as e:
if isinstance(e.cause, FedRemoteError):
logger.warning("Encounter RemoteError happend in other parties"
f", prepare to exit, error message: {e.cause}")
if (get_global_context().acquire_shutdown_flag()):
if isinstance(e, FedRemoteError):
logger.warning(
"Encounter RemoteError happend in other parties"
f", prepare to exit, error message: {e.cause}"
)
if get_global_context().acquire_shutdown_flag():
_shutdown(intended=False)
raise e

Expand Down
25 changes: 14 additions & 11 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
are mutable.
"""

import fed._private.compatible_utils as compatible_utils
import fed._private.constants as fed_constants
import cloudpickle
import json

from typing import Dict, List, Optional
from dataclasses import dataclass, fields
from typing import Dict, List, Optional

import cloudpickle

import fed._private.compatible_utils as compatible_utils
import fed._private.constants as fed_constants


class ClusterConfig:
Expand Down Expand Up @@ -48,24 +49,26 @@ def cross_silo_comm_config_dict(self) -> Dict:
_job_config = None


def get_cluster_config(job_name: str = None):
def get_cluster_config(job_name: str = None) -> ClusterConfig:
"""This function is not thread safe to use."""
global _cluster_config
if _cluster_config is None:
assert job_name is not None, \
"Initializing internal kv need to provide job_name."
assert (
job_name is not None
), "Initializing internal kv need to provide job_name."
compatible_utils._init_internal_kv(job_name)
raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG)
_cluster_config = ClusterConfig(raw_dict)
return _cluster_config


def get_job_config(job_name: str = None):
def get_job_config(job_name: str = None) -> JobConfig:
"""This config still acts like cluster config for now"""
global _job_config
if _job_config is None:
assert job_name is not None, \
"Initializing internal kv need to provide job_name."
assert (
job_name is not None
), "Initializing internal kv need to provide job_name."
compatible_utils._init_internal_kv(job_name)
raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG)
_job_config = JobConfig(raw_dict)
Expand Down
4 changes: 2 additions & 2 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def __init__(
self._addresses = addresses
self._party = party
self._tls_config = tls_config
job_config = fed_config.get_job_config()
job_config = fed_config.get_job_config(job_name=job_name)
cross_silo_comm_config = job_config.cross_silo_comm_config_dict
self._proxy_instance = proxy_cls(
addresses, party, tls_config, cross_silo_comm_config
Expand Down Expand Up @@ -397,7 +397,7 @@ def send(
except Exception as e:
logger.error(f'Failed to {send_log_msg}, error: {e}')
return False
logger.debug(f"Succeeded to send {send_log_msg}. Response is {response}")
logger.debug(f"Succeeded to {send_log_msg}. Response is {response}")
return True # True indicates it's sent successfully.

def _get_stats(self):
Expand Down
Loading