encoder_only_train.py
# Fine-tune an encoder-only sequence-classification model with the Hugging Face
# Trainer on pre-tokenized datasets, once per depth level and dataset variant.
from transformers import (
    Trainer,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    EarlyStoppingCallback,
)
from datasets import load_from_disk
import evaluate
import numpy as np
import os
import sys
def parse_arguments(args):
    """Parse command-line arguments given as key=value pairs."""
    parsed_args = {}
    for arg in args[1:]:
        key_value = arg.split('=')
        if len(key_value) == 2:
            key, value = key_value
            parsed_args[key] = value
        else:
            raise ValueError("argument parsing failed for {0}".format(arg))
    return parsed_args


args = parse_arguments(sys.argv)

# Pin the job to the requested GPU(s) and read the run configuration.
os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']
checkpoint = args['ckpt']                              # model checkpoint to fine-tune
predata_file = args['dataFile']                        # dataset path template containing '{depth}'
prestore_file = args['storeFile'] + '/tokenized_data'  # output path template for tokenized data
def load_preprocess_data(data_file, store_file):
    """Tokenize a saved DatasetDict and write the result to store_file."""
    raw = load_from_disk(data_file)
    raw = raw.rename_column('label', 'labels')

    # Drop every column except the text inputs and the labels.
    columns_to_keep = ["labels", "inputs"]
    raw = raw.remove_columns(
        [col for col in raw['train'].features.keys() if col not in columns_to_keep])

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # First pass: tokenize without padding just to find the longest sequence.
    tokenized_inputs = raw.map(lambda x: tokenizer(x["inputs"], truncation=True),
                               batched=True, remove_columns=["inputs", "labels"])
    max_source_length = max(len(x) for subset in ['train', 'dev', 'test']
                            for x in tokenized_inputs[subset]['input_ids'])
    print(f"Max source length: {max_source_length}")

    # Second pass: pad/truncate everything to that length, capped at 512 tokens.
    def preprocess_function(sample):
        return tokenizer(sample["inputs"], max_length=min(512, max_source_length),
                         padding="max_length", truncation=True)

    tokenized_dataset = raw.map(
        preprocess_function, batched=True, remove_columns=["inputs"])
    print(data_file, tokenized_dataset)
    tokenized_dataset.save_to_disk(store_file)
def do_training(store_file, level):
    """Fine-tune the classifier on the tokenized dataset at store_file ('level' is currently unused)."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint, num_labels=2)
    tokenized_datasets = load_from_disk(store_file)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=store_file.replace("tokenized_data", "output"),
        per_device_train_batch_size=int(args['batchsz']),
        per_device_eval_batch_size=int(args['batchsz']),
        learning_rate=float(args['lr']),
        num_train_epochs=1000,  # early stopping ends training well before this
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
    )
    print("--- Training arguments ---")
    print(training_args)

    # Report validation accuracy after every epoch.
    metric = evaluate.load('accuracy')

    def compute_metrics(eval_preds):
        logits, true_labels = eval_preds
        pred = np.argmax(logits, axis=-1)
        return metric.compute(predictions=pred, references=true_labels)

    # Stop once accuracy fails to improve by at least 1e-4 for 5 evaluations.
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=5,
        early_stopping_threshold=0.0001,
    )
    trainer = Trainer(
        model,
        training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["dev"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[early_stopping_callback],
    )
    trainer.train()
# Train one model per depth level, first on the 'D{depth}' dataset and then on
# its 'LD{depth}' counterpart.
for depth in range(1, 7):
    data_file = predata_file.format(depth=depth)
    store_file = prestore_file.format(depth=depth)
    load_preprocess_data(data_file, store_file)
    do_training(store_file, depth)

    # Repeat for the 'LD' variant at the same depth.
    data_file = data_file.replace(f"D{depth}", f"LD{depth}")
    store_file = store_file.replace(f"D{depth}", f"LD{depth}")
    load_preprocess_data(data_file, store_file)
    do_training(store_file, depth)
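# Example invocation (illustrative only; the checkpoint name, paths, and
# hyperparameter values below are placeholders, not values fixed by this script):
#
#   python encoder_only_train.py gpu=0 ckpt=bert-base-uncased \
#       dataFile='data/D{depth}' storeFile='runs/D{depth}' batchsz=16 lr=2e-5
#
# 'dataFile' and 'storeFile' must each contain a '{depth}' placeholder (filled
# with 1..6), and the resulting path must include the substring 'D<depth>',
# which is rewritten to 'LD<depth>' for the second training pass at each depth.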