scheduler.py
from torch.optim.lr_scheduler import LambdaLR


def create_scheduler(args, optimizer):
    # Derive the total number of training steps when not given explicitly.
    if 'num_training_steps' not in args:
        args['num_training_steps'] = args['epochs'] * args['step_per_epoch']
    print("### num_training_steps, ", args['num_training_steps'], flush=True)

    # A float warmup value is interpreted as a fraction of the total steps.
    if isinstance(args['num_warmup_steps'], float):
        assert 0 <= args['num_warmup_steps'] < 1
        args['num_warmup_steps'] = int(args['num_training_steps'] * args['num_warmup_steps'])
    print("### num_warmup_steps, ", args['num_warmup_steps'], flush=True)

    if args['sched'] == 'linear':
        def lr_lambda(current_step: int):
            # Linear warmup from 0 to the base LR, then linear decay to 0.
            if current_step < args['num_warmup_steps']:
                return float(current_step) / float(max(1, args['num_warmup_steps']))
            return max(
                0.0,
                float(args['num_training_steps'] - current_step)
                / float(max(1, args['num_training_steps'] - args['num_warmup_steps'])),
            )

        lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
    else:
        raise NotImplementedError(f"args['sched'] == {args['sched']}")
    return lr_scheduler
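
# A minimal usage sketch, assuming args is a plain dict carrying the keys the
# function reads; the model, optimizer, and hyperparameter values below are
# illustrative, not part of the original file.
#
# import torch
#
# model = torch.nn.Linear(10, 2)
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#
# args = {
#     'epochs': 3,
#     'step_per_epoch': 100,
#     'num_warmup_steps': 0.1,  # float -> warm up over 10% of all steps
#     'sched': 'linear',
# }
# scheduler = create_scheduler(args, optimizer)  # fills in args['num_training_steps']
#
# for step in range(args['num_training_steps']):
#     optimizer.step()
#     scheduler.step()  # advance the schedule once per optimizer step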