diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index 000b5c9a685..3ec8c05a1da 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -1,9 +1,11 @@
+import itertools
 import socket
 from contextlib import contextmanager
 from functools import wraps
-from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
+from torch import distributed as dist
 
 from ignite.distributed.comp_models import (
     _SerialModel,
@@ -43,6 +45,7 @@
     "one_rank_only",
     "new_group",
     "one_rank_first",
+    "all_gather_tensors_with_shapes",
 ]
 
 _model = _SerialModel()
@@ -350,6 +353,60 @@ def all_reduce(
     return _model.all_reduce(tensor, op, group=group)
 
 
+def all_gather_tensors_with_shapes(
+    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
+) -> List[torch.Tensor]:
+    """Helper method to gather tensors of possibly different shapes but with the same number of dimensions
+    across processes.
+
+    This function takes the shapes of the participating tensors as input, so they must be known in advance.
+    If the tensors have different numbers of dimensions, or their shapes are not known beforehand, use
+    ``torch.distributed.all_gather_object`` instead; otherwise this method is considerably faster.
+
+    Examples:
+        .. code-block:: python
+
+            import torch
+
+            import ignite.distributed as idist
+
+            rank = idist.get_rank()
+            ws = idist.get_world_size()
+            tensor = torch.randn(rank + 1, rank + 2)
+            tensors = idist.all_gather_tensors_with_shapes(tensor, [[r + 1, r + 2] for r in range(ws)])
+
+    Args:
+        tensor: tensor to collect across participating processes.
+        shapes: sequence containing the shape of each participating process' ``tensor``.
+        group: list of integer ranks or the process group for each backend. If None, the default process group
+            will be used.
+
+    Returns:
+        List[torch.Tensor]: gathered tensors, one per participating process, in rank order.
+    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
+    if isinstance(group, list) and all(isinstance(item, int) for item in group):
+        group = _model.new_group(group)
+
+    if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
+        return [tensor]
+
+    max_shape = torch.tensor(shapes).amax(dim=0)
+    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
+    padded_tensor = torch.nn.functional.pad(
+        tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
+    )
+    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
+    return [
+        all_padded_tensors[
+            [
+                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
+                for dim, dim_size in enumerate(shape)
+            ]
+        ]
+        for rank, shape in enumerate(shapes)
+    ]
+
+
 def all_gather(
     tensor: Union[torch.Tensor, float, str, Any], group: Optional[Union[Any, List[int]]] = None
 ) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
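
A note on the padding step in the new function: ``torch.nn.functional.pad`` takes its sizes as
``(before, after)`` pairs ordered from the *last* dimension backwards, which is why the implementation
reverses ``padding_sizes``. A minimal single-process sketch of the pad-then-slice round trip (the shapes
and values here are illustrative, not taken from the PR):

.. code-block:: python

    import itertools

    import torch
    import torch.nn.functional as F

    shapes = [[1, 2], [2, 3], [3, 4]]  # hypothetical per-rank shapes
    max_shape = torch.tensor(shapes).amax(dim=0)  # tensor([3, 4])

    tensor = torch.ones(2, 3)  # pretend this is rank 1's tensor
    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()  # [1, 1]
    # F.pad expects (left, right) pairs starting from the last dimension,
    # hence the reversed() in the implementation
    pad = tuple(itertools.chain.from_iterable((0, p) for p in reversed(padding_sizes)))
    padded = F.pad(tensor, pad)  # shape (3, 4), zero-padded at the ends
    # after all_gather, each rank's block is sliced back to its true shape
    recovered = padded[: tensor.shape[0], : tensor.shape[1]]
    assert torch.equal(recovered, tensor)
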
diff --git a/ignite/handlers/wandb_logger.py b/ignite/handlers/wandb_logger.py
index ec3871ae9f8..0264f27d8c9 100644
--- a/ignite/handlers/wandb_logger.py
+++ b/ignite/handlers/wandb_logger.py
@@ -141,7 +141,7 @@ def __getattr__(self, attr: Any) -> Any:
         return getattr(self._wandb, attr)
 
     def close(self) -> None:
-        self._wandb.finish()  # type: ignore[attr-defined]
+        self._wandb.finish()
 
     def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler":
         return OutputHandler(*args, **kwargs)
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index 7845f0cd1ce..1f3ad55dd84 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -3,7 +3,7 @@
 import torch.distributed as dist
 
 import ignite.distributed as idist
-from ignite.distributed.utils import sync
+from ignite.distributed.utils import all_gather_tensors_with_shapes, sync
 from ignite.engine import Engine, Events
 
 
@@ -291,6 +291,60 @@ def _test_distrib_all_gather_group(device):
         res = idist.all_gather(t, group="abc")
 
 
+def _test_idist_all_gather_tensors_with_shapes(device):
+    torch.manual_seed(41)
+    rank = idist.get_rank()
+    ws = idist.get_world_size()
+    reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+    rank_tensor = reference[
+        rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
+        rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
+        rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+    ]
+    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)])
+    for r in range(ws):
+        r_tensor = reference[
+            r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+            r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+            r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+        ]
+        assert (r_tensor == tensors[r]).all()
+
+
+def _test_idist_all_gather_tensors_with_shapes_group(device):
+    if idist.get_world_size() > 1:
+        torch.manual_seed(41)
+
+        rank = idist.get_rank()
+        ranks = list(range(1, idist.get_world_size()))
+        ws = idist.get_world_size()
+        bnd = idist.backend()
+        if rank in ranks:
+            reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+            rank_tensor = reference[
+                rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
+                rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
+                rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+            ]
+        else:
+            rank_tensor = torch.tensor([rank], device=device)
+        if bnd == "horovod":
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+        else:
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+            if rank in ranks:
+                for r in ranks:
+                    r_tensor = reference[
+                        r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                        r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                        r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+                    ]
+                    assert (r_tensor == tensors[r - 1]).all()
+            else:
+                assert [rank_tensor] == tensors
+
+
 def _test_distrib_broadcast(device):
     rank = idist.get_rank()
     ws = idist.get_world_size()
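
If the per-rank shapes are not known ahead of time but every tensor has the same number of dimensions, one
workable pattern is to exchange the shapes themselves with a regular ``idist.all_gather`` first. A hedged
sketch (``all_gather_unknown_shapes`` is a hypothetical helper, not part of this changeset):

.. code-block:: python

    from typing import List

    import torch

    import ignite.distributed as idist

    def all_gather_unknown_shapes(tensor: torch.Tensor) -> List[torch.Tensor]:
        ws = idist.get_world_size()
        # idist.all_gather concatenates along dim 0, so gathering each rank's
        # (ndim,)-shaped shape tensor yields a flat (ws * ndim,) tensor
        all_shapes = idist.all_gather(torch.tensor(tensor.shape)).reshape(ws, -1)
        return idist.all_gather_tensors_with_shapes(tensor, all_shapes.tolist())
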
diff --git a/tests/ignite/distributed/utils/test_horovod.py b/tests/ignite/distributed/utils/test_horovod.py
index ead6ed4c330..0f297772f94 100644
--- a/tests/ignite/distributed/utils/test_horovod.py
+++ b/tests/ignite/distributed/utils/test_horovod.py
@@ -17,6 +17,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
+    _test_idist_all_gather_tensors_with_shapes,
+    _test_idist_all_gather_tensors_with_shapes_group,
     _test_sync,
 )
 
@@ -163,6 +165,8 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
     np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
     gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
     gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_idist_all_gather_tensors_with_shapes, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_idist_all_gather_tensors_with_shapes_group, (device,), np=np, do_init=True)
 
 
 @pytest.mark.distributed
diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py
index fda3e1126cc..ee828ef5d9f 100644
--- a/tests/ignite/distributed/utils/test_native.py
+++ b/tests/ignite/distributed/utils/test_native.py
@@ -19,6 +19,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
+    _test_idist_all_gather_tensors_with_shapes,
+    _test_idist_all_gather_tensors_with_shapes_group,
     _test_sync,
 )
 
@@ -253,6 +255,23 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
     _test_distrib_all_gather_group(device)
 
 
+@pytest.mark.distributed
+@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_idist_all_gather_tensors_with_shapes_nccl(distributed_context_single_node_nccl):
+    device = idist.device()
+    _test_idist_all_gather_tensors_with_shapes(device)
+    _test_idist_all_gather_tensors_with_shapes_group(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+def test_idist_all_gather_tensors_with_shapes_gloo(distributed_context_single_node_gloo):
+    device = idist.device()
+    _test_idist_all_gather_tensors_with_shapes(device)
+    _test_idist_all_gather_tensors_with_shapes_group(device)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
diff --git a/tests/ignite/distributed/utils/test_serial.py b/tests/ignite/distributed/utils/test_serial.py
index df2d6742b54..a680cffe25c 100644
--- a/tests/ignite/distributed/utils/test_serial.py
+++ b/tests/ignite/distributed/utils/test_serial.py
@@ -10,6 +10,7 @@
     _test_distrib_barrier,
     _test_distrib_broadcast,
     _test_distrib_new_group,
+    _test_idist_all_gather_tensors_with_shapes,
     _test_sync,
 )
 
@@ -70,6 +71,7 @@ def test_idist__model_methods_no_dist():
 def test_idist_collective_ops_no_dist():
     _test_distrib_all_reduce("cpu")
     _test_distrib_all_gather("cpu")
+    _test_idist_all_gather_tensors_with_shapes("cpu")
     _test_distrib_barrier("cpu")
     _test_distrib_broadcast("cpu")
     _test_distrib_new_group("cpu")
@@ -77,6 +79,7 @@ def test_idist_collective_ops_no_dist():
     if torch.cuda.device_count() > 1:
         _test_distrib_all_reduce("cuda")
         _test_distrib_all_gather("cuda")
+        _test_idist_all_gather_tensors_with_shapes("cuda")
         _test_distrib_barrier("cuda")
         _test_distrib_broadcast("cuda")
         _test_distrib_new_group("cuda")
diff --git a/tests/ignite/distributed/utils/test_xla.py b/tests/ignite/distributed/utils/test_xla.py
index bb109eacdea..b41a14c2b88 100644
--- a/tests/ignite/distributed/utils/test_xla.py
+++ b/tests/ignite/distributed/utils/test_xla.py
@@ -15,6 +15,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
+    _test_idist_all_gather_tensors_with_shapes,
+    _test_idist_all_gather_tensors_with_shapes_group,
     _test_sync,
 )
 
@@ -151,6 +153,8 @@ def test_idist_all_gather_xla():
     device = idist.device()
     _test_distrib_all_gather(device)
     _test_distrib_all_gather_group(device)
+    _test_idist_all_gather_tensors_with_shapes(device)
+    _test_idist_all_gather_tensors_with_shapes_group(device)
 
 
 def _test_idist_all_gather_xla_in_child_proc(index):
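
For completeness, a sketch of the group-restricted call that the ``*_group`` tests above exercise, for the
native backend (Horovod raises ``NotImplementedError`` for group arguments, as tested above). The ranks,
shapes, and tensors here are illustrative; as the tests assert, a rank outside the group gets its own tensor
back as a one-element list:

.. code-block:: python

    import torch

    import ignite.distributed as idist

    rank = idist.get_rank()
    ranks = list(range(1, idist.get_world_size()))  # every rank but rank 0
    if rank in ranks:
        tensor = torch.randn(rank + 1, rank + 2)
    else:
        tensor = torch.empty(0)  # non-members still call the collective
    shapes = [[r + 1, r + 2] for r in ranks]
    # members receive len(ranks) tensors in rank order; non-members get [tensor]
    tensors = idist.all_gather_tensors_with_shapes(tensor, shapes, ranks)
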