Skip to content

Commit

Permalink
Fix formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouaihui committed Dec 13, 2023
1 parent 09019eb commit a284864
Show file tree
Hide file tree
Showing 56 changed files with 1,020 additions and 694 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ jobs:
- name: Lint
run: |
. py3/bin/activate
black --check --diff .
black --check --diff . --exclude fed/grpc
3 changes: 1 addition & 2 deletions .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ use_parentheses=True
float_to_top=True
filter_files=True

known_local_folder=ray
known_third_party=grpc
known_local_folder=fed
sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
18 changes: 10 additions & 8 deletions benchmarks/many_tiny_tasks_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ray
import time
import sys
import time

import ray

import fed


Expand All @@ -31,11 +33,11 @@ def aggr(self, val1, val2):


def main(party):
ray.init(address='local')
ray.init(address="local")

addresses = {
'alice': '127.0.0.1:11010',
'bob': '127.0.0.1:11011',
"alice": "127.0.0.1:11010",
"bob": "127.0.0.1:11011",
}
fed.init(addresses=addresses, party=party)

Expand All @@ -53,13 +55,13 @@ def main(party):
if i % 100 == 0:
print(f"Running {i}th call")
print(f"num calls: {num_calls}")
print("total time (ms) = ", (time.time() - start)*1000)
print("per task overhead (ms) =", (time.time() - start)*1000/num_calls)
print("total time (ms) = ", (time.time() - start) * 1000)
print("per task overhead (ms) =", (time.time() - start) * 1000 / num_calls)

fed.shutdown()
ray.shutdown()


if __name__ == "__main__":
assert len(sys.argv) == 2, 'Please run this script with party.'
assert len(sys.argv) == 2, "Please run this script with party."
main(sys.argv[1])
32 changes: 16 additions & 16 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,34 @@

# -- Project information

project = 'RayFed'
copyright = '2022, The RayFed Team'
author = 'The RayFed Authors'
project = "RayFed"
copyright = "2022, The RayFed Team"
author = "The RayFed Authors"

release = '0.1'
version = '0.1.0'
release = "0.1"
version = "0.1.0"

# -- General configuration

extensions = [
'sphinx.ext.duration',
'sphinx.ext.doctest',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
"sphinx.ext.duration",
"sphinx.ext.doctest",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
]

intersphinx_mapping = {
'python': ('https://docs.python.org/3/', None),
'sphinx': ('https://www.sphinx-doc.org/en/master/', None),
"python": ("https://docs.python.org/3/", None),
"sphinx": ("https://www.sphinx-doc.org/en/master/", None),
}
intersphinx_disabled_domains = ['std']
intersphinx_disabled_domains = ["std"]

templates_path = ['_templates']
templates_path = ["_templates"]

# -- Options for HTML output

html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"

# -- Options for EPUB output
epub_show_urls = 'footnote'
epub_show_urls = "footnote"
5 changes: 2 additions & 3 deletions fed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from fed.api import (get, init, kill, remote,
shutdown)
from fed.api import get, init, kill, remote, shutdown
from fed.proxy.barriers import recv, send
from fed.fed_object import FedObject
from fed.exceptions import FedRemoteError
Expand All @@ -27,5 +26,5 @@
"recv",
"send",
"FedObject",
"FedRemoteError"
"FedRemoteError",
]
67 changes: 34 additions & 33 deletions fed/_private/compatible_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

import abc
import ray
import fed._private.constants as fed_constants

import ray
import ray.experimental.internal_kv as ray_internal_kv

import fed._private.constants as fed_constants
from fed._private import constants


Expand All @@ -26,8 +27,8 @@ def _compare_version_strings(version1, version2):
True if version1 is greater, and False if they're equal, and
False if version2 is greater.
"""
v1_list = version1.split('.')
v2_list = version2.split('.')
v1_list = version1.split(".")
v2_list = version2.split(".")
len1 = len(v1_list)
len2 = len(v2_list)

Expand All @@ -41,45 +42,43 @@ def _compare_version_strings(version1, version2):


def _ray_version_less_than_2_0_0():
""" Whther the current ray version is less 2.0.0.
"""
"""Whther the current ray version is less 2.0.0."""
return _compare_version_strings(
fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__)
fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__
)


def init_ray(address: str = None, **kwargs):
"""A compatible API to init Ray.
"""
if address == 'local' and _ray_version_less_than_2_0_0():
"""A compatible API to init Ray."""
if address == "local" and _ray_version_less_than_2_0_0():
# Ignore the `local` when ray < 2.0.0
ray.init(**kwargs)
else:
ray.init(address=address, **kwargs)


def _get_gcs_address_from_ray_worker():
"""A compatible API to get the gcs address from the ray worker module.
"""
"""A compatible API to get the gcs address from the ray worker module."""
try:
return ray._private.worker._global_node.gcs_address
except AttributeError:
return ray.worker._global_node.gcs_address


def wrap_kv_key(job_name, key: str):
"""Add an prefix to the key to avoid conflict with other jobs.
"""
assert isinstance(key, str), \
f"The key of KV data must be `str` type, got {type(key)}."
"""Add an prefix to the key to avoid conflict with other jobs."""
assert isinstance(
key, str
), f"The key of KV data must be `str` type, got {type(key)}."

return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(
job_name, key)
return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(job_name, key)


class AbstractInternalKv(abc.ABC):
""" An abstract class that represents for bridging Ray internal kv in
"""An abstract class that represents for bridging Ray internal kv in
both Ray client mode and non Ray client mode.
"""

