Skip to content

Commit

Permalink
ucx
Browse files Browse the repository at this point in the history
  • Loading branch information
luweizheng committed Sep 24, 2024
1 parent ab373ab commit 39de3ab
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
13 changes: 7 additions & 6 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,16 @@ jobs:
MODULE: ${{ matrix.module }}
if: ${{ matrix.module != 'gpu' }}
run: |
pip install numpy scipy cython coverage flaky
pip install numpy scipy cython coverage flaky ucxx-cu12
pip install -e ".[dev,extra]"
ucx_info -v
working-directory: ./python

- name: Install ucx dependencies
if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version != '3.11') }}
run: |
conda install -c conda-forge -c rapidsai ucx-proc=*=cpu ucx ucx-py
pip install -U numpy
# - name: Install ucx dependencies
# if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version != '3.11') }}
# run: |
# conda install -c conda-forge -c rapidsai ucx-proc=*=cpu ucx ucx-py
# pip install -U numpy

- name: Install fury
if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version == '3.9') }}
Expand Down
11 changes: 6 additions & 5 deletions python/xoscar/backends/communication/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .core import register_client, register_server
from .errors import ChannelClosed

ucp = lazy_import("ucp")
ucp = lazy_import("ucxx")
numba_cuda = lazy_import("numba.cuda")
rmm = lazy_import("rmm")

Expand Down Expand Up @@ -86,7 +86,7 @@ def _get_options(ucx_config: dict) -> Tuple[dict, dict]:
tls += ",cuda_copy"

if ucx_config.get("infiniband"): # pragma: no cover
tls = "rc," + tls
tls = "ib," + tls
if ucx_config.get("nvlink"): # pragma: no cover
tls += ",cuda_ipc"

Expand Down Expand Up @@ -177,7 +177,8 @@ def init(ucx_config: dict):
new_environ.update(envs)
os.environ = new_environ # type: ignore
try:
ucp.init(options=options, env_takes_precedence=True)
# let UCX determine the appropriate transports
ucp.init()
finally:
os.environ = original_environ

Expand Down Expand Up @@ -313,7 +314,7 @@ async def send_buffers(self, buffers: list, meta: Optional[_MessageBase] = None)
await self.ucp_endpoint.send(buf)
for buffer in buffers:
await self.ucp_endpoint.send(buffer)
except ucp.exceptions.UCXBaseException: # pragma: no cover
except ucp.exceptions.UCXError:: # pragma: no cover
self.abort()
raise ChannelClosed("While writing, the connection was closed")

Expand Down Expand Up @@ -516,7 +517,7 @@ async def connect(

try:
ucp_endpoint = await ucp.create_endpoint(host, port)
except ucp.exceptions.UCXBaseException as e: # pragma: no cover
except ucp.exceptions.UCXError as e: # pragma: no cover
raise ChannelClosed(
f"Connection closed before handshake completed, "
f"local address: {local_address}, dest address: {dest_address}"
Expand Down

0 comments on commit 39de3ab

Please sign in to comment.