Skip to content

Commit

Permalink
FEAT: Support copy_to semantic, from buffer to buffer (#23)
Browse files Browse the repository at this point in the history
* support copy_to

* fix mypy

* fix mypy

* fix codecov

* add more tests

* fix __init__

* fix version

* support block_size option

* assert

* recover toml

* fix

* fix

* fix

* fix typing

* remove conflict string

* fix comments

* fix send err message
  • Loading branch information
ChengjieLi28 authored Jun 25, 2023
1 parent 56b832b commit 5dd2fa7
Show file tree
Hide file tree
Showing 18 changed files with 805 additions and 89 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ jobs:
pip install -e ".[dev,extra]"
working-directory: ./python

- name: Install ucx dependencies
if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version != '3.11') }}
run: |
conda install -c conda-forge -c rapidsai ucx-proc=*=cpu ucx ucx-py
- name: Install on GPU
if: ${{ matrix.module == 'gpu' }}
run: |
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- id: isort
args: [--sp, python/setup.cfg]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.1
rev: v1.3.0
hooks:
- id: mypy
additional_dependencies: [tokenize-rt==3.2.0]
Expand Down
2 changes: 2 additions & 0 deletions python/xoscar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
has_actor,
destroy_actor,
kill_actor,
buffer_ref,
copy_to,
Actor,
StatelessActor,
create_actor_pool,
Expand Down
46 changes: 43 additions & 3 deletions python/xoscar/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

from collections import defaultdict
from numbers import Number
from typing import TYPE_CHECKING, Any, Dict, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from urllib.parse import urlparse

from .backend import get_backend
from .context import get_context
from .core import ActorRef, _Actor, _StatelessActor
from .core import ActorRef, BufferRef, _Actor, _StatelessActor

if TYPE_CHECKING:
from .backends.config import ActorPoolConfig
Expand Down Expand Up @@ -105,7 +105,7 @@ async def actor_ref(*args, **kwargs) -> ActorRef:
return await ctx.actor_ref(*args, **kwargs)


async def kill_actor(actor_ref):
async def kill_actor(actor_ref: ActorRef):
# TODO: explain the meaning of 'kill'
"""
Forcefully kill an actor.
Expand Down Expand Up @@ -161,6 +161,46 @@ async def create_actor_pool(
)


def buffer_ref(address: str, buffer: Any) -> BufferRef:
    """
    Create a reference to a buffer living at the given address.

    Parameters
    ----------
    address
        The address where the buffer lives.
    buffer
        CPU / GPU buffer object. It must support slicing and ``len()``.

    Returns
    -------
    BufferRef
        Reference to the buffer.
    """
    ctx = get_context()
    return ctx.buffer_ref(address, buffer)


async def copy_to(
    local_buffers: list,
    remote_buffer_refs: List[BufferRef],
    block_size: Optional[int] = None,
):
    """
    Copy data from local buffers to the corresponding remote buffers.

    Parameters
    ----------
    local_buffers
        Local buffers to read from.
    remote_buffer_refs
        References to the remote destination buffers, matched positionally
        with ``local_buffers``.
    block_size
        Transfer block size used when the channel is not UCX
        (presumably in bytes — confirm against the backend context);
        ``None`` means the backend default.
    """
    ctx = get_context()
    return await ctx.copy_to(local_buffers, remote_buffer_refs, block_size)


async def wait_actor_pool_recovered(address: str, main_pool_address: str | None = None):
"""
Wait until the specified actor pool has recovered from failure.
Expand Down
92 changes: 64 additions & 28 deletions python/xoscar/backends/communication/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
import logging
import os
import weakref
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type

import cloudpickle
import numpy as np

from ...nvutils import get_cuda_context, get_index_and_uuid
from ...serialization import deserialize
from ...serialization.aio import BUFFER_SIZES_NAME, AioSerializer, get_header_length
from ...utils import classproperty, implements, lazy_import
from ...utils import classproperty, implements, is_cuda_buffer, lazy_import
from ..message import _MessageBase
from .base import Channel, ChannelType, Client, Server
from .core import register_client, register_server
from .errors import ChannelClosed
Expand Down Expand Up @@ -176,9 +177,7 @@ def init(ucx_config: dict):
new_environ.update(envs)
os.environ = new_environ # type: ignore
try:
ucp.init(
options=options, env_takes_precedence=True, blocking_progress_mode=False
)
ucp.init(options=options, env_takes_precedence=True)
finally:
os.environ = original_environ

Expand Down Expand Up @@ -237,6 +236,11 @@ def _close_channel(channel_ref: weakref.ReferenceType):
if channel is not None:
channel._closed = True

async def _serialize(self, message: Any) -> List[bytes]:
compress = self.compression or 0
serializer = AioSerializer(message, compress=compress)
return await serializer.run()

@property
@implements(Channel.type)
def type(self) -> int:
Expand All @@ -247,26 +251,8 @@ async def send(self, message: Any):
if self.closed:
raise ChannelClosed("UCX Endpoint is closed, unable to send message")

