-
Notifications
You must be signed in to change notification settings - Fork 6
/
pretrain_bert.py
executable file
·613 lines (513 loc) · 22.7 KB
/
pretrain_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT"""
from comet_ml import Experiment
#from apex import amp
import os
import random
import numpy as np
import psutil
import torch
from olfmlm.arguments import get_args
from olfmlm.configure_data import configure_data
from olfmlm.learning_rates import AnnealingLR
from olfmlm.model import BertModel
from olfmlm.model import get_params_for_weight_decay_optimization
from olfmlm.model import DistributedDataParallel as DDP
from olfmlm.optim import Adam
from olfmlm.utils import Timers
from olfmlm.utils import save_checkpoint
from olfmlm.utils import load_checkpoint
def get_model(tokenizer, args):
    """Create the BERT model, place it on the current CUDA device and wrap it.

    :param tokenizer: tokenizer object forwarded to ``BertModel``
    :param args: arguments; ``args.world_size`` decides whether the model is
        wrapped for distributed training
    :return: the model, wrapped in ``DDP`` when ``args.world_size > 1``
    """
    print('building BERT model ...')
    network = BertModel(tokenizer, args)

    # Count every element of every parameter tensor for the banner line.
    param_count = sum([weight.nelement() for weight in network.parameters()])
    print(' > number of parameters: {}'.format(param_count), flush=True)

    # Move the weights onto the GPU selected for this process.
    network.cuda(torch.cuda.current_device())

    # Only multi-process runs need the distributed wrapper.
    return DDP(network) if args.world_size > 1 else network
def get_optimizer(model, args):
    """Create the Adam optimizer for the (possibly wrapped) model.

    :param model: the model, optionally nested inside one or more ``DDP`` wrappers
    :param args: arguments providing ``lr`` and ``weight_decay``
    :return: an ``Adam`` optimizer over the groups returned by ``model.get_params()``
    """
    # Peel off every distributed wrapper to reach the underlying module.
    unwrapped = model
    while isinstance(unwrapped, DDP):
        unwrapped = unwrapped.module

    # The model itself supplies the parameter groups
    # (presumably split into decay / no-decay groups - confirm in the model code).
    return Adam(unwrapped.get_params(),
                lr=args.lr, weight_decay=args.weight_decay)
def get_learning_rate_scheduler(optimizer, args):
    """Create the ``AnnealingLR`` schedule used during pre-training.

    The schedule length is ``args.lr_decay_iters`` when that is set, otherwise
    ``args.train_tokens * args.epochs``. The warm-up length is the fraction
    ``args.warmup`` of the schedule length.

    :param optimizer: optimizer whose learning rate is scheduled
    :param args: arguments (``lr``, ``warmup``, ``lr_decay_style``, ...)
    :return: the configured ``AnnealingLR`` instance
    """
    if args.lr_decay_iters is None:
        schedule_length = args.train_tokens * args.epochs
    else:
        schedule_length = args.lr_decay_iters

    return AnnealingLR(optimizer,
                       start_lr=args.lr,
                       warmup_iter=args.warmup * schedule_length,
                       num_iters=schedule_length,
                       decay_style=args.lr_decay_style,
                       # The schedule always starts from -1 here.
                       last_iter=-1)
def setup_model_and_optimizer(args, tokenizer):
    """Build the model, optimizer, learning-rate scheduler and loss criteria.

    When ``args.load`` is set, a checkpoint is restored into the model,
    optimizer and scheduler; the value returned by ``load_checkpoint`` is stored
    in ``args.epoch`` and ``args.resume_dataloader`` is switched on.

    :param args: arguments (mutated when a checkpoint is loaded)
    :param tokenizer: tokenizer forwarded to the model constructor
    :return: ``(model, optimizer, lr_scheduler, criterion)`` where ``criterion``
        is the pair ``(classification_loss, regression_loss)``
    """
    model = get_model(tokenizer, args)
    optimizer = get_optimizer(model, args)
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    # Both criteria return un-reduced, per-element losses so callers can apply
    # their own masking. ``reduction='none'`` replaces the deprecated
    # ``reduce=False`` spelling, which only triggers a warning in current PyTorch.
    criterion_cls = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
    criterion_reg = torch.nn.MSELoss(reduction='none')
    criterion = (criterion_cls, criterion_reg)

    if args.load is not None:
        args.epoch = load_checkpoint(model, optimizer, lr_scheduler, args)
        args.resume_dataloader = True
    return model, optimizer, lr_scheduler, criterion
def get_batch(data):
    """Turn one data-loader batch into CUDA tensors ready for the model.

    Reads ``data['aux_labels']``, ``data['n']``, ``data['num_tokens']`` and, for
    every sentence index ``i`` below ``min(data['n'])``, the keys ``text_i``,
    ``types_i``, ``task_i``, ``pad_mask_i``, ``mask_labels_i``, ``mask_i`` and
    ``tgs_mask_i``. The per-sentence label and mask tensors are concatenated
    along dimension 0.

    :param data: dict produced by the data loader
    :return: ``(tokens, types, tasks, aux_labels, loss_mask, tgs_mask,
        lm_labels, att_mask, num_tokens)``
    """
    # TODO Add trigram mask

    def on_gpu(tensor):
        # Same wrapping as everywhere else in this file: Variable, then CUDA.
        return torch.autograd.Variable(tensor).cuda()

    aux_labels = {}
    for mode, label in data['aux_labels'].items():
        if label.shape[1] == 2:
            # Two label columns: stack column 0 followed by column 1.
            flat = torch.cat([label[:, 0], label[:, 1]])
        else:
            flat = label.squeeze()
        aux_labels[mode] = on_gpu(flat.long())

    sentence_counts = data['n']
    num_tokens = torch.tensor(sum(data['num_tokens']).item()).long().cuda()

    tokens, types, tasks, att_mask = [], [], [], []
    label_parts, loss_parts, tgs_parts = [], [], []
    # Only sentence slots present for every example in the batch are used.
    for idx in range(min(sentence_counts)):
        suffix = "_" + str(idx)
        tokens.append(on_gpu(data['text' + suffix].long()))
        types.append(on_gpu(data['types' + suffix].long()))
        tasks.append(on_gpu(data['task' + suffix].long()))
        # The attention mask is the complement of the padding mask.
        att_mask.append(1 - on_gpu(data['pad_mask' + suffix].byte()))
        label_parts.append(data['mask_labels' + suffix].long())
        loss_parts.append(data['mask' + suffix].float())
        tgs_parts.append(data['tgs_mask' + suffix].float())

    lm_labels = on_gpu(torch.cat(label_parts, dim=0).long())
    loss_mask = on_gpu(torch.cat(loss_parts, dim=0).float())
    tgs_mask = on_gpu(torch.cat(tgs_parts, dim=0).float())
    return (tokens, types, tasks, aux_labels, loss_mask, tgs_mask, lm_labels, att_mask, num_tokens)
def forward_step(data, model, criterion, modes, args):
    """Run the model on one batch and compute one loss per requested mode.

    :param data: raw batch dict, converted with ``get_batch``
    :param model: callable returning a dict of scores keyed by mode
    :param criterion: pair ``(classification_loss, regression_loss)``; both are
        expected to return un-reduced losses
    :param modes: list of task names to compute
    :param args: arguments (``seq_length``, ``data_size``, ``checkpoint_activations``)
    :return: ``(losses, num_tokens)`` where ``losses`` maps mode -> scalar loss
    """
    classification_loss, regression_loss = criterion

    (tokens, types, tasks, aux_labels, loss_mask,
     tgs_mask, lm_labels, att_mask, num_tokens) = get_batch(data)

    # Labels that depend only on the batch size are generated here.
    if "rg" in modes:
        aux_labels['rg'] = torch.autograd.Variable(torch.arange(tokens[0].shape[0]).long()).cuda()
    if "fs" in modes:
        target_len = tokens[0].shape[0] * 2 * args.seq_length
        aux_labels['fs'] = torch.autograd.Variable(torch.ones(target_len).long()).cuda()

    scores = model(modes, tokens, types, tasks, att_mask,
                   checkpoint_activations=args.checkpoint_activations)
    # The model must answer exactly the modes that were requested.
    assert sorted(list(scores.keys())) == sorted(modes)

    losses = {}
    for mode, score in scores.items():
        if mode in ("mlm", "sbo"):
            # Masked average of the per-token classification loss.
            logits = score.view(-1, args.data_size).contiguous().float()
            per_token = classification_loss(logits, lm_labels.view(-1).contiguous())
            loss_mask = loss_mask.view(-1).contiguous()
            losses[mode] = torch.sum(per_token * loss_mask.view(-1).float()) / loss_mask.sum()
        elif mode == "tgs":
            # Scores are reshaped to 6 columns and weighted by ``tgs_mask``.
            logits = score.view(-1, 6).contiguous().float()
            per_token = classification_loss(logits, aux_labels[mode].view(-1).contiguous())
            per_token = per_token.view(-1).contiguous()
            losses[mode] = torch.sum(per_token * tgs_mask.view(-1).float() / tgs_mask.sum())
        elif mode in ("fs", "wlen", "tf", "tf_idf"):
            # These modes are scored with the regression criterion.
            predicted = score.view(-1).contiguous().float()
            expected = aux_labels[mode].view(-1).contiguous().float()
            losses[mode] = regression_loss(predicted, expected).mean()
        else:
            # Everything else is plain classification; "tc" and "cap" scores
            # are first reshaped to two columns.
            if mode in ("tc", "cap"):
                score = score.view(-1, 2)
            losses[mode] = classification_loss(score.contiguous().float(),
                                               aux_labels[mode].view(-1).contiguous()).mean()
    return losses, num_tokens
def backward_step(optimizer, model, losses, num_tokens, args):
    """Back-propagate the losses, synchronise across processes and clip gradients.

    :param optimizer: optimizer whose gradients are zeroed before the backward pass
    :param model: model being trained; must offer ``allreduce_params`` when
        ``args.world_size > 1``
    :param losses: dict mapping mode -> scalar loss tensor
    :param num_tokens: token-count tensor (summed in place across processes)
    :param args: arguments (``no_aux``, ``world_size``, ``clip_grad``)
    :return: ``(losses_for_reporting, num_tokens)``; the losses are averaged
        over processes in the distributed case and returned unchanged otherwise
    """
    optimizer.zero_grad()

    # ``no_aux`` is a testing switch: optimise the MLM loss only.
    objective = losses['mlm'] if args.no_aux else sum(losses.values())
    objective.backward()

    reported = losses
    if args.world_size > 1:
        # Pack the scalars into one tensor so a single all-reduce covers them.
        names = list(losses.keys())
        packed = torch.cat([losses[name].view(1) for name in names])
        torch.distributed.all_reduce(packed.data)
        torch.distributed.all_reduce(num_tokens)
        packed.data = packed.data / args.world_size
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=False)  # args.fp32_allreduce)
        reported = {name: packed[position] for position, name in enumerate(names)}

    # Clipping gradients helps prevent the exploding gradient.
    if args.clip_grad > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
    return reported, num_tokens
def train_step(input_data, model, criterion, optimizer, lr_scheduler, modes, args):
    """Perform one optimisation step on a single batch.

    ``lr_scheduler`` is accepted for interface compatibility but is not used
    here; the caller steps the schedule.

    :return: ``(reduced_losses, num_tokens)`` as produced by ``backward_step``
    """
    # Forward pass: one loss per mode plus the batch token count.
    raw_losses, token_count = forward_step(input_data, model, criterion, modes, args)

    # Backward pass, cross-process reduction and gradient clipping.
    reported_losses, token_count = backward_step(optimizer, model, raw_losses, token_count, args)

    # Apply the parameter update.
    optimizer.step()
    return reported_losses, token_count
def get_stage_info(total_tokens, num_tasks):
    """
    Compute the per-stage token budget of every task for continual multi-task
    learning (based on ERNIE 2.0). There are as many stages as tasks.

    With ``unit = total_tokens / num_tasks / (num_tasks + 1)``, stage ``i``
    gives task ``j``:
      * ``0`` when the task has not been introduced yet (``i < j``),
      * ``unit * (i + 2)`` in the stage that introduces it (``i == j``),
      * ``unit`` in every later stage (``i > j``).

    :param total_tokens: total number of tokens to train on
    :param num_tasks: number of tasks
    :return: list of stages, each a list with one token budget per task
    """
    unit = (total_tokens / num_tasks) / (num_tasks + 1)

    def budget(stage, task):
        if stage < task:
            return 0
        if stage > task:
            return unit
        return unit * (stage + 2)

    return [[budget(stage, task) for task in range(num_tasks)]
            for stage in range(num_tasks)]
def set_up_stages(args):
    """
    Prepare the stage schedule for continual multi-task learning (ERNIE 2.0 style).

    The tasks come from the comma-separated ``args.modes``; when
    ``args.always_mlm`` is set the first entry is left out of the schedule.
    The total budget is ``args.epochs * args.train_tokens``.

    :param args: arguments; ``args.incremental`` must be false
    :return: a zero-argument function that returns, on each call, the next
        stage as a dict mapping task -> token budget. Once all stages are used
        up it returns the full token budget for every task.
    """
    assert not args.incremental
    total_tokens = args.epochs * args.train_tokens

    tasks = args.modes.split(',')
    if args.always_mlm:
        tasks = tasks[1:]

    schedule = get_stage_info(total_tokens, len(tasks))
    position = 0

    def next_stage():
        nonlocal position
        if position >= len(schedule):
            print("Finished all training, shouldn't reach this unless it's the very final iteration")
            return {task: float(total_tokens) for task in tasks}

        budgets = schedule[position]
        assert len(tasks) == len(budgets)
        stage = dict(zip(tasks, budgets))
        print("Starting stage {} of {}, with task distribution: ".format(position, len(schedule)))
        print(stage)
        position += 1
        return stage

    return next_stage
def get_mode_from_stage(current_stage, args):
    """
    Sample the task to train on next, given the current stage.

    Each task is drawn with probability proportional to the number of tokens it
    has left in ``current_stage``. When ``args.always_mlm`` is set, the first
    entry of ``args.modes`` is excluded from the draw.

    :param current_stage: number of tokens left for each task for this stage
    :param args: arguments (``modes`` as a comma-separated string, ``always_mlm``)
    :return: single-element list holding the selected mode
    :raises ValueError: if there is no task to choose from or no tokens are left
    """
    modes = args.modes.split(',')
    if args.always_mlm:
        modes = modes[1:]

    # Force a float array: with integer-only counts the in-place normalisation
    # below would fail, because a true-division result cannot be cast back
    # into an integer array.
    p = np.array([current_stage[m] for m in modes], dtype=float)
    total = np.sum(p)
    if p.size == 0 or total == 0:
        # Without this check the division yields NaNs and np.random.choice
        # fails with a far less helpful message.
        raise ValueError(
            "Cannot sample a mode: no tokens remain for any of {}".format(modes))
    p /= total
    return [np.random.choice(modes, p=p)]
def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args,
                current_stage=None, next_stage=None):
    """Train until ``args.train_tokens`` tokens have been processed.

    The modes used at each step depend on the arguments:
      * ``args.continual_learning``: one mode sampled from ``current_stage``
        (plus ``'mlm'`` when ``args.always_mlm``); the stage budgets are
        decremented and refilled from ``next_stage()`` once they all reach zero.
      * ``args.alternating``: cycle through the modes one per step (keeping
        ``'mlm'`` in every step when ``args.always_mlm``).
      * otherwise: ``'mlm'`` plus one rotating sentence task plus all token tasks.

    The rotation index is the number of steps since the last log line, so it
    restarts whenever a log line is written.

    :param epoch: 1-based epoch number (used for logging and the LR schedule)
    :param train_data: data loader; restarted transparently when exhausted
    :param experiment: tracker receiving ``set_step`` / ``log_metrics`` calls
    :param metrics: dict updated in place with the latest logged values
    :param current_stage: remaining token budget per task (continual learning only)
    :param next_stage: function returning the next stage (continual learning only)
    """
    print("Starting training of epoch {}".format(epoch), flush=True)
    # Turn on training mode which enables dropout.
    model.train()

    token_budget = args.train_tokens
    running_losses = {}
    tokens_since_log = 0
    tokens_seen = 0
    steps_since_log = 0
    steps_flushed = 0

    modes = args.modes.split(',')
    # Incremental training adds one more task with each epoch.
    if args.incremental:
        modes = modes[:epoch]
    train_data.dataset.set_args(modes)
    sent_tasks = [m for m in modes if m in train_data.dataset.sentence_tasks]
    tok_tasks = [m for m in modes if m not in train_data.dataset.sentence_tasks + ["mlm"]]

    def choose_modes(step):
        """Pick the modes for the step with the given rotation index."""
        if args.continual_learning:
            sampled = get_mode_from_stage(current_stage, args)
            return ['mlm'] + sampled if args.always_mlm else sampled
        if args.alternating:
            if not args.always_mlm:
                return [modes[step % len(modes)]]
            chosen = ['mlm']
            if len(modes[1:]) > 0:
                # Rotate through everything after the first mode.
                chosen += [modes[(step % (len(modes) - 1)) + 1]]
            return chosen
        rotating = [] if len(sent_tasks) == 0 else [sent_tasks[step % len(sent_tasks)]]
        return ['mlm'] + rotating + tok_tasks

    batches = iter(train_data)
    timers('interval time').start()
    while tokens_seen < token_budget:
        step_modes = choose_modes(steps_since_log)

        # Restart the loader whenever it runs dry, then retry the step.
        while True:
            try:
                losses, num_tokens = train_step(next(batches),
                                                model,
                                                criterion,
                                                optimizer,
                                                lr_scheduler,
                                                step_modes,
                                                args)
            except StopIteration:
                batches = iter(train_data)
            else:
                break

        step_tokens = num_tokens.item()
        tokens_since_log += step_tokens
        tokens_seen += step_tokens

        if args.continual_learning:
            for m in step_modes:
                if args.always_mlm and m == "mlm":
                    continue
                current_stage[m] = max(0, current_stage[m] - step_tokens)
            # Stage exhausted: refill the budgets from the next stage.
            if sum(current_stage.values()) == 0:
                current_stage.update(next_stage())

        # The schedule is indexed by tokens processed since the start of training.
        lr_scheduler.step(step_num=(epoch - 1) * token_budget + tokens_seen)
        steps_since_log += 1

        for mode, loss in losses.items():
            running_losses[mode] = running_losses.get(mode, 0.0) + loss.data.detach().float()

        if tokens_since_log <= args.log_interval:
            continue

        # ---- periodic logging ----
        tokens_since_log = 0
        learning_rate = optimizer.param_groups[0]['lr']
        avg_loss = {mode: acc.item() / steps_since_log for mode, acc in running_losses.items()}
        elapsed_time = timers('interval time').elapsed()
        pieces = [
            ' epoch{:2d} |'.format(epoch),
            ' tokens {:8d}/{:8d} |'.format(tokens_seen, token_budget),
            ' elapsed time per iteration (ms): {:.1f} |'.format(
                elapsed_time * 1000.0 / steps_since_log),
            ' learning rate {:.3E} |'.format(learning_rate),
        ]
        pieces.extend(' {} loss {:.3E} |'.format(mode, value) for mode, value in avg_loss.items())
        print(''.join(pieces), flush=True)

        running_losses = {}
        experiment.set_step((epoch - 1) * token_budget + tokens_seen)
        metrics['learning_rate'] = learning_rate
        metrics.update(avg_loss)
        experiment.log_metrics(metrics)
        steps_flushed += steps_since_log
        steps_since_log = 0

    print("Learnt using {} tokens over {} iterations this epoch".format(tokens_seen, steps_flushed + steps_since_log))
def evaluate(epoch, data_source, model, criterion, elapsed_time, args, test=False):
    """Evaluate the model on ``data_source`` and print a summary line.

    Batches are consumed until at least ``args.eval_tokens`` tokens have been
    evaluated; the loader is restarted when it runs out. Every mode listed in
    ``args.modes`` is evaluated on every batch.

    :param epoch: epoch number shown in the summary (ignored when ``test``)
    :param data_source: data loader to evaluate on
    :param model: model to evaluate; switched back to train mode before returning
    :param criterion: pair ``(classification_loss, regression_loss)``
    :param elapsed_time: time in seconds shown in the summary
    :param args: arguments (``eval_tokens``, ``modes``, ``world_size``)
    :param test: if true, print the end-of-training banner instead of the
        end-of-epoch one
    :return: sum over modes of the average per-iteration loss
    """
    print("Entering evaluation", flush=True)
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_losses = {}
    max_tokens = args.eval_tokens
    tokens = 0
    modes = args.modes.split(',')
    data_source.dataset.set_args(modes)
    data_iters = iter(data_source)
    with torch.no_grad():
        iteration = 0
        while tokens < max_tokens:
            # Forward evaluation; keep trying until one batch goes through.
            while True:
                try:
                    losses, num_tokens = forward_step(next(data_iters), model, criterion, modes, args)
                    break
                except (TypeError, RuntimeError) as e:
                    # Best effort: skip the offending batch and carry on.
                    print("Ooops, caught: '{}', continuing".format(e))
                except StopIteration:
                    data_iters = iter(data_source)
            # Reduce across processes.
            if isinstance(model, DDP):
                losses_reduced = [[k, v] for k, v in losses.items()]
                reduced_losses = torch.cat([x[1].view(1) for x in losses_reduced])
                torch.distributed.all_reduce(reduced_losses.data)
                reduced_losses.data = reduced_losses.data / args.world_size
                torch.distributed.all_reduce(num_tokens)
                losses = {name: reduced_losses[i] for i, (name, _) in enumerate(losses_reduced)}
            assert sorted(list(losses.keys())) == sorted(modes)
            for mode, loss in losses.items():
                total_losses[mode] = total_losses.get(mode, 0.0) + loss.data.detach().float().item()
            iteration += 1
            tokens += num_tokens.item()
    print("Evaluated using {} tokens over {} iterations.".format(tokens, iteration), flush=True)
    # Move model back to the train mode.
    model.train()

    # Average over the iterations actually run. The loop above is bounded by
    # tokens, not by ``args.eval_iters``, so dividing by that fixed setting
    # would mis-scale the result. ``total_losses`` is only non-empty when at
    # least one iteration ran, so the division is safe.
    avg_loss = {mode: v / iteration for mode, v in total_losses.items()}
    tot_loss = sum(avg_loss.values())

    sep_char = '=' if test else '-'
    print(sep_char * 100)
    # Format the header with the epoch, and the remainder with exactly the two
    # values it has fields for (time, total loss).
    if test:
        log_string = '| End of training | '
    else:
        log_string = '| End of epoch {:3d} | '.format(epoch)
    log_string += 'time: {:5.2f}s | valid loss {:.4E} | '.format(elapsed_time, tot_loss)
    for mode, v in avg_loss.items():
        log_string += ' {} loss {:.3E} |'.format(mode, v)
    print(log_string, flush=True)
    print(sep_char * 100, flush=True)
    return tot_loss
def initialize_distributed(args):
    """Select this process's CUDA device and, if needed, join the process group.

    The device is ``args.local_rank`` when given, otherwise ``args.rank`` modulo
    the number of visible CUDA devices. For ``args.world_size > 1`` the process
    group is initialised over TCP using the ``MASTER_ADDR`` / ``MASTER_PORT``
    environment variables (defaults ``localhost`` / ``6000``).

    :param args: arguments (``rank``, ``local_rank``, ``world_size``,
        ``distributed_backend``)
    """
    # Manually set the device ids.
    rank_based_device = args.rank % torch.cuda.device_count()
    device = rank_based_device if args.local_rank is None else args.local_rank
    torch.cuda.set_device(device)

    # Single-process runs do not need a process group.
    if not args.world_size > 1:
        return

    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size, rank=args.rank,
        init_method='tcp://' + master_ip + ':' + master_port)
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) for reproducibility.

    Nothing happens when ``seed`` is ``None`` or not strictly positive.

    :param seed: integer seed, or ``None``
    """
    if seed is None or not seed > 0:
        return
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
def main():
    """Main training program.

    Parses the arguments, sets up experiment tracking, distributed state, data,
    model and optimizer, then trains for ``args.epochs`` epochs. After every
    epoch the model is validated, and checkpointed when ``args.save`` is set.
    If a test set exists it is evaluated once training finishes. Ctrl + C stops
    training and exits the process without running the test evaluation.
    """
    print('Pretrain BERT model')
    # Disable CuDNN.
    torch.backends.cudnn.enabled = False
    # Timer.
    timers = Timers()
    # Arguments.
    args = get_args()

    # NOTE(review): an API key is committed in source. It can now be overridden
    # through the COMET_API_KEY environment variable; the literal is kept only
    # as a fallback so existing setups keep working. It should be rotated and
    # removed from the repository.
    comet_api_key = os.getenv('COMET_API_KEY', '1jl4lQOnJsVdZR6oekS6WO5FI')
    experiment = Experiment(api_key=comet_api_key,
                            project_name=args.model_type,
                            auto_param_logging=False, auto_metric_logging=False,
                            disabled=(not args.track_results))
    experiment.log_parameters(vars(args))
    metrics = {}

    # Pytorch distributed.
    initialize_distributed(args)
    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Data stuff.
    data_config = configure_data()
    data_config.set_defaults(data_set_type='BERT', transpose=False)
    (train_data, val_data, test_data), tokenizer = data_config.apply(args)
    args.data_size = tokenizer.num_tokens

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler, criterion = setup_model_and_optimizer(
        args, tokenizer)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    timers("total time").start()
    # Defined before the loop so the final test evaluation has a value even if
    # no epoch runs.
    epoch = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        start_epoch = 1
        best_val_loss = float('inf')
        # Resume data loader if necessary.
        if args.resume_dataloader:
            start_epoch = args.epoch
        next_stage = None
        current_stage = None
        if args.continual_learning:
            next_stage = set_up_stages(args)
            current_stage = next_stage()
            if args.resume_dataloader:
                # NOTE(review): checkpoints are saved with ``epoch + 1``, so
                # ``args.epoch`` looks like the next epoch to run; confirm that
                # this product is the intended count of consumed tokens.
                num_tokens = args.epoch * args.train_tokens
                # Get to the right stage
                while num_tokens > sum(current_stage.values()):
                    num_tokens -= sum(current_stage.values())
                    ns = next_stage()
                    for m in ns:
                        current_stage[m] = ns[m]
                # Get to right part of stage: spread the tokens already used
                # over the tasks in proportion to their share of the stage.
                stage_tokens = sum(current_stage.values())
                stage_ratios = {k: v / float(stage_tokens) for k, v in current_stage.items()}
                for k in current_stage:
                    current_stage[k] -= num_tokens * stage_ratios[k]
        # Train for required epochs
        for epoch in range(start_epoch, args.epochs+1):
            if args.shuffle:
                train_data.batch_sampler.sampler.set_epoch(epoch+args.seed)
            timers('epoch time').start()
            # Train
            train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args,
                        current_stage=current_stage, next_stage=next_stage)
            elapsed_time = timers('epoch time').elapsed()
            if args.save:
                ck_path = 'ck/model_{}.pt'.format(epoch)
                print('saving ck model to:',
                      os.path.join(args.save, ck_path))
                save_checkpoint(ck_path, epoch+1, model, optimizer, lr_scheduler, args)
            # Validate
            val_loss = evaluate(epoch, val_data, model, criterion, elapsed_time, args)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                if args.save:
                    best_path = 'best/model.pt'
                    print('saving best model to:',
                          os.path.join(args.save, best_path))
                    save_checkpoint(best_path, epoch+1, model, optimizer, lr_scheduler, args)
    except KeyboardInterrupt:
        print('-' * 100)
        print('Exiting from training early')
        # ``exit()`` is an interactive helper installed by the ``site`` module
        # and is not guaranteed to exist; raising SystemExit is the portable
        # way to stop with status 0.
        raise SystemExit
    if test_data is not None:
        # Run on test data.
        print('entering test')
        elapsed_time = timers("total time").elapsed()
        evaluate(epoch, test_data, model, criterion, elapsed_time, args, test=True)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()