From 5902ece7a4fca522c847fa5dffcc3dcfdc45fa1d Mon Sep 17 00:00:00 2001 From: Riccardo Orlando Date: Tue, 6 Aug 2024 09:47:14 +0000 Subject: [PATCH] Add docstrings --- relik/retriever/indexers/base.py | 23 +++++++++++++++++++++++ relik/retriever/indexers/document.py | 17 ++++++++++++++--- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/relik/retriever/indexers/base.py b/relik/retriever/indexers/base.py index b3e5edf..18c2d3b 100644 --- a/relik/retriever/indexers/base.py +++ b/relik/retriever/indexers/base.py @@ -415,6 +415,29 @@ def from_pretrained( *args, **kwargs, ) -> "BaseDocumentIndex": + """ + Load a pre-trained document index from the specified location. + + Args: + name_or_path (Union[str, os.PathLike]): The name or path of the pre-trained index. + device (str, optional): The device to load the index on. Defaults to "cpu". + precision (str | None, optional): The precision of the index. Defaults to None. + config_file_name (str | None, optional): The name of the configuration file. Defaults to None. + document_file_name (str | None, optional): The name of the document file. Defaults to None. + embedding_file_name (str | None, optional): The name of the embedding file. Defaults to None. + index_file_name (str | None, optional): The name of the index file. Defaults to None. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + BaseDocumentIndex: The loaded pre-trained document index. + + Raises: + FileNotFoundError: If the model configuration file is not found. + ValueError: If the document file does not exist. + ImportError: If the `faiss` package is not installed when trying to load a FAISS index. 
+ + """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) skip_metadata = kwargs.pop("skip_metadata", False) diff --git a/relik/retriever/indexers/document.py b/relik/retriever/indexers/document.py index b8d7d01..d760b20 100644 --- a/relik/retriever/indexers/document.py +++ b/relik/retriever/indexers/document.py @@ -310,7 +310,20 @@ def from_dict(cls, d): return cls([Document.from_dict(doc) for doc in d]) @classmethod - def from_file(cls, file_path: Union[str, Path], skip_metadata: bool = False, **kwargs): + def from_file( + cls, file_path: Union[str, Path], skip_metadata: bool = False, **kwargs + ): + """ + Load documents from a file. + + Args: + file_path (Union[str, Path]): The path to the file containing the documents. + skip_metadata (bool, optional): Whether to skip loading metadata for each document. Defaults to False. + **kwargs: Additional keyword arguments. + + Returns: + cls: An instance of the class with the loaded documents. + """ with open(file_path, "r") as f: docs = [] for line in f: @@ -318,8 +331,6 @@ def from_file(cls, file_path: Union[str, Path], skip_metadata: bool = False, **k if skip_metadata: doc.pop("metadata", None) docs.append(Document.from_dict(doc)) - # load a json lines file - # d = [Document.from_dict(json.loads(line)) for line in f] return cls(docs) @classmethod