
Cross-encoder

This section uses MSMARCO as an example to show the cross-encoder workflow. Part of the code is adapted from reranker.

Prepare Data

The data format is the same as for retriever training. Note that you should select the tokenizer that matches the pretrained language model you use. For example, to use microsoft/deberta-v3-base as the reranker model:

python preprocess.py  --tokenizer_name microsoft/deberta-v3-base --max_seq_length 150 --output_dir ./data/DebertaTokenizer_data --use_title
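Why the tokenizer has to match: preprocess.py presumably stores token IDs rather than raw text (the output directory is named after the tokenizer), and token IDs only mean something under the vocabulary that produced them. The toy sketch below uses two made-up word-level vocabularies (hypothetical, not real model vocabularies) to show that the same IDs decode to different text under another vocabulary. It also assumes that --max_seq_length truncates the stored IDs.

```python
# Toy illustration with two hypothetical word-level vocabularies. Real models use
# subword tokenizers, but the failure mode is the same: IDs are vocabulary-specific.
VOCAB_A = {"[UNK]": 0, "what": 1, "is": 2, "the": 3, "msmarco": 4, "dataset": 5}
VOCAB_B = {"[UNK]": 0, "dataset": 1, "msmarco": 2, "the": 3, "is": 4, "what": 5}


def encode(text, vocab, max_seq_length):
    """Map words to IDs with `vocab`, then truncate to max_seq_length (assumed behaviour)."""
    ids = [vocab.get(word, vocab["[UNK]"]) for word in text.lower().split()]
    return ids[:max_seq_length]


def decode(ids, vocab):
    """Map IDs back to words using the inverse of `vocab`."""
    inverse = {i: w for w, i in vocab.items()}
    return " ".join(inverse[i] for i in ids)


ids = encode("what is the MSMARCO dataset", VOCAB_A, max_seq_length=4)
print(ids)                   # [1, 2, 3, 4]
print(decode(ids, VOCAB_A))  # what is the msmarco
print(decode(ids, VOCAB_B))  # dataset msmarco the is
```

Data preprocessed with one tokenizer and fed to a model with a different vocabulary would not fail loudly; the model would simply see scrambled input.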

Train and Inference

The following script trains the reranker and predicts the scores for test_file (the results are saved to the path given by --prediction_save_path, here dev_rerank_score.txt):

torchrun --nproc_per_node 8 \
-m cross_encoder.run \
--output_dir {path to save model} \
--model_name_or_path microsoft/deberta-v3-base \
--fp16  \
--do_train \
--corpus_file ./data/DebertaTokenizer_data/corpus \
--train_query_file ./data/DebertaTokenizer_data/train_query \
--train_qrels ./data/DebertaTokenizer_data/train_qrels.txt \
--neg_file {negative file} \
--max_len 200 \
--per_device_train_batch_size 1 \
--train_group_size 65 \
--gradient_accumulation_steps 2 \
--warmup_steps 1000 \
--weight_decay 0.01 \
--learning_rate 1e-5 \
--num_train_epochs 4 \
--do_predict  \
--test_query_file ./data/DebertaTokenizer_data/dev_query \
--test_file {ranked results of dev queries provided by retriever} \
--prediction_save_path dev_rerank_score.txt \
--prediction_topk 300 \
--dataloader_num_workers 6 
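For orientation, assume --train_group_size counts one positive plus train_group_size - 1 hard negatives per query (the convention in reranker). Under that assumption, each optimizer step with the flags above covers 1 (per-device batch) × 8 (GPUs) × 2 (accumulation steps) = 16 queries, that is 16 × 65 = 1,040 query-passage pairs. The plain-Python sketch below shows a group-wise softmax cross-entropy loss of the kind used by reranker's localized contrastive estimation. It illustrates that assumption and is not the code in cross_encoder.run.

```python
import math


def group_softmax_ce(scores):
    """Cross-entropy for one query group; scores[0] is the positive passage (assumed layout)."""
    m = max(scores)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum_exp - scores[0]  # equals -log(softmax(scores)[0])


def batch_loss(groups):
    """Mean loss over query groups, each a list of cross-encoder scores."""
    return sum(group_softmax_ce(g) for g in groups) / len(groups)


def pairs_per_step(per_device_batch, num_gpus, grad_accum, group_size):
    """Queries and query-passage pairs seen per optimizer step."""
    queries = per_device_batch * num_gpus * grad_accum
    return queries, queries * group_size


print(pairs_per_step(1, 8, 2, 65))             # (16, 1040)
print(round(group_softmax_ce([0.0, 0.0]), 4))  # 0.6931 (= log 2)
```

Because the positive always sits at index 0, the loss pushes its score above all 64 negatives of the group jointly rather than comparing pairs one at a time.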

Two files depend on the bi-encoder used in the first stage and must be set accordingly:

  • neg_file: The hard negatives generated by the bi-encoder. Follow this to generate the file for Shitao/RetroMAE_MSMARCO_finetune or your own bi-encoder.

  • test_file: The ranking results of the bi-encoder. Follow test.py to generate ranking_file for the dev queries with Shitao/RetroMAE_MSMARCO_finetune or your own bi-encoder.
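Conceptually, both files come from the first-stage ranking. The sketch below uses hypothetical helper names and in-memory data; the actual file formats are defined by the referenced scripts and test.py, not here. It shows one common way to draw hard negatives from a bi-encoder ranking (sample from the top-ranked passages, skipping labelled positives) and how a cross-encoder reranks only the top --prediction_topk candidates. The depth, sample size, seed and scores are illustrative assumptions.

```python
import random


def sample_hard_negatives(ranking, positives, depth, num_negs, seed=0):
    """Sample up to num_negs non-positive passages from the top `depth` of a ranking.

    ranking: passage ids ordered by bi-encoder score (best first).
    positives: set of passage ids labelled relevant for the query.
    """
    candidates = [pid for pid in ranking[:depth] if pid not in positives]
    rng = random.Random(seed)
    return rng.sample(candidates, min(num_negs, len(candidates)))


def rerank_topk(ranking, score_fn, topk):
    """Re-score the first `topk` candidates with score_fn and sort by descending score."""
    return sorted(ranking[:topk], key=score_fn, reverse=True)


bi_encoder_ranking = ["p7", "p3", "p9", "p1", "p4"]
negs = sample_hard_negatives(bi_encoder_ranking, {"p3"}, depth=4, num_negs=2)
print(negs)  # two ids drawn from p7, p9, p1

made_up_ce_scores = {"p7": 0.2, "p3": 0.9, "p9": 0.5, "p1": 0.1, "p4": 0.8}
print(rerank_topk(bi_encoder_ranking, made_up_ce_scores.get, topk=3))  # ['p3', 'p9', 'p7']
```

Note that p4 is dropped despite its high made-up score: reranking only reorders the candidates that the first stage supplies, which is why test_file must come from the bi-encoder you actually use.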

Test

python test.py \
--score_file dev_rerank_score.txt \
--qrels_file ./data/qrels.dev.tsv
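MRR@10 is the standard metric for the MSMARCO passage dev set. The sketch below computes it from in-memory rankings and qrels. It assumes, without confirming, that test.py reports a metric of this kind, and it does not parse the actual dev_rerank_score.txt format.

```python
def mrr_at_k(rankings, qrels, k=10):
    """Mean reciprocal rank of the first relevant passage within the top k.

    rankings: {qid: [pid, ...]} ordered best first.
    qrels: {qid: set of relevant pids}; every query in qrels counts,
    and a query with no relevant passage in its top k contributes 0.
    """
    total, n = 0.0, 0
    for qid, relevant in qrels.items():
        n += 1
        for rank, pid in enumerate(rankings.get(qid, [])[:k], start=1):
            if pid in relevant:
                total += 1.0 / rank
                break
    return total / n if n else 0.0


rankings = {"q1": ["p2", "p5", "p9"], "q2": ["p4", "p1"], "q3": ["p8"]}
qrels = {"q1": {"p5"}, "q2": {"p4"}, "q3": {"p0"}}
print(mrr_at_k(rankings, qrels))  # 0.5, i.e. (1/2 + 1 + 0) / 3
```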

Prediction for Knowledge Distillation

Training the bi-encoder to learn from the cross-encoder's predicted scores can improve its performance. Besides the rerank scores for the ranking results of the train queries, the scores of the positive (query, passage) pairs also need to be predicted, because some positive passages may not be retrieved by the bi-encoder.

torchrun --nproc_per_node 8 \
-m cross_encoder.run \
--output_dir {path to save model} \
--model_name_or_path {reranker model} \
--fp16  \
--corpus_file ./data/DebertaTokenizer_data/corpus \
--max_len 200 \
--do_predict  \
--test_query_file ./data/DebertaTokenizer_data/train_query \
--test_file ./data/train_qrels.txt \
--prediction_save_path train_qrels_score.txt \
--dataloader_num_workers 6 

torchrun --nproc_per_node 8 \
-m cross_encoder.run \
--output_dir {path to save model} \
--model_name_or_path {reranker model} \
--fp16  \
--corpus_file ./data/DebertaTokenizer_data/corpus \
--max_len 200 \
--do_predict  \
--test_query_file ./data/DebertaTokenizer_data/train_query \
--test_file {ranked results of train queries} \
--prediction_save_path train_rerank_score.txt \
--dataloader_num_workers 6 
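The two prediction runs score different (query, passage) pairs: train_rerank_score.txt covers the bi-encoder's retrieved candidates, while train_qrels_score.txt covers the labelled positives, some of which the bi-encoder may miss. The sketch below merges two in-memory score tables so that every candidate and every positive has a teacher score. The function name, the dict representation and the scores are hypothetical; how the training code actually reads --teacher_score_files is not shown here.

```python
def merge_teacher_scores(*score_tables):
    """Union of {(qid, pid): score} tables; later tables win on duplicate pairs."""
    merged = {}
    for table in score_tables:
        merged.update(table)
    return merged


qrels_scores = {("q1", "p5"): 7.1}                       # positive pair, scored separately
rerank_scores = {("q1", "p2"): 3.4, ("q1", "p9"): -1.2}  # retrieved candidates; p5 was missed
teacher = merge_teacher_scores(qrels_scores, rerank_scores)
print(sorted(teacher))  # [('q1', 'p2'), ('q1', 'p5'), ('q1', 'p9')]
```

Without the qrels run, the missed positive p5 would have no teacher score at all.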

Then set --teacher_score_files train_qrels_score.txt,train_rerank_score.txt to enable knowledge distillation in bi-encoder training.
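A common way to distill reranker scores into a bi-encoder is to minimise the KL divergence between the softmax of the teacher (cross-encoder) scores and the softmax of the student (bi-encoder) scores over each query's candidate passages. The plain-Python sketch below shows that formulation. It is an illustrative assumption: this README does not specify the exact distillation loss or temperature used with --teacher_score_files.

```python
import math


def softmax(scores, temperature=1.0):
    """Numerically stable softmax over a list of scores."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_distill_loss(teacher_scores, student_scores, temperature=1.0):
    """KL(teacher || student) over one query's candidate passages."""
    p = softmax(teacher_scores, temperature)
    q = softmax(student_scores, temperature)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


teacher = [7.1, 3.4, -1.2]     # made-up cross-encoder scores for one query
aligned = [7.1, 3.4, -1.2]     # student that already matches the teacher
misaligned = [-1.2, 3.4, 7.1]  # student that ranks the passages in reverse
print(kl_distill_loss(teacher, aligned))         # 0.0
print(kl_distill_loss(teacher, misaligned) > 0)  # True
```

Unlike training on binary labels alone, this target also tells the student how the teacher ranks the negatives relative to one another.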