Skip to content

Commit

Permalink
Refactor document retrieval methods in DocumentStore
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed May 24, 2024
1 parent a8ad16b commit 1c7aaf3
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 97 deletions.
62 changes: 17 additions & 45 deletions examples/retriever/create_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,54 +25,26 @@ def build_index(
indexer_class: str = "goldenretriever.indexers.inmemory.InMemoryDocumentIndex",
batch_size: int = 512,
num_workers: int = 4,
passage_max_length: int = 64,
passage_max_length: int = 512,
device: str = "cuda",
index_device: str = "cpu",
precision: str = "fp32",
):
logger.info("Loading documents")
train_dataset = InBatchNegativesDataset(
name="raco_train",
path="/root/golden-retriever/data/commonsense/raco/CommonsenseTraining/train.json",
tokenizer=question_encoder_name_or_path,
question_batch_size=64,
passage_batch_size=400,
max_passage_length=64,
shuffle=True,
)
val_dataset = InBatchNegativesDataset(
name="raco_val",
path="/root/golden-retriever/data/commonsense/raco/CommonsenseTraining/dev.json",
tokenizer=question_encoder_name_or_path,
question_batch_size=64,
passage_batch_size=400,
max_passage_length=64,
)
# if document_file_type == "jsonl":
# documents = DocumentStore.from_jsonl(document_path)
# elif document_file_type == "csv":
# documents = DocumentStore.from_tsv(
# document_path, delimiter=",", quoting=csv.QUOTE_NONE, ingore_case=True
# )
# elif document_file_type == "tsv":
# documents = DocumentStore.from_tsv(
# document_path, delimiter="\t", quoting=csv.QUOTE_NONE, ingore_case=True
# )
# else:
# raise ValueError(
# f"Unknown document file type: {document_file_type}, must be one of jsonl, csv, tsv"
# )
documents = DocumentStore()
logger.info("Adding documents to document store")
for sample in tqdm(train_dataset):
[documents.add_document(s) for s in sample["positives"] if s is not None]
[documents.add_document(s) for s in sample["negatives"] if s is not None]
[documents.add_document(s) for s in sample["hard_negatives"] if s is not None]

for sample in tqdm(val_dataset):
[documents.add_document(s) for s in sample["positives"] if s is not None]
[documents.add_document(s) for s in sample["negatives"] if s is not None]
[documents.add_document(s) for s in sample["hard_negatives"] if s is not None]
    # Load the passage collection into a DocumentStore, dispatching on the
    # declared file type. JSONL goes through the generic loader; csv/tsv share
    # from_tsv and differ only in the delimiter.
    if document_file_type == "jsonl":
        documents = DocumentStore.from_file(document_path)
    elif document_file_type == "csv":
        documents = DocumentStore.from_tsv(
            # NOTE(review): `ingore_case` looks like a typo for `ignore_case` —
            # verify against the DocumentStore.from_tsv signature before renaming.
            document_path, delimiter=",", quoting=csv.QUOTE_NONE, ingore_case=True
        )
    elif document_file_type == "tsv":
        documents = DocumentStore.from_tsv(
            # NOTE(review): same suspected `ingore_case` typo as the csv branch.
            document_path, delimiter="\t", quoting=csv.QUOTE_NONE, ingore_case=True
        )
    else:
        # Fail fast on an unsupported type rather than proceeding with no documents.
        raise ValueError(
            f"Unknown document file type: {document_file_type}, must be one of jsonl, csv, tsv"
        )

