Keep global rank for embedding tower intra-node RW/CW sharding
Summary:
ShardedTensor change pytorch/pytorch#123230 changed the validation logic in `_parse_and_validate_remote_device` to check the remote device's rank against the ranks belonging to the passed process group (previously it was checked against the global process group containing all ranks 0...WORLD_SIZE-1).

This causes an issue for embedding towers, which internally convert TWRW sharding to RW sharding over an intra-node process group (e.g. 2 nodes with 2 GPUs each would have 2 intra-node process groups: one containing ranks [0,1] and the other ranks [2,3]). In the process, the shard placements in the sharding plan's sharding spec are rewritten to local ranks instead of global ranks.
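
For illustration, a minimal sketch of the rank remapping described above (pure Python, no torchrec APIs; the 2-node x 2-GPU layout and the `rank % local_size` mapping come from this summary and the diff below):

```python
# Illustrative sketch only: 2 nodes x 2 GPUs each => WORLD_SIZE = 4, local_size = 2.
WORLD_SIZE = 4
local_size = 2

# Intra-node process groups, one per node: [[0, 1], [2, 3]].
intra_node_pgs = [
    list(range(node * local_size, (node + 1) * local_size))
    for node in range(WORLD_SIZE // local_size)
]

# TWRW -> RW conversion remaps each global rank to its local rank within the node.
global_ranks = [2, 3]  # a table hosted on node 1
local_ranks = [rank % local_size for rank in global_ranks]
assert local_ranks == [0, 1]  # these no longer match the member ranks [2, 3] of node 1's pg
```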

This now fails validation and breaks embedding tower (ET) sharding. The core issue is that torchrec expects ranks to be in [0...WORLD_SIZE-1]. The 'fix' here is a hack: we trick sharding into believing we are on ranks 0 and 1 while keeping global ranks in the shard metadata.
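
A rough sketch of the intended state after this change (values are illustrative for the 2x2 layout above, not actual torchrec output):

```python
# For a table hosted on node 1 in the 2x2 example, after this change:
value_ranks = [0, 1]      # value.ranks is still remapped to local ranks, so the intra-node
                          # RW/CW sharding logic behaves as if it runs on ranks 0 and 1
placement_ranks = [2, 3]  # the sharding spec's shard placements keep global ranks, which the
                          # new ShardedTensor validation can check against the intra-node
                          # process group containing ranks [2, 3]
```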

Reviewed By: xunnanxu

Differential Revision: D56120722
sarckk authored and facebook-github-bot committed Apr 17, 2024
1 parent daf747c commit b6d6a58
Showing 1 changed file with 0 additions and 4 deletions.
torchrec/distributed/embedding_tower_sharding.py: 0 additions & 4 deletions
@@ -69,10 +69,6 @@ def _replace_sharding_with_intra_node(
             raise ValueError(f"Sharding type not supported {value.sharding_type}")
         if value.ranks:
             value.ranks = [rank % local_size for rank in value.ranks]
-        if value.sharding_spec:
-            # pyre-ignore [6, 16]
-            for shard, rank in zip(value.sharding_spec.shards, value.ranks):
-                shard.placement._rank = rank


 class TowerLazyAwaitable(LazyAwaitable[torch.Tensor]):
