From ccc7567b1a9cd8942ea6dd5ba515da24bdfd7059 Mon Sep 17 00:00:00 2001
From: Fynn Petersen-Frey
Date: Tue, 14 May 2024 11:48:39 +0000
Subject: [PATCH] qdrant client

---
 backend/src/app/core/search/qdrant_service.py | 207 ++++++++++++++++++
 1 file changed, 207 insertions(+)
 create mode 100644 backend/src/app/core/search/qdrant_service.py

diff --git a/backend/src/app/core/search/qdrant_service.py b/backend/src/app/core/search/qdrant_service.py
new file mode 100644
index 000000000..c87449447
--- /dev/null
+++ b/backend/src/app/core/search/qdrant_service.py
@@ -0,0 +1,207 @@
+from typing import List, Tuple
+from uuid import NAMESPACE_URL, uuid5
+
+import numpy as np
+from loguru import logger
+from qdrant_client import QdrantClient
+from qdrant_client.models import (
+    Distance,
+    FieldCondition,
+    Filter,
+    FilterSelector,
+    MatchValue,
+    PointStruct,
+    RecommendRequest,
+    VectorParams,
+)
+
+from app.core.data.dto.search import SimSearchSentenceHit
+from app.preprocessing.ray_model_service import RayModelService
+from app.preprocessing.ray_model_worker.dto.clip import ClipTextEmbeddingInput
+from app.util.singleton_meta import SingletonMeta
+from config import conf
+
+
+class QdrantService(metaclass=SingletonMeta):
+    def __new__(cls, *args, **kwargs):
+        cls._sentence_class_name = "Sentence"
+        cls._document_class_name = "Document"
+        cls._collections = [cls._sentence_class_name, cls._document_class_name]
+
+        try:
+            cls._client = QdrantClient(
+                host=conf.qdrant.host,
+                port=conf.qdrant.port,
+                grpc_port=conf.qdrant.grpc_port,
+                prefer_grpc=True,
+            )
+            collections = {c.name for c in cls._client.get_collections().collections}
+            if kwargs.get("flush", False):
+                logger.warning("Flushing DWTS Qdrant Data!")
+                for c in collections:
+                    cls._client.delete_collection(c)
+                collections = set()
+            for name in cls._collections:
+                if name not in collections:
+                    cls._client.create_collection(
+                        name,
+                        vectors_config=VectorParams(size=512, distance=Distance.COSINE),
+                    )
+
+        except Exception as e:
+            msg = f"Cannot connect to or initialize Qdrant DB - Error '{e}'"
+            logger.error(msg)
+            raise SystemExit(msg)
+
+        cls.rms = RayModelService()
+
+        return super(QdrantService, cls).__new__(cls)
+
+    @staticmethod
+    def _sentence_point_id(proj_id: int, sdoc_id: int, sent_id: int) -> str:
+        # Qdrant point IDs must be unsigned integers or UUIDs, so we derive
+        # a deterministic UUID from the composite key.
+        return str(uuid5(NAMESPACE_URL, f"{proj_id}-{sdoc_id}-{sent_id}"))
+
+    def add_text_sdoc_to_index(
+        self,
+        proj_id: int,
+        sdoc_id: int,
+        sentences: List[str],
+    ) -> None:
+        sentence_embs = self.rms.clip_text_embedding(
+            ClipTextEmbeddingInput(text=sentences)
+        ).numpy()
+
+        # create cheap & easy (but suboptimal) document embeddings for now
+        doc_emb = sentence_embs.sum(axis=0)
+        doc_emb /= np.linalg.norm(doc_emb)
+
+        logger.debug(
+            f"Adding {len(sentence_embs)} sentences "
+            f"from SDoc {sdoc_id} in Project {proj_id} to Qdrant ..."
+        )
+        sent_points = [
+            PointStruct(
+                id=self._sentence_point_id(proj_id, sdoc_id, sent_id),
+                vector=sent_emb.tolist(),
+                payload={
+                    "project_id": proj_id,
+                    "sdoc_id": sdoc_id,
+                    "sentence_id": sent_id,
+                    "text": sentences[sent_id],
+                },
+            )
+            for sent_id, sent_emb in enumerate(sentence_embs)
+        ]
+        self._client.upsert(
+            collection_name=self._sentence_class_name, points=sent_points
+        )
+        logger.debug(f"Added {len(sent_points)} sentence embeddings to Qdrant.")
+        doc_point = PointStruct(
+            id=str(uuid5(NAMESPACE_URL, f"{proj_id}-{sdoc_id}")),
+            vector=doc_emb.tolist(),
+            payload={
+                "project_id": proj_id,
+                "sdoc_id": sdoc_id,
+                "text": " ".join(sentences),
+            },
+        )
+        self._client.upsert(
+            collection_name=self._document_class_name, points=[doc_point]
+        )
+
+    def remove_text_sdoc_from_index(self, sdoc_id: int) -> None:
+        logger.debug(f"Removing text SDoc {sdoc_id} from Index!")
+        for name in self._collections:
+            self._client.delete(
+                collection_name=name,
+                points_selector=FilterSelector(
+                    filter=Filter(
+                        must=[
+                            FieldCondition(
+                                key="sdoc_id", match=MatchValue(value=sdoc_id)
+                            )
+                        ]
+                    )
+                ),
+            )
+
+    def remove_all_project_embeddings(
+        self,
+        proj_id: int,
+    ) -> None:
+        for name in self._collections:
+            self._client.delete(
+                collection_name=name,
+                points_selector=FilterSelector(
+                    filter=Filter(
+                        must=[
+                            FieldCondition(
+                                key="project_id", match=MatchValue(value=proj_id)
+                            )
+                        ]
+                    )
+                ),
+            )
+
+    def suggest_similar_sentences(
+        self, proj_id: int, sdoc_sent_ids: List[Tuple[int, int]]
+    ) -> List[SimSearchSentenceHit]:
+        return self.__suggest(proj_id, sdoc_sent_ids)
+
+    def __suggest(
+        self,
+        proj_id: int,
+        sdoc_sent_ids: List[Tuple[int, int]],
+        top_k: int = 10,
+        threshold: float = 0.0,
+    ) -> List[SimSearchSentenceHit]:
+        candidates: List[SimSearchSentenceHit] = []
+        # one recommendation query per input sentence, using the stored
+        # vector of that sentence as the (only) positive example
+        requests = [
+            RecommendRequest(
+                positive=[self._sentence_point_id(proj_id, sdoc_id, sent_id)],
+                filter=Filter(
+                    must=[
+                        FieldCondition(
+                            key="project_id", match=MatchValue(value=proj_id)
+                        )
+                    ]
+                ),
+                limit=top_k,
+                score_threshold=threshold,
+                with_payload=True,
+            )
+            for sdoc_id, sent_id in sdoc_sent_ids
+        ]
+
+        results = self._client.recommend_batch(
+            collection_name=self._sentence_class_name, requests=requests
+        )
+
+        for hits in results:
+            for hit in hits:
+                assert hit.payload is not None
+                candidates.append(
+                    SimSearchSentenceHit(
+                        sdoc_id=hit.payload["sdoc_id"],
+                        score=hit.score,
+                        sentence_id=hit.payload["sentence_id"],
+                    )
+                )
+
+        candidates.sort(key=lambda x: (x.sdoc_id, x.sentence_id))
+        hits = self.__unique_consecutive(candidates)
+        hits = [h for h in hits if (h.sdoc_id, h.sentence_id) not in sdoc_sent_ids]
+        return hits
+
+    def __unique_consecutive(self, hits: List[SimSearchSentenceHit]):
+        result = []
+        current = SimSearchSentenceHit(sdoc_id=-1, sentence_id=-1, score=0.0)
+        for hit in hits:
+            if hit.sdoc_id != current.sdoc_id or hit.sentence_id != current.sentence_id:
+                current = hit
+                result.append(hit)
+        return result
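
For reference, below is a minimal, self-contained sketch of the recommend-style
lookup that __suggest() builds on, runnable with qdrant-client's local
in-memory mode (no Qdrant server required). The 4-dimensional toy vectors, the
payload values, and the pid() helper are illustrative assumptions, not part of
the patch; only the collection name and the deterministic-UUID id scheme mirror
the service above.

    from uuid import NAMESPACE_URL, uuid5

    from qdrant_client import QdrantClient
    from qdrant_client.models import (
        Distance,
        FieldCondition,
        Filter,
        MatchValue,
        PointStruct,
        VectorParams,
    )


    def pid(proj: int, sdoc: int, sent: int) -> str:
        # deterministic UUID ids, mirroring _sentence_point_id() in the patch
        return str(uuid5(NAMESPACE_URL, f"{proj}-{sdoc}-{sent}"))


    client = QdrantClient(":memory:")  # local mode, no server needed
    client.create_collection(
        "Sentence",
        vectors_config=VectorParams(size=4, distance=Distance.COSINE),
    )

    # three toy "sentence embeddings"; sentence 1 is a near-duplicate of 0
    vectors = [[1, 0, 0, 0], [0.9, 0.1, 0, 0], [0, 0, 1, 0]]
    client.upsert(
        "Sentence",
        points=[
            PointStruct(
                id=pid(1, 1, i),
                vector=v,
                payload={"project_id": 1, "sdoc_id": 1, "sentence_id": i},
            )
            for i, v in enumerate(vectors)
        ],
    )

    # neighbours of the stored vector of sentence (proj=1, sdoc=1, sent=0);
    # the positive example itself is excluded from the results
    hits = client.recommend(
        collection_name="Sentence",
        positive=[pid(1, 1, 0)],
        query_filter=Filter(
            must=[FieldCondition(key="project_id", match=MatchValue(value=1))]
        ),
        limit=2,
        with_payload=True,
    )
    for h in hits:
        print(h.payload["sdoc_id"], h.payload["sentence_id"], round(h.score, 3))
    # the near-duplicate (sentence_id=1) ranks first with cosine score ~0.994

Issuing one recommend query per input sentence would cost one round trip each;
batching them, as the patch does with recommend_batch, sends all lookups to the
server in a single request.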