compress = self.compression or 0
serializer = AioSerializer(message, compress=compress)
buffers = await serializer.run()
try:
# It is necessary to first synchronize the default stream before start
# sending We synchronize the default stream because UCX is not
# stream-ordered and syncing the default stream will wait for other
# non-blocking CUDA streams. Note this is only sufficient if the memory
# being sent is not currently in use on non-blocking CUDA streams.
if any(hasattr(buf, "__cuda_array_interface__") for buf in buffers):
# has GPU buffer
synchronize_stream(0)

async with self._send_lock:
for buffer in buffers:
if buffer.nbytes if hasattr(buffer, "nbytes") else len(buffer) > 0:
await self.ucp_endpoint.send(buffer)
except ucp.exceptions.UCXBaseException: # pragma: no cover
self.abort()
raise ChannelClosed("While writing, the connection was closed")
buffers = await self._serialize(message)
return await self.send_buffers(buffers)

@implements(Channel.recv)
async def recv(self):
Expand Down Expand Up @@ -306,6 +292,48 @@ async def recv(self):
raise EOFError("Server closed already")
return deserialize(header, buffers)

async def send_buffers(self, buffers: list, meta: Optional[_MessageBase] = None):
try:
# It is necessary to first synchronize the default stream before start
# sending We synchronize the default stream because UCX is not
# stream-ordered and syncing the default stream will wait for other
# non-blocking CUDA streams. Note this is only sufficient if the memory
# being sent is not currently in use on non-blocking CUDA streams.
if any(is_cuda_buffer(buf) for buf in buffers):
# has GPU buffer
synchronize_stream(0)

meta_buffers = None
if meta:
meta_buffers = await self._serialize(meta)

async with self._send_lock:
if meta_buffers:
for buf in meta_buffers:
await self.ucp_endpoint.send(buf)
for buffer in buffers:
if buffer.nbytes if hasattr(buffer, "nbytes") else len(buffer) > 0:
await self.ucp_endpoint.send(buffer)
except ucp.exceptions.UCXBaseException: # pragma: no cover
self.abort()
raise ChannelClosed("While writing, the connection was closed")

    async def recv_buffers(self, buffers: list):
        """
        Receive data from the UCX endpoint directly into the given buffers.

        Parameters
        ----------
        buffers
            Pre-allocated writable buffers, filled in order with one
            endpoint ``recv`` per buffer.

        Raises
        ------
        ChannelClosed
            If the connection breaks while the channel is still open.
        EOFError
            If the channel was already marked closed when the failure
            occurred.
        """
        async with self._recv_lock:
            try:
                for buffer in buffers:
                    await self.ucp_endpoint.recv(buffer)
            except BaseException as e:  # pragma: no cover
                if not self._closed:
                    # In addition to UCX exceptions, may be CancelledError or another
                    # "low-level" exception. The only safe thing to do is to abort.
                    self.abort()
                    raise ChannelClosed(
                        f"Connection closed by writer.\nInner exception: {e!r}"
                    ) from e
                else:
                    raise EOFError("Server closed already")

def abort(self):
self._closed = True
if self.ucp_endpoint is not None:
Expand Down Expand Up @@ -390,7 +418,7 @@ async def serve_forever(client_ucp_endpoint: "ucp.Endpoint"): # type: ignore
client_ucp_endpoint, local_address=server.address
)
except ChannelClosed: # pragma: no cover
logger.debug("Connection closed before handshake completed")
logger.exception("Connection closed before handshake completed")
return

ucp_listener = ucp.create_listener(serve_forever, port=port)
Expand Down Expand Up @@ -442,6 +470,7 @@ async def stop(self):
await asyncio.gather(
*(channel.close() for channel in self._channels if not channel.closed)
)
self._channels = []
self._ucp_listener = None
self._closed.set()

Expand All @@ -456,6 +485,7 @@ class UCXClient(Client):
__slots__ = ()

scheme = UCXServer.scheme
channel: UCXChannel

@classmethod
def parse_config(cls, config: dict) -> dict:
Expand All @@ -477,9 +507,15 @@ async def connect(

try:
ucp_endpoint = await ucp.create_endpoint(host, port)
except ucp.exceptions.UCXBaseException: # pragma: no cover
raise ChannelClosed("Connection closed before handshake completed")
except ucp.exceptions.UCXBaseException as e: # pragma: no cover
raise ChannelClosed(
f"Connection closed before handshake completed, "
f"local address: {local_address}, dest address: {dest_address}"
) from e
channel = UCXChannel(
ucp_endpoint, local_address=local_address, dest_address=dest_address
)
return UCXClient(local_address, dest_address, channel)

async def send_buffers(self, buffers: list, meta: _MessageBase):
return await self.channel.send_buffers(buffers, meta)
Loading

0 comments on commit 5dd2fa7

Please sign in to comment.