-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDAResUNet_test.py
66 lines (55 loc) · 2.03 KB
/
DAResUNet_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn
import os
import h5py
import numpy as np
from DAResUNet.daresunet import DAResUNet
from torchviz import make_dot
from argparse import ArgumentParser
# Command-line interface: one boolean flag that bypasses CUDA detection in main().
parser = ArgumentParser()
parser.add_argument(
    '--force-cpu', '-c',
    action='store_true',
    help='Force CPU device ?',
)
def _select_device(force_cpu=False):
    """Choose the torch device to run on and print a colored status line.

    Args:
        force_cpu: when True, use the CPU even if CUDA is available.

    Returns:
        ``torch.device('cuda:0')`` when CUDA is available and not overridden,
        otherwise ``torch.device('cpu')``.
    """
    if force_cpu:
        print('\033[93m'
              'Forcing CPU usage\n'
              'Running on CPU, model may be slow'
              '\033[0m')
        return torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        print('\033[92m'
              f"CUDA available, running on device '{torch.cuda.get_device_name(device)}'"
              '\033[0m')
        return device
    print('\033[93m'
          'CUDA not available !\n'
          'Running on CPU, model may be slow'
          '\033[0m')
    return torch.device('cpu')


def _load_dataset(dataset_dir, device):
    """Read every HDF5 file under ``dataset_dir`` into uint8 tensors on ``device``.

    Each file must contain a ``'raw'`` and a ``'label'`` dataset. All files'
    arrays must share a shape, because they are stacked along a new leading
    axis (``np.stack`` raises ``ValueError`` otherwise). A singleton axis is
    then inserted at dim 1.

    Args:
        dataset_dir: directory walked recursively for HDF5 files.
        device: torch device the resulting tensors are placed on.

    Returns:
        ``(raws, labels)`` tuple of ``torch.uint8`` tensors.

    Raises:
        FileNotFoundError: if no HDF5 file is found under ``dataset_dir``.
    """
    # os.path.join inserts the separator that the old f'{root}{file}' dropped
    # for subdirectories; sorting makes the sample order deterministic.
    paths = sorted(
        os.path.join(root, name)
        for root, _dirs, names in os.walk(dataset_dir)
        for name in names
    )
    # Ignore stray non-HDF5 files (e.g. READMEs) instead of crashing on them.
    paths = [path for path in paths if h5py.is_hdf5(path)]
    if not paths:
        raise FileNotFoundError(f'No HDF5 files found under {dataset_dir!r}')

    raw_arrays, label_arrays = [], []
    for path in paths:
        # Read fully into memory and close the file; open read-only explicitly.
        with h5py.File(path, 'r') as h5:
            raw_arrays.append(h5['raw'][()])
            label_arrays.append(h5['label'][()])

    raws = torch.as_tensor(np.expand_dims(np.stack(raw_arrays), axis=1),
                           dtype=torch.uint8, device=device)
    labels = torch.as_tensor(np.expand_dims(np.stack(label_arrays), axis=1),
                             dtype=torch.uint8, device=device)
    return raws, labels


def main():
    """Load the challenge dataset onto a device and render a DAResUNet graph.

    Parses the command line (``--force-cpu``/``-c``) and selects a device.
    Loads every HDF5 file under ``challenge_dataset/`` as uint8 tensors and
    reports their size and memory footprint. Finally, runs one forward pass
    of ``DAResUNet`` on a zero tensor and renders its autograd graph to a PNG
    via ``torchviz.make_dot`` under the name ``'Model plot'``.
    """
    args = parser.parse_args()
    device = _select_device(args.force_cpu)

    raws, labels = _load_dataset('challenge_dataset/', device)  # adapt path
    # Real byte count of both resident tensors (elements * bytes per element);
    # the old code multiplied the formatted *string* by 1, which was a no-op.
    total_bytes = sum(t.numel() * t.element_size() for t in (raws, labels))
    print(
        '\033[92m'
        f'Successfully loaded tensors of size {raws.size()} to device\n'
        f'Total memory : {total_bytes:.2e} bytes'
        '\033[0m'
    )

    model = DAResUNet().to(device)
    # NOTE(review): 5-D zero input, presumably (batch, channel, D, H, W) —
    # confirm against DAResUNet's expected input layout.
    dummy = torch.zeros(size=[1, 1, 64, 192, 192],
                        dtype=torch.float32, device=device)
    y_dummy = model(dummy)['y']
    make_dot(y_dummy, params=dict(model.named_parameters())
             ).render('Model plot', format='png')
# Entry point: only run main() when executed as a script, not on import.
if __name__ == '__main__':
    main()