tiny_imagenet.py

import os

from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_arg


class TinyImageNet(VisionDataset):
    """`Tiny ImageNet <http://cs231n.stanford.edu/tiny-imagenet-200.zip>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train`` or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is
            not downloaded again.
    """

    base_folder = 'tiny-imagenet-200/'
    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    filename = 'tiny-imagenet-200.zip'
    md5 = '90528d7ca1a48142e341f4ef8d21d0de'
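
    # Expected on-disk layout after extraction, as assumed by find_classes() and
    # make_dataset() below (a sketch for orientation, not an official spec):
    #
    #   tiny-imagenet-200/
    #       wnids.txt                    one WordNet ID (class name) per line
    #       train/<wnid>/images/*.JPEG   training images, one folder per class
    #       val/images/*.JPEG            validation images
    #       val/val_annotations.txt      tab-separated lines: filename, wnid, ...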

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)

        self.dataset_path = os.path.join(root, self.base_folder)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val"))

        if self._check_integrity():
            print('Files already downloaded and verified.')
        elif download:
            self._download()
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')
        if not os.path.isdir(self.dataset_path):
            print('Extracting...')
            extract_archive(os.path.join(root, self.filename))

        # Map each WordNet ID to an integer label, then collect the
        # (image path, label) pairs for the requested split.
        _, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))
        self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)

    def _download(self):
        print('Downloading...')
        download_url(self.url, root=self.root, filename=self.filename)
        print('Extracting...')
        extract_archive(os.path.join(self.root, self.filename))

    def _check_integrity(self):
        return check_integrity(os.path.join(self.root, self.filename), self.md5)

    def __getitem__(self, index):
        img_path, target = self.data[index]
        image = self.loader(img_path)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.data)


def find_classes(class_file):
    """Read the WordNet IDs listed in ``wnids.txt`` and map each to an integer index."""
    with open(class_file) as r:
        classes = list(map(lambda s: s.strip(), r.readlines()))

    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}

    return classes, class_to_idx


def make_dataset(root, base_folder, dirname, class_to_idx):
    """Build the list of (image path, class index) pairs for one split."""
    images = []
    dir_path = os.path.join(root, base_folder, dirname)

    if dirname == 'train':
        # Training images are grouped into train/<wnid>/images/, one folder per class.
        for fname in sorted(os.listdir(dir_path)):
            cls_fpath = os.path.join(dir_path, fname)
            if os.path.isdir(cls_fpath):
                cls_imgs_path = os.path.join(cls_fpath, 'images')
                for imgname in sorted(os.listdir(cls_imgs_path)):
                    path = os.path.join(cls_imgs_path, imgname)
                    item = (path, class_to_idx[fname])
                    images.append(item)
    else:
        # Validation images all sit in val/images/; their labels come from
        # val_annotations.txt, whose tab-separated lines start with (filename, wnid).
        imgs_path = os.path.join(dir_path, 'images')
        imgs_annotations = os.path.join(dir_path, 'val_annotations.txt')

        with open(imgs_annotations) as r:
            data_info = map(lambda s: s.split('\t'), r.readlines())
            cls_map = {line_data[0]: line_data[1] for line_data in data_info}

        for imgname in sorted(os.listdir(imgs_path)):
            path = os.path.join(imgs_path, imgname)
            item = (path, class_to_idx[cls_map[imgname]])
            images.append(item)

    return images


if __name__ == '__main__':
    train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=False)
    test_dataset = TinyImageNet('./tiny-imagenet', split='val', download=False)
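
    # Minimal usage sketch (assumes torch/torchvision are installed and the archive
    # is already present under ./tiny-imagenet; the transform, batch size, and the
    # printed shapes are illustrative, not prescribed by this module).
    from torch.utils.data import DataLoader
    from torchvision import transforms

    train_dataset.transform = transforms.ToTensor()
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    images, targets = next(iter(train_loader))
    print(images.shape, targets.shape)  # expect roughly [128, 3, 64, 64] and [128]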