Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add black linter. #187

Merged
merged 3 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ jobs:
python3 -m virtualenv -p python3 py3
. py3/bin/activate
which python
pip install pytest torch cloudpickle cryptography
pip install ray==2.0.0
pip install flake8 # For code style checking
pip install black==23.1

- name: Lint
run: |
. py3/bin/activate
flake8
black -S --check --diff . --exclude='fed/grpc|py3'
13 changes: 13 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[settings]
# This is to make isort compatible with Black. See
# https://black.readthedocs.io/en/stable/the_black_code_style.html#how-black-wraps-lines.
line_length=88
profile=black
multi_line_output=3
include_trailing_comma=True
use_parentheses=True
float_to_top=True
filter_files=True

known_local_folder=fed
sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
10 changes: 6 additions & 4 deletions benchmarks/many_tiny_tasks_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ray
import time
import sys
import time

import ray

import fed


Expand Down Expand Up @@ -53,8 +55,8 @@ def main(party):
if i % 100 == 0:
print(f"Running {i}th call")
print(f"num calls: {num_calls}")
print("total time (ms) = ", (time.time() - start)*1000)
print("per task overhead (ms) =", (time.time() - start)*1000/num_calls)
print("total time (ms) = ", (time.time() - start) * 1000)
print("per task overhead (ms) =", (time.time() - start) * 1000 / num_calls)

fed.shutdown()
ray.shutdown()
Expand Down
9 changes: 4 additions & 5 deletions fed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from fed.api import (get, init, kill, remote,
shutdown)
from fed.proxy.barriers import recv, send
from fed.fed_object import FedObject
from fed.api import get, init, kill, remote, shutdown
from fed.exceptions import FedRemoteError
from fed.fed_object import FedObject
from fed.proxy.barriers import recv, send

__all__ = [
"get",
Expand All @@ -27,5 +26,5 @@
"recv",
"send",
"FedObject",
"FedRemoteError"
"FedRemoteError",
]
61 changes: 31 additions & 30 deletions fed/_private/compatible_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

import abc
import ray
import fed._private.constants as fed_constants

import ray
import ray.experimental.internal_kv as ray_internal_kv

import fed._private.constants as fed_constants
from fed._private import constants


Expand All @@ -41,15 +42,14 @@ def _compare_version_strings(version1, version2):


def _ray_version_less_than_2_0_0():
""" Whther the current ray version is less 2.0.0.
"""
"""Whther the current ray version is less 2.0.0."""
return _compare_version_strings(
fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__)
fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__
)


def init_ray(address: str = None, **kwargs):
"""A compatible API to init Ray.
"""
"""A compatible API to init Ray."""
if address == 'local' and _ray_version_less_than_2_0_0():
# Ignore the `local` when ray < 2.0.0
ray.init(**kwargs)
Expand All @@ -58,28 +58,27 @@ def init_ray(address: str = None, **kwargs):


def _get_gcs_address_from_ray_worker():
"""A compatible API to get the gcs address from the ray worker module.
"""
"""A compatible API to get the gcs address from the ray worker module."""
try:
return ray._private.worker._global_node.gcs_address
except AttributeError:
return ray.worker._global_node.gcs_address


def wrap_kv_key(job_name, key: str):
"""Add an prefix to the key to avoid conflict with other jobs.
"""
assert isinstance(key, str), \
f"The key of KV data must be `str` type, got {type(key)}."
"""Add an prefix to the key to avoid conflict with other jobs."""
assert isinstance(
key, str
), f"The key of KV data must be `str` type, got {type(key)}."

return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(
job_name, key)
return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(job_name, key)


