-
Notifications
You must be signed in to change notification settings - Fork 28
/
data.py
91 lines (76 loc) · 3.89 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torchvision
import sys
PATH_TO_CIFAR = "./cifar/"
sys.path.append(PATH_TO_CIFAR)
import train as cifar_train
def get_inp_tar(dataset):
    """Return a dataset's stored samples as a flat float matrix plus its targets.

    Args:
        dataset: object exposing ``.data`` (a torch tensor, or anything
            ``torch.as_tensor`` accepts such as a numpy array) whose first
            axis indexes samples, and ``.targets``.

    Returns:
        ``(inputs, targets)`` where ``inputs`` is a float32 tensor of shape
        ``(num_samples, num_features)`` and ``targets`` is ``dataset.targets``
        returned unchanged.

    NOTE(review): ``.data`` is read directly, so any ``transform`` attached to
    the dataset is presumably not applied here — confirm callers expect raw
    stored values rather than normalized ones.
    """
    # as_tensor is a no-op for tensors and lets numpy-backed datasets work
    # (ndarray.view(N, -1) would misinterpret its arguments and crash).
    data = torch.as_tensor(dataset.data)
    if data.dim() == 1:
        # One scalar per sample: keep the (N, 1) layout view(N, -1) produced.
        flat = data.unsqueeze(1)
    else:
        # flatten() copes with non-contiguous tensors and empty datasets,
        # both of which make view(N, -1) raise.
        flat = data.flatten(start_dim=1)
    return flat.float(), dataset.targets
def get_mnist_dataset(root, is_train, to_download, return_tensor=False):
    """Load one MNIST split with ToTensor + Normalize preprocessing.

    Args:
        root: directory handed to ``torchvision.datasets.MNIST``.
        is_train: ``True`` for the training split, ``False`` for the test split.
        to_download: forwarded as ``download=``.
        return_tensor: when true, return ``get_inp_tar(dataset)`` instead of
            the dataset object itself.

    Returns:
        The ``torchvision.datasets.MNIST`` instance, or the
        ``(inputs, targets)`` pair from ``get_inp_tar`` when ``return_tensor``
        is true.
    """
    # MNIST has a single channel, so mean and std are one-element tuples.
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST(
        root, train=is_train, download=to_download, transform=preprocess
    )
    if return_tensor:
        # NOTE(review): get_inp_tar reads ``dataset.data`` directly, so the
        # transform above is presumably not applied to these tensors —
        # confirm callers expect raw pixel values on this path.
        return get_inp_tar(dataset)
    return dataset
def _normalized_loader(dataset_cls, root, is_train, download, mean, std, batch_size, shuffle):
    """Wrap one split of a torchvision dataset (ToTensor + Normalize) in a DataLoader."""
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std),
    ])
    dataset = dataset_cls(root, train=is_train, download=download, transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def get_dataloader(args, unit_batch=False, no_randomness=False):
    """Build the ``(train_loader, test_loader)`` pair for ``args.dataset``.

    Args:
        args: namespace-like config. Reads ``dataset`` (``'mnist'`` or
            ``'cifar10'``, case-insensitive), ``to_download`` and, unless
            ``unit_batch`` is set, ``batch_size_train`` / ``batch_size_test``.
            For CIFAR-10 it also reads ``cifar_style_data`` and, when that is
            true, ``config``.
        unit_batch: use batch size 1 for both loaders.
        no_randomness: disable shuffling on both loaders.

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        ValueError: if ``args.dataset`` names an unsupported dataset.
    """
    if unit_batch:
        train_bsz, test_bsz = 1, 1
    else:
        train_bsz, test_bsz = args.batch_size_train, args.batch_size_test
    # NOTE(review): shuffling also applies to the test loader, so test batch
    # order varies between runs unless no_randomness is set.
    shuffle = not no_randomness
    dataset_name = args.dataset.lower()

    if dataset_name == 'mnist':
        dataset_cls, root = torchvision.datasets.MNIST, './files/'
        # MNIST has a single channel.
        mean, std = (0.1307,), (0.3081,)
    elif dataset_name == 'cifar10':
        if args.cifar_style_data:
            # NOTE(review): this path ignores unit_batch and no_randomness;
            # batching/shuffling come from cifar_train.get_dataset(args.config).
            train_loader, test_loader = cifar_train.get_dataset(args.config)
            return train_loader, test_loader
        dataset_cls, root = torchvision.datasets.CIFAR10, './data/'
        # Per-channel (mean_ch1, mean_ch2, mean_ch3), (std1, std2, std3);
        # note this differs from the MNIST normalization.
        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    else:
        # Fail loudly instead of returning None, which callers unpacking the
        # result would only discover as an opaque TypeError.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'mnist' or 'cifar10'".format(args.dataset)
        )

    train_loader = _normalized_loader(
        dataset_cls, root, True, args.to_download, mean, std, train_bsz, shuffle)
    test_loader = _normalized_loader(
        dataset_cls, root, False, args.to_download, mean, std, test_bsz, shuffle)
    return train_loader, test_loader