Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[C10D] Make new_group eager when used with comm_split (pytorch#129284)
If users pass `device_id` to init_process_group, they enable eager init for the default group. Then if they subsequently call `new_group`, the device_id argument is not required as it should be assumed to match the one used for init_process_group. However, both `init_process_group` and `new_group` apis share a helper function, which expects a `device_id` value that defaults to None. When it's None, eager initialization is disabled. This PR ensures that if a device_id was passed to init_process_group, the same device_id will automatically be fed into the helper function for any new_group calls that follow. **Test plan** I found an existing test in CI `test_comm_split_subgroup` that failed after my change, because it was asserting that backend comm_split counter did not increment eagerly, and its behavior had changed to increment eagerly. I updated the test in the PR to pass with my change. I also tested locally via simple program with TORCH_CPP_LOG_LEVEL=INFO and observed eager initialization of the 'lows' and 'highs' PGs before the 'Here' print. ``` import torch import torch.distributed as dist dist.init_process_group(backend="nccl", device_id =torch.device(f"cuda:{torch.distributed.get_node_local_rank(0)}")) dist.new_group([0, 1], group_desc="lows") dist.new_group([2, 3], group_desc="highs") print("Here") torch.distributed.destroy_process_group() ``` Output: https://gist.github.com/wconstab/88a5ba0b970244ca1f79133f989e0349 Pull Request resolved: pytorch#129284 Approved by: https://github.com/pavanbalaji, https://github.com/fduwjj, https://github.com/d4l3k, https://github.com/nvcastet
- Loading branch information