-
Notifications
You must be signed in to change notification settings - Fork 23
/
imagenet_dataset.py
69 lines (57 loc) · 2.25 KB
/
imagenet_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2019 Apple Inc. All Rights Reserved.
#
import torchvision.datasets as datasets
from PIL import Image
def pil_loader(path):
    """Load the image at ``path`` with Pillow and return it converted to RGB.

    The file is opened explicitly in binary mode and closed by the ``with``
    block. This avoids the ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The Pillow image converted to ``'RGB'`` mode.
    """
    with open(path, 'rb') as handle:
        # convert() runs while the file handle is still open.
        return Image.open(handle).convert('RGB')
def accimage_loader(path):
    """Load ``path`` with accimage, falling back to Pillow on an IOError.

    Args:
        path: Filesystem path of the image file.

    Returns:
        An ``accimage.Image`` when accimage can read the file. Otherwise,
        the RGB image produced by ``pil_loader``.
    """
    # Imported here rather than at module level, so the module still imports
    # when accimage is not installed.
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # accimage could not read the file; let Pillow try instead.
        return pil_loader(path)
    return image
def default_loader(path):
    """Load ``path`` with the loader that matches torchvision's image backend.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The result of ``accimage_loader`` when torchvision's image backend is
        ``'accimage'``. Otherwise, the result of ``pil_loader``.
    """
    from torchvision import get_image_backend
    chosen = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return chosen(path)
class ImageFolderWithIdx(datasets.ImageFolder):
    """ImageFolder variant whose items also carry their position in the dataset.

    Each item is returned as ``(sample, target, index)`` rather than the
    ``(sample, target)`` pair returned by the parent ``ImageFolder``.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform applied to each
            loaded image, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform applied
            to the target label.
        loader (callable, optional): A function that loads an image given its
            path. Defaults to this module's ``default_loader``.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        # All arguments are forwarded unchanged to ImageFolder.
        super().__init__(
            root=root,
            transform=transform,
            target_transform=target_transform,
            loader=loader,
        )

    def __getitem__(self, index):
        """Return the parent's item for ``index`` with ``index`` appended.

        Args:
            index (int): Position of the item in the dataset.

        Returns:
            tuple: ``(sample, target, index)``, where ``sample`` and
            ``target`` come from ``ImageFolder.__getitem__`` and ``index``
            is the argument passed in.
        """
        sample, target = super().__getitem__(index)
        return sample, target, index