From e26090ac34662d99825d563dce05016462f778a9 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Fri, 6 Oct 2023 04:38:16 +0000 Subject: [PATCH] update gpu_test decorator to avoid TypeError complaint --- tests/test_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index c4100f8d..a0edee01 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,13 +17,14 @@ from torch import nn, Tensor -def gpu_test(gpu_count: int = 1): +def gpu_test(min_gpu_count: int = 1): """ Annotation for GPU tests, skipping the test if the required amount of GPU is not available """ - message = f"Not enough GPUs to run the test: required {gpu_count}" - return pytest.mark.skipif(torch.cuda.device_count() < gpu_count, reason=message) + message = f"Not enough GPUs to run the test: requires {min_gpu_count}" + local_gpu_count = torch.cuda.device_count() + return pytest.mark.skipif(local_gpu_count < min_gpu_count, reason=message) def init_distributed_on_file(world_size: int, gpu_id: int, sync_file: str):