forked from kornia/kornia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conftest.py
128 lines (98 loc) · 4.26 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import sys
from itertools import product
from typing import Dict
import numpy
import pytest
import torch
import kornia
def get_test_devices() -> Dict[str, torch.device]:
    """Collect the devices the test-suite is allowed to run on.

    The CPU is always included. CUDA (first GPU, ``cuda:0``), an XLA device (stored under ``"tpu"``)
    and Apple MPS are added only when the current installation reports them as available.

    Return:
        dict(str, torch.device): mapping from short device name (``"cpu"``, ``"cuda"``, ``"tpu"``,
        ``"mps"``) to the corresponding device, in that insertion order.
    """
    available: Dict[str, torch.device] = {"cpu": torch.device("cpu")}

    if torch.cuda.is_available():
        available["cuda"] = torch.device("cuda:0")

    if kornia.xla_is_available():
        # Imported lazily so torch_xla is only needed once XLA is reported as available.
        import torch_xla.core.xla_model as xm

        available["tpu"] = xm.xla_device()

    # Guard with hasattr: some torch builds may not expose ``torch.backends.mps`` at all.
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        available["mps"] = torch.device("mps")

    return available
def get_test_dtypes() -> Dict[str, torch.dtype]:
    """Map short dtype names to the floating-point dtypes exercised by the tests.

    Return:
        dict(str, torch.dtype): ``bfloat16``, ``float16``, ``float32`` and ``float64``, keyed by
        their ``torch`` attribute name and inserted in that order.
    """
    # The names double as the attribute names on the ``torch`` module.
    ordered_names = ("bfloat16", "float16", "float32", "float64")
    return {name: getattr(torch, name) for name in ordered_names}
# Devices and dtypes available for this test session, keyed by the names accepted
# by the ``--device`` / ``--dtype`` command-line options.
TEST_DEVICES: Dict[str, torch.device] = get_test_devices()
TEST_DTYPES: Dict[str, torch.dtype] = get_test_dtypes()

# (device_name, dtype_name) combinations excluded from testing, e.g. ``{('cpu', 'float16')}``.
# This is a set of tuples (not a dict): ``pytest_generate_tests`` checks membership of
# ``(device, dtype)`` pairs produced by ``itertools.product``.
DEVICE_DTYPE_BLACKLIST = set()
@pytest.fixture()
def device(device_name) -> torch.device:
    """Resolve the parametrized ``device_name`` into its ``torch.device`` from ``TEST_DEVICES``."""
    resolved = TEST_DEVICES[device_name]
    return resolved
@pytest.fixture()
def dtype(dtype_name) -> torch.dtype:
    """Resolve the parametrized ``dtype_name`` into its ``torch.dtype`` from ``TEST_DTYPES``."""
    resolved = TEST_DTYPES[dtype_name]
    return resolved
@pytest.fixture(scope='session')
def torch_optimizer():
    """Provide ``torch.compile`` to tests that exercise compiled (dynamo) code paths.

    Tests requesting this fixture are skipped when the installed torch has no ``torch.compile``,
    or when not running on Linux; each case reports its own skip reason.

    Side effect: calls ``torch.set_float32_matmul_precision('high')``, a process-wide setting that
    this fixture does not restore.

    Returns:
        The ``torch.compile`` callable.
    """
    if not hasattr(torch, 'compile'):
        pytest.skip(f"skipped because {torch.__version__} does not have `compile` available! Failed to setup dynamo.")
    # NOTE(review): the reason for restricting compiled tests to Linux is not visible here -- confirm.
    if sys.platform != "linux":
        pytest.skip(f"skipped because `torch.compile` tests only run on linux (current platform: {sys.platform}).")
    torch.set_float32_matmul_precision('high')
    return torch.compile
