-
Notifications
You must be signed in to change notification settings - Fork 5
/
convert_tfrecord.py
executable file
·195 lines (140 loc) · 6.88 KB
/
convert_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# -*- coding:utf-8 -*-
# Convert a VOC-format image dataset (HollywoodHeads paths are configured at the bottom of this file) into sharded TFRecord files.
import numpy as np
import tensorflow as tf
import datetime
import math
import sys
import random
import threading
import json
import os
import matplotlib.pyplot as plt
import shutil
from prepare_data_global_model import prepare_global_label
def int64_feature(values):
    """Wrap integer data in a ``tf.train.Feature`` backed by an ``Int64List``.

    Args:
      values: A single value, or a tuple/list of values. Anything that is
        not a tuple or a list is first placed in a one-element list.

    Returns:
      A ``tf.train.Feature`` whose ``int64_list`` carries the values.
    """
    already_sequence = isinstance(values, (tuple, list))
    payload = values if already_sequence else [values]
    int64_list = tf.train.Int64List(value=payload)
    return tf.train.Feature(int64_list=int64_list)
def bytes_feature(values):
    """Wrap one value in a ``tf.train.Feature`` backed by a ``BytesList``.

    Args:
      values: The single value to store; it is placed in a one-element list.

    Returns:
      A ``tf.train.Feature`` whose ``bytes_list`` carries ``[values]``.
    """
    bytes_list = tf.train.BytesList(value=[values])
    return tf.train.Feature(bytes_list=bytes_list)
def image_to_tfexample(image_data, image_format, height, width, class_id, filename):
    """Assemble a ``tf.train.Example`` describing a single image.

    Args:
      image_data: Stored under 'image/encoded' as a bytes feature.
      image_format: Stored under 'image/format' as a bytes feature.
      height: Stored under 'image/height' as an int64 feature.
      width: Stored under 'image/width' as an int64 feature.
      class_id: Stored under 'image/class/label' as an int64 feature
        (a scalar or a tuple/list of values).
      filename: Stored under 'image/filename' as a bytes feature.

    Returns:
      A ``tf.train.Example`` holding the six features above.
    """
    feature_map = {
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(filename),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)
# Image reader class with two methods: one reports the dimensions of an
# image, the other decodes the raw encoded image data.
# In practice it serves a single purpose: obtaining the image size.
class ImageReader(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        # Build the RGB JPEG decode op once; each call feeds new data
        # through the string placeholder.
        jpeg_input = tf.placeholder(dtype=tf.string)
        self._decode_jpeg_data = jpeg_input
        self._decode_jpeg = tf.image.decode_jpeg(jpeg_input, channels=3)

    def read_image_dims(self, sess, image_data):
        """Return the first two dimensions of the decoded image.

        Args:
          sess: Session used to run the decode op.
          image_data: Encoded JPEG data fed to the decoder.

        Returns:
          A ``(shape[0], shape[1])`` tuple of the decoded array; callers in
          this file unpack it as ``height, width``.
        """
        decoded = self.decode_jpeg(sess, image_data)
        dim0, dim1 = decoded.shape[0], decoded.shape[1]
        return dim0, dim1

    def decode_jpeg(self, sess, image_data):
        """Run the decode op on ``image_data`` and return the decoded image.

        The result is asserted to be rank 3 with a last dimension of 3.
        """
        feed = {self._decode_jpeg_data: image_data}
        decoded = sess.run(self._decode_jpeg, feed_dict=feed)
        assert len(decoded.shape) == 3
        assert decoded.shape[2] == 3
        return decoded
# 获取要输出的tfrecord的文件名称,最后格式类似于 flowers_train_00001-of-00005.tfrecord
def _get_dataset_filename(TFRECORD_TARGET_PATH, split_name, shard_id, _NUM_SHARDS):
output_filename = '%s_%05d-of-%05d.tfrecord' % (
split_name, shard_id, _NUM_SHARDS)
return os.path.join(TFRECORD_TARGET_PATH, output_filename)
def _convert_dataset(split_name, filenames, labels, TFRECORD_TARGET_PATH, _NUM_SHARDS):
    """Convert the given images and labels to sharded TFRecord files.

    Each image is read from ``<TEMP_TARGET_PATH>/<name>.jpg``, where
    ``TEMP_TARGET_PATH`` is a module-level global that must be defined before
    this function is called. The image is decoded once to obtain its
    dimensions and written, together with its label list and name, as one
    ``tf.train.Example``.

    Args:
      split_name: Split name used as the prefix of the shard file names.
      filenames: Sequence of image names without extension.
      labels: Sequence parallel to ``filenames``; ``labels[i]`` is converted
        with ``list()`` and stored as the int64 label feature.
      TFRECORD_TARGET_PATH: Directory the shard files are written to.
      _NUM_SHARDS: Number of shard files to split the images across.
    """
    num_images = len(filenames)
    num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()

        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(
                    TFRECORD_TARGET_PATH, split_name, shard_id, _NUM_SHARDS)
                start_ndx = shard_id * num_per_shard
                end_ndx = min((shard_id + 1) * num_per_shard, num_images)

                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i + 1, num_images, shard_id))
                        sys.stdout.flush()

                        image_name = filenames[i]
                        file_path = os.path.join(TEMP_TARGET_PATH, image_name + '.jpg')
                        # Close the handle deterministically; the loop may
                        # open a very large number of files.
                        with tf.gfile.FastGFile(file_path, 'rb') as image_file:
                            image_data = image_file.read()
                        height, width = image_reader.read_image_dims(sess, image_data)

                        # bytes(str) raises TypeError on Python 3, so encode
                        # text names explicitly and pass bytes through as-is.
                        if isinstance(image_name, bytes):
                            name_bytes = image_name
                        else:
                            name_bytes = image_name.encode('utf-8')

                        example = image_to_tfexample(
                            image_data, b'jpg', height, width,
                            list(labels[i]), name_bytes)
                        tfrecord_writer.write(example.SerializeToString())

    sys.stdout.write('\n')
    sys.stdout.flush()