class AbstractInternalKv(abc.ABC):
""" An abstract class that represents for bridging Ray internal kv in
"""An abstract class that represents for bridging Ray internal kv in
both Ray client mode and non Ray client mode.
"""

def __init__(self) -> None:
pass

Expand All @@ -105,8 +104,8 @@ def reset(self):


class InternalKv(AbstractInternalKv):
"""The internal kv class for non Ray client mode.
"""
"""The internal kv class for non Ray client mode."""

def __init__(self, job_name: str) -> None:
super().__init__()
self._job_name = job_name
Expand All @@ -120,21 +119,18 @@ def initialize(self):
from ray._raylet import GcsClient

gcs_client = GcsClient(
address=_get_gcs_address_from_ray_worker(),
nums_reconnect_retry=10)
address=_get_gcs_address_from_ray_worker(), nums_reconnect_retry=10
)
return ray_internal_kv._initialize_internal_kv(gcs_client)

def put(self, k, v):
return ray_internal_kv._internal_kv_put(
wrap_kv_key(self._job_name, k), v)
return ray_internal_kv._internal_kv_put(wrap_kv_key(self._job_name, k), v)

def get(self, k):
return ray_internal_kv._internal_kv_get(
wrap_kv_key(self._job_name, k))
return ray_internal_kv._internal_kv_get(wrap_kv_key(self._job_name, k))

def delete(self, k):
return ray_internal_kv._internal_kv_del(
wrap_kv_key(self._job_name, k))
return ray_internal_kv._internal_kv_del(wrap_kv_key(self._job_name, k))

def reset(self):
return ray_internal_kv._internal_kv_reset()
Expand All @@ -144,8 +140,8 @@ def _ping(self):


class ClientModeInternalKv(AbstractInternalKv):
"""The internal kv class for Ray client mode.
"""
"""The internal kv class for Ray client mode."""

def __init__(self) -> None:
super().__init__()
self._internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR")
Expand Down Expand Up @@ -176,9 +172,13 @@ def _init_internal_kv(job_name):
global kv
if kv is None:
from ray._private.client_mode_hook import is_client_mode_enabled

if is_client_mode_enabled:
kv_actor = ray.remote(InternalKv).options(
name="_INTERNAL_KV_ACTOR").remote(job_name)
kv_actor = (
ray.remote(InternalKv)
.options(name="_INTERNAL_KV_ACTOR")
.remote(job_name)
)
response = kv_actor._ping.remote()
ray.get(response)
kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name)
Expand All @@ -192,6 +192,7 @@ def _clear_internal_kv():
kv.delete(constants.KEY_OF_JOB_CONFIG)
kv.reset()
from ray._private.client_mode_hook import is_client_mode_enabled

if is_client_mode_enabled:
_internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR")
ray.kill(_internal_kv_actor)
Expand Down
2 changes: 1 addition & 1 deletion fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT"

RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa
RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa

RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S"

Expand Down
17 changes: 8 additions & 9 deletions fed/_private/fed_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ray
from ray.util.client.common import ClientActorHandle

from fed._private.fed_call_holder import FedCallHolder
from fed.fed_object import FedObject

Expand Down Expand Up @@ -90,19 +91,17 @@ def _execute_impl(self, cls_args, cls_kwargs):
)

def _execute_remote_method(
self,
method_name,
options,
_ray_wrappered_method,
args,
kwargs,
self,
method_name,
options,
_ray_wrappered_method,
args,
kwargs,
):
num_returns = 1
if options and 'num_returns' in options:
num_returns = options['num_returns']
logger.debug(
f"Actor method call: {method_name}, num_returns: {num_returns}"
)
logger.debug(f"Actor method call: {method_name}, num_returns: {num_returns}")

return _ray_wrappered_method.options(
name='',
Expand Down
13 changes: 7 additions & 6 deletions fed/_private/fed_call_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

import logging

# Set config in the very beginning to avoid being overwritten by other packages.
logging.basicConfig(level=logging.INFO)

import fed.config as fed_config
from fed._private.global_context import get_global_context
from fed.proxy.barriers import send
from fed.fed_object import FedObject
from fed.utils import resolve_dependencies
from fed.proxy.barriers import send
from fed.tree_util import tree_flatten
import fed.config as fed_config
from fed.utils import resolve_dependencies

# Set config in the very beginning to avoid being overwritten by other packages.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)

Expand Down
20 changes: 11 additions & 9 deletions fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from fed.cleanup import CleanupManager
from typing import Callable
import threading
from typing import Callable

from fed.cleanup import CleanupManager


class GlobalContext:
def __init__(self, job_name: str,
current_party: str,
failure_handler: Callable[[], None]) -> None:
def __init__(
self, job_name: str, current_party: str, failure_handler: Callable[[], None]
) -> None:
self._job_name = job_name
self._seq_count = 0
self._failure_handler = failure_handler
self._atomic_shutdown_flag_lock = threading.Lock()
self._atomic_shutdown_flag = True
self._cleanup_manager = CleanupManager(
current_party, self.acquire_shutdown_flag)
current_party, self.acquire_shutdown_flag
)

def next_seq_id(self) -> int:
self._seq_count += 1
Expand Down Expand Up @@ -65,9 +67,9 @@ def acquire_shutdown_flag(self) -> bool:
_global_context = None


def init_global_context(current_party: str,
job_name: str,
failure_handler: Callable[[], None] = None) -> None:
def init_global_context(
current_party: str, job_name: str, failure_handler: Callable[[], None] = None
) -> None:
global _global_context
if _global_context is None:
_global_context = GlobalContext(job_name, current_party, failure_handler)
Expand Down
16 changes: 9 additions & 7 deletions fed/_private/message_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading
from collections import deque
import time
import logging

from collections import deque

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,7 +54,8 @@ def _loop():

if self._thread is None or not self._thread.is_alive():
logger.debug(
f"Starting new thread[{self._thread_name}] for message polling.")
f"Starting new thread[{self._thread_name}] for message polling."
)
self._queue = deque()
self._thread = threading.Thread(target=_loop, name=self._thread_name)
self._thread.start()
Expand All @@ -79,9 +79,11 @@ def stop(self):
If False: forcelly kill the for-loop sub-thread.
"""
if threading.current_thread() == self._thread:
logger.error(f"Can't stop the message queue in the message "
f"polling thread[{self._thread_name}]. Ignore it as this"
f"could bring unknown time sequence problems.")
logger.error(
f"Can't stop the message queue in the message "
f"polling thread[{self._thread_name}]. Ignore it as this"
f"could bring unknown time sequence problems."
)
raise RuntimeError("Thread can't kill itself")

# TODO(NKcqx): Force kill sub-thread by calling `._stop()` will
Expand Down
2 changes: 1 addition & 1 deletion fed/_private/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

import io

import cloudpickle

import fed.config as fed_config


_pickle_whitelist = None


Expand Down
Loading
Loading