-
Notifications
You must be signed in to change notification settings - Fork 16
/
DataOperate.py
87 lines (69 loc) · 2.15 KB
/
DataOperate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
""""
This file defines the dataset class and the data-handling helpers.
"""
import glob
import os
from torch.utils.data import Dataset
import numpy as np
import nibabel as nib
import random
import torch
def get_data_list(data_path, ratio=0.8):
    """
    Build the train/test sample lists from a directory of cases.

    The expected layout is one sub-directory per case:
    --data
        --data_1
            image.nii
            label.nii
        --data_2
            image.nii
            label.nii
        ...
    If you use your own data layout, rewrite this function.

    Args:
        data_path: directory whose sub-directories each hold one case.
        ratio: fraction of the name-sorted cases that goes to the training
            list; the remaining cases form the test list.

    Returns:
        A tuple ``(train_list, test_list)``. Each element is a dict with the
        keys ``'data'`` and ``'label'`` holding the image and label paths.
        The paths are only constructed here, not checked for existence.
        The training list is shuffled with ``random.shuffle``; the test list
        keeps the sorted order.
    """
    data_name = 'image.nii'
    label_name = 'label.nii'

    # Only a directory can contain a case. A stray regular file next to the
    # case folders would otherwise turn into a sample with impossible paths
    # and also shift the train/test boundary.
    case_dirs = sorted(
        entry
        for entry in glob.glob(os.path.join(data_path, '*'))
        if os.path.isdir(entry)
    )

    list_all = [
        {
            'data': os.path.join(case_dir, data_name),
            'label': os.path.join(case_dir, label_name),
        }
        for case_dir in case_dirs
    ]

    # Split on the sorted order first so the test set is deterministic, then
    # shuffle only the training part.
    cut = int(ratio * len(list_all))
    train_list = list_all[:cut]
    test_list = list_all[cut:]
    random.shuffle(train_list)
    return train_list, test_list
class MySet(Dataset):
    """
    Dataset over a list of image/label file pairs.

    The constructor receives a list whose items are dicts with at least the
    two keys below (extra keys are ignored by this class):
    data_list = [
        {
            "data": data_path_1,
            "label": label_path_1,
            ...
        }
    ]
    Both files are read with ``nibabel.load`` when an item is requested.
    """

    def __init__(self, data_list):
        """Store the list of sample dicts; no file is opened here."""
        self.data_list = data_list

    def __getitem__(self, item):
        """
        Load one sample and return ``(data_tensor, mask_tensor)``.

        The image is min-max normalized (see ``normalize``) and the label is
        cast to float32. Both arrays are indexed as 3-D volumes and get a new
        leading axis, so each returned tensor has four dimensions.
        """
        sample = self.data_list[item]
        data_path = sample["data"]
        mask_path = sample["label"]

        data = nib.load(data_path).get_fdata()
        mask = nib.load(mask_path).get_fdata()

        data = self.normalize(data)
        # Prepend a singleton axis; the three explicit slices require the
        # loaded arrays to be exactly 3-D.
        data = data[np.newaxis, :, :, :]
        mask = mask.astype(np.float32)
        mask = mask[np.newaxis, :, :, :]

        mask_tensor = torch.from_numpy(mask)
        data_tensor = torch.from_numpy(data)
        return data_tensor, mask_tensor

    @staticmethod
    def normalize(data):
        """
        Min-max scale ``data`` to the range [0, 1] as float32.

        A constant-valued array has a zero value range; it is mapped to all
        zeros instead of dividing 0 by 0, which would fill the result with
        NaN.
        """
        data = data.astype(np.float32)
        low = np.min(data)
        value_range = np.max(data) - low
        if value_range == 0:
            return np.zeros_like(data)
        return (data - low) / value_range

    def __len__(self):
        """Return the number of samples in the list."""
        return len(self.data_list)