Skip to content

Commit

Permalink
properly check for bfloat16
Browse files Browse the repository at this point in the history
- we check only the test device, not the machine in general
- we don't want emulated bfloat16 (e.g. CPU)
  • Loading branch information
catwell committed Oct 9, 2024
1 parent f3d2b6c commit 2796117
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
10 changes: 10 additions & 0 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,13 @@ def str_to_dtype(dtype: str) -> torch.dtype:
return torch.bool
case _:
raise ValueError(f"Unknown dtype: {dtype}")


def device_has_bfloat16(device: torch.device) -> bool:
cuda_version = cast(str | None, torch.version.cuda) # type: ignore
if cuda_version is None or int(cuda_version.split(".")[0]) < 11:
return False
try:
return torch.cuda.get_device_properties(device).major >= 8 # type: ignore
except ValueError:
return False
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from pytest import FixtureRequest, fixture, skip

from refiners.fluxion.utils import str_to_dtype
from refiners.fluxion.utils import device_has_bfloat16, str_to_dtype

PARENT_PATH = Path(__file__).parent

Expand All @@ -21,11 +21,11 @@ def test_device() -> torch.device:
return torch.device(test_device)


def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
def dtype_fixture_factory(params: list[str]) -> Callable[[torch.device, FixtureRequest], torch.dtype]:
@fixture(scope="session", params=params)
def dtype_fixture(request: FixtureRequest) -> torch.dtype:
def dtype_fixture(test_device: torch.device, request: FixtureRequest) -> torch.dtype:
torch_dtype = str_to_dtype(request.param)
if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
if torch_dtype == torch.bfloat16 and not device_has_bfloat16(test_device):
skip("bfloat16 is not supported on this test device")
return torch_dtype

Expand Down

0 comments on commit 2796117

Please sign in to comment.