-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathtrainer.py
105 lines (80 loc) · 4.01 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import torch
import torch.nn.functional as F
import torchvision
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import _create_model_training_folder
class BYOLTrainer:
    """BYOL (Bootstrap Your Own Latent) self-supervised training loop.

    Holds an *online* network, which is trained by ``optimizer`` and followed
    by a ``predictor`` head. It also holds a *target* network, which is never
    trained by gradients. Instead, its parameters are an exponential moving
    average (EMA) of the online network's parameters, with momentum
    ``self.m``. Each step feeds two views of every sample through both
    networks. It then minimises a symmetric loss between the online
    predictions and the target outputs (see ``update``).
    """

    def __init__(self, online_network, target_network, predictor, optimizer, device, **params):
        """Store the networks/optimizer and read the training hyper-parameters.

        Args:
            online_network: Encoder trained by gradient descent.
            target_network: Encoder updated only by EMA from ``online_network``.
                Its ``parameters()`` are paired with the online network's via
                ``zip``, so both must yield parameters in the same order
                (presumably the same architecture).
            predictor: Head applied on top of ``online_network`` outputs.
            optimizer: Optimizer stepped once per batch. Which parameters it
                owns is decided by the caller (presumably the online network
                and the predictor — verify at the call site).
            device: Device the input batches are moved to in ``train``.
            **params: Must contain ``max_epochs``, ``m`` (EMA momentum for the
                target update), ``batch_size``, ``num_workers`` and
                ``checkpoint_interval``. A missing key raises ``KeyError``.
        """
        self.online_network = online_network
        self.target_network = target_network
        self.optimizer = optimizer
        self.device = device
        self.predictor = predictor
        self.max_epochs = params['max_epochs']
        self.writer = SummaryWriter()
        self.m = params['m']
        self.batch_size = params['batch_size']
        self.num_workers = params['num_workers']
        # NOTE(review): stored but never read in this class, and its unit
        # (epochs vs. optimizer steps) is not visible here. Confirm the unit
        # before using it to schedule checkpoints in train().
        self.checkpoint_interval = params['checkpoint_interval']
        # Project helper (utils). From its arguments it presumably prepares the
        # run folder under the writer's log dir and copies these files into it.
        _create_model_training_folder(self.writer, files_to_same=["./config/config.yaml", "main.py", 'trainer.py'])

    @torch.no_grad()
    def _update_target_network_parameters(self):
        """EMA (momentum) update of the target network from the online network.

        For every parameter pair: ``target = m * target + (1 - m) * online``.
        The update is done in place, so target tensors keep their storage and
        no new tensors are allocated per step. Only ``parameters()`` are
        averaged; module buffers are not touched here.
        """
        for online_param, target_param in zip(self.online_network.parameters(),
                                              self.target_network.parameters()):
            target_param.mul_(self.m).add_(online_param, alpha=1.0 - self.m)

    @staticmethod
    def regression_loss(x, y):
        """Return ``2 - 2 * cosine_similarity(x, y)``, computed row by row.

        This equals the squared L2 distance between the L2-normalised vectors,
        so each value lies in ``[0, 4]``. Inputs are normalised along dim 1 and
        the dot product is summed over the last dim; the two coincide for 2-D
        ``(rows, features)`` inputs. For such inputs the result is a 1-D
        tensor with one loss value per row.
        """
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)
        return 2 - 2 * (x * y).sum(dim=-1)

    @torch.no_grad()
    def initializes_target_network(self):
        """Copy the online weights into the target network and freeze the target.

        After this call, target parameters have ``requires_grad=False``. They
        then change only through ``_update_target_network_parameters``.
        """
        for online_param, target_param in zip(self.online_network.parameters(),
                                              self.target_network.parameters()):
            target_param.copy_(online_param)
            target_param.requires_grad = False

    def _log_input_views(self, batch_view_1, batch_view_2, step):
        """Write image grids of up to the first 32 samples of each view to TensorBoard."""
        for tag, batch in (('views_1', batch_view_1), ('views_2', batch_view_2)):
            grid = torchvision.utils.make_grid(batch[:32])
            self.writer.add_image(tag, grid, global_step=step)

    def train(self, train_dataset):
        """Run BYOL training for ``self.max_epochs`` epochs over ``train_dataset``.

        Each loader item is unpacked as ``((view_1, view_2), _)``; the second
        element is ignored. Per batch, the method:

        1. computes the loss,
        2. takes one optimizer step,
        3. EMA-updates the target network.

        The loss is logged to TensorBoard every step. Image grids of both views
        are logged once, at step 0. At the end of every epoch the checkpoint
        ``<writer.log_dir>/checkpoints/model.pth`` is (over)written, so an
        interrupted run keeps its most recent weights.
        """
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size,
                                  num_workers=self.num_workers, drop_last=False, shuffle=True)

        niter = 0
        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
        # torch.save does not create missing directories; make sure it exists.
        os.makedirs(model_checkpoints_folder, exist_ok=True)
        checkpoint_path = os.path.join(model_checkpoints_folder, 'model.pth')

        self.initializes_target_network()

        for epoch_counter in range(self.max_epochs):
            for (batch_view_1, batch_view_2), _ in train_loader:
                batch_view_1 = batch_view_1.to(self.device)
                batch_view_2 = batch_view_2.to(self.device)

                if niter == 0:
                    self._log_input_views(batch_view_1, batch_view_2, niter)

                loss = self.update(batch_view_1, batch_view_2)
                # Log a plain Python float, not the graph-attached tensor.
                self.writer.add_scalar('loss', loss.item(), global_step=niter)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # The target follows the online network only via EMA, after the step.
                self._update_target_network_parameters()
                niter += 1

            print("End of epoch {}".format(epoch_counter))

            # save checkpoints (overwrites the same file each epoch)
            self.save_model(checkpoint_path)

    def update(self, batch_view_1, batch_view_2):
        """Compute the symmetric BYOL loss for one pair of view batches.

        The online prediction for each view is regressed onto the target
        network's output for the *other* view. Both directions are summed and
        averaged with ``mean()``. Assuming 2-D network outputs, this is an
        average over the batch rows. Target outputs are computed under
        ``torch.no_grad()``, so gradients flow only through the online network
        and the predictor.

        Returns:
            A scalar (0-dim) loss tensor.
        """
        # Online branch: encoder + predictor for both views.
        predictions_from_view_1 = self.predictor(self.online_network(batch_view_1))
        predictions_from_view_2 = self.predictor(self.online_network(batch_view_2))

        # Target branch: no gradients; each view serves as the target for the other.
        with torch.no_grad():
            targets_to_view_2 = self.target_network(batch_view_1)
            targets_to_view_1 = self.target_network(batch_view_2)

        loss = (self.regression_loss(predictions_from_view_1, targets_to_view_1)
                + self.regression_loss(predictions_from_view_2, targets_to_view_2))
        return loss.mean()

    def save_model(self, PATH):
        """Save a training checkpoint dict to ``PATH`` with ``torch.save``.

        The checkpoint includes the predictor's state as well as both
        encoders' state and the optimizer's state, so training can be resumed
        without losing the trained predictor head.
        """
        torch.save({
            'online_network_state_dict': self.online_network.state_dict(),
            'target_network_state_dict': self.target_network.state_dict(),
            'predictor_state_dict': self.predictor.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, PATH)