strategies.py

import torch

class TrainingStrategy:
    """Base interface: an optimizer (and optional scheduler) plus a stopping rule."""

    def __init__(self, student_diffusion, student_lr, total_steps):
        raise NotImplementedError()

    def zero_grad(self):
        raise NotImplementedError()

    def step(self):
        raise NotImplementedError()

    def stop(self, N, max_iter):
        return N > max_iter

class StrategyOneCycle(TrainingStrategy):
    """SGD driven by a one-cycle learning-rate schedule."""

    def __init__(self, student_diffusion, student_lr, total_steps):
        self.student_optimizer = torch.optim.SGD(student_diffusion.net_.parameters(), lr=student_lr, weight_decay=min(1e-5, 0.5 * student_lr / 10))
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.student_optimizer, max_lr=student_lr, total_steps=total_steps + 2)

    def zero_grad(self):
        self.student_optimizer.zero_grad()

    def step(self):
        self.student_optimizer.step()
        self.scheduler.step()

class StrategyConstantLR(TrainingStrategy):
    """AdamW with a constant learning rate (no scheduler)."""

    def __init__(self, student_diffusion, student_lr, total_steps):
        self.student_optimizer = torch.optim.AdamW(student_diffusion.net_.parameters(), lr=student_lr)

    def zero_grad(self):
        self.student_optimizer.zero_grad()

    def step(self):
        self.student_optimizer.step()

class StrategyLinearLR(TrainingStrategy):
    """AdamW with a learning rate that decays linearly to zero over total_steps."""

    def __init__(self, student_diffusion, student_lr, total_steps):
        self.student_optimizer = torch.optim.AdamW(student_diffusion.net_.parameters(), lr=student_lr)
        self.scheduler = torch.optim.lr_scheduler.LinearLR(self.student_optimizer, start_factor=1, end_factor=0, total_iters=total_steps)

    def zero_grad(self):
        self.student_optimizer.zero_grad()

    def step(self):
        self.student_optimizer.step()
        self.scheduler.step()

class StrategyCosineAnnel(TrainingStrategy):
    """SGD with cosine annealing and warm restarts; stops only once the learning rate has annealed low enough."""

    def __init__(self, student_diffusion, student_lr, total_steps):
        self.student_lr = student_lr
        self.eta_min = 0.01
        self.student_optimizer = torch.optim.SGD(student_diffusion.net_.parameters(), lr=student_lr, weight_decay=min(1e-5, student_lr / 10))
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.student_optimizer, T_0=100, T_mult=2, eta_min=self.eta_min, last_epoch=-1)

    def zero_grad(self):
        self.student_optimizer.zero_grad()

    def step(self):
        self.student_optimizer.step()
        self.scheduler.step()

    def stop(self, N, max_iter):
        # Require both that the learning rate has decayed close to eta_min and that max_iter has been exceeded.
        return (self.scheduler.get_last_lr()[0] < self.student_lr * self.eta_min * 3) and (N > max_iter)
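
For orientation, here is a minimal sketch of how one of these strategies might drive a training loop. The ToyDiffusion class, its net_ linear layer, the learning rate, and the placeholder loss are illustrative stand-ins, not part of strategies.py; any object exposing a net_ module with parameters would work the same way.

import torch

class ToyDiffusion:
    # Hypothetical stand-in for the student diffusion model: only the `net_` attribute is required.
    def __init__(self):
        self.net_ = torch.nn.Linear(16, 16)

student_diffusion = ToyDiffusion()
total_steps = 1000
strategy = StrategyLinearLR(student_diffusion, student_lr=1e-4, total_steps=total_steps)

N = 0
while not strategy.stop(N, max_iter=total_steps):
    x = torch.randn(8, 16)
    loss = (student_diffusion.net_(x) - x).pow(2).mean()  # placeholder objective
    strategy.zero_grad()
    loss.backward()
    strategy.step()
    N += 1

The loop only talks to the strategy through zero_grad, step, and stop, so swapping StrategyLinearLR for StrategyOneCycle, StrategyConstantLR, or StrategyCosineAnnel requires no other changes.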