forked from amzn/pecos
Commit
Experiment Code of PEFA Paper for WSDM24
Wei-Cheng Chang committed Dec 5, 2023
1 parent 51062ec, commit 63d9eb7
Showing 27 changed files with 850 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Experiment Code for PEFA, WSDM 2024 | ||
|
||
This folder contains code to reproduce experiments in | ||
["PEFA: Parameter-Free Adapters for Large-scale Embedding-based Retrieval Models"](#) | ||
|
||
## 1. Summary | ||
In this repository, we demonstrate how to reproduce Table 2 (NQ-320K) and Table 3 (Trivia-QA) of our PEFA paper. | ||
After following Steps 2-6 in the subsequent sections, you should be able to obtain the following results: | ||
|
||
| NQ-320K | Recall@10 | Recall@100 | | ||
|---|---|---| | ||
| DistilBERT + PEFA-XS | 80.52% | 92.23% | | ||
| DistilBERT + PEFA-XL | 85.26% | 92.53% | | ||
| MPNet + PEFA-XS | 86.67% | 94.53% | | ||
| MPNet + PEFA-XL | 88.72% | 95.13% | | ||
| Sentence-T5-base + PEFA-XS | 82.52% | 92.18% | | ||
| Sentence-T5-base + PEFA-XL | 83.69% | 92.55% | | ||
| GTR-T5-base + PEFA-XS | 84.90% | 93.28% | | ||
| GTR-T5-base + PEFA-XL | 88.71% | 94.36% | | ||
|
||
| Trivia-QA | Recall@20 | Recall@100 | | ||
|---|---|---| | ||
| DistilBERT + PEFA-XS | 86.28% | 93.33% | | ||
| DistilBERT + PEFA-XL | 84.18% | 91.24% | | ||
| MPNet + PEFA-XS | 86.05% | 92.97% | | ||
| MPNet + PEFA-XL | 86.13% | 92.42% | | ||
| Sentence-T5-base + PEFA-XS | 78.39% | 88.57% | | ||
| Sentence-T5-base + PEFA-XL | 75.13% | 87.24% | | ||
| GTR-T5-base + PEFA-XS | 83.81% | 91.02% | | ||
| GTR-T5-base + PEFA-XL | 85.30% | 92.38% | | ||
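
For readers unfamiliar with the metric, the sketch below shows one common definition of Recall@k: the fraction of a query's relevant documents that appear among the top-k retrieved documents, averaged over queries. This is an illustration only; the function names are ours, and the evaluation code in `pefa_xs.py` / `pefa_xl.py` may compute the metric differently.

```python
def recall_at_k(ranked, relevant, k):
    """Fraction of `relevant` items found in the first k entries of `ranked`."""
    relevant = set(relevant)
    if not relevant:
        raise ValueError("each query needs at least one relevant document")
    top_k = set(ranked[:k])
    return len(top_k & relevant) / len(relevant)


def mean_recall_at_k(rankings, relevants, k):
    """Average recall_at_k over all queries."""
    scores = [recall_at_k(r, rel, k) for r, rel in zip(rankings, relevants)]
    return sum(scores) / len(scores)


if __name__ == "__main__":
    # Hypothetical toy data: two queries, three retrieved documents each.
    rankings = [["d3", "d1", "d7"], ["d2", "d5", "d9"]]
    relevants = [{"d1"}, {"d9"}]
    print(mean_recall_at_k(rankings, relevants, k=2))
```

In the toy data, the first query's relevant document is inside the top 2 and the second query's is not, so the averaged score sits halfway between the two.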
|
||
|
||
## 2. Getting Started | ||
* Clone the repository and enter `examples/pefa-wsdm24` directory. | ||
* First create a [virtual environment](https://docs.python.org/3/library/venv.html) and then install dependencies by running the following command: | ||
```bash | ||
python3 -m pip install libpecos==1.2.1 | ||
python3 -m pip install sentence-transformers==2.2.1 | ||
``` | ||
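
As an optional sanity check, a short Python sketch like the one below can confirm that the installed versions match the pins above. The helper name `check_pins` is ours and is not part of PECOS; it relies only on the standard-library `importlib.metadata` module.

```python
from importlib import metadata

# Versions pinned in the install commands above.
PINNED = {"libpecos": "1.2.1", "sentence-transformers": "2.2.1"}


def check_pins(pins):
    """Return {package: installed_version_or_None} for every package that does not match its pin."""
    problems = {}
    for name, wanted in pins.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed in this environment
        if found != wanted:
            problems[name] = found
    return problems


if __name__ == "__main__":
    for name, found in check_pins(PINNED).items():
        print(f"{name}: expected {PINNED[name]}, found {found}")
```

An empty result means both packages are installed at the pinned versions.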
|
||
## 3. Download Pre-processed Data for NQ320K and Trivia-QA | ||
Our pre-processed datasets for NQ320K and Trivia-QA can be downloaded as follows: | ||
```bash | ||
mkdir -p ./data/xmc; cd ./data/xmc; | ||
DATASET="nq320k" # nq320k or trivia | ||
wget https://archive.org/download/pefa-wsdm24/data/xmc/${DATASET}.tar.gz | ||
tar -zxvf ./${DATASET}.tar.gz | ||
cd ../../ # get back to the pecos/examples/pefa-wsdm24 directory | ||
``` | ||
|
||
Additional notes on data pre-processing: | ||
* We first obtained the original NQ320K/Trivia-QA datasets from the [NCI Paper, Wang et al., NeurIPS22](https://github.com/solidsea98/Neural-Corpus-Indexer-NCI). | ||
* We then pre-processed them into our format. | ||
* Details about our data pre-processing scripts can be found in `./data/README.md`. | ||
|
||
|
||
## 4. Generate Embeddings for PEFA Inference | ||
Before running PEFA inference, we select an encoder to generate query/passage embeddings: | ||
```bash | ||
DATASET="nq320k" # nq320k or trivia | ||
ENCODER="gtr-t5-base" | ||
bash run_encoder ${DATASET} ${ENCODER} | ||
``` | ||
The embeddings will be saved to `./data/embeddings/${DATASET}/`. | ||
Regarding the `ENCODER` used in our paper, | ||
* For `nq320k`, we consider `{nq-distilbert-base-v1, multi-qa-mpnet-base-dot-v1, sentence-t5-base, gtr-t5-base}` | ||
* For `trivia`, we consider `{multi-qa-distilbert-dot-v1, multi-qa-mpnet-base-dot-v1, sentence-t5-base, gtr-t5-base}` | ||
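
To cover every dataset/encoder pair listed above, the combinations can be enumerated with a small helper. The sketch below only prints the commands in the form shown in this section and does not run anything; the names `ENCODERS`, `all_runs`, and `embedding_dir` are ours.

```python
# Encoder names per dataset, copied from the list above.
ENCODERS = {
    "nq320k": [
        "nq-distilbert-base-v1",
        "multi-qa-mpnet-base-dot-v1",
        "sentence-t5-base",
        "gtr-t5-base",
    ],
    "trivia": [
        "multi-qa-distilbert-dot-v1",
        "multi-qa-mpnet-base-dot-v1",
        "sentence-t5-base",
        "gtr-t5-base",
    ],
}


def all_runs():
    """Return every (dataset, encoder) pair considered in the paper."""
    return [(d, e) for d, encoders in ENCODERS.items() for e in encoders]


def embedding_dir(dataset):
    """Directory where this section says the embeddings are saved."""
    return f"./data/embeddings/{dataset}/"


if __name__ == "__main__":
    for dataset, encoder in all_runs():
        print(f"bash run_encoder {dataset} {encoder}  # -> {embedding_dir(dataset)}")
```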
|
||
## 5. Run PEFA-XS | ||
```bash | ||
DATASET="nq320k" | ||
ENCODER="gtr-t5-base" | ||
bash run_pefa_xs.sh ${DATASET} ${ENCODER} | ||
``` | ||
The script `run_pefa_xs.sh` calls `pefa_xs.py` with hard-coded hyper-parameters. | ||
For example, it uses `threads=64`. If your machine has fewer CPU cores, please adjust it accordingly. | ||
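
One way to pick a suitable value is to cap the default of 64 at the number of cores the machine reports. The sketch below is a suggestion only: `pick_threads` is a hypothetical helper, and the resulting number still has to be written into the script by hand.

```python
import os


def pick_threads(requested=64):
    """Cap the requested thread count at the number of CPU cores reported by the OS."""
    available = os.cpu_count() or 1  # os.cpu_count() can return None
    return max(1, min(requested, available))


if __name__ == "__main__":
    print(f"threads={pick_threads()}")
```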
|
||
## 6. Run PEFA-XL | ||
```bash | ||
DATASET="nq320k" | ||
ENCODER="gtr-t5-base" | ||
bash run_pefa_xl.sh ${DATASET} ${ENCODER} | ||
``` | ||
The script `run_pefa_xl.sh` calls the `pefa_xl.py` with hard-coded hyper-parameters. | ||
For example, it uses `threads=64`. If your machine has fewer CPU cores, please adjust it accordingly. | ||
|
||
## 7. Citation | ||
If you find this work useful for your research, please cite: | ||
``` | ||
@inproceedings{chang2024pefa, | ||
title={PEFA: Parameter-Free Adapters for Large-scale Embedding-based Retrieval Models}, | ||
author={Wei-Cheng Chang and Jyun-Yu Jiang and Jiong Zhang and Mutasem Al-Darabsah and Choon Hui Teo and Cho-Jui Hsieh and Hsiang-Fu Yu and S. V. N. Vishwanathan}, | ||
booktitle={Proceedings of the 17th ACM International Conference on Web Search and Data Mining (WSDM '24)}, | ||
year={2024} | ||
} | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
|
||
# Download Raw Data from NCI Paper | ||
- NQ320K: https://drive.google.com/drive/folders/1epfUw4yQjAtqnZTQDLAUOwTJg-YMCGdD | ||
- Trivia: https://drive.google.com/drive/folders/1SY28Idba1X8DNi4PYaDDH9CbUpdKiTXQ | ||
``` | ||
unzip the NQ320K folder to ./raw/NQ_dataset (the path that proc_nq320k.py reads from) | ||
unzip the Trivia folder to ./raw/trivia_newdata | ||
``` | ||
|
||
# Process the Raw Data to XMC Format | ||
- NQ320K: | ||
``` | ||
python proc_nq320k.py | ||
``` | ||
|
||
- Trivia: | ||
``` | ||
python proc_trivia.py | ||
``` | ||
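
The core of both scripts is turning (query, doc_id) pairs into the row/column indices of a sparse query-by-document label matrix. The following dependency-free sketch illustrates that idea on toy data. It is a simplification: the real scripts build the document-to-label map in a separate pass over the corpus files and then pass the indices to `scipy.sparse.csr_matrix`, and the function name `build_triplets` is ours.

```python
def build_triplets(pairs, did_to_lid, split_ids=False):
    """Map (query, doc_id) pairs to row/column indices of a label matrix.

    pairs      : iterable of (query_text, doc_id_string)
    did_to_lid : dict from document id to label (column) id, extended in place
    split_ids  : if True, doc_id_string may hold several comma-separated ids
                 (as in the Trivia data)
    """
    qry_to_row = {}
    rows, cols = [], []
    for query, doc_id in pairs:
        doc_ids = doc_id.split(",") if split_ids else [doc_id]
        for did in doc_ids:
            # Unseen ids get the next consecutive index.
            lid = did_to_lid.setdefault(did, len(did_to_lid))
            row = qry_to_row.setdefault(query, len(qry_to_row))
            rows.append(row)
            cols.append(lid)
    queries = list(qry_to_row)  # row index -> query text (dicts keep insertion order)
    return rows, cols, queries


if __name__ == "__main__":
    # Hypothetical toy data.
    toy_pairs = [("who wrote hamlet", "12"), ("capital of france", "7,12")]
    label_map = {}
    print(build_triplets(toy_pairs, label_map, split_ids=True))
    print(label_map)
```

Each (row, col) pair marks one relevant document for one query; repeated queries reuse the same row.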
|
||
You should see the following data artifacts | ||
``` | ||
./xmc/{nq320k|trivia} | ||
|- X.trn.abs.txt | ||
|- X.trn.d2q.txt | ||
|- X.trn.doc.txt | ||
|- X.trn.txt | ||
|- X.tst.txt | ||
|- Y.trn.abs.npz | ||
|- Y.trn.d2q.npz | ||
|- Y.trn.doc.npz | ||
|- Y.trn.npz | ||
|- Y.tst.npz | ||
``` |
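
To confirm that processing completed, a quick check such as the one below can list any artifact from the tree above that is missing. It is a minimal sketch using only the standard library; `missing_artifacts` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path

# File names taken from the listing above.
SPLITS = ["trn.abs", "trn.d2q", "trn.doc", "trn", "tst"]
EXPECTED = [f"X.{s}.txt" for s in SPLITS] + [f"Y.{s}.npz" for s in SPLITS]


def missing_artifacts(root):
    """Return the expected file names that are not present under `root`."""
    root = Path(root)
    return [name for name in EXPECTED if not (root / name).is_file()]


if __name__ == "__main__":
    for dataset in ("nq320k", "trivia"):
        missing = missing_artifacts(Path("./xmc") / dataset)
        print(dataset, "OK" if not missing else f"missing: {missing}")
```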
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
|
||
import os
import numpy as np
import pandas as pd
import scipy.sparse as smat
from pecos.utils import smat_util


COL_NAME_LIST = [
    "query", "qid", "doc_id",
    "bert_k30_c30_1", "bert_k30_c30_2", "bert_k30_c30_3", "bert_k30_c30_4", "bert_k30_c30_5",
]


def load_df(input_tsv_path):
    # Read a headerless TSV and keep only the query text, query id, and document id.
    return pd.read_csv(
        input_tsv_path,
        encoding='utf-8', header=None, sep='\t',
        names=COL_NAME_LIST,
        dtype={"query": str, "qid": str, 'doc_id': str}
    ).loc[:, ["query", "qid", "doc_id"]]


def build_did2lid_map(df_inp, did_to_lid):
    # Give each unseen document id the next consecutive label id (updates did_to_lid in place).
    for i in range(len(df_inp)):
        did_str = df_inp["doc_id"][i]
        if did_str not in did_to_lid:
            did_to_lid[did_str] = len(did_to_lid)


def build_corpus_and_label_mat(df, did_to_lid, skip_same_lid=False):
    # Build the list of unique queries and the sparse query-by-document label matrix.
    qry_to_qid = {}
    rows, cols = [], []
    inc_lid_set = set()
    for i in range(len(df)):
        query = df["query"][i]
        did_str = df["doc_id"][i]

        lid = did_to_lid[did_str]
        # With skip_same_lid=True, only the first row seen for each document is kept.
        if skip_same_lid and lid in inc_lid_set:
            continue
        inc_lid_set.add(lid)

        if query not in qry_to_qid:
            qry_to_qid[query] = len(qry_to_qid)
        qid = qry_to_qid[query]
        rows.append(qid)
        cols.append(lid)

    vals = [1.0 for _ in range(len(rows))]

    num_inp, num_out = len(qry_to_qid), len(did_to_lid)
    Y = smat.csr_matrix(
        (vals, (rows, cols)),
        shape=(num_inp, num_out),
        dtype=np.float32,
    )
    print("#Q {:7d} #L {:7d} NNZ {:9d}".format(num_inp, num_out, Y.nnz))
    id2query = [str(query).lower() for query, qid in sorted(qry_to_qid.items(), key=lambda x: x[1])]
    return id2query, Y


def write_qtxt(id2qtxt, output_path):
    # Write one query per line; UTF-8 matches the encoding used when reading the TSVs.
    with open(output_path, "w", encoding="utf-8") as fout:
        for query_txt in id2qtxt:
            fout.write(f"{query_txt}\n")


def main():
    df_trn = load_df("./raw/NQ_dataset/nq_train_doc_newid.tsv")
    df_tst = load_df("./raw/NQ_dataset/nq_dev_doc_newid.tsv")
    df_abs = load_df("./raw/NQ_dataset/nq_title_abs.tsv")
    df_doc = load_df("./raw/NQ_dataset/NQ_doc_aug.tsv")
    df_d2q = load_df("./raw/NQ_dataset/NQ_512_qg.tsv")

    # A plain dict makes a lookup of an unknown document id fail loudly with a KeyError.
    did_to_lid = {}
    build_did2lid_map(df_abs, did_to_lid)
    print("After df_abs, #uniq_label {:9d}".format(len(did_to_lid)))
    build_did2lid_map(df_doc, did_to_lid)
    print("After df_doc, #uniq_label {:9d}".format(len(did_to_lid)))
    build_did2lid_map(df_d2q, did_to_lid)
    print("After df_d2q, #uniq_label {:9d}".format(len(did_to_lid)))

    id2qtxt_trn, Y_trn_all = build_corpus_and_label_mat(df_trn, did_to_lid, skip_same_lid=False)
    id2qtxt_tst, Y_tst_all = build_corpus_and_label_mat(df_tst, did_to_lid, skip_same_lid=False)
    id2qtxt_abs, Y_trn_abs = build_corpus_and_label_mat(df_abs, did_to_lid, skip_same_lid=True)
    id2qtxt_doc, Y_trn_doc = build_corpus_and_label_mat(df_doc, did_to_lid, skip_same_lid=False)
    id2qtxt_d2q, Y_trn_d2q = build_corpus_and_label_mat(df_d2q, did_to_lid, skip_same_lid=False)

    output_dir = "./xmc/nq320k"
    os.makedirs(output_dir, exist_ok=True)
    write_qtxt(id2qtxt_trn, f"{output_dir}/X.trn.txt")
    write_qtxt(id2qtxt_tst, f"{output_dir}/X.tst.txt")
    write_qtxt(id2qtxt_abs, f"{output_dir}/X.trn.abs.txt")
    write_qtxt(id2qtxt_doc, f"{output_dir}/X.trn.doc.txt")
    write_qtxt(id2qtxt_d2q, f"{output_dir}/X.trn.d2q.txt")
    smat_util.save_matrix(f"{output_dir}/Y.trn.npz", Y_trn_all)
    smat_util.save_matrix(f"{output_dir}/Y.tst.npz", Y_tst_all)
    smat_util.save_matrix(f"{output_dir}/Y.trn.abs.npz", Y_trn_abs)
    smat_util.save_matrix(f"{output_dir}/Y.trn.doc.npz", Y_trn_doc)
    smat_util.save_matrix(f"{output_dir}/Y.trn.d2q.npz", Y_trn_d2q)


if __name__ == "__main__":
    main()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
|
||
import os
import numpy as np
import pandas as pd
import scipy.sparse as smat
from pecos.utils import smat_util


COL_NAME_LIST = [
    "query", "qid", "doc_id",
    "bert_k30_c30_1", "bert_k30_c30_2", "bert_k30_c30_3", "bert_k30_c30_4", "bert_k30_c30_5",
]

# The query-generation file has no "qid" column.
QG_COL_NAME_LIST = [
    "query", "doc_id",
    "bert_k30_c30_1", "bert_k30_c30_2", "bert_k30_c30_3", "bert_k30_c30_4", "bert_k30_c30_5",
]


def load_df(input_tsv_path):
    # Read a headerless TSV, skipping malformed and blank lines; keep query, qid, and doc_id.
    return pd.read_csv(
        input_tsv_path,
        encoding='utf-8', header=None, sep='\t',
        names=COL_NAME_LIST,
        dtype={"query": str, "qid": str, 'doc_id': str},
        on_bad_lines="skip",
        skip_blank_lines=True,
    ).loc[:, ["query", "qid", "doc_id"]]


def load_qg_df(input_tsv_path):
    # Same as load_df, but for the query-generation TSV, which lacks the qid column.
    return pd.read_csv(
        input_tsv_path,
        encoding='utf-8', header=None, sep='\t',
        names=QG_COL_NAME_LIST,
        dtype={"query": str, 'doc_id': str},
        on_bad_lines="skip",
        skip_blank_lines=True,
    ).loc[:, ["query", "doc_id"]]


def build_did2lid_map(df_inp, did_to_lid):
    # Give each unseen document id the next label id; doc_id may hold several comma-separated ids.
    for i in range(len(df_inp)):
        try:
            did_list = df_inp["doc_id"][i].split(",")
        except AttributeError:
            # A non-string doc_id (e.g. NaN from an empty field) has no .split(); report it and abort.
            print(i, type(df_inp["doc_id"][i]), df_inp["doc_id"][i])
            raise SystemExit(1)
        for did_str in did_list:
            if did_str not in did_to_lid:
                did_to_lid[did_str] = len(did_to_lid)


def build_corpus_and_label_mat(df, did_to_lid, skip_same_lid=False):
    # Build the list of unique queries and the sparse query-by-document label matrix.
    qry_to_qid = {}
    rows, cols = [], []
    inc_lid_set = set()
    for i in range(len(df)):
        query = df["query"][i]
        try:
            did_list = df["doc_id"][i].split(",")
        except AttributeError:
            print(i, type(df["doc_id"][i]), df["doc_id"][i])
            raise SystemExit(1)
        for did_str in did_list:
            lid = did_to_lid[did_str]
            # With skip_same_lid=True, only the first row seen for each document is kept.
            if skip_same_lid and lid in inc_lid_set:
                continue
            inc_lid_set.add(lid)

            if query not in qry_to_qid:
                qry_to_qid[query] = len(qry_to_qid)
            qid = qry_to_qid[query]
            rows.append(qid)
            cols.append(lid)

    vals = [1.0 for _ in range(len(rows))]

    num_inp, num_out = len(qry_to_qid), len(did_to_lid)
    Y = smat.csr_matrix(
        (vals, (rows, cols)),
        shape=(num_inp, num_out),
        dtype=np.float32,
    )
    print("#Q {:7d} #L {:7d} NNZ {:9d}".format(num_inp, num_out, Y.nnz))
    id2query = [str(query).lower() for query, qid in sorted(qry_to_qid.items(), key=lambda x: x[1])]
    return id2query, Y


def write_qtxt(id2qtxt, output_path):
    # Write one query per line, flattening embedded newlines so line i stays aligned with row i.
    with open(output_path, "w", encoding="utf-8") as fout:
        for query_txt in id2qtxt:
            query_txt = query_txt.replace('\r\n', ' ')
            query_txt = query_txt.replace('\n', ' ')
            fout.write(f"{query_txt}\n")


def main():
    df_trn = load_df("./raw/trivia_newdata/train.tsv")
    df_tst = load_df("./raw/trivia_newdata/dev.tsv")
    df_abs = load_df("./raw/trivia_newdata/trivia_title_cont.tsv")
    df_doc = load_df("./raw/trivia_newdata/trivia_doc_aug.tsv")
    df_d2q = load_qg_df("./raw/trivia_newdata/trivia_512_qg.tsv")

    # A plain dict makes a lookup of an unknown document id fail loudly with a KeyError.
    did_to_lid = {}
    build_did2lid_map(df_abs, did_to_lid)
    print("After df_abs, #uniq_label {:9d}".format(len(did_to_lid)))
    build_did2lid_map(df_doc, did_to_lid)
    print("After df_doc, #uniq_label {:9d}".format(len(did_to_lid)))
    build_did2lid_map(df_d2q, did_to_lid)
    print("After df_d2q, #uniq_label {:9d}".format(len(did_to_lid)))

    id2qtxt_trn, Y_trn_all = build_corpus_and_label_mat(df_trn, did_to_lid, skip_same_lid=False)
    id2qtxt_tst, Y_tst_all = build_corpus_and_label_mat(df_tst, did_to_lid, skip_same_lid=False)
    id2qtxt_abs, Y_trn_abs = build_corpus_and_label_mat(df_abs, did_to_lid, skip_same_lid=True)
    id2qtxt_doc, Y_trn_doc = build_corpus_and_label_mat(df_doc, did_to_lid, skip_same_lid=False)
    id2qtxt_d2q, Y_trn_d2q = build_corpus_and_label_mat(df_d2q, did_to_lid, skip_same_lid=False)

    output_dir = "./xmc/trivia"
    os.makedirs(output_dir, exist_ok=True)
    write_qtxt(id2qtxt_trn, f"{output_dir}/X.trn.txt")
    write_qtxt(id2qtxt_tst, f"{output_dir}/X.tst.txt")
    write_qtxt(id2qtxt_abs, f"{output_dir}/X.trn.abs.txt")
    write_qtxt(id2qtxt_doc, f"{output_dir}/X.trn.doc.txt")
    write_qtxt(id2qtxt_d2q, f"{output_dir}/X.trn.d2q.txt")
    smat_util.save_matrix(f"{output_dir}/Y.trn.npz", Y_trn_all)
    smat_util.save_matrix(f"{output_dir}/Y.tst.npz", Y_tst_all)
    smat_util.save_matrix(f"{output_dir}/Y.trn.abs.npz", Y_trn_abs)
    smat_util.save_matrix(f"{output_dir}/Y.trn.doc.npz", Y_trn_doc)
    smat_util.save_matrix(f"{output_dir}/Y.trn.d2q.npz", Y_trn_d2q)


if __name__ == "__main__":
    main()