-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
31 lines (20 loc) · 943 Bytes
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from torchvision.datasets import CIFAR10, CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Registry mapping a dataset's string name to the torchvision dataset class
# that create_dataloader() instantiates for it. Only the names listed here
# are accepted as `dataset_name`.
DATASET_NAME_TO_DATASET_CLASS = {
"CIFAR10": CIFAR10,
"CIFAR100": CIFAR100,
}
def create_dataloader(
    dataset_name: str,
    *,
    root: str = "./",
    batch_size: int = 256,
    num_workers: int = 4,
    download: bool = True,
    shuffle_test: bool = False,
) -> "tuple[DataLoader, DataLoader]":
    """Build the train and test DataLoaders for a registered dataset.

    Both splits share one preprocessing pipeline: ``ToTensor`` followed by a
    per-channel ``Normalize`` with mean 0.5 and std 0.5 on each of the three
    channels.

    Args:
        dataset_name: Key into ``DATASET_NAME_TO_DATASET_CLASS``
            (e.g. ``"CIFAR10"`` or ``"CIFAR100"``).
        root: Directory the dataset is stored in / downloaded to.
        batch_size: Batch size used for both loaders.
        num_workers: Number of loader worker processes for both loaders.
        download: Whether the dataset class may download missing data.
        shuffle_test: Whether to shuffle the test split. Defaults to
            ``False`` so that evaluation visits samples in a stable order;
            pass ``True`` to reproduce the previous shuffled behaviour.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader is always
        shuffled.

    Raises:
        KeyError: If ``dataset_name`` is not a registered dataset name.
    """
    # Resolve the class once, up front, so an invalid name fails before any
    # transform is built or any download is attempted.
    try:
        dataset_cls = DATASET_NAME_TO_DATASET_CLASS[dataset_name]
    except KeyError:
        supported = ", ".join(sorted(DATASET_NAME_TO_DATASET_CLASS))
        raise KeyError(
            f"Unknown dataset {dataset_name!r}; supported datasets: {supported}"
        ) from None

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    trainset = dataset_cls(
        root=root, train=True, download=download, transform=transform
    )
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    testset = dataset_cls(
        root=root, train=False, download=download, transform=transform
    )
    # Evaluation does not benefit from shuffling; keep a fixed order unless
    # the caller explicitly asks otherwise.
    testloader = DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=shuffle_test,
        num_workers=num_workers,
    )

    return trainloader, testloader