-
Notifications
You must be signed in to change notification settings - Fork 6
/
cifar.py
134 lines (109 loc) · 4.04 KB
/
cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# -*- coding: utf-8 -*-
"""
@author: Abderrahmen Amich
@email: aamich@umich.edu
"""
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict
def get_transform():
    '''
    @brief: assemble the default preprocessing pipeline used for CIFAR10
    @return: a transforms.Compose that resizes the image to 32x32, converts
             it to a tensor and normalizes each of the three channels with
             mean 0.5 and std 0.5
    '''
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    steps = [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)
def get_datasets(root='data', train=True, test=True, transform=None, batch=128, model='target', num_workers=2):
    '''
    @brief: build CIFAR10 DataLoaders for the requested splits; the dataset
            is fetched with download=True into `root`
    @param root: place to store the original data
    @param train: when truthy, the train DataLoader is returned
    @param test: when truthy, the test DataLoader is returned
    @param transform: transform applied to both splits; when None the
            result of get_transform() is used
    @param batch: batch size for the DataLoaders
    @param model: when equal to 'copycat' the train DataLoader is not
            shuffled; any other value shuffles it
    @param num_workers: number of worker processes per DataLoader
    @return: EasyDict with one DataLoader per requested split.
             Possible keys: 'train', 'test'
    @raise AssertionError: when neither train nor test is requested
    '''
    if not (train or test):
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; exception type and message are the
        # same ones the assert statement produced.
        raise AssertionError('You must select train, test, or both')

    transform = get_transform() if transform is None else transform
    loaders = {}

    if train:
        trainset = torchvision.datasets.CIFAR10(
            root=root, train=True, download=True, transform=transform
        )
        # Shuffling is switched off only for model == 'copycat'.
        loaders['train'] = torch.utils.data.DataLoader(
            trainset, batch_size=batch, shuffle=(model != 'copycat'),
            num_workers=num_workers
        )

    if test:
        testset = torchvision.datasets.CIFAR10(
            root=root, train=False, download=True, transform=transform
        )
        loaders['test'] = torch.utils.data.DataLoader(
            testset, batch_size=batch, shuffle=False, num_workers=num_workers
        )

    # The result is keyed on the truthiness of the flags, so a truthy
    # non-bool flag (e.g. train="yes") yields its loader instead of falling
    # through every `== True` comparison and returning None.
    return EasyDict(**loaders)
class CNN(nn.Module):
    """Sample convolutional classifier.

    Three convolutional blocks (each: conv + batch-norm + ReLU, conv + ReLU,
    2x2 max-pool) feed a three-layer fully connected head that emits 10
    scores per sample. The head's first layer takes 4096 features, which is
    256 channels x 4 x 4, i.e. what the conv stack yields for 32x32 inputs.
    """

    def __init__(self):
        super().__init__()

        # Modules are created in the same order as they are stacked so that
        # the Sequential indices (and therefore state_dict keys) stay fixed.
        feature_layers = [
            # Block 1: 3 -> 32 -> 64 channels
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Block 2: 64 -> 128 -> 128 channels, followed by light dropout
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=0.05),
            # Block 3: 128 -> 256 -> 256 channels
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        ]
        self.conv_layer = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10),
        ]
        self.fc_layer = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the FC head."""
        features = self.conv_layer(x)
        batch_size = features.size(0)
        flat = features.view(batch_size, -1)
        return self.fc_layer(flat)