# sample_code.py
# Few-shot episode construction, training loop, and conv-4 encoder
# (written against the pre-0.4 PyTorch API: Variable / volatile / .data[0]).

import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms

import matplotlib.pyplot as pl
from IPython import display

# Normalization with the standard ImageNet channel statistics.
standardize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
])

# Training-time augmentation: random horizontal flip and a padded random
# crop back to 84x84, followed by the same normalization.
alter = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(84, padding=10),
    standardize
])
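
# `alter` expects a tensor or HxWxC uint8 array (ToPILImage runs first),
# pads by 10 pixels, crops back to 84x84 at random, then standardizes.
# For a single dataset image (illustrative only, not part of the script):
#   x = alter(theset[cl][i])   # -> FloatTensor of size (3, 84, 84)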

def batchmaker(way, trainshot, testshot, theset, alterful=False):
    """Sample a `way`-class episode with `trainshot` support and `testshot`
    query images per class, from a dataset with 600 images per class."""
    classes = np.random.choice(len(theset), way)
    transform = alter if alterful else standardize
    li = [torch.cat([transform(theset[cl][i]).view(1, 3, 84, 84) for i in
                     np.random.choice(600, trainshot + testshot)], dim=0).float()
          for cl in classes]
    support = torch.cat([t[:trainshot] for t in li], dim=0)
    stargs = torch.LongTensor([i // trainshot for i in range(trainshot * way)])
    query = torch.cat([t[trainshot:] for t in li], dim=0)
    qtargs = torch.LongTensor([i // testshot for i in range(testshot * way)])
    # volatile=True turns off autograd history for evaluation-only episodes
    # (pre-0.4 idiom; the modern equivalent is torch.no_grad()).
    return (Variable(support).cuda(),
            Variable(query, volatile=(not alterful)).cuda(),
            Variable(qtargs, volatile=(not alterful)).cuda(),
            Variable(stargs).cuda())
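
# Episode shape sanity check (a sketch; assumes `trainset` is already loaded
# and a CUDA device is available - illustrative, not part of the script):
#   s, q, qt, st = batchmaker(5, 1, 15, trainset)
#   s.size()   # torch.Size([5, 3, 84, 84])   way * trainshot support images
#   q.size()   # torch.Size([75, 3, 84, 84])  way * testshot query images
#   st         # [0, 1, 2, 3, 4]: labels repeat in blocks of trainshot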

# Training loop. Assumes `embed` (the encoder), `model` (the few-shot head),
# `criterion`, the episode settings (way, trainshot, testshot, evalway, reps),
# the datasets, and the tracker lists (losstracker, evallosstracker,
# evalacctracker) are defined earlier, along with an `evaluate` helper.
vbity = 200     # reporting interval, in iterations
epoch = 2000    # iterations per learning-rate step
start = time.time()
runningloss = 0
for it in range(10 * epoch):
    if it % 10 == 0:
        print(it)
    # Build an augmented training episode
    support, query, targs, _ = batchmaker(way, trainshot, testshot, trainset,
                                          alterful=True)
    # Predict
    embed.zero_grad()
    model.zero_grad()
    embeds = embed(support)
    qembeds = embed(query)
    preds = model(embeds, qembeds, way)
    # Calculate loss on the query set
    loss = criterion(preds, targs)
    runningloss += loss.data[0]
    # Backprop. The optimizer is rebuilt at each epoch boundary with a halved
    # learning rate (note this also resets Adam's moment estimates).
    if it % epoch == 0:
        optimizer = optim.Adam(embed.parameters(), lr=.001 / (2 ** (it // epoch)))
    loss.backward()
    # nn.utils.clip_grad_norm(model.parameters(), 1)
    optimizer.step()
    # Report every `vbity` iterations
    if it % vbity == vbity - 1:
        display.clear_output(wait=True)
        losstracker.append(runningloss / vbity)
        embed = embed.eval()
        evalloss, evalacc, _ = evaluate(embed, model, criterion, evalway,
                                        trainshot, testshot, reps, testset)
        embed = embed.train()
        evallosstracker.append(evalloss)
        evalacctracker.append(evalacc)
        pl.figure(1, figsize=(15, 5))
        pl.subplot(1, 2, 1)
        pl.plot(losstracker)
        pl.plot(evallosstracker)
        pl.ylim((.5, 3))
        pl.title("Loss: Training Blue, Validation Gold")
        pl.subplot(1, 2, 2)
        pl.plot(evalacctracker)
        pl.ylim((0.3, .8))
        pl.title("Validation Acc")
        pl.show()
        print("Train loss is: " + str(runningloss / vbity) +
              "\nValidation accuracy is: " + str(evalacc) +
              "\nValidation loss is: " + str(evalloss) + "\n")
        runningloss = 0
print(time.time() - start)
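
# The `evaluate` helper called above is defined elsewhere. A minimal sketch
# consistent with its call signature and (loss, accuracy, extra) return
# values - an assumption for illustration, not the original implementation:
def evaluate_sketch(embed, model, criterion, way, trainshot, testshot, reps, dataset):
    losses, accs = [], []
    for _ in range(reps):
        # alterful=False: no augmentation, volatile (no-grad) query Variables
        support, query, qtargs, _ = batchmaker(way, trainshot, testshot, dataset)
        preds = model(embed(support), embed(query), way)
        losses.append(criterion(preds, qtargs).data[0])
        _, hard = preds.max(1)
        accs.append((hard == qtargs).float().mean().data[0])
    return sum(losses) / reps, sum(accs) / reps, accs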

class Block(nn.Module):
    """Conv3x3-BatchNorm-ReLU-MaxPool block; halves spatial resolution."""
    def __init__(self, insize, outsize):
        super(Block, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(insize, outsize, kernel_size=3, padding=1),
            nn.BatchNorm2d(outsize),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, inp):
        return self.layers(inp)
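
# Each Block halves spatial resolution via MaxPool2d(2), flooring odd sizes:
# an 84x84 input shrinks 84 -> 42 -> 21 -> 10 -> 5 across four blocks, so
# Block(3, 64) maps a (N, 3, 84, 84) batch to (N, 64, 42, 42).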

class ENCODER(nn.Module):
    """Four-block conv-64 embedding network: maps a 3x84x84 image to a flat
    1600-dimensional feature vector (64 channels x 5 x 5)."""
    def __init__(self):
        super(ENCODER, self).__init__()
        self.block1 = Block(3, 64)
        self.block2 = Block(64, 64)
        self.block3 = Block(64, 64)
        self.block4 = Block(64, 64)
        # self.final = nn.Conv2d(64, 64, kernel_size=1, padding=0)

    def forward(self, inp):
        out = self.block1(inp)
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        # out = self.final(out)
        return out.view(out.size(0), -1)
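
# Encoder shape check (a sketch; illustrative, not part of the script):
#   enc = ENCODER()
#   feats = enc(Variable(torch.randn(2, 3, 84, 84)))
#   feats.size()   # torch.Size([2, 1600]) = (batch, 64 * 5 * 5)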