def convert_original_iamge(split_name='train'):
    """Prepare the images and labels of one split and write them as TFRecords.

    For every image name listed in the split's list file, the image and its
    XML annotation are passed to ``prepare_global_label``; the returned image
    is saved as ``<name>.jpg`` and the returned label as ``<name>.txt`` under
    ``TEMP_TARGET_PATH``. All images are then packed into TFRecord shards
    under ``TFRECORD_TARGET_PATH``.

    Relies on module-level globals (``TRAIN_BASIC_PATH``, the ``*_IMAGE_PATH``
    list files, ``TEMP_TARGET_PATH``, ``TFRECORD_TARGET_PATH``) being defined
    before the call.

    Args:
      split_name: 'train' (10 shards), 'test' or 'val' (1 shard each). Any
        other value reads the list file named by ``global_mat_IMAGE_PATH``
        and writes 1 shard.
    """
    start_time = datetime.datetime.now()

    # Only the list file and the shard count differ between splits; the
    # image/annotation locations are shared.
    if split_name == 'train':
        NUM_SHARDS = 10
        list_path = TRAIN_IMAGE_PATH
    elif split_name == 'test':
        NUM_SHARDS = 1
        list_path = TEST_IMAGE_PATH
    elif split_name == 'val':
        NUM_SHARDS = 1
        list_path = VAL_IMAGE_PATH
    else:
        NUM_SHARDS = 1
        # NOTE(review): global_mat_IMAGE_PATH is not defined in the visible
        # part of this file; it must exist at module level for this branch.
        list_path = global_mat_IMAGE_PATH

    # ndmin=1 keeps a one-line list file from collapsing into a 0-d array,
    # which supports neither len() nor iteration.
    temp = np.loadtxt(list_path, dtype='str', ndmin=1)

    # Create the output directories up front instead of failing on the
    # first save.
    for target_dir in (TEMP_TARGET_PATH, TFRECORD_TARGET_PATH):
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)

    labels = []
    for img_id in temp:
        img_path = os.path.join(TRAIN_BASIC_PATH, 'JPEGImages', img_id + '.jpeg')
        xml_path = os.path.join(TRAIN_BASIC_PATH, 'Annotations', img_id + '.xml')
        img_resized, label = prepare_global_label(img_path, xml_path)
        img_resized.save(os.path.join(TEMP_TARGET_PATH, img_id + '.jpg'))
        np.savetxt(os.path.join(TEMP_TARGET_PATH, img_id + '.txt'), label)
        labels.append([int(value) for value in label])

    _convert_dataset(split_name, temp, labels, TFRECORD_TARGET_PATH, NUM_SHARDS)
    print('\nFinished converting the dataset to tfrecord! Time cost ',datetime.datetime.now()-start_time)
if __name__ == '__main__':
    # hollywood data -- test: 1297, train: 216694, val: 6676
    # Location of the source data; it must be in VOC format.
    TRAIN_BASIC_PATH = '/nishome/zl/faster-rcnn/data/HollywoodHeads/'

    # Locations the output files are saved to.
    TFRECORD_TARGET_PATH = os.path.join(TRAIN_BASIC_PATH, 'tfrecord')
    TEMP_TARGET_PATH = os.path.join(TRAIN_BASIC_PATH, 'temp')

    # List files naming the images of each split.
    TRAIN_IMAGE_PATH = os.path.join(TRAIN_BASIC_PATH, 'ImageSets/Main/train.txt')
    VAL_IMAGE_PATH = os.path.join(TRAIN_BASIC_PATH, 'ImageSets/Main/val.txt')
    TEST_IMAGE_PATH = os.path.join(TRAIN_BASIC_PATH, 'ImageSets/Main/test.txt')

    # Only the train split is converted; the other splits are left disabled.
    # convert_original_iamge(split_name='test')
    # convert_original_iamge(split_name='val')
    convert_original_iamge(split_name='train')