# mlm_mBERT_xnli.py
import torch
from transformers import LineByLineTextDataset, Trainer, TrainingArguments, BertTokenizer, BertForMaskedLM, \
    DataCollatorForLanguageModeling

# select the GPU that transformers / PyTorch should use
torch.cuda.set_device(1)
print(torch.cuda.current_device())

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertForMaskedLM.from_pretrained('bert-base-multilingual-cased')
# initialize the training arguments
training_args = TrainingArguments(
    output_dir='models',                 # output directory where model checkpoints are saved
    # evaluation_strategy="steps",       # would evaluate every `logging_steps` steps, but no eval_dataset is passed to the Trainer below
    overwrite_output_dir=True,
    num_train_epochs=2,                  # number of training epochs, feel free to tweak
    per_device_train_batch_size=16,      # training batch size; set it as high as your GPU memory allows
    gradient_accumulation_steps=8,       # accumulate gradients over 8 steps before updating the weights
    per_device_eval_batch_size=64,       # evaluation batch size (unused here, since no eval dataset is provided)
    logging_steps=1000,                  # log every 1000 steps
    save_steps=1000,                     # save a model checkpoint every 1000 steps
    # load_best_model_at_end=True,       # load the best model (in terms of loss) at the end of training; requires evaluation
    # save_total_limit=3,                # keep only the 3 most recent checkpoints on disk if space is tight
)
# initialize the data collator, which randomly masks 15% of tokens for masked language modeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# initialize the dataset: one training example per line of the text file
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path='data/xnli/xnli-all.txt',
    block_size=128,                      # maximum sequence length in tokens
)
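
# Note: LineByLineTextDataset is deprecated in recent transformers releases.
# A rough equivalent using the Hugging Face `datasets` library is sketched
# below; this is an assumption about a drop-in replacement, not part of the
# original script:
#
#     from datasets import load_dataset
#     raw = load_dataset('text', data_files='data/xnli/xnli-all.txt')
#     dataset = raw['train'].map(
#         lambda batch: tokenizer(batch['text'], truncation=True, max_length=128),
#         batched=True,
#         remove_columns=['text'],
#     )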
# initialize the trainer and pass everything to it
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
# run the training procedure
trainer.train()

# save the fine-tuned model and tokenizer
trainer.save_model('./models/bert_xnli')
tokenizer.save_pretrained('./models/bert_xnli')
print('Finished training; model saved at ./models/bert_xnli')
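
# The saved model can later be reloaded for inference or further fine-tuning,
# e.g. (a minimal sketch, assuming the save path used above):
#
#     from transformers import BertTokenizer, BertForMaskedLM
#     tokenizer = BertTokenizer.from_pretrained('./models/bert_xnli')
#     model = BertForMaskedLM.from_pretrained('./models/bert_xnli')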