Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update genius_pretrain_chinese.py #2

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion pre_training/genius_pretrain_chinese.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from transformers import BertTokenizer, AutoModel, AutoConfig, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from datasets import load_dataset, load_metric
from rouge import Rouge
import argparse
parser = argparse.ArgumentParser(allow_abbrev=False)
args = parser.parse_args()
Expand Down Expand Up @@ -71,6 +72,62 @@ def compute_metrics(eval_pred):
return {k: round(v, 4) for k, v in result.items()}


# Module-level ROUGE scorer (from the `rouge` pip package), shared by the
# Chinese metric function below so it is constructed only once.
rouge = Rouge()

def get_avg(scores, rouge_name, metric_name):
    """Sum one ROUGE metric over a list of per-sample score dicts.

    Note: despite the name, this returns the *sum*; the caller divides by
    the number of samples to obtain the average.

    Args:
        scores: list of per-sample dicts as returned by ``Rouge.get_scores``,
            e.g. ``[{"rouge-1": {"p": ..., "r": ..., "f": ...}, ...}, ...]``.
        rouge_name: which ROUGE variant to read, e.g. ``"rouge-1"``.
        metric_name: which field to read: ``"p"``, ``"r"`` or ``"f"``.

    Returns:
        The sum of ``score[rouge_name][metric_name]`` over all samples.
    """
    total = 0
    for sample in scores:
        total += sample[rouge_name][metric_name]
    return total

def compute_chinese_metrics(eval_pred):
    """Compute corpus-average ROUGE-1/2/L *recall* for a generation eval step.

    Intended as a ``compute_metrics`` callback for a HF ``Seq2SeqTrainer``.
    Relies on module-level names defined elsewhere in this file: ``tokenizer``,
    ``sent_tokenize``, ``rouge`` (a ``Rouge()`` instance) and ``get_avg``.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays, where
            label positions to ignore are marked with -100.

    Returns:
        dict with keys ``"rouge-1"``, ``"rouge-2"``, ``"rouge-L"`` mapping to
        the mean recall over the batch. Only recall is reported (the original
        comment: "只返回recall" — "only return recall").
    """
    predictions, labels = eval_pred
    # Decode generated summaries into text
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 (the ignore index) with pad_token_id so labels can be decoded
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    # Decode reference summaries into text
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # The rouge package expects a newline after each sentence
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    # One per-sample score dict for each (prediction, reference) pair
    scores = rouge.get_scores(decoded_preds, decoded_labels)
    length = len(scores)
    # Original version also computed precision/F1 averages but discarded them;
    # only the recall averages are actually returned.
    return {
        "rouge-1": get_avg(scores, "rouge-1", "r") / length,
        "rouge-2": get_avg(scores, "rouge-2", "r") / length,
        "rouge-L": get_avg(scores, "rouge-l", "r") / length,
    }

##################################################################
# training
##################################################################
Expand Down Expand Up @@ -118,4 +175,4 @@ def compute_metrics(eval_pred):
)


trainer.train(resume_from_checkpoint = False)
trainer.train(resume_from_checkpoint = False)