def _requested_names(metafunc, option, available):
    """Return the names requested through a ``--device``/``--dtype`` style command-line option.

    ``'all'`` expands to every key of ``available`` (in insertion order). Any other value is split
    on commas, and surrounding whitespace is stripped so ``"cpu, cuda"`` works like ``"cpu,cuda"``.
    """
    raw_value = metafunc.config.getoption(option)
    if raw_value == 'all':
        return list(available.keys())
    return [name.strip() for name in raw_value.split(',')]


def pytest_generate_tests(metafunc):
    """Parametrize tests that request ``device_name`` and/or ``dtype_name``.

    Names come from the ``--device`` / ``--dtype`` options. When a test requests both, every
    (device, dtype) combination is generated except those in ``DEVICE_DTYPE_BLACKLIST``.
    """
    device_names = None
    dtype_names = None
    if 'device_name' in metafunc.fixturenames:
        device_names = _requested_names(metafunc, '--device', TEST_DEVICES)
    if 'dtype_name' in metafunc.fixturenames:
        dtype_names = _requested_names(metafunc, '--dtype', TEST_DTYPES)

    if device_names is not None and dtype_names is not None:
        # Exclude any blacklisted device/dtype combinations.
        params = [combo for combo in product(device_names, dtype_names) if combo not in DEVICE_DTYPE_BLACKLIST]
        metafunc.parametrize('device_name,dtype_name', params)
    elif device_names is not None:
        metafunc.parametrize('device_name', device_names)
    elif dtype_names is not None:
        metafunc.parametrize('dtype_name', dtype_names)
def pytest_addoption(parser):
    """Register the ``--device`` and ``--dtype`` command-line options.

    Both accept a comma-separated list of names or ``all``; they are consumed by
    ``pytest_generate_tests`` to parametrize tests.
    """
    parser.addoption(
        '--device',
        action="store",
        default="cpu",
        help="Comma-separated device names to test on (cpu, cuda, tpu, mps), or 'all' for every available device.",
    )
    parser.addoption(
        '--dtype',
        action="store",
        default="float32",
        help="Comma-separated dtype names to test with (bfloat16, float16, float32, float64), or 'all'.",
    )
@pytest.fixture(autouse=True)
def add_doctest_deps(doctest_namespace):
    """Expose ``np``, ``torch`` and ``kornia`` as globals inside every doctest."""
    doctest_namespace.update(np=numpy, torch=torch, kornia=kornia)
# Commit hashes of the kornia/data_test repository that pin the exact versions of the
# test-data files downloaded by the ``data`` fixture.
# sha: LoFTR outdoor/homography and indoor/fundamental data.
sha: str = 'cb8f42bf28b9f347df6afba5558738f62a11f28a'
# sha2: AdaLAM test data.
sha2: str = '824ff1518870864644df6842a4ec964040f64504'
# sha3: DISK outdoor data (knchurch_disk.pt).
sha3: str = '8b98f44abbe92b7a84631ed06613b08fee7dae14'
@pytest.fixture(scope='session')
def data(request):
    """Load a remote test-data state dict via ``torch.hub.load_state_dict_from_url``.

    ``request.param`` selects one of the keys below. The fixture declares no ``params`` itself,
    so tests presumably parametrize it indirectly -- confirm at call sites. Tensors are mapped to CPU.
    """
    key = request.param
    remote_files = {
        'loftr_homo': f'https://github.com/kornia/data_test/blob/{sha}/loftr_outdoor_and_homography_data.pt?raw=true',
        'loftr_fund': f'https://github.com/kornia/data_test/blob/{sha}/loftr_indoor_and_fundamental_data.pt?raw=true',
        'adalam_idxs': f'https://github.com/kornia/data_test/blob/{sha2}/adalam_test.pt?raw=true',
        'disk_outdoor': f'https://github.com/kornia/data_test/blob/{sha3}/knchurch_disk.pt?raw=true',
        'dexined': 'https://cmp.felk.cvut.cz/~mishkdmy/models/DexiNed_BIPED_10.pth',
    }
    checkpoint_url = remote_files[key]
    return torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))