-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_data.py
208 lines (181 loc) · 7.75 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import numpy as np
import os
import random
import torch
from torch.utils.data import IterableDataset
import time
import imageio
from PIL import Image
import torchvision.transforms as T
def get_images(paths, labels, nb_samples=None, shuffle=True):
    """
    Pair every image file found in a set of character folders with a label.

    Args:
        paths: A list of character folders
        labels: List or numpy array of same length as paths; labels[i] is
            attached to every image taken from paths[i]
        nb_samples: Number of images to draw (without replacement) from each
            folder; None keeps every file in the folder
        shuffle: If True, shuffle the resulting pairs in place with the
            global `random` module before returning them
    Returns:
        List of (label, image_path) tuples
    """
    pairs = []
    for label, folder in zip(labels, paths):
        filenames = os.listdir(folder)
        if nb_samples is not None:
            # random.sample raises ValueError when the folder holds fewer
            # than nb_samples entries.
            filenames = random.sample(filenames, nb_samples)
        for filename in filenames:
            pairs.append((label, os.path.join(folder, filename)))

    if shuffle:
        random.shuffle(pairs)
    return pairs
class DataGenerator(IterableDataset):
    """
    Data Generator capable of generating batches of Omniglot data.
    A "class" is considered a class of omniglot characters.

    Iterating over the generator yields an endless stream of
    (images, labels) tasks produced by `_sample`.
    """

    def __init__(
        self,
        num_classes,
        num_samples_per_class,
        batch_type,
        config=None,
        device=torch.device("cpu"),
        cache=True,
        augment_support_set=False,
        augmenter=None
    ):
        """
        Args:
            num_classes: Number of classes for classification (N-way)
            num_samples_per_class: num samples to generate per class in one batch (K+1)
            batch_type: train/val/test; any value other than "train" or
                "val" selects the test split
            config: Optional dict. Recognised keys are "data_folder"
                (default "./omniglot_resized") and "img_size"
                (default (28, 28))
            device: Stored on the instance as `self.device`
            cache: If True, decoded (non-augmented) images are kept in memory
                keyed by filename
            augment_support_set: If True, every support image is followed by
                an augmented copy of itself
            augmenter: One of "randaug", "autoaug_cifar10",
                "autoaug_imagenet", "autoaug_svhn"; anything else falls back
                to RandAugment
        """
        super().__init__()
        # A None default avoids sharing one dict object across all instances.
        if config is None:
            config = {}

        self.num_samples_per_class = num_samples_per_class
        self.num_classes = num_classes

        data_folder = config.get("data_folder", "./omniglot_resized")
        self.img_size = config.get("img_size", (28, 28))

        self.dim_input = np.prod(self.img_size)
        self.dim_output = self.num_classes

        # Collect every <data_folder>/<family>/<character> directory.
        character_folders = []
        for family in os.listdir(data_folder):
            family_path = os.path.join(data_folder, family)
            if not os.path.isdir(family_path):
                continue
            for character in os.listdir(family_path):
                character_path = os.path.join(family_path, character)
                if os.path.isdir(character_path):
                    character_folders.append(character_path)

        # Fixed seed so the split is the same for every instance.
        # NOTE(review): this reseeds the *global* `random` module, and the
        # split still depends on os.listdir order, which is platform
        # dependent - confirm whether sorting first is acceptable.
        random.seed(1)
        random.shuffle(character_folders)
        num_val = 100
        num_train = 1100
        self.metatrain_character_folders = character_folders[:num_train]
        self.metaval_character_folders = character_folders[num_train : num_train + num_val]
        self.metatest_character_folders = character_folders[num_train + num_val :]
        self.device = device
        self.image_caching = cache
        self.stored_images = {}
        self.augment_support_set = augment_support_set
        self.augmenter = augmenter

        if batch_type == "train":
            self.folders = self.metatrain_character_folders
        elif batch_type == "val":
            self.folders = self.metaval_character_folders
        else:
            self.folders = self.metatest_character_folders

    def image_file_to_array(self, filename, dim_input, augment=False):
        """
        Takes an image path and returns numpy array
        Args:
            filename: Image filename
            dim_input: Flattened shape of image
            augment: If True, run the image through `augment_image` before
                flattening; augmented results never touch the cache
        Returns:
            1 channel image, flattened to [dim_input], float32, scaled by
            1/255 and inverted (1.0 - value)
        """
        # Augmented images must bypass the cache in both directions:
        # a cache hit would return the clean image instead of an augmented
        # one, and storing an augmented image would poison later clean reads.
        use_cache = self.image_caching and not augment
        if use_cache and filename in self.stored_images:
            return self.stored_images[filename]

        image = imageio.imread(filename)  # misc.imread(filename)
        if augment:
            image = self.augment_image(image)
        image = image.reshape([dim_input])
        image = image.astype(np.float32) / 255.0
        image = 1.0 - image

        if use_cache:
            self.stored_images[filename] = image
        return image

    def augment_image(self, image):
        """
        Apply the augmentation policy named by `self.augmenter` to one image.
        Args:
            image: Image to augment. A numpy array is converted to a PIL
                image for the transform and converted back afterwards; any
                other input is handed to the transform unchanged.
        Returns:
            The augmented image (numpy array if a numpy array was given)
        """
        if self.augmenter == "randaug":
            augmenter = T.RandAugment()
        elif self.augmenter == "autoaug_cifar10":
            augmenter = T.AutoAugment(T.AutoAugmentPolicy.CIFAR10)
        elif self.augmenter == "autoaug_imagenet":
            augmenter = T.AutoAugment(T.AutoAugmentPolicy.IMAGENET)
        elif self.augmenter == "autoaug_svhn":
            augmenter = T.AutoAugment(T.AutoAugmentPolicy.SVHN)
        else:
            print(f'Augmenter {self.augmenter} not supported, defaulting to RandAug')
            augmenter = T.RandAugment()

        if isinstance(image, np.ndarray):
            # torchvision's RandAugment/AutoAugment accept PIL images or
            # tensors, not numpy arrays, so round-trip through PIL.
            # NOTE(review): assumes the array dtype/shape maps to a PIL mode
            # the policy ops support (e.g. 2-D uint8 -> "L") - confirm.
            return np.asarray(augmenter(Image.fromarray(image)))
        return augmenter(image)

    def _sample(self):
        """
        Samples a batch for training, validation, or testing
        Args:
            does not take any arguments
        Returns:
            A tuple of (1) Image batch and (2) Label batch:
                1. image batch has shape [K+1, N, dim_input] and
                2. label batch has shape [K+1, N, N]
            where K is the number of "shots", N is number of classes and
            `self.num_samples_per_class` is K+1. With `augment_support_set`
            the leading dimension is 2K+1 instead.
        Note:
            1. The last row holds the query set (one image per class, in
               shuffled order); all preceding rows hold support images.
            2. Support rows are produced by reshaping a class-major list, so
               for K > 1 a single support row is not one-image-per-class.
               Image/label pairing is preserved because images and labels
               are reshaped identically.
        """
        num_classes = self.num_classes
        samples_per_class = self.num_samples_per_class
        # Use the configured flattened size rather than a hard-coded 784 so
        # the "img_size" config entry is honoured.
        dim_input = self.dim_input

        # Sample N different characters from specified folders
        char_folders = random.sample(self.folders, num_classes)

        # Load K + 1 images per character; length = N x (K + 1), grouped by class
        one_hot_labels = np.eye(num_classes)
        images_labels = get_images(
            char_folders, one_hot_labels, samples_per_class, shuffle=False
        )

        support_images, support_labels = [], []
        query_images, query_labels = [], []
        for idx, (label, img_path) in enumerate(images_labels):
            if idx % samples_per_class == 0:
                # The first image of each class goes to the query set
                query_images.append(self.image_file_to_array(img_path, dim_input))
                query_labels.append(label)
            else:
                support_images.append(self.image_file_to_array(img_path, dim_input))
                support_labels.append(label)
                if self.augment_support_set:
                    support_images.append(
                        self.image_file_to_array(img_path, dim_input, augment=True)
                    )
                    support_labels.append(label)

        # Shuffle the query set, applying one permutation to images and
        # labels so the pairing stays intact.
        order = np.random.permutation(len(query_images))
        query_images = [query_images[i] for i in order]
        query_labels = [query_labels[i] for i in order]

        # Rows: K support rows (2K with augmentation) followed by 1 query row
        if self.augment_support_set:
            num_rows = 2 * samples_per_class - 1
        else:
            num_rows = samples_per_class

        images = np.vstack(support_images + query_images).reshape(
            (num_rows, num_classes, -1)
        )
        labels = np.vstack(support_labels + query_labels).reshape(
            (-1, num_classes, num_classes)
        )
        return images, labels

    def __iter__(self):
        """Yield freshly sampled (images, labels) tasks forever."""
        while True:
            yield self._sample()