Skip to content

Commit

Permalink
Merge pull request #31 from SCAI-BIO/re-enable-weaviate-tests
Browse files Browse the repository at this point in the history
test: weaviate tests
  • Loading branch information
tiadams authored Oct 17, 2024
2 parents 6b65aa0 + b11e4a3 commit 2d6bcb3
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 101 deletions.
29 changes: 25 additions & 4 deletions datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,9 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping]):
elif isinstance(model_object_instance, Concept):
model_object_instance.uuid = random_uuid
if not self._concept_exists(model_object_instance.concept_identifier):
# recursion: create terminology if not existing
if not self._terminology_exists(model_object_instance.terminology.name):
self.store(model_object_instance.terminology)
properties = {
"conceptID": model_object_instance.concept_identifier,
"prefLabel": model_object_instance.pref_label,
Expand All @@ -424,6 +427,8 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping]):
f"already exists. Skipping.")
elif isinstance(model_object_instance, Mapping):
if not self._mapping_exists(model_object_instance.embedding):
if not self._concept_exists(model_object_instance.concept.concept_identifier):
self.store(model_object_instance.concept)
properties = {
"text": model_object_instance.text,
"hasSentenceEmbedder": model_object_instance.sentence_embedder
Expand Down Expand Up @@ -459,7 +464,11 @@ def _sentence_embedder_exists(self, name: str) -> bool:
"operator": "Equal",
"valueText": name
}).do()
return len(result["data"]["Get"]["Mapping"]) > 0
result_data = result["data"]["Get"]["Mapping"]
if result_data is not None:
return len(result_data) > 0
else:
return False
except Exception as e:
raise RuntimeError(f"Failed to check if sentence embedder exists: {e}")

Expand All @@ -470,7 +479,11 @@ def _terminology_exists(self, name: str) -> bool:
"operator": "Equal",
"valueText": name
}).do()
return len(result["data"]["Get"]["Terminology"]) > 0
result_data = result["data"]["Get"]["Terminology"]
if result_data is not None:
return len(result_data) > 0
else:
return False
except Exception as e:
raise RuntimeError(f"Failed to check if terminology exists: {e}")

Expand All @@ -481,7 +494,11 @@ def _concept_exists(self, concept_id: str) -> bool:
"operator": "Equal",
"valueText": concept_id
}).do()
return len(result["data"]["Get"]["Concept"]) > 0
result_data = result["data"]["Get"]["Concept"]
if result_data is not None:
return len(result_data) > 0
else:
return False
except Exception as e:
raise RuntimeError(f"Failed to check if concept exists: {e}")

Expand All @@ -491,6 +508,10 @@ def _mapping_exists(self, embedding) -> bool:
"vector": embedding,
"distance": float(0) # Ensure distance is explicitly casted to float
}).do()
return len(result["data"]["Get"]["Mapping"]) > 0
result_data = result["data"]["Get"]["Mapping"]
if result_data is not None:
return len(result_data) > 0
else:
return False
except Exception as e:
raise RuntimeError(f"Failed to check if mapping exists: {e}")
198 changes: 101 additions & 97 deletions tests/test_weaviate_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,121 +5,125 @@
from datastew.repository import Terminology, Concept, Mapping
from datastew.repository.weaviate import WeaviateRepository


class Test(TestCase):
@unittest.skip("currently broken on github workflows")
def test_repository(self):

repository = WeaviateRepository(mode="disk", path="db")

embedding_model1 = MPNetAdapter()
embedding_model2 = MPNetAdapter("FremyCompany/BioLORD-2023")
model_name1 = embedding_model1.get_model_name()
model_name2 = embedding_model2.get_model_name()

terminology1 = Terminology("snomed CT", "SNOMED")
terminology2 = Terminology("NCI Thesaurus OBO Edition", "NCIT")

text1 = "Diabetes mellitus (disorder)"
concept1 = Concept(terminology1, text1, "Concept ID: 11893007")
mapping1 = Mapping(concept1, text1, embedding_model1.get_embedding(text1), model_name1)

text2 = "Hypertension (disorder)"
concept2 = Concept(terminology1, text2, "Concept ID: 73211009")
mapping2 = Mapping(concept2, text2, embedding_model2.get_embedding(text2), model_name2)

text3 = "Asthma"
concept3 = Concept(terminology1, text3, "Concept ID: 195967001")
mapping3 = Mapping(concept3, text3, embedding_model1.get_embedding(text3), model_name1)

text4 = "Heart attack"
concept4 = Concept(terminology1, text4, "Concept ID: 22298006")
mapping4 = Mapping(concept4, text4, embedding_model2.get_embedding(text4), model_name2)

text5 = "Common cold"
concept5 = Concept(terminology2, text5, "Concept ID: 13260007")
mapping5 = Mapping(concept5, text5, embedding_model1.get_embedding(text5), model_name1)

text6 = "Stroke"
concept6 = Concept(terminology2, text6, "Concept ID: 422504002")
mapping6 = Mapping(concept6, text6, embedding_model2.get_embedding(text6), model_name2)

