diff --git a/pre_training/genius_pretrain_chinese.py b/pre_training/genius_pretrain_chinese.py
index 97fd1eb..ec9f2c7 100644
--- a/pre_training/genius_pretrain_chinese.py
+++ b/pre_training/genius_pretrain_chinese.py
@@ -7,6 +7,7 @@
 from transformers import BertTokenizer, AutoModel, AutoConfig, AutoModelForSeq2SeqLM
 from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
 from datasets import load_dataset, load_metric
+from rouge import Rouge
 import argparse
 parser = argparse.ArgumentParser(allow_abbrev=False)
 args = parser.parse_args()
@@ -71,6 +72,62 @@ def compute_metrics(eval_pred):
     return {k: round(v, 4) for k, v in result.items()}
 
+rouge = Rouge()
+
+def get_avg(scores, rouge_name, metric_name):
+    return sum([score[rouge_name][metric_name] for score in scores])
+
+def compute_chinese_metrics(eval_pred):
+    predictions, labels = eval_pred
+    # Decode generated summaries into text
+    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+    # Replace -100 in the labels as we can't decode them
+    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+    # Decode reference summaries into text
+    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+    # ROUGE expects a newline after each sentence
+    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
+    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
+    # Compute ROUGE scores
+    scores = rouge.get_scores(decoded_preds, decoded_labels)
+    avg_scores = {}
+    length = len(scores)
+    avg_p_1 = get_avg(scores, "rouge-1", "p") / length
+    avg_r_1 = get_avg(scores, "rouge-1", "r") / length
+    avg_f_1 = get_avg(scores, "rouge-1", "f") / length
+    avg_p_2 = get_avg(scores, "rouge-2", "p") / length
+    avg_r_2 = get_avg(scores, "rouge-2", "r") / length
+    avg_f_2 = get_avg(scores, "rouge-2", "f") / length
+    avg_p_l = get_avg(scores, "rouge-l", "p") / length
+    avg_r_l = get_avg(scores, "rouge-l", "r") / length
+    avg_f_l = get_avg(scores, "rouge-l", "f") / length
+    """
+    avg_scores = {
+        "rouge-1":{
+            "p": avg_p_1,
+            "r": avg_r_1,
+            "f": avg_f_1,
+        },
+        "rouge-2":{
+            "p": avg_p_2,
+            "r": avg_r_2,
+            "f": avg_f_2,
+        },
+        "rouge-L":{
+            "p": avg_p_l,
+            "r": avg_r_l,
+            "f": avg_f_l,
+        },
+    }
+    """
+    # Only return recall
+    avg_scores = {
+        "rouge-1": avg_r_1,
+        "rouge-2": avg_r_2,
+        "rouge-L": avg_r_l,
+    }
+    return avg_scores
+
 
 ##################################################################
 # training
 ##################################################################
@@ -118,4 +175,4 @@
 )
 
 
-trainer.train(resume_from_checkpoint = False)
\ No newline at end of file
+trainer.train(resume_from_checkpoint = False)
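
Note on the averaging in `compute_chinese_metrics`: the `get_avg` helper returns a plain sum that each caller divides by `len(scores)`. A minimal sketch (separate from the patch, with made-up example strings) of how this manual averaging lines up with the `rouge` package's built-in `avg=True` mode:

```python
from rouge import Rouge

rouge = Rouge()
preds = ["the cat sat on the mat", "a quick brown fox"]
refs = ["the cat is on the mat", "the quick brown fox jumps over it"]

# Per-pair scores averaged by hand, as get_avg(...) / length does in the patch
scores = rouge.get_scores(preds, refs)
manual_r1 = sum(s["rouge-1"]["r"] for s in scores) / len(scores)

# Built-in averaging produces the same number in a single call
avg = rouge.get_scores(preds, refs, avg=True)
assert abs(avg["rouge-1"]["r"] - manual_r1) < 1e-9
```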
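One caveat worth flagging, as an observation about the `rouge` package rather than something this patch addresses: `get_scores` tokenizes on whitespace, and `sent_tokenize` is presumably NLTK's English sentence splitter, so an unsegmented Chinese string is compared as a single token. A common workaround is to space-separate the characters before scoring; the strings below are made-up examples:

```python
from rouge import Rouge

rouge = Rouge()
# Space-join the characters so ROUGE unigrams/bigrams are per character
pred = " ".join("今天天气很好")  # -> "今 天 天 气 很 好"
ref = " ".join("今天天气不错")   # -> "今 天 天 气 不 错"
score = rouge.get_scores([pred], [ref])[0]
print(score["rouge-1"]["r"])  # character-level unigram recall
```

With this preprocessing applied to `decoded_preds` and `decoded_labels`, the recall values returned by `compute_chinese_metrics` become meaningful at the character level; without it, scores for near-identical Chinese outputs collapse toward zero.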