From e746294ba63292bd6e60afd495830ef3e089002d Mon Sep 17 00:00:00 2001 From: Riccardo Orlando Date: Thu, 8 Aug 2024 14:21:51 +0000 Subject: [PATCH] Update README.md --- README.md | 4 ++-- relik/common/utils.py | 7 ++++--- relik/retriever/indexers/base.py | 8 +++++--- relik/retriever/indexers/faissindex.py | 4 ++-- relik/retriever/indexers/inmemory.py | 2 +- relik/version.py | 2 +- 6 files changed, 15 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 4a62add..ea333a3 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,7 @@ Output: Retrievers and Readers can be used separately. In the case of retriever-only ReLiK, the output will contain the candidates for the input text. -Reader-only example: +Retriever-only example: ```python from relik import Relik @@ -246,7 +246,7 @@ Output: ), ) -Retriever-only example: +Reader-only example: ```python from relik import Relik diff --git a/relik/common/utils.py b/relik/common/utils.py index 9be7012..65de435 100644 --- a/relik/common/utils.py +++ b/relik/common/utils.py @@ -339,9 +339,10 @@ def download_from_hf( downloaded_paths.append(downloaded_path) except OSError: if ignore_failure: - logger.warn( - f"Couldn't download {filename} from {path_or_repo_id}, ignoring" - ) + # logger.warn( + # f"Couldn't download {filename} from {path_or_repo_id}, ignoring" + # ) + pass else: raise diff --git a/relik/retriever/indexers/base.py b/relik/retriever/indexers/base.py index 18c2d3b..4e5bb82 100644 --- a/relik/retriever/indexers/base.py +++ b/relik/retriever/indexers/base.py @@ -467,13 +467,15 @@ def from_pretrained( ) config = OmegaConf.load(config_path) - # add the actual cls class to the config in place of the _target_ if cls is not BaseDocumentIndex + # add the actual cls class to the config in place of the _target_ if cls is not BaseDocumentIndex target = config.get("_target_", None) + use_faiss = kwargs.get("use_faiss", False) or "FaissDocumentIndex" in target + if cls.__name__ != "BaseDocumentIndex" and target is None: kwargs["_target_"] = f"{cls.__module__}.{cls.__name__}" - - use_faiss = kwargs.get("use_faiss", False) or "FaissDocumentIndex" in target + if use_faiss and "FaissDocumentIndex" not in target: + kwargs["_target_"] = "relik.retriever.indexers.faissindex.FaissDocumentIndex" kwargs["device"] = device kwargs["precision"] = precision diff --git a/relik/retriever/indexers/faissindex.py b/relik/retriever/indexers/faissindex.py index 3f3d9ae..18f0241 100644 --- a/relik/retriever/indexers/faissindex.py +++ b/relik/retriever/indexers/faissindex.py @@ -179,8 +179,8 @@ def _build_faiss_index( embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings ) - # convert to float32 if embeddings is a torch.Tensor and is float16 - if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16: + # convert to float32 if embeddings is a torch.Tensor and not already float32 + if isinstance(embeddings, torch.Tensor) and embeddings.dtype != torch.float32: embeddings = embeddings.float() # logger.info("Training the index.") diff --git a/relik/retriever/indexers/inmemory.py b/relik/retriever/indexers/inmemory.py index 7f08d91..a47845b 100644 --- a/relik/retriever/indexers/inmemory.py +++ b/relik/retriever/indexers/inmemory.py @@ -41,7 +41,7 @@ def __init__( separator: str | None = None, name_or_path: str | os.PathLike | None = None, device: str = "cpu", - precision: str | int | torch.dtype = 32, + precision: str | int | torch.dtype | None = 32, *args, **kwargs, ) -> None: diff --git a/relik/version.py b/relik/version.py index e55efe2..94e759c 100644 --- a/relik/version.py +++ b/relik/version.py @@ -4,7 +4,7 @@ _MINOR = "0" # On main and in a nightly release the patch should be one ahead of the last # released build. -_PATCH = "5" +_PATCH = "6" # This is mainly for nightly builds which have the suffix ".dev$DATE". See # https://semver.org/#is-v123-a-semantic-version for the semantics. _SUFFIX = os.environ.get("RELIK_VERSION_SUFFIX", "")