import uuid
from typing import List, Tuple

import numpy as np
from loguru import logger
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams

from app.core.data.dto.search import SimSearchSentenceHit
from app.preprocessing.ray_model_service import RayModelService
from app.preprocessing.ray_model_worker.dto.clip import ClipTextEmbeddingInput
from app.util.singleton_meta import SingletonMeta
from config import conf
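

# Singleton service wrapping the Qdrant vector DB. It maintains one collection
# for sentence embeddings and one for document embeddings (512-dim CLIP
# vectors, cosine distance) and offers add / remove / similarity-search
# operations on them.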
class QdrantService(metaclass=SingletonMeta):
    def __new__(cls, *args, **kwargs):
        cls._sentence_class_name = "Sentence"
        cls._document_class_name = "Document"
        cls._collections = [cls._sentence_class_name, cls._document_class_name]

        try:
            cls._client = QdrantClient(
                host=conf.qdrant.host,
                port=conf.qdrant.port,
                grpc_port=conf.qdrant.grpc_port,
                prefer_grpc=True,
            )
            collections = {
                c.name for c in cls._client.get_collections().collections
            }
            if kwargs.get("flush", False):
                logger.warning("Flushing DWTS Qdrant Data!")
                for c in collections:
                    cls._client.delete_collection(c)
                # everything is gone now, so all collections get recreated below
                collections.clear()
            for name in cls._collections:
                if name not in collections:
                    cls._client.create_collection(
                        collection_name=name,
                        vectors_config=VectorParams(
                            size=512, distance=Distance.COSINE
                        ),
                    )

        except Exception as e:
            msg = f"Cannot connect to or initialize Qdrant DB - Error '{e}'"
            logger.error(msg)
            raise SystemExit(msg)

        cls.rms = RayModelService()

        return super(QdrantService, cls).__new__(cls)
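
    # NOTE: Qdrant point IDs must be unsigned integers or UUIDs, not arbitrary
    # strings such as "<proj_id>-<sdoc_id>-<sent_id>". We therefore derive a
    # deterministic UUID from that composite key, so the same
    # (project, sdoc, sentence) triple always maps to the same point.
    @staticmethod
    def _point_id(*parts: int) -> str:
        return str(uuid.uuid5(uuid.NAMESPACE_URL, "-".join(map(str, parts))))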

    def add_text_sdoc_to_index(
        self,
        proj_id: int,
        sdoc_id: int,
        sentences: List[str],
    ) -> None:
        sentence_embs = self.rms.clip_text_embedding(
            ClipTextEmbeddingInput(text=sentences)
        ).numpy()

        # create cheap & easy (but suboptimal) document embeddings for now:
        # the L2-normalized sum of the sentence embeddings
        doc_emb = sentence_embs.sum(axis=0)
        doc_emb /= np.linalg.norm(doc_emb)

        logger.debug(
            f"Adding {len(sentence_embs)} sentences "
            f"from SDoc {sdoc_id} in Project {proj_id} to Qdrant ..."
        )
        sent_points = [
            models.PointStruct(
                id=self._point_id(proj_id, sdoc_id, sent_id),
                vector=sent_emb.tolist(),
                payload={
                    "project_id": proj_id,
                    "sdoc_id": sdoc_id,
                    "sentence_id": sent_id,
                    "text": sentences[sent_id],
                },
            )
            for sent_id, sent_emb in enumerate(sentence_embs)
        ]
        self._client.upsert(
            collection_name=self._sentence_class_name, points=sent_points
        )
        logger.debug(f"Added {len(sent_points)} sentences to Qdrant!")

        doc_point = models.PointStruct(
            id=self._point_id(proj_id, sdoc_id),
            vector=doc_emb.tolist(),
            payload={
                "project_id": proj_id,
                "sdoc_id": sdoc_id,
                "text": " ".join(sentences),
            },
        )
        self._client.upsert(
            collection_name=self._document_class_name, points=[doc_point]
        )
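
    # Removal is filter-based: every point carries its sdoc_id and project_id
    # in the payload, so we can delete by matching on those fields instead of
    # tracking individual point IDs.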

    def remove_text_sdoc_from_index(self, sdoc_id: int) -> None:
        logger.debug(f"Removing text SDoc {sdoc_id} from Index!")
        sdoc_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="sdoc_id", match=models.MatchValue(value=sdoc_id)
                )
            ]
        )
        for name in self._collections:
            self._client.delete(
                collection_name=name,
                points_selector=models.FilterSelector(filter=sdoc_filter),
            )

    def remove_all_project_embeddings(
        self,
        proj_id: int,
    ) -> None:
        project_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="project_id", match=models.MatchValue(value=proj_id)
                )
            ]
        )
        for name in self._collections:
            self._client.delete(
                collection_name=name,
                points_selector=models.FilterSelector(filter=project_filter),
            )

    def suggest_similar_sentences(
        self, proj_id: int, sdoc_sent_ids: List[Tuple[int, int]]
    ) -> List[SimSearchSentenceHit]:
        return self.__suggest(proj_id, sdoc_sent_ids)
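
    # Similarity search uses Qdrant's recommend API: given the point ID of an
    # already-indexed sentence, it returns the nearest neighbours by the
    # stored vector, excluding the query point itself from the results.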

    def __suggest(
        self,
        proj_id: int,
        sdoc_sent_ids: List[Tuple[int, int]],
        top_k: int = 10,
        threshold: float = 0.0,
    ) -> List[SimSearchSentenceHit]:
        candidates: List[SimSearchSentenceHit] = []
        # restrict all neighbour lookups to the given project
        project_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="project_id", match=models.MatchValue(value=proj_id)
                )
            ]
        )

        for sdoc_id, sent_id in sdoc_sent_ids:
            hits = self._client.recommend(
                collection_name=self._sentence_class_name,
                positive=[self._point_id(proj_id, sdoc_id, sent_id)],
                query_filter=project_filter,
                limit=top_k,
                score_threshold=threshold,
            )
            for hit in hits:
                candidates.append(
                    SimSearchSentenceHit(
                        sdoc_id=hit.payload["sdoc_id"],
                        score=hit.score,
                        sentence_id=hit.payload["sentence_id"],
                    )
                )

        candidates.sort(key=lambda x: (x.sdoc_id, x.sentence_id))
        hits = self.__unique_consecutive(candidates)
        # the query sentences themselves are not valid suggestions
        hits = [h for h in hits if (h.sdoc_id, h.sentence_id) not in sdoc_sent_ids]
        return hits
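
    # Drops hits that refer to the same (sdoc_id, sentence_id) as their
    # predecessor; assumes the input is sorted by that key, so duplicates
    # are consecutive.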
    def __unique_consecutive(self, hits: List[SimSearchSentenceHit]):
        result = []
        current = SimSearchSentenceHit(sdoc_id=-1, sentence_id=-1, score=0.0)
        for hit in hits:
            if hit.sdoc_id != current.sdoc_id or hit.sentence_id != current.sentence_id:
                current = hit
                result.append(hit)
        return result