Skip to content

Commit

Permalink
Add SHARD_GRAD_OP to device mesh error check (#3058)
Browse files Browse the repository at this point in the history
* fix tests

* fi xerror

* fix
  • Loading branch information
mvpatel2000 authored Feb 26, 2024
1 parent 1c19e1c commit 1c52d47
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
4 changes: 3 additions & 1 deletion composer/trainer/dist_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def sync_hook(*args):
if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'):
if 'device_mesh' in fsdp_config:
device_mesh_size = len(fsdp_config['device_mesh'])
if sharding_strategy in [ShardingStrategy.FULL_SHARD, ShardingStrategy.NO_SHARD] and device_mesh_size != 1:
if sharding_strategy in [
ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD
] and device_mesh_size != 1:
raise ValueError(f'FSDP sharding strategy {sharding_map_key.upper()} requires a device mesh '
f'of size 1 but got device mesh size of {device_mesh_size}.')
elif sharding_strategy in [ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2
Expand Down
16 changes: 13 additions & 3 deletions tests/trainer/test_fsdp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import contextlib
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -219,10 +220,19 @@ def test_fsdp_process_group(world_size: int):
@pytest.mark.gpu
@world_size(2)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.2.0'), reason='Device mesh requires Torch 2.2')
def test_wrong_size_device_mesh_error(world_size: int):
with pytest.raises(ValueError, match='.*requires a device mesh of size 1.*'):
@pytest.mark.parametrize('sharding_strategy',
['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'])
@pytest.mark.parametrize('device_mesh', [[2], [1, 2]])
def test_wrong_size_device_mesh_error(world_size: int, sharding_strategy: str, device_mesh: list[int]):
context = contextlib.nullcontext()
if sharding_strategy in ['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD'] and len(device_mesh) != 1:
context = pytest.raises(ValueError, match='.*requires a device mesh of size 1.*')
if sharding_strategy in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] and len(device_mesh) != 2:
context = pytest.raises(ValueError, match='.*requires a device mesh of size 2.*')
with context:
Trainer(model=SimpleModel(), fsdp_config={
'device_mesh': [1, 2],
'sharding_strategy': sharding_strategy,
'device_mesh': device_mesh,
})


Expand Down

0 comments on commit 1c52d47

Please sign in to comment.