-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
142 lines (113 loc) · 4.83 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from typing import Optional

import pandas as pd
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
from torchvision.datasets import StanfordCars

from cars_transforms import transform_visualize, transform_predict
class StanfordCarsCAM(StanfordCars):
    '''
    Stanford Cars dataset restricted to the classes that match optional
    brand / type / production-year filters.

    Class metadata is parsed from the class names inherited from
    ``StanfordCars``: the first word is the brand ("Land" is expanded to
    "Land Rover"), the last word is the production year and the word before
    it is the car type. Classes that survive the filters are renumbered
    ``0..n-1`` and the inherited samples are relabelled accordingly; samples
    of dropped classes are discarded.

    Parameters
    ----------
    root: str
        forwarded to ``StanfordCars`` as ``root``
    split: str
        train or test; forwarded to ``StanfordCars`` as ``split``
    car_brand: str, optional
        whitespace-separated brand names to keep, e.g. "Audi BMW" or
        "Land Rover"; no brand filtering when empty or None
    car_type: str, optional
        whitespace-separated car types to keep; no type filtering when
        empty or None
    car_production_year: int, optional
        keep only classes with this production year; no filtering when
        None (or any other falsy value)
    download_datasets: bool
        forwarded to ``StanfordCars`` as ``download``
    generate_img_for_cam: bool
        when True, ``__getitem__`` additionally returns the image passed
        through ``transform_visualization``
    transform_prediction: transforms.Compose
        transform applied to every image to produce the model input
    transform_visualization: transforms.Compose
        transform applied to produce the image used for CAM visualisation

    Attributes
    ----------
    classes_specification: pandas.DataFrame
        one row per kept class with columns car_class, old_idx, car_brand,
        car_type, car_production_year and new_idx
    classes: list
        names of the kept classes
    class_to_idx: dict
        NOTE: despite the name, after construction this maps
        new index -> class name
    '''

    def __init__(
        self,
        root: str,
        split: str,
        car_brand: Optional[str] = None,
        car_type: Optional[str] = None,
        car_production_year: Optional[int] = None,
        download_datasets: bool = False,
        generate_img_for_cam: bool = False,
        transform_prediction: transforms.Compose = transform_predict,
        transform_visualization: transforms.Compose = transform_visualize,
    ) -> None:
        super().__init__(root=root, split=split, download=download_datasets)
        self.car_brand = car_brand
        self.car_type = car_type
        self.car_production_year = car_production_year
        self.generate_img_for_cam = generate_img_for_cam
        self.transform_prediction = transform_prediction
        self.transform_visualization = transform_visualization

        # Must run while ``class_to_idx`` still holds the parent's
        # class name -> index mapping; it is overwritten just below.
        self.classes_specification = self._classes_specification()

        # update inherited fields according to filter conditions
        kept_classes = self.classes_specification["car_class"].to_list()
        new_idxs = self.classes_specification["new_idx"].to_list()
        self.classes = kept_classes
        # NOTE: orientation is new index -> class name (kept for backward
        # compatibility), i.e. the inverse of the parent's mapping.
        self.class_to_idx = dict(zip(new_idxs, kept_classes))
        self._samples = self._filter_samples()

    def __getitem__(self, idx):
        '''
        Load one sample.

        Returns ``(image_predict, target)``, or
        ``(image_predict, image_visualize, target)`` when
        ``generate_img_for_cam`` is set. ``target`` is the renumbered
        class index.
        '''
        image_path, target = self._samples[idx]
        # The context manager releases the file handle deterministically;
        # ``convert`` returns an independent in-memory image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image_predict = self.transform_prediction(image)

        # generate transformed image for cam purpose
        if self.generate_img_for_cam:
            image_visualize = self.transform_visualization(image)
            return image_predict, image_visualize, target
        return image_predict, target

    def _filter_samples(self) -> list:
        '''
        Keep only samples whose class survived the filters and replace
        their original class index with the renumbered one.
        '''
        old_to_new_idxs = dict(zip(
            self.classes_specification["old_idx"].to_list(),
            self.classes_specification["new_idx"].to_list(),
        ))
        # sample[0] is the image path, sample[1] the original class index
        return [
            (sample[0], old_to_new_idxs[sample[1]])
            for sample in self._samples
            if sample[1] in old_to_new_idxs
        ]

    @staticmethod
    def _brand_requested(brand: str, requested_words: list) -> bool:
        '''
        Return True when the words of ``brand`` occur contiguously in
        ``requested_words``. For a single-word brand this is plain
        membership; for "Land Rover" it requires "Land" directly followed
        by "Rover".
        '''
        brand_words = brand.split()
        if not brand_words:
            return False
        span = len(brand_words)
        return any(
            requested_words[start:start + span] == brand_words
            for start in range(len(requested_words) - span + 1)
        )

    def _classes_specification(self) -> pd.DataFrame:
        '''
        creates specification of samples including:
        - full class name
        - class id
        - brand
        - type
        - year of production
        and, after filtering, the renumbered class id (new_idx)
        '''
        columns = [
            "car_class",
            "old_idx",
            "car_brand",
            "car_type",
            "car_production_year",
        ]
        records = []
        for car_class, old_idx in self.class_to_idx.items():
            class_record = car_class.split(" ")
            # Only "Land" is expanded to a two-word brand; every other
            # brand is the first word of the class name.
            brand = class_record[0] if class_record[0] != "Land" else "Land Rover"
            records.append((
                car_class,
                old_idx,
                brand,
                class_record[-2],
                int(class_record[-1]),
            ))
        specification = pd.DataFrame(records, columns=columns)

        # filter specification by car_brand
        if self.car_brand:
            requested_words = self.car_brand.split()
            # Matching on contiguous words (instead of single tokens) is
            # what makes the two-word brand "Land Rover" selectable.
            requested_brands = {
                brand
                for brand in specification["car_brand"].unique()
                if self._brand_requested(brand, requested_words)
            }
            specification = specification[
                specification["car_brand"].isin(requested_brands)
            ]

        # filter specification by car_type
        if self.car_type:
            specification = specification[
                specification["car_type"].isin(self.car_type.split())
            ]

        # filter specification by car_production_year
        if self.car_production_year:
            specification = specification[
                specification["car_production_year"] == self.car_production_year
            ]

        # adjust ids for model; ``assign`` returns a new frame, which avoids
        # writing a column into a filtered slice (SettingWithCopyWarning)
        return specification.assign(new_idx=range(len(specification)))