Skip to content

Commit

Permalink
update gpu_test decorator to avoid TypeError complaint
Browse files Browse the repository at this point in the history
  • Loading branch information
lessw2020 committed Oct 6, 2023
1 parent b5ad836 commit e26090a
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
from torch import nn, Tensor


def gpu_test(gpu_count: int = 1):
def gpu_test(min_gpu_count: int = 1):
"""
Annotation for GPU tests, skipping the test if the
required amount of GPU is not available
"""
message = f"Not enough GPUs to run the test: required {gpu_count}"
return pytest.mark.skipif(torch.cuda.device_count() < gpu_count, reason=message)
message = f"Not enough GPUs to run the test: requires {min_gpu_count}"
local_gpu_count = torch.cuda.device_count()
return pytest.mark.skipif(local_gpu_count < min_gpu_count, reason=message)


def init_distributed_on_file(world_size: int, gpu_id: int, sync_file: str):
Expand Down

0 comments on commit e26090a

Please sign in to comment.