-
Notifications
You must be signed in to change notification settings - Fork 0
/
mario_net.py
65 lines (56 loc) · 1.97 KB
/
mario_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import collections
import time
import torch
import torchvision
from torch import nn
from mario_dataset import MarioDataset
class MarioNet(nn.Module):
    """
    pre-trained torchvision LR-ASPP MobileNetV3 network with its last top-level child removed,
    followed by a small fully connected head that maps flattened features to 'outputs' values
    """

    def __init__(self, outputs=5, in_features=237120):
        """
        builds the pre-trained feature extractor and the fully connected head
        :param outputs: number of values produced by the final linear layer
        :param in_features: length of the flattened feature vector fed to the head. the default,
            237120 (= 40 * 104 * 57), presumably corresponds to the 'low' feature map of the
            3x832x456 input used in __main__ -- confirm and change it if the input size changes
        """
        super().__init__()
        # NOTE: weights='DEFAULT' requests pre-trained weights, which torchvision may download
        segmentation_model = torchvision.models.segmentation.lraspp_mobilenet_v3_large(weights='DEFAULT')
        # keep every top-level child except the last (presumably the segmentation classifier)
        layers = list(segmentation_model.children())[:-1]
        self.flatten = nn.Flatten()
        self.end_layers = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Linear(256, outputs),
        )
        self.network = nn.Sequential(*layers)

    def freeze(self, layers=0):
        """
        disables gradient updates for some or all children of the pre-trained network.
        children of self.network are frozen in order, from the first up to (not including)
        index 'layers'. passing 0 freezes all of them, and -1 freezes all but the last.
        :param layers: number of leading children to freeze (0 for all, negative numbers
            count back from the end)
        :return: None
        """
        # NOTE(review): if self.network holds a single child (e.g. the whole backbone),
        # freeze(-1) freezes nothing -- confirm this granularity is what callers expect.
        # NOTE(review): only requires_grad is changed; if the frozen children contain
        # BatchNorm layers, their running statistics still update while in train mode.
        if layers <= 0:
            layers += len(self.network)
        for i, child in enumerate(self.network.children()):
            if i >= layers:
                break
            for param in child.parameters():
                param.requires_grad = False

    def forward(self, x):
        """
        runs the feature extractor and the fully connected head
        :param x: input batch -- presumably (batch, 3, 832, 456) so that the flattened
            features match in_features; confirm
        :return: tensor of shape (batch, outputs)
        """
        features = self.network(x)
        if isinstance(features, collections.OrderedDict):
            # the feature extractor returned a dict of feature maps: only the 'low' entry is
            # used, all other entries are ignored
            features = features['low']
        return self.end_layers(self.flatten(features))

    def save(self, path):
        """
        writes this model's state_dict to 'path' with torch.save
        :param path: destination file path
        """
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """
        loads a state_dict written by save() into this model
        :param path: file produced by save()
        :param map_location: forwarded to torch.load; e.g. 'cpu' allows a checkpoint saved
            from a GPU to be loaded on a CPU-only machine. None keeps torch.load's default
        """
        # NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted
        # sources (consider weights_only=True on PyTorch versions that support it)
        self.load_state_dict(torch.load(path, map_location=map_location))
if __name__ == "__main__":
    # smoke test: run one recorded sample through the network, time it, and compare the
    # prediction with the recorded control values
    dataset = MarioDataset(['bcdata/luigi-circuit_1.pkl'])
    img, contr = dataset[0]
    net = MarioNet()
    # inference only: eval() switches layers such as BatchNorm/Dropout (if present) to their
    # inference behavior, and no_grad() avoids building an unused autograd graph
    net.eval()
    with torch.no_grad():
        t = time.perf_counter()
        # NOTE(review): assumes img holds 3 * 832 * 456 values in channel-first order -- confirm
        out = net(img.reshape((1, 3, 832, 456)))
        dt = time.perf_counter() - t
    print(f"compute time: {dt}s")
    print(f"predicted: {out}")
    print(f"actual: {contr}")