From 39de3abbc9bf4d4a41753fb43ebd6dd406fd439a Mon Sep 17 00:00:00 2001 From: Lu Weizheng Date: Tue, 24 Sep 2024 15:00:11 +0800 Subject: [PATCH] ucx --- .github/workflows/python.yaml | 13 +++++++------ python/xoscar/backends/communication/ucx.py | 11 ++++++----- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 9b8c26b8..abae1a97 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -117,15 +117,16 @@ jobs: MODULE: ${{ matrix.module }} if: ${{ matrix.module != 'gpu' }} run: | - pip install numpy scipy cython coverage flaky + pip install numpy scipy cython coverage flaky ucxx-cu12 pip install -e ".[dev,extra]" + ucx_info -v working-directory: ./python - - name: Install ucx dependencies - if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version != '3.11') }} - run: | - conda install -c conda-forge -c rapidsai ucx-proc=*=cpu ucx ucx-py - pip install -U numpy + # - name: Install ucx dependencies + # if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version != '3.11') }} + # run: | + # conda install -c conda-forge -c rapidsai ucx-proc=*=cpu ucx ucx-py + # pip install -U numpy - name: Install fury if: ${{ (matrix.module != 'gpu') && (matrix.os == 'ubuntu-latest') && (matrix.python-version == '3.9') }} diff --git a/python/xoscar/backends/communication/ucx.py b/python/xoscar/backends/communication/ucx.py index 548bf6f8..d232fe1c 100644 --- a/python/xoscar/backends/communication/ucx.py +++ b/python/xoscar/backends/communication/ucx.py @@ -34,7 +34,7 @@ from .core import register_client, register_server from .errors import ChannelClosed -ucp = lazy_import("ucp") +ucp = lazy_import("ucxx") numba_cuda = lazy_import("numba.cuda") rmm = lazy_import("rmm") @@ -86,7 +86,7 @@ def _get_options(ucx_config: dict) -> Tuple[dict, dict]: tls += ",cuda_copy" if ucx_config.get("infiniband"): # pragma: no cover - tls = "rc," + tls + tls = "ib," + tls if ucx_config.get("nvlink"): # pragma: no cover tls += ",cuda_ipc" @@ -177,7 +177,8 @@ def init(ucx_config: dict): new_environ.update(envs) os.environ = new_environ # type: ignore try: - ucp.init(options=options, env_takes_precedence=True) + # let UCX determine the appropriate transports + ucp.init() finally: os.environ = original_environ @@ -313,7 +314,7 @@ async def send_buffers(self, buffers: list, meta: Optional[_MessageBase] = None) await self.ucp_endpoint.send(buf) for buffer in buffers: await self.ucp_endpoint.send(buffer) - except ucp.exceptions.UCXBaseException: # pragma: no cover + except ucp.exceptions.UCXError:: # pragma: no cover self.abort() raise ChannelClosed("While writing, the connection was closed") @@ -516,7 +517,7 @@ async def connect( try: ucp_endpoint = await ucp.create_endpoint(host, port) - except ucp.exceptions.UCXBaseException as e: # pragma: no cover + except ucp.exceptions.UCXError as e: # pragma: no cover raise ChannelClosed( f"Connection closed before handshake completed, " f"local address: {local_address}, dest address: {dest_address}"