Skip to content

Commit

Permalink
[coll] Expose configuration. (#10983)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Nov 15, 2024
1 parent b835917 commit 6891f51
Show file tree
Hide file tree
Showing 20 changed files with 437 additions and 189 deletions.
14 changes: 14 additions & 0 deletions doc/python/python_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,17 @@ PySpark API
:members:
:inherited-members:
:show-inheritance:


Collective
----------

.. automodule:: xgboost.collective

.. autoclass:: xgboost.collective.Config

.. autofunction:: xgboost.collective.init

.. automodule:: xgboost.tracker

.. autoclass:: xgboost.tracker.RabitTracker
15 changes: 6 additions & 9 deletions doc/tutorials/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -536,25 +536,22 @@ Troubleshooting
- In some environments XGBoost might fail to resolve the IP address of the scheduler. A
symptom is the user receiving an ``OSError: [Errno 99] Cannot assign requested address``
error during training. A quick workaround is to specify the address explicitly. To do that
dask config is used:
the collective :py:class:`~xgboost.collective.Config` is used:

.. versionadded:: 1.6.0
.. versionadded:: 3.0.0

.. code-block:: python
import dask
from distributed import Client
from xgboost import dask as dxgb
from xgboost.collective import Config
# let xgboost know the scheduler address
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
coll_cfg = Config(retry=1, timeout=20, tracker_host_ip="10.23.170.98", tracker_port=0)
with Client(scheduler_file="sched.json") as client:
reg = dxgb.DaskXGBRegressor()
# We can specify the port for XGBoost as well
with dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"}):
reg = dxgb.DaskXGBRegressor()
reg = dxgb.DaskXGBRegressor(coll_cfg=coll_cfg)
- Please note that XGBoost requires a different port than dask. By default, on a unix-like
system XGBoost uses the port 0 to find available ports, which may fail if a user is
Expand Down
65 changes: 56 additions & 9 deletions python-package/xgboost/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import logging
import os
import pickle
from dataclasses import dataclass
from enum import IntEnum, unique
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional, TypeAlias, Union

import numpy as np

Expand All @@ -15,7 +16,53 @@
LOGGER = logging.getLogger("[xgboost.collective]")


def init(**args: Any) -> None:
# Type of a single communicator argument value: an int, a str, or None.
_ArgVals: TypeAlias = Optional[Union[int, str]]
# Mapping from communicator argument name to its value.
_Args: TypeAlias = Dict[str, _ArgVals]


@dataclass
class Config:
    """Collective settings supplied by the user.

    The class exists to make integration with distributed frameworks easier. Code
    that uses the collective module directly can instead pass the same parameters
    straight into the tracker and the communicator.

    .. versionadded:: 3.0

    Attributes
    ----------
    retry :
        See `dmlc_retry` in :py:meth:`init`.

    timeout :
        See `dmlc_timeout` in :py:meth:`init`. Only communicators use this value, the
        tracker does not. The two timeouts are separate parameters: the tracker
        timeout bounds only the time spent starting and finalizing the communication
        group, while the communicator timeout bounds the time spent in collective
        operations.

    tracker_host_ip :
        See :py:class:`~xgboost.tracker.RabitTracker`.

    tracker_port :
        See :py:class:`~xgboost.tracker.RabitTracker`.

    tracker_timeout :
        See :py:class:`~xgboost.tracker.RabitTracker`.

    """

    retry: Optional[int] = None
    timeout: Optional[int] = None

    tracker_host_ip: Optional[str] = None
    tracker_port: Optional[int] = None
    tracker_timeout: Optional[int] = None

    def get_comm_config(self, args: _Args) -> _Args:
        """Write the communicator options that are set into `args` and return it.

        `args` is modified in place; options left as ``None`` are not touched.
        """
        overrides = (("dmlc_retry", self.retry), ("dmlc_timeout", self.timeout))
        for key, value in overrides:
            if value is not None:
                args[key] = value
        return args


def init(**args: _ArgVals) -> None:
"""Initialize the collective library with arguments.
Parameters
Expand All @@ -36,9 +83,7 @@ def init(**args: Any) -> None:
- dmlc_timeout: Timeout in seconds.
- dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication.
Only applicable to the Federated communicator (use upper case for environment
variables, use lower case for runtime configuration):
Only applicable to the Federated communicator:
- federated_server_address: Address of the federated server.
- federated_world_size: Number of federated workers.
- federated_rank: Rank of the current worker.
Expand All @@ -47,6 +92,9 @@ def init(**args: Any) -> None:
- federated_client_key: Client key file path. Only needed for the SSL mode.
- federated_client_cert: Client certificate file path. Only needed for the SSL
mode.
Use upper case for environment variables, use lower case for runtime configuration.
"""
_check_call(_LIB.XGCommunicatorInit(make_jcargs(**args)))

Expand Down Expand Up @@ -117,7 +165,6 @@ def get_processor_name() -> str:
name_str = ctypes.c_char_p()
_check_call(_LIB.XGCommunicatorGetProcessorName(ctypes.byref(name_str)))
value = name_str.value
assert value
return py_str(value)


Expand Down Expand Up @@ -247,7 +294,7 @@ def signal_error() -> None:
class CommunicatorContext:
"""A context controlling collective communicator initialization and finalization."""

def __init__(self, **args: Any) -> None:
def __init__(self, **args: _ArgVals) -> None:
self.args = args
key = "dmlc_nccl_path"
if args.get(key, None) is not None:
Expand Down Expand Up @@ -275,12 +322,12 @@ def __init__(self, **args: Any) -> None:
except ImportError:
pass

def __enter__(self) -> Dict[str, Any]:
def __enter__(self) -> _Args:
init(**self.args)
assert is_distributed()
LOGGER.debug("-------------- communicator say hello ------------------")
return self.args

def __exit__(self, *args: List) -> None:
def __exit__(self, *args: Any) -> None:
finalize()
LOGGER.debug("--------------- communicator say bye ------------------")
7 changes: 4 additions & 3 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# pylint: disable= invalid-name, unused-import
# pylint: disable=invalid-name,unused-import
"""For compatibility and optional dependencies."""
import importlib.util
import logging
import sys
import types
from typing import Any, Dict, List, Optional, Sequence, cast
from typing import Any, Sequence, cast

import numpy as np

Expand All @@ -13,8 +13,9 @@
assert sys.version_info[0] == 3, "Python 2 is no longer supported."


def py_str(x: bytes) -> str:
def py_str(x: bytes | None) -> str:
"""convert c string back to python string"""
assert x is not None # ctypes might return None
return x.decode("utf-8") # type: ignore


Expand Down
Loading

0 comments on commit 6891f51

Please sign in to comment.