This repository contains code to train Flax-based MonoBERT ranking models from scratch on the large-scale Baidu-ULTR search dataset. The repository is part of a larger reproducibility study of unbiased learning-to-rank methods on the Baidu-ULTR datasets.
- We recommend installing dependencies using either Mamba (`mamba env create --file environment.yaml`) or Poetry (`poetry install`).
- Make sure you install JAX with CUDA support for your system.
- Next, download the Baidu-ULTR dataset for training. We upload the first 125 partitions here. Afterwards, update the project config with your dataset path under `config/user_const.yaml`.
- You can train our BERTs on a SLURM cluster using, e.g., `sbatch scripts/train.sh <model-name>`, where `<model-name>` is the ranking objective, one of `[naive-pointwise, naive-listwise, pbm, dla, ips-pointwise, ips-listwise]`.
- You can evaluate all pre-trained models by running `sbatch scripts/eval.sh <model-name>`.
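The training and evaluation scripts take one model name at a time. If you want to submit jobs for all objectives in one go, a small helper along the lines below can wrap the `sbatch` calls shown above. This is a hypothetical convenience sketch, not part of the repository:

```python
# Hypothetical convenience script (not part of the repository): submit one
# SLURM training job per ranking objective by wrapping the `sbatch` command above.
import subprocess

MODEL_NAMES = [
    "naive-pointwise",
    "naive-listwise",
    "pbm",
    "dla",
    "ips-pointwise",
    "ips-listwise",
]

for model_name in MODEL_NAMES:
    # Equivalent to manually running: sbatch scripts/train.sh <model-name>
    subprocess.run(["sbatch", "scripts/train.sh", model_name], check=True)
```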
You can download all pre-trained models from the Hugging Face Hub by clicking the model names below. We also list the evaluation results on the Baidu-ULTR test set. Ranking performance is measured in DCG, nDCG, and MRR on expert annotations (6,985 queries). Click prediction performance is measured as log-likelihood on one test partition of user clicks (≈297k queries).
Model | Log-likelihood | DCG@1 | DCG@3 | DCG@5 | DCG@10 | nDCG@10 | MRR@10 |
---|---|---|---|---|---|---|---|
Pointwise Naive | 0.227 | 1.641 | 3.462 | 4.752 | 7.251 | 0.357 | 0.609 |
Pointwise Two-Tower | 0.218 | 1.629 | 3.471 | 4.822 | 7.456 | 0.367 | 0.607 |
Pointwise IPS | 0.222 | 1.295 | 2.811 | 3.977 | 6.296 | 0.307 | 0.534 |
Listwise Naive | - | 1.947 | 4.108 | 5.614 | 8.478 | 0.405 | 0.639 |
Listwise IPS | - | 1.671 | 3.530 | 4.873 | 7.450 | 0.361 | 0.603 |
Listwise DLA | - | 1.796 | 3.730 | 5.125 | 7.802 | 0.377 | 0.615 |
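For reference, the ranking metrics in the table can be computed roughly as follows. This is a minimal sketch that assumes graded relevance labels ordered by model score, a linear DCG gain, and a relevance threshold of 1 for MRR; the exact definitions behind the reported numbers are those used in the paper.

```python
import numpy as np

def dcg_at_k(relevance, k):
    """DCG@k with a linear gain: sum_i rel_i / log2(i + 1), ranks starting at 1."""
    rel = np.asarray(relevance, dtype=float)[:k]
    discounts = np.log2(np.arange(2, rel.size + 2))
    return float(np.sum(rel / discounts))

def ndcg_at_k(relevance, k):
    """DCG@k normalized by the DCG of the ideal (descending) ordering."""
    ideal = dcg_at_k(sorted(relevance, reverse=True), k)
    return dcg_at_k(relevance, k) / ideal if ideal > 0 else 0.0

def mrr_at_k(relevance, k, threshold=1):
    """Reciprocal rank of the first document with relevance >= threshold."""
    for rank, rel in enumerate(relevance[:k], start=1):
        if rel >= threshold:
            return 1.0 / rank
    return 0.0

# Toy example: expert relevance labels (0-4) of one ranked list, ordered by model score.
ranked_relevance = [3, 0, 2, 1, 0, 0, 4, 0, 0, 1]
print(dcg_at_k(ranked_relevance, 10), ndcg_at_k(ranked_relevance, 10), mrr_at_k(ranked_relevance, 10))
```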
The following example shows how to load a click dataset and run one of the pre-trained models:

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

from src.data import collate_click_fn
from src.model import CrossEncoder

# As an example, we use a smaller click dataset based on Baidu-ULTR:
test_clicks = load_dataset(
    "philipphager/baidu-ultr_uva-mlm-ctr",
    name="clicks",
    split="test",
    trust_remote_code=True,
)

click_loader = DataLoader(
    test_clicks,
    batch_size=64,
    collate_fn=collate_click_fn,
)

# Download the naive-pointwise model from the Hugging Face Hub.
# Note that you have to change the model class to instantiate different models:
model = CrossEncoder.from_pretrained(
    "philipphager/baidu-ultr_uva-bert_naive-pointwise",
)

# Use the model for click / relevance prediction:
batch = next(iter(click_loader))
model(batch)

# Use the model only for relevance prediction, e.g., for evaluation:
model.predict_relevance(batch)
```
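To estimate a click log-likelihood number like the one in the table above, you can average the per-position Bernoulli log-likelihood of the observed clicks under the model's predicted click logits. The sketch below is hypothetical: the exact shapes and names of the model outputs and batch fields depend on `CrossEncoder` and `collate_click_fn`, so treat `click_scores`, `clicks`, and `mask` as assumptions.

```python
import jax
import jax.numpy as jnp

def click_log_likelihood(click_scores, clicks, mask):
    """Mean Bernoulli log-likelihood of observed clicks given predicted click logits.

    `click_scores` are logits, `clicks` are 0/1 labels, and `mask` marks valid
    (non-padded) positions in each ranked list. All names are assumptions.
    """
    log_p = jax.nn.log_sigmoid(click_scores)
    log_not_p = jax.nn.log_sigmoid(-click_scores)
    log_likelihood = clicks * log_p + (1.0 - clicks) * log_not_p
    return jnp.sum(log_likelihood * mask) / jnp.sum(mask)

# Toy example with synthetic logits and clicks (shape: [batch, list_size]):
scores = jnp.array([[2.0, -1.0, 0.5]])
clicks = jnp.array([[1.0, 0.0, 0.0]])
mask = jnp.ones_like(clicks)
print(click_log_likelihood(scores, clicks, mask))
```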
The basis for all ranking models in this repository is a MonoBERT cross-encoder architecture. In a cross-encoder, the user query and each candidate document are concatenated as the BERT input and the CLS token is used to predict query-document relevance. We train BERT models from scratch using a masked language modeling objective by randomly masking the model input and training the model to predict missing tokens. We tune the CLS token to predict query-document relevance using ranking objectives on user clicks. We display a rough sketch of the model architecture below:
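In addition to that sketch, here is a minimal, hypothetical Flax example of the relevance head: the CLS embedding produced by the BERT encoder for a concatenated query-document input is projected to a single relevance score. Module names and dimensions are illustrative only; the actual model is the `CrossEncoder` class used above.

```python
import flax.linen as nn
import jax

class RelevanceHead(nn.Module):
    """Maps the CLS embedding of a (query, document) input to a scalar relevance score."""
    hidden_dim: int = 768  # illustrative, matching a BERT-base CLS embedding

    @nn.compact
    def __call__(self, cls_embedding):
        x = nn.Dense(self.hidden_dim)(cls_embedding)
        x = nn.relu(x)
        return nn.Dense(1)(x).squeeze(-1)

# Toy usage with random CLS embeddings for a batch of two query-document pairs:
head = RelevanceHead()
cls_embedding = jax.random.normal(jax.random.PRNGKey(0), (2, 768))
params = head.init(jax.random.PRNGKey(1), cls_embedding)
scores = head.apply(params, cls_embedding)  # shape: (2,)
```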
We use a pointwise binary cross-entropy loss and a listwise softmax cross-entropy loss as our main ranking losses. We implement several unbiased learning-to-rank methods for position bias mitigation in click data, including a Two-Tower/PBM objective, inverse propensity scoring (IPS), and the dual learning algorithm (DLA). For more details, see our paper or inspect our loss functions in `src/loss.py`.
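As an illustration of the two main objectives, the sketch below shows a pointwise binary cross-entropy and a listwise softmax cross-entropy over clicks, with optional weights that could hold inverse propensity scores for the IPS variants. This is a simplified approximation, not the repository's implementation; the actual objectives (including PBM and DLA) are in `src/loss.py`.

```python
import jax
import jax.numpy as jnp

def pointwise_bce(scores, clicks, weights=None):
    """Pointwise binary cross-entropy between click labels and sigmoid(score).

    `weights` can carry inverse propensity weights for the IPS variant; defaults to 1.
    """
    weights = jnp.ones_like(clicks) if weights is None else weights
    log_p = jax.nn.log_sigmoid(scores)
    log_not_p = jax.nn.log_sigmoid(-scores)
    return -jnp.mean(weights * (clicks * log_p + (1.0 - clicks) * log_not_p))

def listwise_softmax_ce(scores, clicks, weights=None):
    """Listwise softmax cross-entropy: clicked documents act as (optionally
    IPS-weighted) targets for the per-list score distribution."""
    weights = jnp.ones_like(clicks) if weights is None else weights
    log_softmax = jax.nn.log_softmax(scores, axis=-1)
    return -jnp.mean(jnp.sum(weights * clicks * log_softmax, axis=-1))

# Toy example: one ranked list of three documents with a click on the first position.
scores = jnp.array([[1.2, 0.3, -0.8]])
clicks = jnp.array([[1.0, 0.0, 0.0]])
print(pointwise_bce(scores, clicks), listwise_softmax_ce(scores, clicks))
```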
If you use our models or code, please cite our paper:

```bibtex
@inproceedings{Hager2024BaiduULTR,
  author = {Philipp Hager and Romain Deffayet and Jean-Michel Renders and Onno Zoeter and Maarten de Rijke},
  title = {Unbiased Learning to Rank Meets Reality: Lessons from Baidu's Large-Scale Search Dataset},
  booktitle = {Proceedings of the 47th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR '24)},
  organization = {ACM},
  year = {2024},
}
```
This project uses the MIT License.