# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
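
"""Fine-tuning engine: one AMP training epoch with gradient accumulation and
per-iteration LR scheduling, plus a top-1/top-5 evaluation pass."""
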
import math
import sys
from typing import Iterable

import torch

import util.lr_sched as lr_sched
import util.misc as misc
from util.logging import master_print as print


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    max_norm: float = 0,
    args=None,
):
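    """Train `model` for one epoch.

    Uses CUDA AMP for the forward pass, gradient accumulation over
    `args.accum_iter` steps, and a per-iteration LR schedule. `args` is also
    assumed to carry whatever `lr_sched.adjust_learning_rate` reads (e.g.
    base LR and warmup settings).
    """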
    model.train()
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    num_logs_per_epoch = 1

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    for data_iter_step, (samples, targets) in enumerate(
        metric_logger.log_every(data_loader, len(data_loader) // num_logs_per_epoch, header)
    ):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(
                optimizer, data_iter_step / len(data_loader) + epoch, args
            )
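
        # Inputs may arrive as (batch, repeats, channels, time, height, width),
        # presumably several crops/views per sample; fold the repeat dimension
        # into the batch so the model sees a standard 5D video batch.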
        if len(samples.shape) == 6:
            b, r, c, t, h, w = samples.shape
            samples = samples.view(b * r, c, t, h, w)
            targets = targets.view(b * r)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

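        # forward pass and loss under CUDA mixed precision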
        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

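        # Normalize the loss so gradients accumulated over `accum_iter`
        # micro-batches match the magnitude of one large-batch step.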
        loss /= accum_iter
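        # `loss_scaler` wraps backward(), gradient unscaling/clipping, and
        # optimizer.step(); it only applies the update on the step that
        # completes an accumulation window (update_grad=True).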
        loss_scaler(
            loss,
            optimizer,
            clip_grad=max_norm,
            parameters=model.parameters(),
            create_graph=False,
            update_grad=(data_iter_step + 1) % accum_iter == 0,
        )
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

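        # wait for outstanding GPU work so the logged loss reflects a finished step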
        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

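        # track the current LR range across param groups; log the max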
        min_lr = 10.0
        max_lr = 0.0
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])
        metric_logger.update(lr=max_lr)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
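    """Evaluate `model` on `data_loader`: cross-entropy loss and top-1/top-5
    accuracy, averaged across all distributed processes."""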
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = "Test:"
    num_logs_per_epoch = 1

    # switch to evaluation mode
    model.eval()

    for _, (images, target) in enumerate(
        metric_logger.log_every(data_loader, len(data_loader) // num_logs_per_epoch, header)
    ):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
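
        # fold an extra repeat/view dimension into the batch, as in training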
        if len(images.shape) == 6:
            b, r, c, t, h, w = images.shape
            images = images.view(b * r, c, t, h, w)
            target = target.view(b * r)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = misc.accuracy(output, target, topk=(1, 5))
        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print(
        "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
        )
    )
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
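

# Minimal usage sketch (hypothetical wiring; the repo's actual driver script
# sets up the model, data loaders, optimizer, and loss scaler):
#
#   loss_scaler = misc.NativeScaler()  # assumed AMP loss-scaler helper in util.misc
#   for epoch in range(args.epochs):
#       train_stats = train_one_epoch(
#           model, criterion, train_loader, optimizer, device, epoch,
#           loss_scaler, max_norm=args.clip_grad, args=args,
#       )
#       test_stats = evaluate(val_loader, model, device)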