-
Notifications
You must be signed in to change notification settings - Fork 3
/
cifar10.py
235 lines (167 loc) · 6.9 KB
/
cifar10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
########################################################################
#
# Functions for downloading the CIFAR-10 data-set from the internet
# and loading it into memory.
#
# Implemented in Python 3.5
#
# Usage:
# 1) Set the variable data_path with the desired storage path.
# 2) Call maybe_download_and_extract() to download the data-set
# if it is not already located in the given data_path.
# 3) Call load_class_names() to get an array of the class-names.
# 4) Call load_training_data() and load_test_data() to get
# the images, class-numbers and one-hot encoded class-labels
# for the training-set and test-set.
# 5) Use the returned data in your own program.
#
# Format:
# The images for the training- and test-sets are returned as 4-dim numpy
# arrays each with the shape: [image_number, height, width, channel]
# where the individual pixels are floats between 0.0 and 1.0.
#
########################################################################
#
# This file is part of the TensorFlow Tutorials available at:
#
# https://github.com/Hvass-Labs/TensorFlow-Tutorials
#
# Published under the MIT License. See the file LICENSE for details.
#
# Copyright 2016 by Magnus Erik Hvass Pedersen
#
########################################################################
import numpy as np
import pickle
import os
import download
from dataset import one_hot_encoded
########################################################################
# Storage directory for the downloaded and extracted data-set.
# Reassign this before calling any of the functions below to use another location.
data_path = "data/CIFAR-10/"

# Download location of the Python (pickle) version of CIFAR-10.
data_url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"

########################################################################
# Image-geometry constants, intended for use in your own program too.

# Images are square; this is both the height and the width in pixels.
img_size = 32

# Colour channels per pixel: Red, Green, Blue.
num_channels = 3

# Number of values in a single image once flattened to a 1-dim array.
img_size_flat = num_channels * img_size ** 2

# Number of distinct classes.
num_classes = 10

########################################################################
# Sizes used to pre-allocate arrays of the correct shape.

# The training-set is stored as this many batch-files.
_num_files_train = 5

# Number of images held in each training batch-file.
_images_per_file = 10000

# Total training images over all batch-files (used for pre-allocation).
_num_images_train = _images_per_file * _num_files_train
########################################################################
# Private functions for downloading, unpacking and loading data-files.
def _get_file_path(filename=""):
    """
    Build the path of a file inside the extracted CIFAR-10 batch directory.

    The batch directory lives under the module-level ``data_path``. With the
    default empty ``filename``, the path of the directory itself is returned.
    """
    batches_dir = "cifar-10-batches-py/"
    return os.path.join(data_path, batches_dir, filename)
def _unpickle(filename):
    """
    Load and return the object stored in a pickled CIFAR-10 data-file.

    ``filename`` is relative to the batch directory; the directory part is
    prepended via ``_get_file_path`` before the file is opened.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the
    file. Only use this on files extracted from the trusted CIFAR-10 archive.
    """
    full_path = _get_file_path(filename)
    print("Loading data: " + full_path)

    with open(full_path, mode='rb') as handle:
        # Under Python 3 an explicit encoding is required, otherwise loading
        # raises here. 'bytes' keeps stored strings as bytes, which is why the
        # callers index the result with b'...' keys.
        return pickle.load(handle, encoding='bytes')
def _convert_images(raw):
    """
    Turn raw CIFAR-10 pixel rows into a 4-dim float image array.

    The result has shape [image_number, height, width, channel]. Pixel values
    are divided by 255 so that 8-bit intensities map into [0.0, 1.0].
    """
    scaled = np.asarray(raw, dtype=float) / 255.0

    # The raw rows are laid out channel-first, so view them as
    # [image_number, channel, height, width] ...
    channels_first = scaled.reshape([-1, num_channels, img_size, img_size])

    # ... and then move the channel axis to the end.
    return channels_first.transpose([0, 2, 3, 1])
def _load_data(filename):
    """
    Read a single CIFAR-10 batch-file.

    Returns the pair ``(images, cls)``: the images converted by
    ``_convert_images`` and a numpy array with the class-number of each image.
    """
    batch = _unpickle(filename)

    pixel_rows = batch[b'data']
    class_numbers = np.array(batch[b'labels'])

    return _convert_images(pixel_rows), class_numbers
########################################################################
# Public functions that you may call to download the data-set from
# the internet and load the data into memory.
def maybe_download_and_extract():
    """
    Fetch and unpack the CIFAR-10 archive into ``data_path``.

    Delegates to ``download.maybe_download_and_extract`` with ``data_url``.
    Assign ``data_path`` beforehand to choose where the data-set is stored.
    """
    download.maybe_download_and_extract(download_dir=data_path, url=data_url)
def load_class_names():
    """
    Return the CIFAR-10 class-names as a list of str.

    The list is indexed by class-number, e.g. ``names[3]`` is the name of
    class-number 3.
    """
    meta = _unpickle(filename="batches.meta")
    encoded_names = meta[b'label_names']

    # The pickled names are bytes; decode each one to str.
    return [encoded.decode('utf-8') for encoded in encoded_names]
def load_training_data():
    """
    Load all the training-data for the CIFAR-10 data-set.

    The training-set is split across ``_num_files_train`` batch-files named
    ``data_batch_1``, ``data_batch_2``, ... which are loaded in order and
    merged into single arrays here.

    Returns a tuple ``(images, cls, labels)``:
        images: float array of shape [image_number, height, width, channel]
                as produced by ``_convert_images``.
        cls:    int array with the class-number of each image.
        labels: the one-hot encoded class-labels from ``one_hot_encoded``.

    Raises ValueError if the batch-files together contain more than
    ``_num_images_train`` images. If they contain fewer, only the images
    actually loaded are returned, instead of padding with zero-filled images
    labelled as class 0.
    """
    # Pre-allocate the full-size arrays once and fill them batch by batch,
    # for efficiency.
    images = np.zeros(shape=[_num_images_train, img_size, img_size, num_channels], dtype=float)
    cls = np.zeros(shape=[_num_images_train], dtype=int)

    # Index of the first unfilled row; after the loop it equals the total
    # number of images loaded.
    begin = 0
    for file_number in range(1, _num_files_train + 1):
        images_batch, cls_batch = _load_data(filename="data_batch_" + str(file_number))

        end = begin + len(images_batch)
        if end > _num_images_train:
            # Report overflow clearly instead of a numpy broadcast error.
            raise ValueError("CIFAR-10 training batch-files contain more than "
                             + str(_num_images_train) + " images.")

        images[begin:end] = images_batch
        cls[begin:end] = cls_batch
        begin = end

    # Drop rows that were never filled so they cannot masquerade as real
    # images of class 0. For the standard data-set begin == _num_images_train
    # and these slices cover the whole arrays.
    images = images[:begin]
    cls = cls[:begin]

    return images, cls, one_hot_encoded(class_numbers=cls, num_classes=num_classes)
def load_test_data():
    """
    Load the CIFAR-10 test-set from its single ``test_batch`` file.

    Returns a tuple of the images, their class-numbers and the matching
    one-hot encoded class-labels.
    """
    test_images, test_cls = _load_data(filename="test_batch")
    test_labels = one_hot_encoded(class_numbers=test_cls, num_classes=num_classes)
    return test_images, test_cls, test_labels
########################################################################