forked from iskenderkahramanoglu/FARNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
38 lines (29 loc) · 1.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from config import Config
# Module-level configuration object. NOTE: this rebinds the name `Config` from
# the object imported from `config` (presumably a class) to the result of
# calling it, so every later use of `Config` in this module
# (e.g. `Config.device`, `Config.base_number`) reads attributes of this
# single instance rather than of the imported class.
Config = Config()
def _weighted_heatmap_loss(criterion, prediction, target, base_number):
    """Return mean(criterion(prediction, target) * base_number ** target).

    Each element of the criterion's output is multiplied by ``base_number``
    raised to the corresponding target value, so (when ``base_number > 1``)
    elements with larger target values contribute more to the mean.

    NOTE(review): the element-wise weighting is only meaningful if
    ``criterion`` returns an un-reduced loss (e.g. ``reduction='none'``); with
    a scalar-reduced criterion the weight merely broadcasts over the scalar.
    Confirm against the caller that builds ``criterion``.
    """
    loss = criterion(prediction, target)
    weight = torch.pow(base_number, target)
    return torch.mean(torch.mul(loss, weight))


def train_model(model, criterion, optimizer, scheduler, train_loader, num_epochs=300,
                device=None, base_number=None):
    """Train a two-output heatmap model in place and return it.

    For every epoch the model is switched to training mode and each batch of
    ``train_loader`` is processed. ``model(img)`` must return a pair
    ``(outputs, outputs_refine)``. Each output is compared against its own
    target heatmap with an exponentially weighted loss; the two losses are
    summed, and a single optimizer step is taken per batch. The mean loss of
    the epoch is printed and ``scheduler.step()`` is called once per epoch.

    Args:
        model: module called as ``model(img)``, returning
            ``(outputs, outputs_refine)``.
        criterion: loss function applied to ``(prediction, target)``;
            presumably element-wise (``reduction='none'``) — TODO confirm.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once at the end of each
            epoch.
        train_loader: iterable of 6-tuples
            ``(img, heatmaps, heatmaps_refine, img_name, x_all, y_all)``; only
            the first three items are used here.
        num_epochs: number of passes over ``train_loader``.
        device: device that inputs and targets are moved to; defaults to
            ``Config.device``.
        base_number: base of the exponential loss weighting; defaults to
            ``Config.base_number``.

    Returns:
        The same ``model`` object that was passed in, after training.
    """
    # Resolve defaults lazily so callers that pass both values never depend on
    # the module-level configuration.
    if device is None:
        device = Config.device
    if base_number is None:
        base_number = Config.base_number

    for epoch in range(num_epochs):
        print('Epoch{}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        model.train()
        running_loss = 0.0
        num_batches = 0
        for img, heatmaps, heatmaps_refine, _img_name, _x_all, _y_all in train_loader:
            img = img.to(device)
            heatmaps = heatmaps.to(device)
            heatmaps_refine = heatmaps_refine.to(device)

            outputs, outputs_refine = model(img)
            loss = (
                _weighted_heatmap_loss(criterion, outputs, heatmaps, base_number)
                + _weighted_heatmap_loss(criterion, outputs_refine, heatmaps_refine, base_number)
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate a detached value: adding the loss tensor itself would
            # keep a reference to every batch's autograd graph alive for the
            # whole epoch. Detaching (instead of .item()) avoids a per-batch
            # device sync.
            running_loss = running_loss + loss.detach()
            num_batches += 1

        # Guard against an empty loader (no division by zero).
        if num_batches:
            print('Loss: {:.6f}'.format(float(running_loss) / num_batches))
        scheduler.step()
    return model