-
Notifications
You must be signed in to change notification settings - Fork 2
/
iCIFAR100.py
84 lines (67 loc) · 2.73 KB
/
iCIFAR100.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from PIL import Image
from torchvision.datasets import CIFAR100
class iCIFAR100(CIFAR100):
    """CIFAR-100 with a per-label index and helpers that return paired views.

    Extends torchvision's ``CIFAR100`` with two extra transforms, ``t1`` and
    ``t2``. The ``get_images_by_*`` helpers apply both to the same raw image,
    producing two differently transformed copies of it. The class also keeps
    a label -> sample-indices map that is built once at construction time.

    NOTE(review): ``class_to_idx`` is reassigned here to map each label to the
    list of sample indices carrying that label. torchvision's CIFAR datasets
    use the same attribute name for a class-name -> label map, which this
    shadows. Confirm that no caller relies on the torchvision meaning.
    """

    def __init__(self, root, train=True, t1=None, t2=None, transform=None,
                 target_transform=None, download=False):
        """
        :param root: dataset root directory, forwarded to ``CIFAR100``.
        :param train: training set if True, test set otherwise.
        :param t1: first transform used by the ``get_images_by_*`` helpers.
        :param t2: second transform used by the ``get_images_by_*`` helpers.
        :param transform: transform applied to the image in ``__getitem__``.
        :param target_transform: transform applied to the label in ``__getitem__``.
        :param download: forwarded to ``CIFAR100``.
        """
        super().__init__(root, train=train, download=download,
                         transform=transform, target_transform=target_transform)
        self.t1 = t1
        self.t2 = t2
        # label -> list of sample indices, in ascending index order.
        # Only the labels are needed here, so the image array is not iterated.
        self.class_to_idx = {}
        for index, label in enumerate(self.targets):
            self.class_to_idx.setdefault(label, []).append(index)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]
        # Return a PIL Image, consistent with the other torchvision datasets.
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    def _transform_pair(self, array):
        """Return ``(t1(img), t2(img))`` for one raw image array.

        :raises TypeError: if ``t1`` or ``t2`` was not provided.
        """
        if self.t1 is None or self.t2 is None:
            raise TypeError(
                "iCIFAR100: t1 and t2 must be callables to build transformed "
                "image pairs (got t1=%r, t2=%r)" % (self.t1, self.t2))
        # Build a separate PIL image for each view, so that a transform which
        # modifies its input in place cannot leak into the other view.
        # t1 is called before t2, as in the original per-item ordering.
        return self.t1(Image.fromarray(array)), self.t2(Image.fromarray(array))

    def get_images_by_indexes(self, indexes):
        """
        Retrieve all the images from a given list of indexes.
        :param indexes: iterable of sample indexes.
        :return: two lists holding the same images, the first transformed
            with ``t1`` and the second with ``t2``.
        :raises TypeError: if ``indexes`` is non-empty and ``t1``/``t2`` is unset.
        """
        images1, images2 = [], []
        for i in indexes:
            view1, view2 = self._transform_pair(self.data[i])
            images1.append(view1)
            images2.append(view2)
        return images1, images2

    def get_indexes_by_classes(self, classes):
        """
        Retrieve the indexes of all images whose label is in ``classes``.
        :param classes: iterable of labels (e.g. a ``range`` of classes).
        :return: a flat list of indexes, grouped by label in the order of ``classes``.
        :raises KeyError: if a label in ``classes`` has no samples.
        """
        return [idx for label in classes for idx in self.class_to_idx[label]]

    def get_images_by_class(self, label):
        """
        Retrieve all images of the class ``label``.

        Uses the label index built in ``__init__`` instead of rescanning the
        whole dataset. If ``self.targets`` is modified after construction,
        ``class_to_idx`` must be rebuilt accordingly.
        :param label: class label.
        :return: three lists: the images transformed with ``t1``, their
            indexes (ascending), and the images transformed with ``t2``.
            All three are empty if no sample has this label.
        """
        # Unwrap 0-d numpy/torch scalars so they hit the dict keys, which
        # come straight from self.targets.
        key = label.item() if hasattr(label, "item") else label
        # Copy the index list so callers cannot mutate the shared index.
        indexes = list(self.class_to_idx.get(key, ()))
        images1, images2 = self.get_images_by_indexes(indexes)
        return images1, indexes, images2