logger.info("Loading document index")
# document_index = InMemoryDocumentIndex(
Expand Down Expand Up @@ -123,9 +95,9 @@ def build_index(
arg_parser.add_argument("--document_file_type", type=str, default="jsonl")
arg_parser.add_argument("--output_folder", type=str, required=True)
arg_parser.add_argument("--batch_size", type=int, default=128)
arg_parser.add_argument("--passage_max_length", type=int, default=64)
arg_parser.add_argument("--passage_max_length", type=int, default=256)
arg_parser.add_argument("--device", type=str, default="cuda")
arg_parser.add_argument("--index_device", type=str, default="cpu")
arg_parser.add_argument("--precision", type=str, default="fp32")

build_index(**vars(arg_parser.parse_args()))
build_index(**vars(arg_parser.parse_args()))
57 changes: 28 additions & 29 deletions examples/retriever/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from goldenretriever.common.log import get_logger
from goldenretriever.indexers.base import BaseDocumentIndex
from goldenretriever.indexers.document import DocumentStore
from goldenretriever.trainer import Trainer
from goldenretriever.trainer.train import Trainer
from goldenretriever import GoldenRetriever
from goldenretriever.indexers.inmemory import InMemoryDocumentIndex
from goldenretriever.indexers.faiss import FaissDocumentIndex
Expand All @@ -12,37 +12,36 @@

if __name__ == "__main__":
# instantiate retriever
document_index = BaseDocumentIndex.from_pretrained(
"/root/golden-retriever/data/retrievers/raco-e5-small-v2/training-index/document_index",
# _target_="goldenretriever.indexers.faiss.FaissDocumentIndex"
)
# document_index = BaseDocumentIndex.from_pretrained(
# "wandb/run-20240521_212032-fqnhon6y/files/retriever/document_index",
# )
retriever = GoldenRetriever(
question_encoder="/root/golden-retriever/data/retrievers/raco-e5-small-v2/training-index/question_encoder",
document_index=document_index,
question_encoder="wandb/run-20240521_212032-fqnhon6y/files/retriever/question_encoder",
document_index="wandb/run-20240521_212032-fqnhon6y/files/retriever/document_index",
device="cuda",
precision="16",
)

train_dataset = InBatchNegativesDataset(
name="raco_train",
path="/root/golden-retriever/data/commonsense/raco/CommonsenseTraining/train.json",
tokenizer=retriever.question_tokenizer,
question_batch_size=64,
passage_batch_size=400,
max_passage_length=64,
shuffle=True,
)
val_dataset = InBatchNegativesDataset(
name="raco_val",
path="/root/golden-retriever/data/commonsense/raco/CommonsenseTraining/dev.json",
tokenizer=retriever.question_tokenizer,
question_batch_size=64,
passage_batch_size=400,
max_passage_length=64,
)
# train_dataset = InBatchNegativesDataset(
# name="raco_train",
# path="/root/golden-retriever/data/commonsense/raco/CommonsenseTraining/train.json",
# tokenizer=retriever.question_tokenizer,
# question_batch_size=64,
# passage_batch_size=400,
# max_passage_length=64,
# shuffle=True,
# )
# val_dataset = InBatchNegativesDataset(
# name="raco_val",
# path="/root/golden-retriever/data/commonsense/raco/CommonsenseTraining/dev.json",
# tokenizer=retriever.question_tokenizer,
# question_batch_size=64,
# passage_batch_size=400,
# max_passage_length=64,
# )
test_dataset = InBatchNegativesDataset(
name="raco_val",
path="/root/golden-retriever/data/commonsense/raco/CommonsenseTraining/obqa.json",
path="/media/data/commonsense/RACo/data/dev.json",
tokenizer=retriever.question_tokenizer,
question_batch_size=64,
passage_batch_size=400,
Expand All @@ -61,16 +60,16 @@

trainer = Trainer(
retriever=retriever,
train_dataset=train_dataset,
val_dataset=val_dataset,
train_dataset=None,
val_dataset=None,
test_dataset=test_dataset,
num_workers=0,
max_steps=25_000,
wandb_online_mode=False,
wandb_project_name="golden-retriever-raco",
wandb_experiment_name="raco-e5-small-inbatch",
wandb_experiment_name="raco-e5-base-eval",
max_hard_negatives_to_mine=0,
top_k=[5, 10]
top_k=[1, 3, 5, 10, 20, 50, 100]
)

# trainer.train()
Expand Down
81 changes: 81 additions & 0 deletions examples/training/train_raco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Training script for a GoldenRetriever dense retriever on the RACo
commonsense data.

Builds a retriever (e5-base-v2 question encoder + in-memory document index
over a JSONL passage file), wires up in-batch-negatives train/validation
datasets, and runs the Trainer. All dataset and index paths are hard-coded
for a specific machine layout.
"""
# NOTE(review): tqdm and AidaInBatchNegativesDataset are imported but unused
# in the active code below (the latter only appears in commented-out code).
from tqdm import tqdm
from goldenretriever.common.log import get_logger
from goldenretriever.indexers.document import DocumentStore
from goldenretriever import Trainer
from goldenretriever import GoldenRetriever
from goldenretriever.indexers.inmemory import InMemoryDocumentIndex
from goldenretriever.data.datasets import AidaInBatchNegativesDataset, InBatchNegativesDataset

logger = get_logger(__name__)

if __name__ == "__main__":
    # instantiate retriever
    # Question encoder: intfloat/e5-base-v2 (the contriever alternative is
    # kept commented out). The index holds the RACo train+dev passages in
    # memory, encoded on GPU at 16-bit precision.
    retriever = GoldenRetriever(
        # question_encoder="facebook/contriever",
        question_encoder="intfloat/e5-base-v2",
        document_index=InMemoryDocumentIndex(
            documents=DocumentStore.from_file(
                "/media/data/commonsense/RACo/data/train.dev.index.jsonl"
            ),
            # metadata_fields=["definition"],
            # separator=" <def> ",
            device="cuda",
            precision="16",
        ),
    )

    # In-batch-negatives training data: 64 questions x 200 passages per batch,
    # passages truncated to 64 tokens; order shuffled each epoch.
    train_dataset = InBatchNegativesDataset(
        name="raco_train",
        path="/media/data/commonsense/RACo/data/train.json",
        tokenizer=retriever.question_tokenizer,
        question_batch_size=64,
        passage_batch_size=200,
        max_passage_length=64,
        shuffle=True,
        # load_from_cache_file=False,
    )
    # Validation uses the same batching as training but without shuffling.
    val_dataset = InBatchNegativesDataset(
        name="raco_val",
        path="/media/data/commonsense/RACo/data/dev.json",
        tokenizer=retriever.question_tokenizer,
        question_batch_size=64,
        passage_batch_size=200,
        max_passage_length=64,
        # load_from_cache_file=False,
    )
    # test_dataset = AidaInBatchNegativesDataset(
    #     name="aida_test",
    #     path="/root/golden-retriever/data/entitylinking/aida_32_tokens_topic/test.jsonl",
    #     tokenizer=retriever.question_tokenizer,
    #     question_batch_size=64,
    #     passage_batch_size=400,
    #     max_passage_length=64,
    #     use_topics=True,
    # )

    # logger.info("Loading document index")
    # document_index = InMemoryDocumentIndex(
    #     documents=DocumentStore.from_file("/root/golden-retriever/data/entitylinking/documents.jsonl"),
    #     metadata_fields=["definition"],
    #     separator=" <def> ",
    #     device="cuda",
    #     precision="16",
    # )
    # retriever.document_index = document_index

    # Train for 55k steps; no test set is passed, so this run only trains and
    # validates. Hard-negative mining is disabled (max_hard_negatives_to_mine=0;
    # the mining-probability knob below is commented out). Results are logged to
    # the "golden-retriever-raco" W&B project without uploading model artifacts.
    trainer = Trainer(
        retriever=retriever,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        num_workers=0,
        max_steps=55_000,
        wandb_online_mode=True,
        wandb_log_model=False,
        wandb_project_name="golden-retriever-raco",
        # wandb_experiment_name="e5-base-v2-raco-self-index-hn=5-prob=0.2-batch=400",
        wandb_experiment_name="e5-base-v2-raco-self-index-batch=400",
        max_hard_negatives_to_mine=0,
        # mine_hard_negatives_with_probability=0.2,
    )

    trainer.train()
17 changes: 13 additions & 4 deletions goldenretriever/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def __init__(
passage_batch_size: int = 32,
question_batch_size: int = 32,
max_positives: int = -1,
max_negatives: int = 0,
max_hard_negatives: int = 0,
max_negatives: int = -1,
max_hard_negatives: int = -1,
max_question_length: int = 256,
max_passage_length: int = 64,
shuffle: bool = False,
Expand Down Expand Up @@ -520,8 +520,17 @@ def load_fn(
passage=passage,
positive_pssgs=passage[: len(positives)],
positives=positives,
negatives=negatives,
hard_negatives=hard_negatives,
negatives=(
negatives
if len(negatives) > 0
else None
),
hard_negatives=(
hard_negatives
if len(hard_negatives) > 0
# else datasets.Sequence(datasets.Value("string"))
else None
),
)
return output

Expand Down
2 changes: 1 addition & 1 deletion goldenretriever/indexers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def get_document_from_index(self, index: int) -> Document | None:
Returns:
`str`: The document.
"""
return self.documents.get_document_from_id(index)
return self.documents.get_document_from_index(index)

def get_passage_from_index(self, index: int) -> str:
"""
Expand Down
21 changes: 20 additions & 1 deletion goldenretriever/indexers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,21 @@ def get_document_from_text(self, text: str) -> Document | None:
if text not in self._documents_reverse_index:
logger.warning(f"Document with text `{text}` does not exist, skipping")
return self._documents_reverse_index.get(text, None)

def get_document_from_index(self, index: int) -> Document | None:
    """
    Retrieve the document by its positional index in the store.

    Args:
        index (`int`):
            The index of the document to retrieve.

    Returns:
        Optional[Document]: The document with the given index, or None if it
            does not exist.
    """
    # Bug fix: the previous version only warned on an out-of-range index and
    # then indexed the list anyway, raising IndexError instead of returning
    # None as the docstring promises. Negative indices are rejected too, so a
    # stray -1 sentinel from an index search cannot silently alias the last
    # document. This mirrors get_document_from_text, which warns and returns
    # None for a missing key.
    if index < 0 or index >= len(self._documents):
        logger.warning(f"Document with index `{index}` does not exist, skipping")
        return None
    return self._documents[index]

def add_documents(self, documents: List[Document] | List[str] | List[Dict]) -> List[Document]:
"""
Expand Down Expand Up @@ -175,13 +190,17 @@ def add_document(
Document: The document just added.
"""
if isinstance(text, str):
# check if the document already exists
if text in self:
logger.warning(f"Document `{text}` already exists, skipping")
return self._documents_reverse_index[text]
if id is None:
# get the len of the documents and add 1
id = len(self._documents) # + 1
text = Document(text, id, metadata)

if text in self:
logger.warning(f"Document {text} already exists, skipping")
logger.warning(f"Document `{text}` already exists, skipping")
return self._documents_index[text.id]

self._documents.append(text)
Expand Down
2 changes: 1 addition & 1 deletion goldenretriever/indexers/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]
batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist()
# Retrieve the passages corresponding to the indices
batch_docs = [
[self.documents.get_document_from_id(i) for i in indices if i != -1]
[self.get_document_from_index(i) for i in indices if i != -1]
for indices in batch_top_k
]
# build the output object
Expand Down
2 changes: 1 addition & 1 deletion goldenretriever/indexers/inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]
batch_scores: List[List[float]] = retriever_out.values.detach().cpu().tolist()
# Retrieve the passages corresponding to the indices
batch_docs = [
[self.documents.get_document_from_id(i) for i in indices]
[self.get_document_from_index(i) for i in indices]
for indices in batch_top_k
]
# build the output object
Expand Down
2 changes: 1 addition & 1 deletion goldenretriever/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}
PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}
Loading

0 comments on commit 1c7aaf3

Please sign in to comment.