From 6891f5109f479d7ff6b480aa43fc4b7795549736 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 15 Nov 2024 18:23:32 +0800 Subject: [PATCH] [coll] Expose configuration. (#10983) --- doc/python/python_api.rst | 14 ++ doc/tutorials/dask.rst | 15 +- python-package/xgboost/collective.py | 65 ++++++- python-package/xgboost/compat.py | 7 +- python-package/xgboost/dask/__init__.py | 172 ++++++++++++------ python-package/xgboost/dask/utils.py | 82 ++++++++- python-package/xgboost/spark/core.py | 26 +-- python-package/xgboost/spark/utils.py | 17 +- python-package/xgboost/testing/dask.py | 8 +- python-package/xgboost/testing/federated.py | 12 +- python-package/xgboost/tracker.py | 46 ++++- src/collective/comm_group.cc | 1 + tests/ci_build/lint_python.py | 2 + tests/python/test_collective.py | 16 +- tests/python/test_tracker.py | 23 ++- .../test_gpu_external_memory.py | 12 +- .../test_gpu_with_dask/test_gpu_with_dask.py | 25 +-- .../test_with_dask/test_external_memory.py | 11 +- .../test_with_dask/test_with_dask.py | 59 +++--- .../test_with_spark/test_spark_local.py | 13 +- 20 files changed, 437 insertions(+), 189 deletions(-) diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst index 11de9385b62e..a8999e119ab4 100644 --- a/doc/python/python_api.rst +++ b/doc/python/python_api.rst @@ -192,3 +192,17 @@ PySpark API :members: :inherited-members: :show-inheritance: + + +Collective +---------- + +.. automodule:: xgboost.collective + +.. autoclass:: xgboost.collective.Config + +.. autofunction:: xgboost.collective.init + +.. automodule:: xgboost.tracker + +.. autoclass:: xgboost.tracker.RabitTracker \ No newline at end of file diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 43c27e786b8f..6e68d83a0083 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -536,25 +536,22 @@ Troubleshooting - In some environments XGBoost might fail to resolve the IP address of the scheduler, a symptom is user receiving ``OSError: [Errno 99] Cannot assign requested address`` error during training. A quick workaround is to specify the address explicitly. To do that - dask config is used: + the collective :py:class:`~xgboost.collective.Config` is used: - .. versionadded:: 1.6.0 + .. versionadded:: 3.0.0 .. code-block:: python import dask from distributed import Client from xgboost import dask as dxgb + from xgboost.collective import Config + # let xgboost know the scheduler address - dask.config.set({"xgboost.scheduler_address": "192.0.0.100"}) + coll_cfg = Config(retry=1, timeout=20, tracker_host_ip="10.23.170.98", tracker_port=0) with Client(scheduler_file="sched.json") as client: - reg = dxgb.DaskXGBRegressor() - - # We can specify the port for XGBoost as well - with dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"}): - reg = dxgb.DaskXGBRegressor() - + reg = dxgb.DaskXGBRegressor(coll_cfg=coll_cfg) - Please note that XGBoost requires a different port than dask. 
By default, on a unix-like system XGBoost uses the port 0 to find available ports, which may fail if a user is
diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py
index 0f3feeeb4a6d..715853d0ab54 100644
--- a/python-package/xgboost/collective.py
+++ b/python-package/xgboost/collective.py
@@ -4,8 +4,9 @@
 import logging
 import os
 import pickle
+from dataclasses import dataclass
 from enum import IntEnum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional, TypeAlias, Union
 
 import numpy as np
 
@@ -15,7 +16,53 @@
 LOGGER = logging.getLogger("[xgboost.collective]")
 
 
-def init(**args: Any) -> None:
+_ArgVals: TypeAlias = Optional[Union[int, str]]
+_Args: TypeAlias = Dict[str, _ArgVals]
+
+
+@dataclass
+class Config:
+    """User configuration for the communicator context. This is used for easier
+    integration with distributed frameworks. Users of the collective module can pass
+    the parameters directly into the tracker and the communicator.
+
+    .. versionadded:: 3.0
+
+    Attributes
+    ----------
+    retry : See `dmlc_retry` in :py:meth:`init`.
+
+    timeout :
+        See `dmlc_timeout` in :py:meth:`init`. This is only used for communicators, not
+        the tracker. They are different parameters since the timeout for the tracker
+        limits only the time for starting and finalizing the communication group,
+        whereas the timeout for communicators limits the time used for collective
+        operations.
+
+    tracker_host_ip : See :py:class:`~xgboost.tracker.RabitTracker`.
+
+    tracker_port : See :py:class:`~xgboost.tracker.RabitTracker`.
+
+    tracker_timeout : See :py:class:`~xgboost.tracker.RabitTracker`.
+
+    """
+
+    retry: Optional[int] = None
+    timeout: Optional[int] = None
+
+    tracker_host_ip: Optional[str] = None
+    tracker_port: Optional[int] = None
+    tracker_timeout: Optional[int] = None
+
+    def get_comm_config(self, args: _Args) -> _Args:
+        """Update the arguments for the communicator."""
+        if self.retry is not None:
+            args["dmlc_retry"] = self.retry
+        if self.timeout is not None:
+            args["dmlc_timeout"] = self.timeout
+        return args
+
+
+def init(**args: _ArgVals) -> None:
     """Initialize the collective library with arguments.
 
     Parameters
@@ -36,9 +83,7 @@ def init(**args: Any) -> None:
     - dmlc_timeout: Timeout in seconds.
     - dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication.
 
-    Only applicable to the Federated communicator (use upper case for environment
-    variables, use lower case for runtime configuration):
-
+    Only applicable to the Federated communicator:
     - federated_server_address: Address of the federated server.
     - federated_world_size: Number of federated workers.
     - federated_rank: Rank of the current worker.
@@ -47,6 +92,9 @@ def init(**args: Any) -> None:
     - federated_client_key: Client key file path. Only needed for the SSL mode.
     - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
+
+    Use upper case for environment variables and lower case for runtime configuration.
+
     """
     _check_call(_LIB.XGCommunicatorInit(make_jcargs(**args)))
 
@@ -117,7 +165,6 @@ def get_processor_name() -> str:
     name_str = ctypes.c_char_p()
     _check_call(_LIB.XGCommunicatorGetProcessorName(ctypes.byref(name_str)))
     value = name_str.value
-    assert value
     return py_str(value)
 
 
@@ -247,7 +294,7 @@ def signal_error() -> None:
 class CommunicatorContext:
     """A context controlling collective communicator initialization and finalization."""
 
-    def __init__(self, **args: Any) -> None:
+    def __init__(self, **args: _ArgVals) -> None:
         self.args = args
         key = "dmlc_nccl_path"
         if args.get(key, None) is not None:
@@ -275,12 +322,12 @@ def __init__(self, **args: Any) -> None:
         except ImportError:
             pass
 
-    def __enter__(self) -> Dict[str, Any]:
+    def __enter__(self) -> _Args:
         init(**self.args)
         assert is_distributed()
         LOGGER.debug("-------------- communicator say hello ------------------")
         return self.args
 
-    def __exit__(self, *args: List) -> None:
+    def __exit__(self, *args: Any) -> None:
         finalize()
         LOGGER.debug("--------------- communicator say bye ------------------")
diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index 9a7d1ce83dcf..3cffcaa2585c 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -1,10 +1,10 @@
-# pylint: disable= invalid-name, unused-import
+# pylint: disable=invalid-name,unused-import
 """For compatibility and optional dependencies."""
 import importlib.util
 import logging
 import sys
 import types
-from typing import Any, Dict, List, Optional, Sequence, cast
+from typing import Any, Sequence, cast
 
 import numpy as np
 
@@ -13,8 +13,9 @@
 assert sys.version_info[0] == 3, "Python 2 is no longer supported."
 
 
-def py_str(x: bytes) -> str:
+def py_str(x: bytes | None) -> str:
     """convert c string back to python string"""
+    assert x is not None  # ctypes might return None
     return x.decode("utf-8")  # type: ignore
 
 
diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py
index a946d86b6228..f7329de71887 100644
--- a/python-package/xgboost/dask/__init__.py
+++ b/python-package/xgboost/dask/__init__.py
@@ -20,10 +20,33 @@
 Optional dask configuration
 ===========================
 
-- **xgboost.scheduler_address**: Specify the scheduler address, see :ref:`tracker-ip`.
+- **coll_cfg**:
+    Specify the scheduler address along with communicator configurations. This can be
+    used as a replacement for the existing global Dask configuration
+    `xgboost.scheduler_address` (see below). See :ref:`tracker-ip` for more
+    information. The `tracker_host_ip` should specify the IP address of the Dask
+    scheduler node.
+
+    .. versionadded:: 3.0.0
+
+    .. code-block:: python
+
+        from xgboost import dask as dxgb
+        from xgboost.collective import Config
+
+        coll_cfg = Config(
+            retry=1, timeout=20, tracker_host_ip="10.23.170.98", tracker_port=0
+        )
+
+        clf = dxgb.DaskXGBClassifier(coll_cfg=coll_cfg)
+        # or
+        dxgb.train(client, {}, Xy, num_boost_round=10, coll_cfg=coll_cfg)
+
+- **xgboost.scheduler_address**: Specify the scheduler address.
 
   .. versionadded:: 1.6.0
 
+  .. deprecated:: 3.0.0
+
   ..
code-block:: python dask.config.set({"xgboost.scheduler_address": "192.0.0.100"}) @@ -50,6 +73,7 @@ Sequence, Set, Tuple, + TypeAlias, TypedDict, TypeVar, Union, @@ -62,11 +86,14 @@ from dask import bag as db from dask import dataframe as dd -from xgboost import collective, config -from xgboost._typing import _T, FeatureNames, FeatureTypes, IterationRange -from xgboost.callback import TrainingCallback -from xgboost.compat import DataFrame, concat, lazy_isinstance -from xgboost.core import ( +from .. import collective, config +from .._typing import _T, FeatureNames, FeatureTypes, IterationRange +from ..callback import TrainingCallback +from ..collective import Config as CollConfig +from ..collective import _Args as CollArgs +from ..collective import _ArgVals as CollArgsVals +from ..compat import DataFrame, concat, lazy_isinstance +from ..core import ( Booster, DataIter, DMatrix, @@ -78,8 +105,8 @@ _deprecate_positional_args, _expect, ) -from xgboost.data import _is_cudf_ser, _is_cupy_alike -from xgboost.sklearn import ( +from ..data import _is_cudf_ser, _is_cupy_alike +from ..sklearn import ( XGBClassifier, XGBClassifierBase, XGBModel, @@ -93,13 +120,12 @@ _wrap_evaluation_matrices, xgboost_model_doc, ) -from xgboost.tracker import RabitTracker -from xgboost.training import train as worker_train +from ..tracker import RabitTracker +from ..training import train as worker_train +from .utils import get_address_from_user, get_n_threads -from .utils import get_n_threads - -_DaskCollection = Union[da.Array, dd.DataFrame, dd.Series] -_DataT = Union[da.Array, dd.DataFrame] # do not use series as predictor +_DaskCollection: TypeAlias = Union[da.Array, dd.DataFrame, dd.Series] +_DataT: TypeAlias = Union[da.Array, dd.DataFrame] # do not use series as predictor TrainReturnT = TypedDict( "TrainReturnT", { @@ -149,8 +175,9 @@ def _try_start_tracker( n_workers: int, addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]], -) -> Dict[str, Union[int, str]]: - env: Dict[str, Union[int, str]] = {} + timeout: Optional[int], +) -> CollArgs: + env: CollArgs = {} try: if isinstance(addrs[0], tuple): host_ip = addrs[0][0] @@ -160,15 +187,20 @@ def _try_start_tracker( host_ip=host_ip, port=port, sortby="task", + timeout=0 if timeout is None else timeout, ) else: addr = addrs[0] assert isinstance(addr, str) or addr is None rabit_tracker = RabitTracker( - n_workers=n_workers, host_ip=addr, sortby="task" + n_workers=n_workers, + host_ip=addr, + sortby="task", + timeout=0 if timeout is None else timeout, ) rabit_tracker.start() + # No timeout since we don't want to abort the training thread = Thread(target=rabit_tracker.wait_for) thread.daemon = True thread.start() @@ -183,7 +215,7 @@ def _try_start_tracker( str(addrs[1]), str(e), ) - env = _try_start_tracker(n_workers, addrs[1:]) + env = _try_start_tracker(n_workers, addrs[1:], timeout) return env @@ -192,17 +224,19 @@ def _start_tracker( n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[Tuple[str, int]], -) -> Dict[str, Union[int, str]]: + timeout: Optional[int], +) -> CollArgs: """Start Rabit tracker, recurse to try different addresses.""" - env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask]) + env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask], timeout) return env class CommunicatorContext(collective.CommunicatorContext): """A context controlling collective communicator initialization and finalization.""" - def __init__(self, **args: Any) -> None: + def __init__(self, **args: CollArgsVals) -> None: 
super().__init__(**args) + worker = distributed.get_worker() with distributed.worker_client() as client: info = client.scheduler_info() @@ -223,7 +257,7 @@ def dconcat(value: Sequence[_T]) -> _T: return dd.multi.concat(list(value), axis=0) -def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Client": +def _get_client(client: Optional["distributed.Client"]) -> "distributed.Client": """Simple wrapper around testing None.""" if not isinstance(client, (type(distributed.get_client()), type(None))): raise TypeError( @@ -284,7 +318,7 @@ def __init__( feature_weights: Optional[_DaskCollection] = None, enable_categorical: bool = False, ) -> None: - client = _xgb_get_client(client) + client = _get_client(client) self.feature_names = feature_names self.feature_types = feature_types @@ -363,8 +397,9 @@ def check_columns(parts: numpy.ndarray) -> None: ) def to_delayed(d: _DaskCollection) -> List[Delayed]: - """Breaking data into partitions, a trick borrowed from dask_xgboost. `to_delayed` - downgrades high-level objects into numpy or pandas equivalents . + """Breaking data into partitions, a trick borrowed from + dask_xgboost. `to_delayed` downgrades high-level objects into numpy or + pandas equivalents. """ d = client.persist(d) @@ -496,7 +531,7 @@ async def map_worker_partitions( """Map a function onto partitions of each worker.""" # Note for function purity: # XGBoost is sensitive to data partition and uses random number generator. - client = _xgb_get_client(client) + client = _get_client(client) futures = [] for addr in workers: args = [] @@ -835,7 +870,10 @@ def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix: async def _get_rabit_args( - n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client" + client: "distributed.Client", + n_workers: int, + dconfig: Optional[Dict[str, Any]] = None, + coll_cfg: Optional[CollConfig] = None, ) -> Dict[str, Union[str, int]]: """Get rabit context arguments from data distribution in DaskDMatrix.""" # There are 3 possible different addresses: @@ -843,23 +881,12 @@ async def _get_rabit_args( # 2. Guessed by xgboost `get_host_ip` function # 3. From dask scheduler # We try 1 and 3 if 1 is available, otherwise 2 and 3. - valid_config = ["scheduler_address"] + # See if user config is available + coll_cfg = CollConfig() if coll_cfg is None else coll_cfg host_ip: Optional[str] = None port: int = 0 - if dconfig is not None: - for k in dconfig: - if k not in valid_config: - raise ValueError(f"Unknown configuration: {k}") - host_ip = dconfig.get("scheduler_address", None) - if host_ip is not None and host_ip.startswith("[") and host_ip.endswith("]"): - # convert dask bracket format to proper IPv6 address. - host_ip = host_ip[1:-1] - if host_ip is not None: - try: - host_ip, port = distributed.comm.get_address_host_port(host_ip) - except ValueError: - pass + host_ip, port = get_address_from_user(dconfig, coll_cfg) if host_ip is not None: user_addr = (host_ip, port) @@ -874,9 +901,11 @@ async def _get_rabit_args( except Exception: # pylint: disable=broad-except sched_addr = None + # We assume the scheduler is a fair process and run the tracker there. 
env = await client.run_on_scheduler( - _start_tracker, n_workers, sched_addr, user_addr + _start_tracker, n_workers, sched_addr, user_addr, coll_cfg.tracker_timeout ) + env = coll_cfg.get_comm_config(env) return env @@ -924,9 +953,12 @@ def _get_dmatrices( evals_name: Sequence[str], n_threads: int, ) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]: + # Create training DMatrix Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads) + # Create evaluation DMatrices evals: List[Tuple[DMatrix, str]] = [] for i, ref in enumerate(refs): + # Same DMatrix as the training if evals_id[i] == train_id: evals.append((Xy, evals_name[i])) continue @@ -960,17 +992,20 @@ async def _train_async( xgb_model: Optional[Booster], callbacks: Optional[Sequence[TrainingCallback]], custom_metric: Optional[Metric], + coll_cfg: Optional[CollConfig], ) -> Optional[TrainReturnT]: workers = _get_workers_from_data(dtrain, evals) await _check_workers_are_alive(workers, client) - _rabit_args = await _get_rabit_args(len(workers), dconfig, client) + coll_args = await _get_rabit_args( + client, len(workers), dconfig=dconfig, coll_cfg=coll_cfg + ) _check_distributed_params(params) # This function name is displayed in the Dask dashboard task status, let's make it # clear that it's XGBoost training. def do_train( # pylint: disable=too-many-positional-arguments parameters: Dict, - rabit_args: Dict[str, Union[str, int]], + coll_args: Dict[str, Union[str, int]], train_id: int, evals_name: List[str], evals_id: List[int], @@ -984,7 +1019,7 @@ def do_train( # pylint: disable=too-many-positional-arguments local_history: TrainingCallback.EvalsLog = {} - with CommunicatorContext(**rabit_args), config.config_context(**global_config): + with CommunicatorContext(**coll_args), config.config_context(**global_config): Xy, evals = _get_dmatrices( train_ref, train_id, @@ -1035,7 +1070,7 @@ def do_train( # pylint: disable=too-many-positional-arguments do_train, # extra function parameters params, - _rabit_args, + coll_args, id(dtrain), evals_name, evals_id, @@ -1061,6 +1096,7 @@ def train( # pylint: disable=unused-argument verbose_eval: Union[int, bool] = True, callbacks: Optional[Sequence[TrainingCallback]] = None, custom_metric: Optional[Metric] = None, + coll_cfg: Optional[CollConfig] = None, ) -> Any: """Train XGBoost model. @@ -1078,6 +1114,10 @@ def train( # pylint: disable=unused-argument Specify the dask client used for training. Use default client returned from dask if it's set to None. + coll_cfg : + Configuration for the communicator used during training. See + :py:class:`~xgboost.collective.Config`. + Returns ------- results: dict @@ -1091,7 +1131,7 @@ def train( # pylint: disable=unused-argument 'eval': {'logloss': ['0.480385', '0.357756']}}} """ - client = _xgb_get_client(client) + client = _get_client(client) args = locals() return client.sync( _train_async, @@ -1470,7 +1510,7 @@ def predict( # pylint: disable=unused-argument shape. 
""" - client = _xgb_get_client(client) + client = _get_client(client) return client.sync(_predict_async, global_config=config.get_config(), **locals()) @@ -1487,7 +1527,7 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches base_margin: Optional[_DaskCollection], strict_shape: bool, ) -> _DaskCollection: - client = _xgb_get_client(client) + client = _get_client(client) booster = await _get_model_future(client, model) if not isinstance(data, (da.Array, dd.DataFrame)): raise TypeError(_expect([da.Array, dd.DataFrame], type(data))) @@ -1592,7 +1632,7 @@ def inplace_predict( # pylint: disable=unused-argument shape. """ - client = _xgb_get_client(client) + client = _get_client(client) # When used in asynchronous environment, the `client` object should have # `asynchronous` attribute as True. When invoked by the skl interface, it's # responsible for setting up the client. @@ -1647,6 +1687,11 @@ class DaskScikitLearnBase(XGBModel): _client = None + def __init__(self, *, coll_cfg: Optional[CollConfig] = None, **kwargs: Any) -> None: + super().__init__(**kwargs) + + self.coll_cfg = coll_cfg + async def _predict_async( self, data: _DataT, @@ -1750,13 +1795,13 @@ def __getstate__(self) -> Dict: @property def client(self) -> "distributed.Client": - """The dask client used in this model. The `Client` object can not be serialized for - transmission, so if task is launched from a worker instead of directly from the - client process, this attribute needs to be set at that worker. + """The dask client used in this model. The `Client` object can not be + serialized for transmission, so if task is launched from a worker instead of + directly from the client process, this attribute needs to be set at that worker. """ - client = _xgb_get_client(self._client) + client = _get_client(self._client) return client @client.setter @@ -1855,6 +1900,7 @@ async def _fit_async( verbose_eval=verbose, early_stopping_rounds=self.early_stopping_rounds, callbacks=self.callbacks, + coll_cfg=self.coll_cfg, xgb_model=model, ) self._Booster = results["booster"] @@ -1963,6 +2009,7 @@ async def _fit_async( verbose_eval=verbose, early_stopping_rounds=self.early_stopping_rounds, callbacks=self.callbacks, + coll_cfg=self.coll_cfg, xgb_model=model, ) self._Booster = results["booster"] @@ -2080,10 +2127,16 @@ def _argmax(x: Any) -> Any: ) class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): @_deprecate_positional_args - def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any): + def __init__( + self, + *, + objective: str = "rank:pairwise", + coll_cfg: Optional[CollConfig] = None, + **kwargs: Any, + ) -> None: if callable(objective): raise ValueError("Custom objective function not supported by XGBRanker.") - super().__init__(objective=objective, **kwargs) + super().__init__(objective=objective, coll_cfg=coll_cfg, **kwargs) async def _fit_async( self, @@ -2103,7 +2156,7 @@ async def _fit_async( xgb_model: Optional[Union[XGBModel, Booster]], feature_weights: Optional[_DaskCollection], ) -> "DaskXGBRanker": - msg = "Use `qid` instead of `group` on dask interface." + msg = "Use the `qid` instead of the `group` with the dask interface." 
if not (group is None and eval_group is None): raise ValueError(msg) if qid is None: @@ -2148,6 +2201,7 @@ async def _fit_async( early_stopping_rounds=self.early_stopping_rounds, callbacks=self.callbacks, xgb_model=model, + coll_cfg=self.coll_cfg, ) self._Booster = results["booster"] self.evals_result_ = results["history"] @@ -2202,6 +2256,7 @@ def __init__( subsample: Optional[float] = 0.8, colsample_bynode: Optional[float] = 0.8, reg_lambda: Optional[float] = 1e-5, + coll_cfg: Optional[CollConfig] = None, **kwargs: Any, ) -> None: super().__init__( @@ -2209,6 +2264,7 @@ def __init__( subsample=subsample, colsample_bynode=colsample_bynode, reg_lambda=reg_lambda, + coll_cfg=coll_cfg, **kwargs, ) @@ -2262,6 +2318,7 @@ def __init__( subsample: Optional[float] = 0.8, colsample_bynode: Optional[float] = 0.8, reg_lambda: Optional[float] = 1e-5, + coll_cfg: Optional[CollConfig] = None, **kwargs: Any, ) -> None: super().__init__( @@ -2269,6 +2326,7 @@ def __init__( subsample=subsample, colsample_bynode=colsample_bynode, reg_lambda=reg_lambda, + coll_cfg=coll_cfg, **kwargs, ) diff --git a/python-package/xgboost/dask/utils.py b/python-package/xgboost/dask/utils.py index d433c807288e..7f71ca6e3bc2 100644 --- a/python-package/xgboost/dask/utils.py +++ b/python-package/xgboost/dask/utils.py @@ -1,13 +1,14 @@ """Utilities for the XGBoost Dask interface.""" import logging -from typing import TYPE_CHECKING, Any, Dict +import warnings +from typing import Any, Dict, Optional, Tuple -LOGGER = logging.getLogger("[xgboost.dask]") +import distributed +from ..collective import Config -if TYPE_CHECKING: - import distributed +LOGGER = logging.getLogger("[xgboost.dask]") def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int: @@ -23,3 +24,76 @@ def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> if n_threads == 0 or n_threads is None: n_threads = dwnt return n_threads + + +def get_address_from_user( + dconfig: Optional[Dict[str, Any]], coll_cfg: Config +) -> Tuple[Optional[str], int]: + """Get the tracker address from the optional user configuration. + + Parameters + ---------- + dconfig : + Dask global configuration. + + coll_cfg : + Collective configuration. + + Returns + ------- + The IP address along with the port number. + + """ + + valid_config = ["scheduler_address"] + + host_ip = None + port = 0 + + if dconfig is not None: + for k in dconfig: + if k not in valid_config: + raise ValueError(f"Unknown configuration: {k}") + warnings.warn( + ( + "Use `coll_cfg` instead of the Dask global configuration store" + f" for the XGBoost tracker configuration: {k}." + ), + FutureWarning, + ) + else: + dconfig = {} + + host_ip = dconfig.get("scheduler_address", None) + if host_ip is not None and host_ip.startswith("[") and host_ip.endswith("]"): + # convert dask bracket format to proper IPv6 address. + host_ip = host_ip[1:-1] + if host_ip is not None: + try: + host_ip, port = distributed.comm.get_address_host_port(host_ip) + except ValueError: + pass + + if coll_cfg is None: + coll_cfg = Config() + if coll_cfg.tracker_host_ip is not None: + if host_ip is not None and coll_cfg.tracker_host_ip != host_ip: + raise ValueError( + "Conflicting host IP addresses from the dask configuration and the " + f"collective configuration: {host_ip} v.s. {coll_cfg.tracker_host_ip}." 
+ ) + host_ip = coll_cfg.tracker_host_ip + if coll_cfg.tracker_port is not None: + if ( + port != 0 + and port is not None + and coll_cfg.tracker_port != 0 + and port != coll_cfg.tracker_port + ): + raise ValueError( + "Conflicting ports from the dask configuration and the " + f"collective configuration: {port} v.s. {coll_cfg.tracker_port}." + ) + port = coll_cfg.tracker_port + + return host_ip, port diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index c947d473ec78..3d5618d5d8f4 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -123,7 +123,7 @@ "pred_contrib_col", "use_gpu", "launch_tracker_on_driver", - "tracker_host", + "tracker_host_ip", "tracker_port", ] @@ -257,9 +257,9 @@ class _SparkXGBParams( "launched on the driver side; otherwise, it will be launched on the executor side.", TypeConverters.toBoolean, ) - tracker_host = Param( + tracker_host_ip = Param( Params._dummy(), - "tracker_host", + "tracker_host_ip", "A string variable. The tracker host IP address. To set tracker host ip, you need to " "enable launch_tracker_on_driver to be true first", TypeConverters.toString, @@ -1030,25 +1030,29 @@ def _get_tracker_args(self) -> Tuple[bool, Dict[str, Any]]: launch_tracker_on_driver = self.getOrDefault(self.launch_tracker_on_driver) rabit_args = {} if launch_tracker_on_driver: - tracker_host: Optional[str] = None - if self.isDefined(self.tracker_host): - tracker_host = self.getOrDefault(self.tracker_host) + tracker_host_ip: Optional[str] = None + if self.isDefined(self.tracker_host_ip): + tracker_host_ip = self.getOrDefault(self.tracker_host_ip) else: - tracker_host = ( + tracker_host_ip = ( _get_spark_session().sparkContext.getConf().get("spark.driver.host") ) - assert tracker_host is not None + assert tracker_host_ip is not None tracker_port = 0 if self.isDefined(self.tracker_port): tracker_port = self.getOrDefault(self.tracker_port) num_workers = self.getOrDefault(self.num_workers) - rabit_args.update(_get_rabit_args(tracker_host, num_workers, tracker_port)) + rabit_args.update( + _get_rabit_args(tracker_host_ip, num_workers, tracker_port) + ) else: - if self.isDefined(self.tracker_host) or self.isDefined(self.tracker_port): + if self.isDefined(self.tracker_host_ip) or self.isDefined( + self.tracker_port + ): raise ValueError( "You must enable launch_tracker_on_driver to use " - "tracker_host and tracker_port" + "tracker_host_ip and tracker_port" ) return launch_tracker_on_driver, rabit_args diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index f03770059564..a8a2314272a6 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -14,9 +14,12 @@ from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext from pyspark.sql.session import SparkSession -from xgboost import Booster, XGBModel -from xgboost.collective import CommunicatorContext as CCtx -from xgboost.tracker import RabitTracker +from ..collective import CommunicatorContext as CCtx +from ..collective import _Args as CollArgs +from ..collective import _ArgVals as CollArgsVals +from ..core import Booster +from ..sklearn import XGBModel +from ..tracker import RabitTracker def get_class_name(cls: Type) -> str: @@ -46,14 +49,14 @@ def _get_default_params_from_func( class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods """Context with PySpark specific task ID.""" - def __init__(self, context: BarrierTaskContext, **args: 
Any) -> None: + def __init__(self, context: BarrierTaskContext, **args: CollArgsVals) -> None: args["dmlc_task_id"] = str(context.partitionId()) super().__init__(**args) -def _start_tracker(host: str, n_workers: int, port: int = 0) -> Dict[str, Any]: +def _start_tracker(host: str, n_workers: int, port: int = 0) -> CollArgs: """Start Rabit tracker with n_workers""" - args: Dict[str, Any] = {"n_workers": n_workers} + args: CollArgs = {"n_workers": n_workers} tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task", port=port) tracker.start() thread = Thread(target=tracker.wait_for) @@ -63,7 +66,7 @@ def _start_tracker(host: str, n_workers: int, port: int = 0) -> Dict[str, Any]: return args -def _get_rabit_args(host: str, n_workers: int, port: int = 0) -> Dict[str, Any]: +def _get_rabit_args(host: str, n_workers: int, port: int = 0) -> CollArgs: """Get rabit context arguments to send to each worker.""" env = _start_tracker(host, n_workers, port) return env diff --git a/python-package/xgboost/testing/dask.py b/python-package/xgboost/testing/dask.py index b730728f4a9a..93514a97fbfd 100644 --- a/python-package/xgboost/testing/dask.py +++ b/python-package/xgboost/testing/dask.py @@ -1,6 +1,6 @@ """Tests for dask shared by different test modules.""" -from typing import List, Literal, cast +from typing import Any, List, Literal, cast import numpy as np import pandas as pd @@ -14,6 +14,7 @@ from xgboost.testing.updater import get_basescore from .. import dask as dxgb +from ..dask import _get_rabit_args def check_init_estimation_clf( @@ -168,3 +169,8 @@ def check_external_memory( # pylint: disable=too-many-locals np.testing.assert_allclose( results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4 ) + + +def get_rabit_args(client: Client, n_workers: int) -> Any: + """Get RABIT collective communicator arguments for tests.""" + return client.sync(_get_rabit_args, client, n_workers) diff --git a/python-package/xgboost/testing/federated.py b/python-package/xgboost/testing/federated.py index 13755af9064d..ddcce88c75f3 100644 --- a/python-package/xgboost/testing/federated.py +++ b/python-package/xgboost/testing/federated.py @@ -16,6 +16,8 @@ from xgboost import testing as tm from xgboost.training import TrainingCallback +from ..collective import _Args as CollArgs + SERVER_KEY = "server-key.pem" SERVER_CERT = "server-cert.pem" CLIENT_KEY = "client-key.pem" @@ -40,23 +42,23 @@ def run_worker( port: int, world_size: int, rank: int, with_ssl: bool, device: str ) -> None: """Run federated client worker for test.""" - communicator_env = { + comm_env: CollArgs = { "dmlc_communicator": "federated", "federated_server_address": f"localhost:{port}", "federated_world_size": world_size, "federated_rank": rank, } if with_ssl: - communicator_env["federated_server_cert_path"] = SERVER_CERT - communicator_env["federated_client_key_path"] = CLIENT_KEY - communicator_env["federated_client_cert_path"] = CLIENT_CERT + comm_env["federated_server_cert_path"] = SERVER_CERT + comm_env["federated_client_key_path"] = CLIENT_KEY + comm_env["federated_client_cert_path"] = CLIENT_CERT cpu_count = os.cpu_count() assert cpu_count is not None n_threads = cpu_count // world_size # Always call this before using distributed module - with xgb.collective.CommunicatorContext(**communicator_env): + with xgb.collective.CommunicatorContext(**comm_env): # Load file, file will not be sharded in federated mode. 
X, y = load_svmlight_file(f"agaricus.txt-{rank}.train")
         dtrain = xgb.DMatrix(X, y)
diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py
index ab47b6b0d769..926c957b28a4 100644
--- a/python-package/xgboost/tracker.py
+++ b/python-package/xgboost/tracker.py
@@ -19,28 +19,60 @@ class RabitTracker:
     workers.
 
     Parameters
-    ..........
+    ----------
+
+    n_workers:
+
+        The total number of workers in the communication group.
+
+    host_ip:
+
+        The IP address of the tracker node. XGBoost can try to guess one by probing
+        with sockets, but it's best to pass an address explicitly.
+
+    port:
+
+        The port this tracker should listen to. XGBoost can query an available port
+        from the OS; this configuration is useful for restricted network environments.
+
     sortby:
         How to sort the workers for rank assignment. The default is host, but users can
-        set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain
-        deterministic rank assignment. Available options are:
+        set the `DMLC_TASK_ID` via arguments of :py:meth:`~xgboost.collective.init` and
+        obtain deterministic rank assignment through sorting by task name. Available
+        options are:
+
        - host
        - task
 
     timeout :
-        Timeout for constructing the communication group and waiting for the tracker to
-        shutdown when it's instructed to, doesn't apply to communication when tracking
-        is running.
+        Timeout for constructing (bootstrap) and shutting down the communication
+        group; it doesn't apply to communication once the group is up and running.
 
         The timeout value should take the time of data loading and pre-processing into
-        account, due to potential lazy execution.
+        account, due to potential lazy execution. By default the tracker doesn't have
+        any timeout, to avoid aborting prematurely.
 
         The :py:meth:`.wait_for` method has a different timeout parameter that can stop
         the tracker even if the tracker is still being used. A value error is raised
         when timeout is reached.
 
+    Examples
+    --------
+
+    ..
code-block:: python + + from xgboost.tracker import RabitTracker + from xgboost import collective as coll + + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2) + tracker.start() + + with coll.CommunicatorContext(**tracker.worker_args()): + ret = coll.broadcast("msg", 0) + assert str(ret) == "msg" + """ @unique diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc index a9b58ecb5505..6b4c03686a32 100644 --- a/src/collective/comm_group.cc +++ b/src/collective/comm_group.cc @@ -77,6 +77,7 @@ CommGroup::CommGroup() auto retry = get_param("dmlc_retry", static_cast(DefaultRetry()), Integer{}); auto timeout = get_param("dmlc_timeout", static_cast(DefaultTimeoutSec()), Integer{}); + CHECK_GE(timeout, 0); auto task_id = get_param("dmlc_task_id", std::string{}, String{}); if (type == "rabit") { diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index 4de18f2339e8..e97b13f2c465 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -21,6 +21,7 @@ class LintersPaths: # tests "tests/python/test_config.py", "tests/python/test_callback.py", + "tests/python/test_collective.py", "tests/python/test_data_iterator.py", "tests/python/test_dmatrix.py", "tests/python/test_dt.py", @@ -94,6 +95,7 @@ class LintersPaths: # core "python-package/", # tests + "tests/python/test_collective.py", "tests/python/test_dt.py", "tests/python/test_demos.py", "tests/python/test_data_iterator.py", diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index bd444a2375dd..473b38b5b742 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -1,6 +1,5 @@ import socket -import sys -from threading import Thread +from dataclasses import asdict import numpy as np import pytest @@ -9,6 +8,7 @@ import xgboost as xgb from xgboost import RabitTracker, build_info, federated from xgboost import testing as tm +from xgboost.collective import Config def run_rabit_worker(rabit_env: dict, world_size: int) -> int: @@ -59,14 +59,14 @@ def run_federated_worker(port: int, world_size: int, rank: int) -> int: @pytest.mark.skipif(**tm.skip_win()) @pytest.mark.skipif(**tm.no_loky()) -def test_federated_communicator(): +def test_federated_communicator() -> None: if not build_info()["USE_FEDERATED"]: pytest.skip("XGBoost not built with federated learning enabled") port = 9091 world_size = 2 - with get_reusable_executor(max_workers=world_size+1) as pool: - kwargs={"port": port, "n_workers": world_size, "blocking": False} + with get_reusable_executor(max_workers=world_size + 1) as pool: + kwargs = {"port": port, "n_workers": world_size, "blocking": False} tracker = pool.submit(federated.run_federated_server, **kwargs) if not tracker.running(): raise RuntimeError("Error starting Federated Learning server") @@ -79,3 +79,9 @@ def test_federated_communicator(): workers.append(worker) for worker in workers: assert worker.result() == 0 + + +def test_config_serialization() -> None: + cfg = Config(retry=1, timeout=2, tracker_host_ip="127.0.0.1", tracker_port=None) + cfg1 = Config(**asdict(cfg)) + assert cfg == cfg1 diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 099fba0fb572..81f6fc23083e 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -7,7 +7,6 @@ import pytest from hypothesis import HealthCheck, given, settings, strategies -import xgboost as xgb from xgboost import RabitTracker, collective from xgboost import testing as tm @@ -25,9 +24,22 @@ def test_rabit_tracker() -> None: 
pytest.skip("Windows is not supported.") with pytest.raises(ValueError, match="Failed to bind socket"): + # Port is already being used RabitTracker(host_ip="127.0.0.1", port=port, n_workers=1) +@pytest.mark.skipif(**tm.not_linux()) +def test_wait() -> None: + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2) + tracker.start() + + with pytest.raises(ValueError, match="Timeout waiting for the tracker"): + tracker.wait_for(1) + + with pytest.raises(ValueError, match="Failed to accept"): + tracker.free() + + @pytest.mark.skipif(**tm.not_linux()) def test_socket_error() -> None: tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2) @@ -150,6 +162,7 @@ def test_rank_assignment() -> None: from distributed import Client, LocalCluster from xgboost import dask as dxgb + from xgboost.testing.dask import get_rabit_args def local_test(worker_id): with dxgb.CommunicatorContext(**args) as ctx: @@ -163,13 +176,7 @@ def local_test(worker_id): with LocalCluster(n_workers=8) as cluster: with Client(cluster) as client: workers = tm.get_client_workers(client) - args = client.sync( - dxgb._get_rabit_args, - len(workers), - None, - client, - ) - + args = get_rabit_args(client, len(workers)) futures = client.map(local_test, range(len(workers)), workers=workers) client.gather(futures) diff --git a/tests/test_distributed/test_gpu_with_dask/test_gpu_external_memory.py b/tests/test_distributed/test_gpu_with_dask/test_gpu_external_memory.py index 1d58e33415ad..c8559ed003db 100644 --- a/tests/test_distributed/test_gpu_with_dask/test_gpu_external_memory.py +++ b/tests/test_distributed/test_gpu_with_dask/test_gpu_external_memory.py @@ -4,9 +4,7 @@ from dask_cuda import LocalCUDACluster from distributed import Client -import xgboost as xgb -from xgboost import dask as dxgb -from xgboost.testing.dask import check_external_memory +from xgboost.testing.dask import check_external_memory, get_rabit_args @pytest.mark.parametrize("is_qdm", [True, False]) @@ -14,13 +12,7 @@ def test_external_memory(is_qdm: bool) -> None: n_workers = 2 with LocalCUDACluster(n_workers=2) as cluster: with Client(cluster) as client: - args = client.sync( - dxgb._get_rabit_args, - 2, - None, - client, - ) - + args = get_rabit_args(client, 2) futs = client.map( check_external_memory, range(n_workers), diff --git a/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py b/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py index f03deff4c05b..3bc7d46eb721 100644 --- a/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py +++ b/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py @@ -14,14 +14,9 @@ import xgboost as xgb from xgboost import testing as tm from xgboost.collective import CommunicatorContext +from xgboost.testing.dask import get_rabit_args from xgboost.testing.params import hist_parameter_strategy -pytestmark = [ - pytest.mark.skipif(**tm.no_dask()), - pytest.mark.skipif(**tm.no_dask_cuda()), - tm.timeout(60), -] - from ..test_with_dask.test_with_dask import generate_array from ..test_with_dask.test_with_dask import kCols as random_cols from ..test_with_dask.test_with_dask import ( @@ -38,6 +33,12 @@ suppress, ) +pytestmark = [ + pytest.mark.skipif(**tm.no_dask()), + pytest.mark.skipif(**tm.no_dask_cuda()), + tm.timeout(60), +] + try: import cudf import dask.dataframe as dd @@ -494,9 +495,7 @@ def test_data_initialization(self, local_cuda_client: Client) -> None: m = dxgb.DaskDMatrix(local_cuda_client, X, y, feature_weights=fw) workers = tm.get_client_workers(local_cuda_client) - rabit_args = 
local_cuda_client.sync( - dxgb._get_rabit_args, len(workers), None, local_cuda_client - ) + rabit_args = get_rabit_args(local_cuda_client, len(workers)) def worker_fn(worker_addr: str, data_ref: Dict) -> None: with dxgb.CommunicatorContext(**rabit_args): @@ -597,9 +596,7 @@ def test_with_asyncio(local_cuda_client: Client) -> None: def test_invalid_nccl(local_cuda_client: Client) -> None: client = local_cuda_client workers = tm.get_client_workers(client) - args = client.sync( - dxgb._get_rabit_args, len(workers), dxgb._get_dask_config(), client - ) + args = get_rabit_args(client, len(workers)) def run(wid: int) -> None: ctx = CommunicatorContext(dmlc_nccl_path="foo", **args) @@ -638,9 +635,7 @@ def make_model() -> None: client = local_cuda_client workers = tm.get_client_workers(client) - args = client.sync( - dxgb._get_rabit_args, len(workers), dxgb._get_dask_config(), client - ) + args = get_rabit_args(client, len(workers)) # nccl is loaded def run(wid: int) -> None: diff --git a/tests/test_distributed/test_with_dask/test_external_memory.py b/tests/test_distributed/test_with_dask/test_external_memory.py index 0694daed27c8..7643f7305e27 100644 --- a/tests/test_distributed/test_with_dask/test_external_memory.py +++ b/tests/test_distributed/test_with_dask/test_external_memory.py @@ -4,10 +4,8 @@ from distributed import Client, Scheduler, Worker from distributed.utils_test import gen_cluster -import xgboost as xgb -from xgboost import dask as dxgb from xgboost import testing as tm -from xgboost.testing.dask import check_external_memory +from xgboost.testing.dask import check_external_memory, get_rabit_args @pytest.mark.parametrize("is_qdm", [True, False]) @@ -16,13 +14,8 @@ async def test_external_memory( client: Client, s: Scheduler, a: Worker, b: Worker, is_qdm: bool ) -> None: workers = tm.get_client_workers(client) - args = await client.sync( - dxgb._get_rabit_args, - len(workers), - None, - client, - ) n_workers = len(workers) + args = await get_rabit_args(client, n_workers) futs = client.map( check_external_memory, diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py index a9a8af6acbe0..dac0860babf8 100644 --- a/tests/test_distributed/test_with_dask/test_with_dask.py +++ b/tests/test_distributed/test_with_dask/test_with_dask.py @@ -20,7 +20,7 @@ import scipy import sklearn from distributed import Client, LocalCluster, Nanny, Worker -from distributed.scheduler import KilledWorker +from distributed.scheduler import KilledWorker, Scheduler from distributed.utils_test import async_poll_for, gen_cluster from hypothesis import HealthCheck, assume, given, note, settings from sklearn.datasets import make_classification, make_regression @@ -29,8 +29,9 @@ from xgboost import collective as coll from xgboost import dask as dxgb from xgboost import testing as tm +from xgboost.collective import Config as CollConfig from xgboost.dask import DaskDMatrix -from xgboost.testing.dask import check_init_estimation, check_uneven_nan +from xgboost.testing.dask import check_init_estimation, check_uneven_nan, get_rabit_args from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy from xgboost.testing.shared import ( get_feature_weights, @@ -992,7 +993,7 @@ def test_empty_dmatrix(tree_method) -> None: async def run_from_dask_array_asyncio(scheduler_address: str) -> dxgb.TrainReturnT: async with Client(scheduler_address, asynchronous=True) as client: X, y, _ = generate_array() - m = await DaskDMatrix(client, X, y) + m = 
await DaskDMatrix(client, X, y) # type: ignore output = await dxgb.train(client, {}, dtrain=m) with_m = await dxgb.predict(client, output, m) @@ -1097,8 +1098,8 @@ async def train() -> None: ) as cluster: async with Client(cluster, asynchronous=True) as client: X, y, w = generate_array(with_weights=True) - dtrain = await DaskDMatrix(client, X, y, weight=w) - dvalid = await DaskDMatrix(client, X, y, weight=w) + dtrain = await DaskDMatrix(client, X, y, weight=w) # type: ignore + dvalid = await DaskDMatrix(client, X, y, weight=w) # type: ignore output = await dxgb.train(client, {}, dtrain=dtrain) await dxgb.predict(client, output, data=dvalid) @@ -1335,6 +1336,30 @@ def after_iteration(self, model, epoch: int, evals_log) -> bool: pass +def test_invalid_config(client: "Client") -> None: + X, y, _ = generate_array() + dtrain = DaskDMatrix(client, X, y) + + with dask.config.set({"xgboost.foo": "bar"}): + with pytest.raises(ValueError, match=r"Unknown configuration.*"): + dxgb.train(client, {}, dtrain, num_boost_round=4) + + with dask.config.set({"xgboost.scheduler_address": "127.0.0.1:foo"}): + with pytest.raises(socket.gaierror, match=r".*not known.*"): + dxgb.train(client, {}, dtrain, num_boost_round=1) + + # No failure only because we are also using the Dask scheduler address. + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + cfg = CollConfig(tracker_host_ip="127.0.0.1", tracker_port=port) + dxgb.train(client, {}, dtrain, num_boost_round=1, coll_cfg=cfg) + + with pytest.raises(ValueError, match=r"comm_group.*timeout >= 0.*"): + cfg = CollConfig(tracker_host_ip="127.0.0.1", tracker_port=0, timeout=-1) + dxgb.train(client, {}, dtrain, num_boost_round=1, coll_cfg=cfg) + + class TestWithDask: def test_dmatrix_binary(self, client: "Client") -> None: def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None: @@ -1355,7 +1380,7 @@ def load_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None: with tempfile.TemporaryDirectory() as tmpdir: workers = tm.get_client_workers(client) - rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client) + rabit_args = get_rabit_args(client, len(workers)) futures = [] for w in workers: # same argument for each worker, must set pure to False otherwise dask @@ -1367,7 +1392,7 @@ def load_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None: futures.append(f) client.gather(futures) - rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client) + rabit_args = get_rabit_args(client, len(workers)) futures = [] for w in workers: f = client.submit( @@ -1426,14 +1451,6 @@ def after_iteration( os.remove(before_fname) os.remove(after_fname) - with dask.config.set({"xgboost.foo": "bar"}): - with pytest.raises(ValueError, match=r"Unknown configuration.*"): - dxgb.train(client, {}, dtrain, num_boost_round=4) - - with dask.config.set({"xgboost.scheduler_address": "127.0.0.1:foo"}): - with pytest.raises(socket.gaierror, match=r".*not known.*"): - dxgb.train(client, {}, dtrain, num_boost_round=1) - def run_updater_test( self, client: "Client", @@ -1619,9 +1636,7 @@ def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool: with LocalCluster(n_workers=2, dashboard_address=":0") as cluster: with Client(cluster) as client: workers = tm.get_client_workers(client) - rabit_args = client.sync( - dxgb._get_rabit_args, len(workers), None, client - ) + rabit_args = get_rabit_args(client, len(workers)) futures = [] for i, _ 
in enumerate(workers): f = client.submit(local_test, rabit_args, i) @@ -1757,9 +1772,7 @@ def test_no_duplicated_partition(self) -> None: n_partitions = X.npartitions m = dxgb.DaskDMatrix(client, X, y) workers = tm.get_client_workers(client) - rabit_args = client.sync( - dxgb._get_rabit_args, len(workers), None, client - ) + rabit_args = get_rabit_args(client, len(workers)) n_workers = len(workers) def worker_fn(worker_addr: str, data_ref: Dict) -> None: @@ -2259,11 +2272,11 @@ def test_callback(self, client: "Client") -> None: clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True, ) -async def test_worker_left(c, s, a, b): +async def test_worker_left(c: Client, s: Scheduler, a: Worker, b: Worker): async with Worker(s.address): dx = da.random.random((1000, 10)).rechunk(chunks=(10, None)) dy = da.random.random((1000,)).rechunk(chunks=(10,)) - d_train = await dxgb.DaskDMatrix( + d_train = await dxgb.DaskDMatrix( # type: ignore c, dx, dy, diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index 4521ec70d927..8d64dc205ef0 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -4,7 +4,7 @@ import tempfile import uuid from collections import namedtuple -from typing import Generator, Sequence, Type +from typing import Generator, Sequence import numpy as np import pytest @@ -1650,14 +1650,16 @@ def test_unsupported_params(self): def test_tracker(self): classifier = SparkXGBClassifier( launch_tracker_on_driver=True, - tracker_host="192.168.1.32", + tracker_host_ip="192.168.1.32", tracker_port=59981, ) with pytest.raises(Exception, match="Failed to bind socket"): classifier._get_tracker_args() classifier = SparkXGBClassifier( - launch_tracker_on_driver=False, tracker_host="127.0.0.1", tracker_port=58892 + launch_tracker_on_driver=False, + tracker_host_ip="127.0.0.1", + tracker_port=58892, ) with pytest.raises( ValueError, match="You must enable launch_tracker_on_driver" @@ -1666,12 +1668,11 @@ def test_tracker(self): classifier = SparkXGBClassifier( launch_tracker_on_driver=True, - tracker_host="127.0.0.1", - tracker_port=58892, + tracker_host_ip="127.0.0.1", num_workers=2, ) launch_tracker_on_driver, rabit_envs = classifier._get_tracker_args() - assert launch_tracker_on_driver == True + assert launch_tracker_on_driver is True assert rabit_envs["n_workers"] == 2 assert rabit_envs["dmlc_tracker_uri"] == "127.0.0.1"
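
For reference, below is a minimal end-to-end sketch of the configuration surface this patch exposes, combining the :py:class:`~xgboost.collective.Config` plumbing added to the Dask interface with the tracker behaviour documented above. The local cluster, random data, and parameter values are illustrative placeholders, not anything mandated by the patch:

.. code-block:: python

    import dask.array as da
    from distributed import Client, LocalCluster

    from xgboost import dask as dxgb
    from xgboost.collective import Config

    # retry/timeout are forwarded to the communicator as dmlc_retry/dmlc_timeout
    # via Config.get_comm_config; the tracker_* fields configure the RabitTracker
    # that the Dask interface starts on the scheduler node.
    coll_cfg = Config(
        retry=2,  # number of connection retry attempts
        timeout=30,  # timeout in seconds for collective operations
        tracker_host_ip="127.0.0.1",  # placeholder; use the scheduler node's IP
        tracker_port=0,  # 0 lets the OS pick an available port
        tracker_timeout=0,  # 0 disables the bootstrap/shutdown timeout
    )

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random(1000, chunks=100)
            Xy = dxgb.DaskDMatrix(client, X, y)

            # The functional interface takes the config directly ...
            output = dxgb.train(
                client,
                {"tree_method": "hist"},
                Xy,
                num_boost_round=10,
                coll_cfg=coll_cfg,
            )
            booster = output["booster"]

            # ... and the estimator interface accepts the same object.
            reg = dxgb.DaskXGBRegressor(coll_cfg=coll_cfg).fit(X, y)

The split between ``timeout`` and ``tracker_timeout`` mirrors the ``Config`` docstring: the former bounds individual collective operations through ``dmlc_timeout``, while the latter only bounds the bootstrap and shutdown of the communication group.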