[C10D] Make new_group eager when used with comm_split (pytorch#129284)
If users pass `device_id` to `init_process_group`, they enable eager
initialization for the default group. If they subsequently call
`new_group`, the `device_id` argument is not required: it is assumed to
match the one used for `init_process_group`.

However, the `init_process_group` and `new_group` APIs share a helper
function whose `device_id` parameter defaults to None, and when it is
None, eager initialization is disabled.

This PR ensures that if a `device_id` was passed to
`init_process_group`, the same `device_id` will automatically be fed
into the helper function for any `new_group` calls that follow.
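
To make the before/after behavior concrete, here is a minimal, purely
illustrative sketch. `_shared_helper_sketch` and `_DefaultGroupSketch` are
stand-in names, not PyTorch internals; the real change lives in
`torch/distributed/distributed_c10d.py` and is shown in the diff below.

```
# Illustrative sketch only -- stand-in names, not the real PyTorch internals.
from typing import Optional

import torch


def _shared_helper_sketch(device_id: Optional[torch.device]) -> str:
    # The shared helper only initializes communicators eagerly when it
    # knows which device to bind; device_id=None means lazy init.
    return "eager" if device_id is not None else "lazy"


class _DefaultGroupSketch:
    # Stands in for the default group's bound_device_id, which is set when
    # init_process_group received a device_id and is None otherwise.
    bound_device_id: Optional[torch.device] = torch.device("cuda", 0)


def new_group_before_fix() -> str:
    # Old behavior: new_group never forwarded a device, so the helper saw
    # device_id=None and fell back to lazy initialization.
    return _shared_helper_sketch(device_id=None)


def new_group_after_fix(default_pg: _DefaultGroupSketch) -> str:
    # Fixed behavior: reuse whatever device the default group was bound to.
    return _shared_helper_sketch(device_id=default_pg.bound_device_id or None)


print(new_group_before_fix())                      # prints "lazy"
print(new_group_after_fix(_DefaultGroupSketch()))  # prints "eager"
```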

**Test plan**
I found an existing CI test, `test_comm_split_subgroup`, that failed after my change because it asserted that the backend's comm_split counter did not increment eagerly; with this change the counter now increments eagerly. I updated the test in this PR to reflect the new behavior.

I also tested locally with a simple program under TORCH_CPP_LOG_LEVEL=INFO and
observed eager initialization of the 'lows' and 'highs' PGs before the
'Here' print.

```
import torch
import torch.distributed as dist

# Binding a device here enables eager initialization for the default group.
dist.init_process_group(
    backend="nccl",
    device_id=torch.device(f"cuda:{torch.distributed.get_node_local_rank(0)}"),
)
# With this change, these subgroups are also initialized eagerly.
dist.new_group([0, 1], group_desc="lows")
dist.new_group([2, 3], group_desc="highs")
print("Here")
torch.distributed.destroy_process_group()
```

Output:
https://gist.github.com/wconstab/88a5ba0b970244ca1f79133f989e0349
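
As a programmatic alternative to reading the C++ logs, the NCCL backend
exposes a split counter (the same one the updated test below checks). A rough
sketch, assuming a multi-GPU job launched with torchrun; accessors such as
`_get_backend` on the default group are private APIs and may vary by version.

```
import torch
import torch.distributed as dist

# Assumes env:// initialization via torchrun on a multi-GPU node.
device = torch.device(f"cuda:{dist.get_node_local_rank(0)}")
dist.init_process_group(backend="nccl", device_id=device)

# Grab the default group's NCCL backend; with eager init, creating a
# subgroup splits the parent communicator immediately, before any
# collective runs on the subgroup.
backend = dist.group.WORLD._get_backend(device)
dist.new_group(list(range(dist.get_world_size())), group_desc="eager_subgroup")
assert backend.comm_split_count() == 1

dist.destroy_process_group()
```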

Pull Request resolved: pytorch#129284
Approved by: https://github.com/pavanbalaji, https://github.com/fduwjj, https://github.com/d4l3k, https://github.com/nvcastet
wconstab authored and pytorchmergebot committed Jun 25, 2024
1 parent e58ef5b commit e1499f6
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 3 additions & 5 deletions test/distributed/test_c10d_nccl.py
```
@@ -614,14 +614,12 @@ def test_comm_split_subgroup(self):
         original_tensor = tensor.clone()
         ng = c10d.new_group([0])
 
-        # rank 0 hasn't split yet, but rank 1 did for the
-        # nocolor... so split count matches rank count coincidentally
-        # in each of the proceses this test spawned!
-        self.assertEqual(backend.comm_split_count(), self.rank)
+        # comm split happens eagerly since device_id is passed to init_process_group.
+        self.assertEqual(backend.comm_split_count(), 1)
         if self.rank == 0:
             dist.broadcast(tensor, 0, group=ng)
 
-        # now everyone has split because rank 0 has performed a comm
+        # no additional comm split happens after a collective.
         self.assertEqual(backend.comm_split_count(), 1)
         self.assertEqual(tensor, original_tensor)
 
```

2 changes: 2 additions & 0 deletions torch/distributed/distributed_c10d.py
```
@@ -4394,6 +4394,7 @@ def _new_group_with_tag(
     global _world
 
     default_pg = _get_default_group()
+    device_id = default_pg.bound_device_id or None
     default_backend, default_store = _world.pg_map[default_pg]
     global_rank = default_pg.rank()
     global_world_size = default_pg.size()
@@ -4457,6 +4458,7 @@ def _new_group_with_tag(
         pg_options=pg_options,
         timeout=timeout,
         pg_tag=pg_tag,
+        device_id=device_id,
         group_desc=group_desc,
     )
```

