-
Notifications
You must be signed in to change notification settings - Fork 2
/
data.py
118 lines (108 loc) · 3.96 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
import numpy as np
import os
import glob
import cv2
from libtiff import TIFF
# NOTE: libtiff is not used anywhere below; the import above can be commented out if libtiff is not installed.
class dataProcess(object):
    """Convert folders of images into ``.npy`` arrays and load them back normalized.

    ``create_*`` methods read every ``*.<img_type>`` file of a folder as a
    grayscale image and stack them into a ``(N, out_rows, out_cols, 1)``
    ``uint8`` array saved under ``npy_path``. ``load_*`` methods read those
    arrays back as ``float32``, scale them by 1/255 and subtract the
    per-pixel mean.
    """

    def __init__(self, out_rows, out_cols, data_path="./raw/train", label_path="./raw/label",
                 test_path="./raw/test", npy_path="./npydata", result_path="./results", img_type="tif"):
        """Store the configuration and make sure the output folders exist.

        Args:
            out_rows: Height of every image in the stacked arrays.
            out_cols: Width of every image in the stacked arrays.
            data_path: Folder holding the training images.
            label_path: Folder holding the training masks; each mask has the
                same file name as its training image.
            test_path: Folder holding the test images.
            npy_path: Folder where the ``.npy`` files are written and read.
            result_path: Folder created for results (not otherwise used by
                this class).
            img_type: File extension (without the dot) of the images.
        """
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.data_path = data_path
        self.label_path = label_path
        self.img_type = img_type
        self.test_path = test_path
        self.npy_path = npy_path
        self.result_path = result_path
        self._ensure_dir(self.npy_path)
        self._ensure_dir(self.result_path)

    @staticmethod
    def _ensure_dir(path):
        """Create *path*, including missing parent folders, if it does not exist."""
        # os.makedirs (unlike os.mkdir) also handles nested paths such as "out/run1/npy".
        if not os.path.exists(path):
            os.makedirs(path)

    @staticmethod
    def _read_gray(path):
        """Load the image at *path* in grayscale and return it as an array."""
        # presumably yields a (rows, cols, 1) array -- verify against the installed Keras
        return img_to_array(load_img(path, grayscale=True))

    def _list_images(self, folder):
        """Return the sorted paths of all ``*.<img_type>`` files in *folder*."""
        pattern = os.path.join(folder, "*." + self.img_type)
        # Sorting makes the order of the saved arrays reproducible across runs.
        return sorted(glob.glob(pattern))

    @staticmethod
    def _scale_and_center(imgs):
        """Return *imgs* as float32, divided by 255 and with the mean over axis 0 removed."""
        imgs = imgs.astype('float32')
        imgs /= 255
        imgs -= imgs.mean(axis=0)
        return imgs

    # Build the training data
    def create_train_data(self):
        """Stack all training images and their masks and save both as ``.npy`` files.

        Writes ``imgs_train.npy`` and ``imgs_mask_train.npy`` into ``npy_path``.
        Images are not resized here: each one is assigned into a fixed
        ``(out_rows, out_cols, 1)`` slot, so a different size raises an error.
        """
        print('-'*30)
        print('Creating training images...')
        print('-'*30)
        imgs = self._list_images(self.data_path)
        print("Number of train images is {}".format(len(imgs)))
        shape = (len(imgs), self.out_rows, self.out_cols, 1)
        imgdatas = np.ndarray(shape, dtype=np.uint8)
        imglabels = np.ndarray(shape, dtype=np.uint8)
        for i, imgname in enumerate(imgs):
            # basename works with whatever separator glob returned on this platform.
            midname = os.path.basename(imgname)
            imgdatas[i] = self._read_gray(os.path.join(self.data_path, midname))
            # The mask is looked up by the same file name in the label folder.
            imglabels[i] = self._read_gray(os.path.join(self.label_path, midname))
            if i % 100 == 0:
                print('Done: {0}/{1} images'.format(i, len(imgs)))
        print('loading done')
        np.save(os.path.join(self.npy_path, 'imgs_train.npy'), imgdatas)
        np.save(os.path.join(self.npy_path, 'imgs_mask_train.npy'), imglabels)
        print('Saving to imgs_train.npy and imgs_mask_train.npy files done.')

    # Build the test data
    def create_test_data(self):
        """Stack all test images (sorted by path) and save them as ``imgs_test.npy``."""
        print('-'*30)
        print('Creating test images...')
        print('-'*30)
        imgs = self._list_images(self.test_path)
        print("Number of test images is {}".format(len(imgs)))
        imgdatas = np.ndarray((len(imgs), self.out_rows, self.out_cols, 1), dtype=np.uint8)
        for i, imgname in enumerate(imgs):
            midname = os.path.basename(imgname)
            imgdatas[i] = self._read_gray(os.path.join(self.test_path, midname))
        print('loading done')
        np.save(os.path.join(self.npy_path, 'imgs_test.npy'), imgdatas)
        print('Saving to imgs_test.npy files done.')

    # Load the training images and masks
    def load_train_data(self):
        """Load the saved training arrays and normalize them.

        Returns:
            A tuple ``(imgs_train, imgs_mask_train)`` of float32 arrays.
            Images are scaled by 1/255 and mean-centered over axis 0; masks
            are scaled by 1/255 and binarized to 0/1 at a 0.5 threshold.
        """
        print('-'*30)
        print('load train images...')
        print('-'*30)
        imgs_train = np.load(os.path.join(self.npy_path, "imgs_train.npy"))
        imgs_mask_train = np.load(os.path.join(self.npy_path, "imgs_mask_train.npy"))
        imgs_train = self._scale_and_center(imgs_train)
        imgs_mask_train = imgs_mask_train.astype('float32')
        imgs_mask_train /= 255
        # Threshold the mask: values above 0.5 count as object, the rest as background.
        imgs_mask_train[imgs_mask_train > 0.5] = 1
        imgs_mask_train[imgs_mask_train <= 0.5] = 0
        return imgs_train, imgs_mask_train

    # Load the test images
    def load_test_data(self):
        """Load the saved test array, scale it by 1/255 and mean-center it over axis 0.

        Returns:
            The normalized test images as a float32 array.
        """
        print('-'*30)
        print('load test images...')
        print('-'*30)
        imgs_test = np.load(os.path.join(self.npy_path, "imgs_test.npy"))
        # NOTE(review): the test set is centered with its own mean, not the
        # training mean -- confirm this is what the model expects.
        return self._scale_and_center(imgs_test)
if __name__ == "__main__":
    # Script entry point: build the .npy files from the default ./raw folders
    # with 512x512 images, then load them back and print the array shapes.
    processor = dataProcess(512, 512)
    processor.create_train_data()
    processor.create_test_data()
    train_images, train_masks = processor.load_train_data()
    test_images = processor.load_test_data()
    print(train_images.shape, train_masks.shape, test_images.shape)