Commit
Keep global rank for embedding tower intra-node RW/CW sharding
Summary: ShardedTensor change pytorch/pytorch#123230 changed the validation logic of `_parse_and_validate_remote_device` to check the remote device's rank against the ranks belonging to the passed process group (previously it was checked against the global process group containing all ranks 0...WORLD_SIZE-1). This causes an issue for embedding towers, which internally convert TWRW sharding to RW sharding with intra-node process groups (e.g., 2 nodes with 2 GPUs each would have 2 intra-node process groups: one containing ranks [0,1] and the other containing ranks [2,3]). In the process, the shard placements in the sharding plan's sharding spec are replaced with local ranks instead of global ranks. This fails validation and breaks ET sharding. The core issue is that torchrec expects ranks to be in [0...WORLD_SIZE-1]. The 'fix' here is a hack: we trick sharding into thinking we are on ranks 0 and 1 while keeping the global ranks in the shard metadata.

Reviewed By: xunnanxu

Differential Revision: D56120722
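The sketch below is a minimal, hypothetical illustration (not the actual torchrec or PyTorch code) of the global-vs-local rank mismatch described above, assuming the 2-node x 2-GPU layout from the summary; the `local_rank` helper and group construction are invented for illustration only.

```python
# Hypothetical sketch of the rank mismatch for TWRW -> intra-node RW sharding.
# Assumes 2 nodes with 2 GPUs each (WORLD_SIZE = 4), as in the summary above.

WORLD_SIZE = 4
GPUS_PER_NODE = 2

# Intra-node process groups: node 0 holds global ranks [0, 1], node 1 holds [2, 3].
intra_node_groups = [
    list(range(n * GPUS_PER_NODE, (n + 1) * GPUS_PER_NODE))
    for n in range(WORLD_SIZE // GPUS_PER_NODE)
]
assert intra_node_groups == [[0, 1], [2, 3]]


def local_rank(global_rank: int) -> int:
    """Rank of `global_rank` within its intra-node group (illustrative helper)."""
    return global_rank % GPUS_PER_NODE


# A shard intended for global rank 3 sits at local rank 1 of the group [2, 3].
# If the sharding spec records the local rank (1) but validation resolves ranks
# against global ranks 0...WORLD_SIZE-1, the placement no longer identifies the
# intended device; keeping the global rank (3) in the shard metadata avoids
# that mismatch, which is the idea behind this change.
assert local_rank(3) == 1
```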