Skip to content

Commit

Permalink
Merge pull request #11 from SCAI-BIO/weaviate
Browse files Browse the repository at this point in the history
Add function to additionally retrieve similarities together with clos…
  • Loading branch information
tiadams authored Jul 15, 2024
2 parents d0b3891 + ddf52f8 commit c499c03
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
42 changes: 38 additions & 4 deletions datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import shutil
from typing import List, Union
from typing import List, Union, Tuple
import uuid as uuid
import weaviate
from weaviate.embedded import EmbeddedOptions
Expand All @@ -11,7 +11,6 @@


class WeaviateRepository(BaseRepository):

logger = logging.getLogger(__name__)

def __init__(self, mode="memory", path=None):
Expand Down Expand Up @@ -101,7 +100,8 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]:
try:
result = self.client.query.get(
"Mapping",
["text", "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"]
["text",
"hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"]
).with_additional("vector").with_limit(limit).do()
for item in result['data']['Get']['Mapping']:
embedding_vector = item["_additional"]["vector"]
Expand Down Expand Up @@ -132,7 +132,8 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]:
try:
result = self.client.query.get(
"Mapping",
["text", "_additional { distance }", "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"]
["text", "_additional { distance }",
"hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"]
).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do()
for item in result['data']['Get']['Mapping']:
embedding_vector = item["_additional"]["vector"]
Expand All @@ -158,6 +159,39 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]:
raise RuntimeError(f"Failed to fetch closest mappings: {e}")
return mappings

def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tuple[Mapping, float]]:
mappings_with_similarities = []
try:
result = self.client.query.get(
"Mapping",
["text", "_additional { distance }",
"hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"]
).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do()
for item in result['data']['Get']['Mapping']:
similarity = 1 - item["_additional"]["distance"]
embedding_vector = item["_additional"]["vector"]
concept_data = item["hasConcept"][0] # Assuming it has only one concept
terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology
terminology = Terminology(
name=terminology_data["name"],
id=terminology_data["_additional"]["id"]
)
concept = Concept(
concept_identifier=concept_data["conceptID"],
pref_label=concept_data["prefLabel"],
terminology=terminology,
id=concept_data["_additional"]["id"]
)
mapping = Mapping(
text=item["text"],
concept=concept,
embedding=embedding_vector
)
mappings_with_similarities.append((mapping, similarity))
except Exception as e:
raise RuntimeError(f"Failed to fetch closest mappings with similarities: {e}")
return mappings_with_similarities

def shut_down(self):
if self.mode == "memory":
shutil.rmtree("db")
Expand Down
14 changes: 8 additions & 6 deletions tests/test_weaviate_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,17 @@ def test_repository(self):
terminologies = repository.get_all_terminologies()
self.assertEqual(len(terminologies), 1)

closest_mappings = repository.get_closest_mappings(embedding_model.get_embedding(text10))
test_embedding = embedding_model.get_embedding(text10)

closest_mappings = repository.get_closest_mappings(test_embedding)
self.assertEqual(len(closest_mappings), 5)
self.assertEqual(closest_mappings[0].text, "Influenza")

closest_mappings_with_similarities = repository.get_closest_mappings_with_similarities(test_embedding)
self.assertEqual(len(closest_mappings_with_similarities), 5)
self.assertEqual(closest_mappings_with_similarities[0][0].text, "Influenza")
self.assertEqual(closest_mappings_with_similarities[0][1], 0.86187172)

# check if it crashed (due to schema re-creation) after restart
repository = WeaviateRepository(mode="disk", path="db")

Expand All @@ -82,8 +89,3 @@ def test_repository(self):
concept8, mapping8, concept9, mapping9
])






0 comments on commit c499c03

Please sign in to comment.