Skip to content

Commit

Permalink
add trailing comma
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 committed Feb 26, 2024
1 parent 54171a5 commit 954c21a
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/trainer/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,13 @@ def test_fsdp_process_group(world_size: int):
@pytest.mark.gpu
@world_size(2)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.2.0'), reason='Device mesh requires Torch 2.2')
@pytest.mark.parametrize('sharding_strategy',
['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'])
@pytest.mark.parametrize('sharding_strategy', [
'NO_SHARD',
'SHARD_GRAD_OP',
'FULL_SHARD',
'HYBRID_SHARD',
'_HYBRID_SHARD_ZERO2',
])
@pytest.mark.parametrize('device_mesh', [[2], [1, 2]])
def test_wrong_size_device_mesh_error(world_size: int, sharding_strategy: str, device_mesh: list[int]):
context = contextlib.nullcontext()
Expand Down

0 comments on commit 954c21a

Please sign in to comment.