text7 = "Migraine"
concept7 = Concept(terminology2, text7, "Concept ID: 386098009")
mapping7 = Mapping(concept7, text7, embedding_model1.get_embedding(text7), model_name1)

text8 = "Influenza"
concept8 = Concept(terminology2, text8, "Concept ID: 57386000")
mapping8 = Mapping(concept8, text8, embedding_model2.get_embedding(text8), model_name2)

text9 = "Osteoarthritis"
concept9 = Concept(terminology2, text9, "Concept ID: 399206004")
mapping9 = Mapping(concept9, text9, embedding_model1.get_embedding(text9), model_name1)

text10 = "The flu"

repository.store_all([
terminology1, terminology2, concept1, mapping1, concept2, mapping2, concept3, mapping3, concept4,
mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7, concept8, mapping8, concept9,
mapping9
])

mappings = repository.get_mappings(limit=5)
class TestWeaviateRepository(TestCase):

@classmethod
def setUpClass(cls):
"""Set up reusable components for the tests."""
cls.repository = WeaviateRepository(mode="disk", path="db")
cls.embedding_model1 = MPNetAdapter()
cls.embedding_model2 = MPNetAdapter("FremyCompany/BioLORD-2023")
cls.model_name1 = cls.embedding_model1.get_model_name()
cls.model_name2 = cls.embedding_model2.get_model_name()

# Terminologies
cls.terminology1 = Terminology("snomed CT", "SNOMED")
cls.terminology2 = Terminology("NCI Thesaurus OBO Edition", "NCIT")

# Concepts and mappings
cls.concepts_mappings = [
cls._create_mapping(cls.terminology1, "Diabetes mellitus (disorder)", "Concept ID: 11893007", cls.embedding_model1),
cls._create_mapping(cls.terminology1, "Hypertension (disorder)", "Concept ID: 73211009", cls.embedding_model2),
cls._create_mapping(cls.terminology1, "Asthma", "Concept ID: 195967001", cls.embedding_model1),
cls._create_mapping(cls.terminology1, "Heart attack", "Concept ID: 22298006", cls.embedding_model2),
cls._create_mapping(cls.terminology2, "Common cold", "Concept ID: 13260007", cls.embedding_model1),
cls._create_mapping(cls.terminology2, "Stroke", "Concept ID: 422504002", cls.embedding_model2),
cls._create_mapping(cls.terminology2, "Migraine", "Concept ID: 386098009", cls.embedding_model1),
cls._create_mapping(cls.terminology2, "Influenza", "Concept ID: 57386000", cls.embedding_model2),
cls._create_mapping(cls.terminology2, "Osteoarthritis", "Concept ID: 399206004", cls.embedding_model1),
]

cls.test_text = "The flu"

# Store terminologies, concepts, and mappings in the repository
cls.repository.store_all([cls.terminology1, cls.terminology2] + [item[0] for item in cls.concepts_mappings] + [item[1] for item in cls.concepts_mappings])

@staticmethod
def _create_mapping(terminology, text, concept_id, embedding_model):
"""Helper function to create a concept and mapping."""
concept = Concept(terminology, text, concept_id)
mapping = Mapping(concept, text, embedding_model.get_embedding(text), embedding_model.get_model_name())
return concept, mapping

def test_store_and_retrieve_mappings(self):
"""Test storing and retrieving mappings from the repository."""
mappings = self.repository.get_mappings(limit=5)
self.assertEqual(len(mappings), 5)

concepts = repository.get_all_concepts()
concept = repository.get_concept("Concept ID: 11893007")
def test_concept_retrieval(self):
"""Test retrieval of individual and all concepts."""
concepts = self.repository.get_all_concepts()
concept = self.repository.get_concept("Concept ID: 11893007")
self.assertEqual(concept.concept_identifier, "Concept ID: 11893007")
self.assertEqual(concept.pref_label, text1)
self.assertEqual(concept.terminology.name, terminology1.name)
self.assertEqual(concept.pref_label, "Diabetes mellitus (disorder)")
self.assertEqual(concept.terminology.name, "snomed CT")
self.assertEqual(len(concepts), 9)

terminology = repository.get_terminology("snomed CT")
terminologies = repository.get_all_terminologies()
terminology_names = [embedding.name for embedding in terminologies]
def test_terminology_retrieval(self):
"""Test retrieval of individual and all terminologies."""
terminology = self.repository.get_terminology("snomed CT")
terminologies = self.repository.get_all_terminologies()
terminology_names = [t.name for t in terminologies]
self.assertEqual(terminology.name, "snomed CT")
self.assertEqual(len(terminologies), 2)
self.assertIn("NCI Thesaurus OBO Edition", terminology_names)
self.assertIn("snomed CT", terminology_names)

sentence_embedders = repository.get_all_sentence_embedders()
def test_sentence_embedders(self):
"""Test retrieval of sentence embedders from the repository."""
sentence_embedders = self.repository.get_all_sentence_embedders()
self.assertEqual(len(sentence_embedders), 2)
self.assertIn(model_name1, sentence_embedders)
self.assertIn(model_name2, sentence_embedders)

