diff --git a/tests/test_utils.py b/tests/test_utils.py index a0edee01..5d2c3627 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,14 +17,14 @@ from torch import nn, Tensor -def gpu_test(min_gpu_count: int = 1): +def gpu_test(gpu_count: int = 1): """ Annotation for GPU tests, skipping the test if the required amount of GPU is not available """ - message = f"Not enough GPUs to run the test: requires {min_gpu_count}" + message = f"Not enough GPUs to run the test: requires {gpu_count}" local_gpu_count = torch.cuda.device_count() - return pytest.mark.skipif(local_gpu_count < min_gpu_count, reason=message) + return pytest.mark.skipif(local_gpu_count < gpu_count, reason=message) def init_distributed_on_file(world_size: int, gpu_id: int, sync_file: str):