-
Notifications
You must be signed in to change notification settings - Fork 4
/
dataset_ingredient.py
81 lines (66 loc) · 2.56 KB
/
dataset_ingredient.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import Callable, Optional, Sequence, Tuple, Union

from sacred import Ingredient
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import Cityscapes, VOCSegmentation
from torchvision.transforms import InterpolationMode

from utils import label_map_cityscapes
# Sacred ingredient named 'dataset'; the named configs and captured
# functions in this module are registered on it.
dataset_ingredient = Ingredient('dataset')
@dataset_ingredient.named_config
def pascal_voc_2012():
    """Named config selecting the Pascal VOC 2012 segmentation dataset."""
    # NOTE: the local variable names below are the config keys; do not rename.
    name = 'pascal_voc_2012'  # key looked up in the _loaders registry
    root = 'data'  # directory passed to VOCSegmentation as root
    split = 'val'  # passed to VOCSegmentation as image_set
    size = 512  # passed to transforms.Resize for images and targets
    num_images = None  # None keeps the whole split; an int keeps the first N
@dataset_ingredient.named_config
def cityscapes():
    """Named config selecting the Cityscapes semantic segmentation dataset."""
    # NOTE: the local variable names below are the config keys; do not rename.
    name = 'cityscapes'  # key looked up in the _loaders registry
    root = 'data/cityscapes'  # directory passed to Cityscapes as root
    split = 'val'  # passed to Cityscapes as split
    # Passed to transforms.Resize; presumably (height, width) per the
    # torchvision convention for sequence sizes -- confirm.
    size = (1024, 2048)
    num_images = None  # None keeps the whole split; an int keeps the first N
@dataset_ingredient.capture
def get_pascal_voc_2012(root: str, size: int, split: str,
                        num_images: Optional[int] = None) -> Tuple[DataLoader, Optional[Callable]]:
    """Build an unshuffled DataLoader over Pascal VOC 2012 segmentation data.

    Args:
        root: Directory handed to ``VOCSegmentation`` as ``root``.
        size: Target size handed to ``transforms.Resize`` for both the
            images and the segmentation targets.
        split: Image set to load, handed to ``VOCSegmentation`` as
            ``image_set``.
        num_images: If not ``None``, keep only the first ``num_images``
            samples of the split.

    Returns:
        ``(loader, None)``. The second slot mirrors ``get_cityscapes``, which
        returns a label-mapping callable there; none is used for this dataset.

    Raises:
        ValueError: If ``num_images`` is negative or larger than the dataset.
    """
    image_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # Targets go through PILToTensor (no value rescaling) and are resized
    # with nearest-neighbour interpolation so label values are never blended.
    target_transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.Resize(size, interpolation=InterpolationMode.NEAREST),
    ])

    dataset = VOCSegmentation(root=root, year='2012', image_set=split,
                              transform=image_transform,
                              target_transform=target_transform)

    if num_images is not None:
        # Explicit check instead of `assert`: asserts vanish under `python -O`
        # and a negative count would silently produce an empty subset.
        if not 0 <= num_images <= len(dataset):
            raise ValueError(
                f'num_images must be between 0 and {len(dataset)}, '
                f'got {num_images}'
            )
        dataset = Subset(dataset, indices=range(num_images))

    # Default batch size (1): Resize(int) keeps the aspect ratio, so samples
    # may differ in spatial size.
    loader = DataLoader(dataset=dataset, shuffle=False)
    return loader, None
@dataset_ingredient.capture
def get_cityscapes(root: str, size: Union[int, Sequence[int]], split: str,
                   num_images: Optional[int] = None,
                   batch_size: int = 1) -> Tuple[DataLoader, Optional[Callable]]:
    """Build an unshuffled DataLoader over Cityscapes semantic segmentation data.

    Args:
        root: Directory handed to ``Cityscapes`` as ``root``.
        size: Target size handed to ``transforms.Resize`` for both the images
            and the targets; an int or a sequence such as ``(1024, 2048)``.
        split: Dataset split, handed to ``Cityscapes`` as ``split``.
        num_images: If not ``None``, keep only the first ``num_images``
            samples of the split.
        batch_size: Batch size handed to the ``DataLoader``.

    Returns:
        ``(loader, label_map_cityscapes)``: the loader plus the same label
        mapping callable that is applied to every target.

    Raises:
        ValueError: If ``num_images`` is negative or larger than the dataset.
    """
    image_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # label_map_cityscapes runs first, on the raw target as loaded by the
    # dataset (presumably remapping label ids -- see utils). The result is
    # then converted without rescaling and resized with nearest-neighbour
    # interpolation so label values are never blended.
    target_transform = transforms.Compose([
        label_map_cityscapes,
        transforms.PILToTensor(),
        transforms.Resize(size, interpolation=InterpolationMode.NEAREST),
    ])

    dataset = Cityscapes(root=root, split=split, target_type='semantic',
                         transform=image_transform,
                         target_transform=target_transform)

    if num_images is not None:
        # Explicit check instead of `assert`: asserts vanish under `python -O`
        # and a negative count would silently produce an empty subset.
        if not 0 <= num_images <= len(dataset):
            raise ValueError(
                f'num_images must be between 0 and {len(dataset)}, '
                f'got {num_images}'
            )
        dataset = Subset(dataset, indices=range(num_images))

    loader = DataLoader(dataset=dataset, shuffle=False, batch_size=batch_size)
    return loader, label_map_cityscapes
# Registry consulted by get_dataset: dataset name -> captured loader factory.
_loaders = dict(
    pascal_voc_2012=get_pascal_voc_2012,
    cityscapes=get_cityscapes,
)
@dataset_ingredient.capture
def get_dataset(name: str) -> Tuple[DataLoader, Optional[Callable]]:
    """Build the loader for the dataset called ``name``.

    Args:
        name: Key into the ``_loaders`` registry, e.g. ``'pascal_voc_2012'``
            or ``'cityscapes'``.

    Returns:
        Whatever the selected loader factory returns: a
        ``(loader, label_map)`` pair where ``label_map`` may be ``None``.

    Raises:
        KeyError: If ``name`` is not a registered dataset.
    """
    # Only the lookup sits inside the try, so a KeyError raised while the
    # factory itself runs is not mistaken for an unknown dataset name.
    try:
        build_loader = _loaders[name]
    except KeyError:
        available = ', '.join(sorted(_loaders))
        raise KeyError(
            f'Unknown dataset {name!r}; available datasets: {available}'
        ) from None
    return build_loader()