-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
76 lines (62 loc) · 2.25 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import torch
import pandas as pd
import numpy as np
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, random_split, DataLoader
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from collections import OrderedDict
import pickle
def get_default_device(prefer_cuda=False):
    """Return the torch device that models and tensors should live on.

    CPU is the default. CUDA is selected only when the caller explicitly asks
    for it via ``prefer_cuda`` AND ``torch.cuda.is_available()`` reports a
    usable GPU.

    Args:
        prefer_cuda: When True, return ``torch.device('cuda')`` if CUDA is
            available; otherwise fall back to CPU.

    Returns:
        ``torch.device('cuda')`` or ``torch.device('cpu')``.
    """
    # CPU stays the default so the zero-argument call keeps returning CPU
    # (presumably a deliberate CPU-only deployment choice — confirm).
    if prefer_cuda and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
def to_device(data, device):
    """Move ``data`` (a tensor/module, or a nested list/tuple of them) to ``device``.

    Lists and tuples are walked recursively and are always returned as a
    ``list``. Any other object must provide a ``.to()`` method; it is moved
    with ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(element, device) for element in data]
# Module-wide target device, resolved once at import time. Called with no
# arguments, get_default_device() in this file yields a CPU device.
# Presumably used as the ``device`` argument to to_device(...) by callers
# elsewhere — confirm.
device = get_default_device()
class ImageClassificationBase(nn.Module):
    """Shared training/validation helpers for image classifiers.

    Subclasses implement ``forward``. Its output is fed to ``F.nll_loss``, so
    it must be per-class log-probabilities (e.g. ending in
    ``nn.LogSoftmax(dim=1)``). Batches are ``(images, targets)`` pairs.
    """

    @staticmethod
    def _accuracy(outputs, targets):
        """Return the fraction of rows whose arg-max class equals the target.

        The result is a 0-dim float tensor, so per-batch values can be
        stacked and averaged in ``validation_epoch_end``.
        """
        # argmax is unaffected by the monotonic log-softmax, so log-probs work.
        preds = outputs.argmax(dim=1)
        return (preds == targets).float().mean()

    def training_step(self, batch):
        """Return the NLL loss (with grad) for one ``(images, targets)`` batch."""
        images, targets = batch
        out = self(images)
        return F.nll_loss(out, targets)

    def validation_step(self, batch):
        """Return detached ``val_acc`` / ``val_loss`` tensors for one batch."""
        images, targets = batch
        out = self(images)
        loss = F.nll_loss(out, targets)
        # Accuracy comes from a class-local helper: this module defines no
        # module-level ``accuracy`` function, so calling one raised NameError.
        acc = self._accuracy(out, targets)
        return {'val_acc': acc.detach(), 'val_loss': loss.detach()}

    def validation_epoch_end(self, outputs):
        """Average per-batch validation metrics into plain Python floats.

        Args:
            outputs: List of dicts as returned by ``validation_step``.

        Returns:
            ``{'val_loss': float, 'val_acc': float}``.
        """
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary for ``epoch``.

        ``result`` must contain ``train_loss``, ``val_loss`` and ``val_acc``.
        """
        print("Epoch [{}] : train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result["train_loss"], result["val_loss"], result["val_acc"]))
class DogBreedPretrainedGoogleNet(ImageClassificationBase):
    """ImageNet-pretrained torchvision GoogLeNet with a new log-softmax head.

    The backbone's final ``fc`` layer is replaced by
    ``Sequential(Linear(in_features, num_classes), LogSoftmax(dim=1))``. The
    model therefore outputs per-class log-probabilities, as the
    ``F.nll_loss`` calls in ``ImageClassificationBase`` require.

    Args:
        num_classes: Size of the output layer. Defaults to 120 (presumably
            the number of dog breeds in the training set — confirm against
            the dataset).
    """

    def __init__(self, num_classes=120):
        super().__init__()
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
        # weight enums; IMAGENET1K_V1 is the same checkpoint. Older releases
        # have no enum, so fall back to the legacy flag there.
        weights_enum = getattr(models, "GoogLeNet_Weights", None)
        if weights_enum is not None:
            self.network = models.googlenet(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.network = models.googlenet(pretrained=True)
        # Replace the 1000-way ImageNet classifier with a task-specific head.
        # The Sequential keeps the ``fc.0.*`` state_dict keys unchanged.
        num_ftrs = self.network.fc.in_features
        self.network.fc = nn.Sequential(
            nn.Linear(num_ftrs, num_classes),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, xb):
        """Return per-class log-probabilities for the image batch ``xb``."""
        return self.network(xb)