Experiment Code of PEFA Paper for WSDM24
Wei-Cheng Chang committed Dec 6, 2023
1 parent 51062ec commit db11e9f
Showing 27 changed files with 850 additions and 0 deletions.
96 changes: 96 additions & 0 deletions examples/pefa-wsdm24/README.md
# Experiment Code for PEFA, WSDM 2024

This folder contains the code to reproduce the experiments in
["PEFA: Parameter-Free Adapters for Large-scale Embedding-based Retrieval Models"](https://arxiv.org/abs/2312.02429).

## 1. Summary
In this repository, we demonstrate how to reproduce Table 2 (NQ320K) and Table 3 (Trivia-QA) of our PEFA paper.
After following Steps 2-6 in the subsequent sections, you should obtain the following results:

| NQ320K | Recall@10 | Recall@100 |
|---|---|---|
| DistilBERT + PEFA-XS | 80.52% | 92.23% |
| DistilBERT + PEFA-XL | 85.26% | 92.53% |
| MPNet + PEFA-XS | 86.67% | 94.53% |
| MPNet + PEFA-XL | 88.72% | 95.13% |
| Sentence-T5-base + PEFA-XS | 82.52% | 92.18% |
| Sentence-T5-base + PEFA-XL | 83.69% | 92.55% |
| GTR-T5-base + PEFA-XS | 84.90% | 93.28% |
| GTR-T5-base + PEFA-XL | 88.71% | 94.36% |

| Trivia-QA | Recall@20 | Recall@100 |
|---|---|---|
| DistilBERT + PEFA-XS | 86.28% | 93.33% |
| DistilBERT + PEFA-XL | 84.18% | 91.24% |
| MPNet + PEFA-XS | 86.05% | 92.97% |
| MPNet + PEFA-XL | 86.13% | 92.42% |
| Sentence-T5-base + PEFA-XS | 78.39% | 88.57% |
| Sentence-T5-base + PEFA-XL | 75.13% | 87.24% |
| GTR-T5-base + PEFA-XS | 83.81% | 91.02% |
| GTR-T5-base + PEFA-XL | 85.30% | 92.38% |


## 2. Getting Started
* Clone the repository and enter the `examples/pefa-wsdm24` directory.
* Create a [virtual environment](https://docs.python.org/3/library/venv.html) and install the dependencies by running the following commands:
```bash
# create and activate a virtual environment first
python3 -m venv venv
source ./venv/bin/activate
python3 -m pip install libpecos==1.2.1
python3 -m pip install sentence-transformers==2.2.1
```

## 3. Download Pre-processed Data for NQ320K and Trivia-QA
Our pre-processed datasets for NQ320K and Trivia-QA can be downloaded as follows:
```bash
mkdir -p ./data/xmc; cd ./data/xmc;
DATASET="nq320k" # nq320k or trivia
wget https://archive.org/download/pefa-wsdm24/data/xmc/${DATASET}.tar.gz
tar -zxvf ./${DATASET}.tar.gz
cd ../../ # get back to the pecos/examples/pefa-wsdm24 directory
```

Additional notes on data pre-processing:
* We first obtained the original NQ320K/Trivia-QA datasets from the [NCI paper (Wang et al., NeurIPS 2022)](https://github.com/solidsea98/Neural-Corpus-Indexer-NCI).
* We then pre-processed them into our XMC format.
* Details about our data pre-processing scripts can be found in `./data/README.md`.


## 4. Generate Embeddings for PEFA Inference
Before running PEFA inference, we select an encoder to generate the query/passage embeddings (see the sketch after the encoder list below):
```bash
DATASET="nq320k" # nq320k or trivia
ENCODER="gtr-t5-base"
bash run_encoder.sh ${DATASET} ${ENCODER}
```
The embeddings will be saved to `./data/embeddings/${DATASET}/`.
Regarding the `ENCODER` used in our paper:
* For `nq320k`, we consider `{nq-distilbert-base-v1, multi-qa-mpnet-base-dot-v1, sentence-t5-base, gtr-t5-base}`
* For `trivia`, we consider `{multi-qa-distilbert-dot-v1, multi-qa-mpnet-base-dot-v1, sentence-t5-base, gtr-t5-base}`
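
The diff for the encoding script is not shown above, so as a rough guide, here is a minimal sketch of such an encoding step with `sentence-transformers`; the batch size, the normalization choice, and the output file name are illustrative assumptions, not the script's actual conventions.
```python
# Minimal sketch of the encoding step (assumption: the actual script may use
# different output names, batch sizes, and normalization settings).
import os
import numpy as np
from sentence_transformers import SentenceTransformer

dataset, encoder = "nq320k", "gtr-t5-base"
model = SentenceTransformer(encoder)

with open(f"./data/xmc/{dataset}/X.tst.txt") as fin:
    queries = fin.read().splitlines()

# Unit-normalize so inner-product search behaves like cosine similarity.
Q_tst = model.encode(queries, batch_size=128, normalize_embeddings=True)

os.makedirs(f"./data/embeddings/{dataset}", exist_ok=True)
np.save(f"./data/embeddings/{dataset}/X.tst.{encoder}.npy", Q_tst)  # hypothetical name
```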

## 5. Run PEFA-XS
```bash
DATASET="nq320k"
ENCODER="gtr-t5-base"
bash run_pefa_xs.sh ${DATASET} ${ENCODER}
```
The script `run_pefa_xs.sh` calls `pefa_xs.py` with hard-coded hyper-parameters.
For example, it uses `threads=64`; if your machine has fewer CPU cores, please adjust it accordingly.
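
To make the method concrete, below is a conceptual NumPy sketch of the PEFA-XS scoring rule as described in the paper: the kNN memory is distilled into one aggregated key per label and interpolated with the original label embedding. The function name, the interpolation weight `beta`, and the uniform averaging are illustrative assumptions; `pefa_xs.py` is the authoritative implementation.
```python
# Conceptual sketch of PEFA-XS scoring (illustrative only; see pefa_xs.py for
# the actual implementation and hyper-parameter names).
import numpy as np

def pefa_xs_scores(Q_tst, D, Q_trn, Y_trn, beta=0.5):
    """Q_tst: (n_tst, d) test query embeddings, D: (L, d) label embeddings,
    Q_trn: (n_trn, d) training query embeddings, Y_trn: (n_trn, L) sparse
    query-to-label matrix, beta: assumed interpolation hyper-parameter."""
    # Distill the kNN memory into one key per label: the average embedding
    # of the training queries linked to that label.
    cnt = np.maximum(np.asarray(Y_trn.sum(axis=0)).ravel(), 1.0)  # (L,)
    D_mem = np.asarray(Y_trn.T @ Q_trn) / cnt[:, None]            # (L, d)
    # Interpolate label embeddings with their memory keys, then score by
    # inner product as in standard embedding-based retrieval.
    D_adapted = beta * D + (1.0 - beta) * D_mem
    return Q_tst @ D_adapted.T                                    # (n_tst, L)
```
Because the adapted label embeddings are fixed vectors, PEFA-XS keeps the usual single-ANN-index serving path; only the indexed embeddings change.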

## 6. Run PEFA-XL
```bash
DATASET="nq320k"
ENCODER="gtr-t5-base"
bash run_pefa_xl.sh ${DATASET} ${ENCODER}
```
The script `run_pefa_xl.sh` calls `pefa_xl.py` with hard-coded hyper-parameters.
For example, it uses `threads=64`; if your machine has fewer CPU cores, please adjust it accordingly.
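
For contrast with PEFA-XS, here is a conceptual sketch of PEFA-XL, which keeps the individual training-query embeddings as the kNN memory instead of one aggregated key per label. The top-k cutoff, the max-aggregation per label, and `lam` are our illustrative assumptions about the paper's formulation; `pefa_xl.py` is the authoritative implementation.
```python
# Conceptual sketch of PEFA-XL scoring for one test query (illustrative only;
# the aggregation choice and hyper-parameter names are assumptions).
import numpy as np

def pefa_xl_scores(q, D, Q_trn, Y_trn, lam=0.5, topk=100):
    """q: (d,) test query, D: (L, d) label embeddings, Q_trn: (n_trn, d)
    training query embeddings, Y_trn: (n_trn, L) sparse query-to-label matrix."""
    base = D @ q                                   # (L,) dual-encoder scores
    sims = Q_trn @ q                               # (n_trn,) memory similarities
    # Keep only the top-k most similar training queries (the kNN hit set).
    kept = np.zeros_like(sims)
    knn_idx = np.argpartition(-sims, topk)[:topk]
    kept[knn_idx] = sims[knn_idx]
    # Propagate each retained similarity to its query's label(s) and
    # aggregate per label (max is one plausible choice).
    knn = Y_trn.multiply(kept[:, None]).tocsr().max(axis=0).toarray().ravel()
    return lam * base + (1.0 - lam) * knn          # (L,)
```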

## 7. Citation
If you find this work useful for your research, please cite:
```
@inproceedings{chang2024pefa,
  title={PEFA: Parameter-Free Adapters for Large-scale Embedding-based Retrieval Models},
  author={Wei-Cheng Chang and Jyun-Yu Jiang and Jiong Zhang and Mutasem Al-Darabsah and Choon Hui Teo and Cho-Jui Hsieh and Hsiang-Fu Yu and S. V. N. Vishwanathan},
  booktitle={Proceedings of the 17th ACM International Conference on Web Search and Data Mining (WSDM '24)},
  year={2024}
}
```
34 changes: 34 additions & 0 deletions examples/pefa-wsdm24/data/README.md

# Download Raw Data from NCI Paper
- NQ320K: https://drive.google.com/drive/folders/1epfUw4yQjAtqnZTQDLAUOwTJg-YMCGdD
- Trivia: https://drive.google.com/drive/folders/1SY28Idba1X8DNi4PYaDDH9CbUpdKiTXQ
Unzip the downloaded archives so that the folder layout matches what the processing scripts expect:
- unzip the NQ320K folder to `./raw/NQ_dataset` (the path read by `proc_nq320k.py`)
- unzip the Trivia folder to `./raw/trivia_newdata` (the path read by `proc_trivia.py`)

# Process the Raw Data to XMC Format
- NQ320K:
```
python proc_nq320k.py
```

- Trivia:
```
python proc_trivia.py
```

You should see the following data artifacts:
```
./xmc/{nq320k|trivia}
|- X.trn.abs.txt
|- X.trn.d2q.txt
|- X.trn.doc.txt
|- X.trn.txt
|- X.tst.txt
|- Y.trn.abs.npz
|- Y.trn.d2q.npz
|- Y.trn.doc.npz
|- Y.trn.npz
|- Y.tst.npz
```
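
After processing, you can sanity-check the artifacts with a short snippet like the following (a minimal sketch; `smat_util.load_matrix` is the PECOS counterpart of the `save_matrix` calls used by the processing scripts):
```python
# Quick sanity check of the generated XMC artifacts (paths from the tree above).
from pecos.utils import smat_util

dataset = "nq320k"  # or "trivia"
Y_trn = smat_util.load_matrix(f"./xmc/{dataset}/Y.trn.npz")
with open(f"./xmc/{dataset}/X.trn.txt") as fin:
    X_trn = fin.read().splitlines()

# Each row of Y.trn.npz is a training query; each column is a label (document).
assert Y_trn.shape[0] == len(X_trn)
print(f"#queries={Y_trn.shape[0]} #labels={Y_trn.shape[1]} nnz={Y_trn.nnz}")
```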
102 changes: 102 additions & 0 deletions examples/pefa-wsdm24/data/proc_nq320k.py
import os
from collections import defaultdict

import numpy as np
import pandas as pd
import scipy.sparse as smat
from pecos.utils import smat_util


COL_NAME_LIST = [
    "query", "qid", "doc_id",
    "bert_k30_c30_1", "bert_k30_c30_2", "bert_k30_c30_3", "bert_k30_c30_4", "bert_k30_c30_5",
]


def load_df(input_tsv_path):
    """Load a raw NCI tsv file, keeping only the (query, qid, doc_id) columns."""
    return pd.read_csv(
        input_tsv_path,
        encoding="utf-8", header=None, sep="\t",
        names=COL_NAME_LIST,
        dtype={"query": str, "qid": str, "doc_id": str},
    ).loc[:, ["query", "qid", "doc_id"]]


def build_did2lid_map(df_inp, did_to_lid):
    """Assign a contiguous label id (lid) to every unseen document id (did)."""
    for i in range(len(df_inp)):
        did_str = df_inp["doc_id"][i]
        if did_str not in did_to_lid:
            did_to_lid[did_str] = len(did_to_lid)


def build_corpus_and_label_mat(df, did_to_lid, skip_same_lid=False):
    """Build the query text list and the sparse query-to-label matrix Y.

    With skip_same_lid=True, at most one row is kept per label (used for the
    title/abstract view, where each document should appear exactly once).
    """
    qry_to_qid = defaultdict(str)
    rows, cols = [], []
    inc_lid_set = set()
    for i in range(len(df)):
        query = df["query"][i]
        did_str = df["doc_id"][i]

        lid = did_to_lid[did_str]
        if skip_same_lid and lid in inc_lid_set:
            continue
        inc_lid_set.add(lid)

        if query not in qry_to_qid:
            qry_to_qid[query] = len(qry_to_qid)
        qid = qry_to_qid[query]
        rows.append(qid)
        cols.append(lid)

    vals = [1.0 for _ in range(len(rows))]

    num_inp, num_out = len(qry_to_qid), len(did_to_lid)
    Y = smat.csr_matrix(
        (vals, (rows, cols)),
        shape=(num_inp, num_out),
        dtype=np.float32,
    )
    print("#Q {:7d} #L {:7d} NNZ {:9d}".format(num_inp, num_out, Y.nnz))
    id2query = [str(query).lower() for query, qid in sorted(qry_to_qid.items(), key=lambda x: x[1])]
    return id2query, Y


def write_qtxt(id2qtxt, output_path):
    """Write one query per line."""
    with open(output_path, "w") as fout:
        for query_txt in id2qtxt:
            fout.write(f"{query_txt}\n")


def main():
    # Raw tsv files released by the NCI paper (see README.md for download links).
    df_trn = load_df("./raw/NQ_dataset/nq_train_doc_newid.tsv")
    df_tst = load_df("./raw/NQ_dataset/nq_dev_doc_newid.tsv")
    df_abs = load_df("./raw/NQ_dataset/nq_title_abs.tsv")
    df_doc = load_df("./raw/NQ_dataset/NQ_doc_aug.tsv")
    df_d2q = load_df("./raw/NQ_dataset/NQ_512_qg.tsv")

    # Build a global doc_id -> label_id map covering all document sources.
    did_to_lid = defaultdict(str)
    build_did2lid_map(df_abs, did_to_lid)
    print("After df_abs, #uniq_label {:9d}".format(len(did_to_lid)))
    build_did2lid_map(df_doc, did_to_lid)
    print("After df_doc, #uniq_label {:9d}".format(len(did_to_lid)))
    build_did2lid_map(df_d2q, did_to_lid)
    print("After df_d2q, #uniq_label {:9d}".format(len(did_to_lid)))

    id2qtxt_trn, Y_trn_all = build_corpus_and_label_mat(df_trn, did_to_lid, skip_same_lid=False)
    id2qtxt_tst, Y_tst_all = build_corpus_and_label_mat(df_tst, did_to_lid, skip_same_lid=False)
    id2qtxt_abs, Y_trn_abs = build_corpus_and_label_mat(df_abs, did_to_lid, skip_same_lid=True)
    id2qtxt_doc, Y_trn_doc = build_corpus_and_label_mat(df_doc, did_to_lid, skip_same_lid=False)
    id2qtxt_d2q, Y_trn_d2q = build_corpus_and_label_mat(df_d2q, did_to_lid, skip_same_lid=False)

    output_dir = "./xmc/nq320k"
    os.makedirs(output_dir, exist_ok=True)
    write_qtxt(id2qtxt_trn, f"{output_dir}/X.trn.txt")
    write_qtxt(id2qtxt_tst, f"{output_dir}/X.tst.txt")
    write_qtxt(id2qtxt_abs, f"{output_dir}/X.trn.abs.txt")
    write_qtxt(id2qtxt_doc, f"{output_dir}/X.trn.doc.txt")
    write_qtxt(id2qtxt_d2q, f"{output_dir}/X.trn.d2q.txt")
    smat_util.save_matrix(f"{output_dir}/Y.trn.npz", Y_trn_all)
    smat_util.save_matrix(f"{output_dir}/Y.tst.npz", Y_tst_all)
    smat_util.save_matrix(f"{output_dir}/Y.trn.abs.npz", Y_trn_abs)
    smat_util.save_matrix(f"{output_dir}/Y.trn.doc.npz", Y_trn_doc)
    smat_util.save_matrix(f"{output_dir}/Y.trn.d2q.npz", Y_trn_d2q)


if __name__ == "__main__":
    main()
134 changes: 134 additions & 0 deletions examples/pefa-wsdm24/data/proc_trivia.py
import os
from collections import defaultdict

import numpy as np
import pandas as pd
import scipy.sparse as smat
from pecos.utils import smat_util


COL_NAME_LIST = [
    "query", "qid", "doc_id",
    "bert_k30_c30_1", "bert_k30_c30_2", "bert_k30_c30_3", "bert_k30_c30_4", "bert_k30_c30_5",
]

QG_COL_NAME_LIST = [
    "query", "doc_id",
    "bert_k30_c30_1", "bert_k30_c30_2", "bert_k30_c30_3", "bert_k30_c30_4", "bert_k30_c30_5",
]


def load_df(input_tsv_path):
    """Load a raw NCI tsv file, keeping only the (query, qid, doc_id) columns."""
    return pd.read_csv(
        input_tsv_path,
        encoding="utf-8", header=None, sep="\t",
        names=COL_NAME_LIST,
        dtype={"query": str, "qid": str, "doc_id": str},
        on_bad_lines="skip",
        skip_blank_lines=True,
    ).loc[:, ["query", "qid", "doc_id"]]


def load_qg_df(input_tsv_path):
    """Load the generated-query (d2q) tsv file, which has no qid column."""
    return pd.read_csv(
        input_tsv_path,
        encoding="utf-8", header=None, sep="\t",
        names=QG_COL_NAME_LIST,
        dtype={"query": str, "doc_id": str},
        on_bad_lines="skip",
        skip_blank_lines=True,
    ).loc[:, ["query", "doc_id"]]


def build_did2lid_map(df_inp, did_to_lid):
    """Assign a contiguous label id (lid) to every unseen document id (did).

    Unlike NQ320K, a Trivia-QA row may list multiple comma-separated doc ids.
    """
    for i in range(len(df_inp)):
        try:
            did_list = df_inp["doc_id"][i].split(",")
        except AttributeError:  # doc_id parsed as NaN on a malformed row
            print(i, type(df_inp["doc_id"][i]), df_inp["doc_id"][i])
            exit(1)
        for did_str in did_list:
            if did_str not in did_to_lid:
                did_to_lid[did_str] = len(did_to_lid)


def build_corpus_and_label_mat(df, did_to_lid, skip_same_lid=False):
    """Build the query text list and the sparse query-to-label matrix Y."""
    qry_to_qid = defaultdict(str)
    rows, cols = [], []
    inc_lid_set = set()
    for i in range(len(df)):
        query = df["query"][i]
        try:
            did_list = df["doc_id"][i].split(",")
        except AttributeError:  # doc_id parsed as NaN on a malformed row
            print(i, type(df["doc_id"][i]), df["doc_id"][i])
            exit(1)
        for did_str in did_list:
            lid = did_to_lid[did_str]
            if skip_same_lid and lid in inc_lid_set:
                continue
            inc_lid_set.add(lid)

            if query not in qry_to_qid:
                qry_to_qid[query] = len(qry_to_qid)
            qid = qry_to_qid[query]
            rows.append(qid)
            cols.append(lid)

    vals = [1.0 for _ in range(len(rows))]

    num_inp, num_out = len(qry_to_qid), len(did_to_lid)
    Y = smat.csr_matrix(
        (vals, (rows, cols)),
        shape=(num_inp, num_out),
        dtype=np.float32,
    )
    print("#Q {:7d} #L {:7d} NNZ {:9d}".format(num_inp, num_out, Y.nnz))
    id2query = [str(query).lower() for query, qid in sorted(qry_to_qid.items(), key=lambda x: x[1])]
    return id2query, Y


def write_qtxt(id2qtxt, output_path):
    """Write one query per line, flattening embedded newlines first."""
    with open(output_path, "w") as fout:
        for query_txt in id2qtxt:
            query_txt = query_txt.replace("\r\n", " ")
            query_txt = query_txt.replace("\n", " ")
            fout.write(f"{query_txt}\n")


def main():
    # Raw tsv files released by the NCI paper (see README.md for download links).
    df_trn = load_df("./raw/trivia_newdata/train.tsv")
    df_tst = load_df("./raw/trivia_newdata/dev.tsv")
    df_abs = load_df("./raw/trivia_newdata/trivia_title_cont.tsv")
    df_doc = load_df("./raw/trivia_newdata/trivia_doc_aug.tsv")
    df_d2q = load_qg_df("./raw/trivia_newdata/trivia_512_qg.tsv")

    # Build a global doc_id -> label_id map covering all document sources.
    did_to_lid = defaultdict(str)
    build_did2lid_map(df_abs, did_to_lid)
    print("After df_abs, #uniq_label {:9d}".format(len(did_to_lid)))
    build_did2lid_map(df_doc, did_to_lid)
    print("After df_doc, #uniq_label {:9d}".format(len(did_to_lid)))
    build_did2lid_map(df_d2q, did_to_lid)
    print("After df_d2q, #uniq_label {:9d}".format(len(did_to_lid)))

    id2qtxt_trn, Y_trn_all = build_corpus_and_label_mat(df_trn, did_to_lid, skip_same_lid=False)
    id2qtxt_tst, Y_tst_all = build_corpus_and_label_mat(df_tst, did_to_lid, skip_same_lid=False)
    id2qtxt_abs, Y_trn_abs = build_corpus_and_label_mat(df_abs, did_to_lid, skip_same_lid=True)
    id2qtxt_doc, Y_trn_doc = build_corpus_and_label_mat(df_doc, did_to_lid, skip_same_lid=False)
    id2qtxt_d2q, Y_trn_d2q = build_corpus_and_label_mat(df_d2q, did_to_lid, skip_same_lid=False)

    output_dir = "./xmc/trivia"
    os.makedirs(output_dir, exist_ok=True)
    write_qtxt(id2qtxt_trn, f"{output_dir}/X.trn.txt")
    write_qtxt(id2qtxt_tst, f"{output_dir}/X.tst.txt")
    write_qtxt(id2qtxt_abs, f"{output_dir}/X.trn.abs.txt")
    write_qtxt(id2qtxt_doc, f"{output_dir}/X.trn.doc.txt")
    write_qtxt(id2qtxt_d2q, f"{output_dir}/X.trn.d2q.txt")
    smat_util.save_matrix(f"{output_dir}/Y.trn.npz", Y_trn_all)
    smat_util.save_matrix(f"{output_dir}/Y.tst.npz", Y_tst_all)
    smat_util.save_matrix(f"{output_dir}/Y.trn.abs.npz", Y_trn_abs)
    smat_util.save_matrix(f"{output_dir}/Y.trn.doc.npz", Y_trn_doc)
    smat_util.save_matrix(f"{output_dir}/Y.trn.d2q.npz", Y_trn_d2q)


if __name__ == "__main__":
    main()