def __init__(self) -> None:
pass

Expand All @@ -105,8 +104,8 @@ def reset(self):


class InternalKv(AbstractInternalKv):
"""The internal kv class for non Ray client mode.
"""
"""The internal kv class for non Ray client mode."""

def __init__(self, job_name: str) -> None:
super().__init__()
self._job_name = job_name
Expand All @@ -120,21 +119,18 @@ def initialize(self):
from ray._raylet import GcsClient

gcs_client = GcsClient(
address=_get_gcs_address_from_ray_worker(),
nums_reconnect_retry=10)
address=_get_gcs_address_from_ray_worker(), nums_reconnect_retry=10
)
return ray_internal_kv._initialize_internal_kv(gcs_client)

def put(self, k, v):
return ray_internal_kv._internal_kv_put(
wrap_kv_key(self._job_name, k), v)
return ray_internal_kv._internal_kv_put(wrap_kv_key(self._job_name, k), v)

def get(self, k):
return ray_internal_kv._internal_kv_get(
wrap_kv_key(self._job_name, k))
return ray_internal_kv._internal_kv_get(wrap_kv_key(self._job_name, k))

def delete(self, k):
return ray_internal_kv._internal_kv_del(
wrap_kv_key(self._job_name, k))
return ray_internal_kv._internal_kv_del(wrap_kv_key(self._job_name, k))

def reset(self):
return ray_internal_kv._internal_kv_reset()
Expand All @@ -144,8 +140,8 @@ def _ping(self):


class ClientModeInternalKv(AbstractInternalKv):
"""The internal kv class for Ray client mode.
"""
"""The internal kv class for Ray client mode."""

def __init__(self) -> None:
super().__init__()
self._internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR")
Expand Down Expand Up @@ -176,9 +172,13 @@ def _init_internal_kv(job_name):
global kv
if kv is None:
from ray._private.client_mode_hook import is_client_mode_enabled

if is_client_mode_enabled:
kv_actor = ray.remote(InternalKv).options(
name="_INTERNAL_KV_ACTOR").remote(job_name)
kv_actor = (
ray.remote(InternalKv)
.options(name="_INTERNAL_KV_ACTOR")
.remote(job_name)
)
response = kv_actor._ping.remote()
ray.get(response)
kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name)
Expand All @@ -192,6 +192,7 @@ def _clear_internal_kv():
kv.delete(constants.KEY_OF_JOB_CONFIG)
kv.reset()
from ray._private.client_mode_hook import is_client_mode_enabled

if is_client_mode_enabled:
_internal_kv_actor = ray.get_actor("_INTERNAL_KV_ACTOR")
ray.kill(_internal_kv_actor)
Expand Down
2 changes: 1 addition & 1 deletion fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT"

RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa
RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa

RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S"

Expand Down
23 changes: 11 additions & 12 deletions fed/_private/fed_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ray
from ray.util.client.common import ClientActorHandle

from fed._private.fed_call_holder import FedCallHolder
from fed.fed_object import FedObject

Expand Down Expand Up @@ -90,22 +91,20 @@ def _execute_impl(self, cls_args, cls_kwargs):
)

def _execute_remote_method(
self,
method_name,
options,
_ray_wrappered_method,
args,
kwargs,
self,
method_name,
options,
_ray_wrappered_method,
args,
kwargs,
):
num_returns = 1
if options and 'num_returns' in options:
num_returns = options['num_returns']
logger.debug(
f"Actor method call: {method_name}, num_returns: {num_returns}"
)
if options and "num_returns" in options:
num_returns = options["num_returns"]
logger.debug(f"Actor method call: {method_name}, num_returns: {num_returns}")

return _ray_wrappered_method.options(
name='',
name="",
num_returns=num_returns,
).remote(
*args,
Expand Down
19 changes: 10 additions & 9 deletions fed/_private/fed_call_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

import logging

# Set config in the very beginning to avoid being overwritten by other packages.
logging.basicConfig(level=logging.INFO)

import fed.config as fed_config
from fed._private.global_context import get_global_context
from fed.proxy.barriers import send
from fed.fed_object import FedObject
from fed.utils import resolve_dependencies
from fed.proxy.barriers import send
from fed.tree_util import tree_flatten
import fed.config as fed_config
from fed.utils import resolve_dependencies

# Set config in the very beginning to avoid being overwritten by other packages.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -97,10 +98,10 @@ def internal_remote(self, *args, **kwargs):
)
if (
self._options
and 'num_returns' in self._options
and self._options['num_returns'] > 1
and "num_returns" in self._options
and self._options["num_returns"] > 1
):
num_returns = self._options['num_returns']
num_returns = self._options["num_returns"]
return [
FedObject(self._node_party, fed_task_id, None, i)
for i in range(num_returns)
Expand Down
Loading

0 comments on commit a284864

Please sign in to comment.