Skip to content

Commit

Permalink
add silly retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
LittlePea13 committed Aug 6, 2024
1 parent d4f1d3f commit 9d94478
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion relik/retriever/pytorch_modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.document import Document
from relik.retriever.indexers.document import Document, DocumentStore
from relik.retriever.indexers.faissindex import FaissDocumentIndex
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
Expand Down Expand Up @@ -647,3 +647,30 @@ def to_config(cls, *args, **kwargs):
"document_index": to_config(cls.document_index),
}
return config

class GoldenSillyRetriever(GoldenRetriever):
def __init__(self, documents: List[str], *args, **kwargs):
self.documents = DocumentStore([Document(doc) for doc in documents])
self.document_index = BaseDocumentIndex(self.documents)
def retrieve(self,
text: Optional[Union[str, List[str]]] = None,
k: int = 100,
*args,
**kwargs,
) -> List[List[RetrievedSample]]:
if isinstance(text, str):
text = [text]
elif text is None:
text = []
return [
[RetrievedSample(score=1.0, document=doc) for doc in self.documents[:k]]
for _ in text
]
def index(self):
pass
def eval(self):
pass
def save_pretrained(self):
pass
def to(self, device):
pass

0 comments on commit 9d94478

Please sign in to comment.