From 9d94478d508c0497b29e92095ba1535e04d86f65 Mon Sep 17 00:00:00 2001 From: LittlePea13 Date: Tue, 6 Aug 2024 11:45:29 +0200 Subject: [PATCH] add silly retriever --- relik/retriever/pytorch_modules/model.py | 29 +++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/relik/retriever/pytorch_modules/model.py b/relik/retriever/pytorch_modules/model.py index a46fa7e..0ca0f52 100644 --- a/relik/retriever/pytorch_modules/model.py +++ b/relik/retriever/pytorch_modules/model.py @@ -21,7 +21,7 @@ from relik.retriever.data.base.datasets import BaseDataset from relik.retriever.data.labels import Labels from relik.retriever.indexers.base import BaseDocumentIndex -from relik.retriever.indexers.document import Document +from relik.retriever.indexers.document import Document, DocumentStore from relik.retriever.indexers.faissindex import FaissDocumentIndex from relik.retriever.indexers.inmemory import InMemoryDocumentIndex from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample @@ -647,3 +647,30 @@ def to_config(cls, *args, **kwargs): "document_index": to_config(cls.document_index), } return config + +class GoldenSillyRetriever(GoldenRetriever): + def __init__(self, documents: List[str], *args, **kwargs): + self.documents = DocumentStore([Document(doc) for doc in documents]) + self.document_index = BaseDocumentIndex(self.documents) + def retrieve(self, + text: Optional[Union[str, List[str]]] = None, + k: int = 100, + *args, + **kwargs, + ) -> List[List[RetrievedSample]]: + if isinstance(text, str): + text = [text] + elif text is None: + text = [] + return [ + [RetrievedSample(score=1.0, document=doc) for doc in self.documents[:k]] + for _ in text + ] + def index(self): + pass + def eval(self): + pass + def save_pretrained(self): + pass + def to(self, device): + pass \ No newline at end of file