-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
141 lines (113 loc) · 4.2 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# -*- coding: utf-8 -*-
"""
@author:puzhen
@date:2023/5/15 21:46
"""
# Split the images and their label files into training, validation and test sets by ratio
import shutil
import random
import os
# Source directories and file extensions of the un-split dataset
image_original_path = "./images/"
label_original_path = "./labels/"
image_last_name = '.png'
label_last_name = '.txt'
# All output paths below are anchored at the working directory at import time
cur_path = os.getcwd()
# Training set directories
train_image_path = os.path.join(cur_path, "datasets/defect/images/train/")
train_label_path = os.path.join(cur_path, "datasets/defect/labels/train/")
# Validation set directories
val_image_path = os.path.join(cur_path, "datasets/defect/images/val/")
val_label_path = os.path.join(cur_path, "datasets/defect/labels/val/")
# Test set directories
test_image_path = os.path.join(cur_path, "datasets/defect/images/test/")
test_label_path = os.path.join(cur_path, "datasets/defect/labels/test/")
# Per-split list files (main() writes one copied image path per line)
list_train = os.path.join(cur_path, "datasets/defect/train.txt")
list_val = os.path.join(cur_path, "datasets/defect/val.txt")
list_test = os.path.join(cur_path, "datasets/defect/test.txt")
# Split ratios.
# NOTE(review): the three values sum to 1.2, and test_percent is not read
# anywhere in the visible code -- main() assigns to the test split whatever is
# left after the train and val samples are drawn. Confirm the intended ratios.
train_percent = 0.4
val_percent = 0.4
test_percent = 0.4
def del_file(path):
    """Delete every entry directly inside the directory ``path``.

    The directory itself is kept. Entry paths are built with ``os.path.join``
    so the function works on every platform and regardless of whether
    ``path`` ends with a separator (the previous hard-coded ``"\\"`` produced
    non-existent paths on POSIX systems).

    Args:
        path: Directory whose contents should be removed.

    Raises:
        OSError: If ``path`` cannot be listed or an entry cannot be removed
            with ``os.remove`` (for example, when the entry is a directory).
    """
    for entry in os.listdir(path):
        os.remove(os.path.join(path, entry))
def mkdir():
    """Prepare the six output directories (images and labels for each split).

    A directory that does not exist yet is created together with any missing
    parents; one that already exists is emptied through ``del_file`` so files
    from an earlier run do not leak into the new split.
    """
    split_dirs = (
        train_image_path,
        train_label_path,
        val_image_path,
        val_label_path,
        test_image_path,
        test_label_path,
    )
    for directory in split_dirs:
        if os.path.exists(directory):
            del_file(directory)
        else:
            os.makedirs(directory)
def clearfile():
    """Remove the train/val/test list files from a previous run, if present."""
    for list_file in (list_train, list_val, list_test):
        if not os.path.exists(list_file):
            continue
        os.remove(list_file)
def main():
    """Randomly split the dataset into train/val/test and copy the files.

    Steps:
      1. Reset the output directories and list files (``mkdir``/``clearfile``).
      2. Collect the label files in ``label_original_path`` whose name ends
         with ``label_last_name``; other files in that directory are ignored.
      3. Draw ``int(n * train_percent)`` samples for training and
         ``int(n * val_percent)`` of the remainder for validation; everything
         left over becomes the test split.
      4. Copy each label and its same-named image (``image_last_name``) into
         the directory of its split and append the copied image path to the
         matching list file.

    Raises:
        OSError: If a source directory is missing, or a label has no
            matching image (raised by ``shutil.copyfile``).
        ValueError: If the configured ratios request more samples than
            there are files (raised by ``random.sample``).
    """
    mkdir()
    clearfile()

    # Sorted so that a seeded run does not depend on os.listdir() ordering.
    label_files = sorted(
        f for f in os.listdir(label_original_path)
        if f.endswith(label_last_name)
    )
    num_txt = len(label_files)
    num_train = int(num_txt * train_percent)
    num_val = int(num_txt * val_percent)
    num_test = num_txt - num_train - num_val

    # Sets give O(1) membership tests in the copy loop below.
    train = set(random.sample(range(num_txt), num_train))
    # Indices not chosen for training; validation is drawn from these and
    # whatever is left afterwards is the test split.
    val_test = [i for i in range(num_txt) if i not in train]
    val = set(random.sample(val_test, num_val))

    print("训练集数目:{}, 验证集数目:{}, 测试集数目:{}".format(num_train, num_val, num_test))

    # Context managers guarantee the list files are closed even if a copy fails.
    with open(list_train, 'w') as file_train, \
            open(list_val, 'w') as file_val, \
            open(list_test, 'w') as file_test:
        for i, label_file in enumerate(label_files):
            # Strip exactly the configured label extension to get the stem.
            name = label_file[:len(label_file) - len(label_last_name)]
            image_file = name + image_last_name

            if i in train:
                image_dir, label_dir, list_file = (
                    train_image_path, train_label_path, file_train)
            elif i in val:
                image_dir, label_dir, list_file = (
                    val_image_path, val_label_path, file_val)
            else:
                image_dir, label_dir, list_file = (
                    test_image_path, test_label_path, file_test)

            dst_image = os.path.join(image_dir, image_file)
            dst_label = os.path.join(label_dir, label_file)
            shutil.copyfile(
                os.path.join(image_original_path, image_file), dst_image)
            shutil.copyfile(
                os.path.join(label_original_path, label_file), dst_label)
            list_file.write(dst_image + '\n')
# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()