test_embedding = embedding_model1.get_embedding(text10)
self.assertIn(self.model_name1, sentence_embedders)
self.assertIn(self.model_name2, sentence_embedders)

closest_mappings = repository.get_closest_mappings(test_embedding)
def test_closest_mappings(self):
"""Test retrieval of the closest mappings based on a test embedding."""
test_embedding = self.embedding_model1.get_embedding(self.test_text)
closest_mappings = self.repository.get_closest_mappings(test_embedding)
self.assertEqual(len(closest_mappings), 5)
self.assertEqual(closest_mappings[0].text, "Common cold")
self.assertEqual(closest_mappings[0].sentence_embedder, model_name1)

terminology_and_model_specific_closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(test_embedding, "snomed CT", model_name1)
self.assertEqual(len(terminology_and_model_specific_closest_mappings), 2)
self.assertEqual(terminology_and_model_specific_closest_mappings[0].text, "Asthma")
self.assertEqual(terminology_and_model_specific_closest_mappings[0].concept.terminology.name, "snomed CT")
self.assertEqual(terminology_and_model_specific_closest_mappings[0].sentence_embedder, model_name1)

closest_mappings_with_similarities = repository.get_closest_mappings_with_similarities(test_embedding)
self.assertEqual(closest_mappings[0].sentence_embedder, self.model_name1)

def test_terminology_and_model_specific_mappings(self):
"""Test retrieval of mappings filtered by terminology and model."""
test_embedding = self.embedding_model1.get_embedding(self.test_text)
specific_mappings = self.repository.get_terminology_and_model_specific_closest_mappings(test_embedding, "snomed CT", self.model_name1)
self.assertEqual(len(specific_mappings), 2)
self.assertEqual(specific_mappings[0].text, "Asthma")
self.assertEqual(specific_mappings[0].concept.terminology.name, "snomed CT")
self.assertEqual(specific_mappings[0].sentence_embedder, self.model_name1)

def test_closest_mappings_with_similarities(self):
"""Test retrieval of closest mappings with similarity scores."""
test_embedding = self.embedding_model1.get_embedding(self.test_text)
closest_mappings_with_similarities = self.repository.get_closest_mappings_with_similarities(test_embedding)
self.assertEqual(len(closest_mappings_with_similarities), 5)
self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold")
self.assertEqual(closest_mappings_with_similarities[0][0].sentence_embedder, model_name1)
self.assertEqual(closest_mappings_with_similarities[0][0].sentence_embedder, self.model_name1)
self.assertAlmostEqual(closest_mappings_with_similarities[0][1], 0.6747197, 3)

terminology_and_model_specific_closest_mappings_with_similarities = repository.get_terminology_and_model_specific_closest_mappings_with_similarities(test_embedding, "snomed CT", model_name1)
self.assertEqual(len(terminology_and_model_specific_closest_mappings_with_similarities), 2)
self.assertEqual(terminology_and_model_specific_closest_mappings_with_similarities[0][0].text, "Asthma")
self.assertEqual(terminology_and_model_specific_closest_mappings_with_similarities[0][0].concept.terminology.name, "snomed CT")
self.assertEqual(terminology_and_model_specific_closest_mappings_with_similarities[0][0].sentence_embedder, model_name1)
self.assertAlmostEqual(terminology_and_model_specific_closest_mappings_with_similarities[0][1], 0.3947341, 3)
def test_terminology_and_model_specific_mappings_with_similarities(self):
"""Test retrieval of terminology and model-specific mappings with similarity scores."""
test_embedding = self.embedding_model1.get_embedding(self.test_text)
specific_mappings_with_similarities = self.repository.get_terminology_and_model_specific_closest_mappings_with_similarities(test_embedding, "snomed CT", self.model_name1)
self.assertEqual(len(specific_mappings_with_similarities), 2)
self.assertEqual(specific_mappings_with_similarities[0][0].text, "Asthma")
self.assertEqual(specific_mappings_with_similarities[0][0].concept.terminology.name, "snomed CT")
self.assertEqual(specific_mappings_with_similarities[0][0].sentence_embedder, self.model_name1)
self.assertAlmostEqual(specific_mappings_with_similarities[0][1], 0.3947341, 3)

# check if it crashed (due to schema re-creation) after restart
@unittest.skip("currently broken on github workflows")
def test_repository_restart(self):
"""Test the repository restart functionality to ensure no data is lost or corrupted."""
# Re-initialize repository
repository = WeaviateRepository(mode="disk", path="db")

# Try storing the same data again (should not create duplicates)
repository.store_all([self.terminology1, self.terminology2] + [item[0] for item in self.concepts_mappings] + [item[1] for item in self.concepts_mappings])

# Check if mappings and concepts are intact
mappings = repository.get_mappings(limit=5)
self.assertEqual(len(mappings), 5)

# try to store all again (should not create new entries since they already exist)
repository.store_all([
terminology1, terminology2, concept1, mapping1, concept2, mapping2, concept3, mapping3,
concept4, mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7, concept8,
mapping8, concept9, mapping9
])

concepts = repository.get_all_concepts()
self.assertEqual(len(concepts), 9)

0 comments on commit 2d6bcb3

Please sign in to comment.