Skip to content

Commit

Permalink
add: DataDictionarySource import function
Browse files Browse the repository at this point in the history
  • Loading branch information
mehmetcanay committed Oct 23, 2024
1 parent 55b7e7c commit 6229a90
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 3 deletions.
6 changes: 6 additions & 0 deletions datastew/repository/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from typing import List, Optional

from datastew.embedding import EmbeddingModel
from datastew.process.parsing import DataDictionarySource
from datastew.repository.model import Mapping, Concept, Terminology


class BaseRepository(ABC):

@abstractmethod
def import_data_dictionary(self, data_dictionary: DataDictionarySource, terminology_name: str, embedding_model: Optional[EmbeddingModel] = None):
    """Store a data dictionary.

    Abstract hook: concrete repositories decide how the dictionary's
    contents are persisted.

    :param data_dictionary: The data dictionary source to import.
    :param terminology_name: Name of the terminology the imported entries
        are associated with.
    :param embedding_model: Optional embedding model for the import. When
        ``None`` the choice of model is left to the implementation.
    """

@abstractmethod
def store(self, model_object_instance):
    """Store a single model object instance.

    Abstract hook: concrete repositories decide how the object is persisted.

    :param model_object_instance: The object to persist (presumably a
        Terminology, Concept or Mapping -- confirm against implementations).
    """
Expand Down
30 changes: 30 additions & 0 deletions datastew/repository/sqllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from sqlalchemy import create_engine, func
from sqlalchemy.orm import sessionmaker

from datastew.embedding import EmbeddingModel
from datastew.process.parsing import DataDictionarySource
from datastew.repository.base import BaseRepository
from datastew.repository.model import Base, Concept, Mapping, Terminology

Expand Down Expand Up @@ -35,6 +37,34 @@ def store_all(self, model_object_instances: List[Union[Terminology, Concept, Map
self.session.add_all(model_object_instances)
self.session.commit()

def import_data_dictionary(self, data_dictionary: DataDictionarySource, terminology_name: str, embedding_model: Optional[EmbeddingModel] = None):
    """Import a data dictionary as one terminology with a concept and mapping per variable.

    For every variable returned by ``data_dictionary.get_embeddings`` a
    ``Concept`` with identifier ``"<terminology_name>:<variable>"`` and a
    ``Mapping`` holding the variable's description and embedding are created.
    The terminology is stored first, then all concepts and mappings are
    handed to ``store_all`` in a single call.

    :param data_dictionary: Source providing ``to_dataframe()`` (with a
        ``"description"`` column) and ``get_embeddings()``.
    :param terminology_name: Passed as both constructor arguments of the
        new ``Terminology`` (presumably its name and id -- confirm) and used
        as the prefix of every concept identifier.
    :param embedding_model: Model used to embed the descriptions. When
        ``None``, mappings are labelled with
        ``"sentence-transformers/all-mpnet-base-v2"``.
    :raises ValueError: If the number of embeddings differs from the number
        of descriptions, since the two could not be paired reliably.
    """
    # Do all parsing and embedding work before touching the repository, so a
    # failure here does not leave an orphaned terminology behind.
    descriptions = data_dictionary.to_dataframe()["description"].tolist()

    if embedding_model is None:
        # NOTE(review): presumably the model get_embeddings falls back to
        # when it is given None -- confirm against DataDictionarySource.
        embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
    else:
        embedding_model_name = embedding_model.get_model_name()

    variable_to_embedding = data_dictionary.get_embeddings(embedding_model)

    # Variables and descriptions are paired purely by position. zip() would
    # silently truncate on a length mismatch and attach descriptions to the
    # wrong variables, so fail loudly instead.
    if len(variable_to_embedding) != len(descriptions):
        raise ValueError(
            f"Cannot import data dictionary: got {len(descriptions)} descriptions "
            f"but {len(variable_to_embedding)} embeddings "
            "(possible cause: duplicate variable names)."
        )

    terminology = Terminology(terminology_name, terminology_name)

    model_objects = []
    for (variable, embedding), description in zip(variable_to_embedding.items(), descriptions):
        concept = Concept(
            terminology=terminology,
            pref_label=variable,
            concept_identifier=f"{terminology_name}:{variable}",
        )
        mapping = Mapping(
            concept=concept,
            text=description,
            embedding=embedding,
            sentence_embedder=embedding_model_name,
        )
        model_objects.extend((concept, mapping))

    self.store(terminology)
    # One batched call instead of one store_all (and commit) per variable.
    self.store_all(model_objects)

def get_all_concepts(self) -> List[Concept]:
    """Return all ``Concept`` objects found by querying the repository session."""
    concept_query = self.session.query(Concept)
    return concept_query.all()

Expand Down
34 changes: 32 additions & 2 deletions datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from weaviate.util import generate_uuid5
from weaviate.classes.query import Filter, QueryReference, MetadataQuery

from datastew.embedding import EmbeddingModel
from datastew.process.parsing import DataDictionarySource
from datastew.repository import Concept, Mapping, Terminology
from datastew.repository.base import BaseRepository
from datastew.repository.weaviate_schema import concept_schema, mapping_schema, terminology_schema
Expand Down Expand Up @@ -62,7 +64,35 @@ def _create_schema_if_not_exists(self, schema):
except Exception as e:
raise RuntimeError(f"Failed to check/create schema for {class_name}: {e}")

def store_all(self, model_object_instances):
def import_data_dictionary(self, data_dictionary: DataDictionarySource, terminology_name: str, embedding_model: Optional[EmbeddingModel] = None):
    """Import a data dictionary as one terminology with a concept and mapping per variable.

    For every variable returned by ``data_dictionary.get_embeddings`` a
    ``Concept`` with identifier ``"<terminology_name>:<variable>"`` and a
    ``Mapping`` holding the variable's description and embedding are created.
    The terminology is stored first, then every concept is stored directly
    before its mapping.

    :param data_dictionary: Source providing ``to_dataframe()`` (with a
        ``"description"`` column) and ``get_embeddings()``.
    :param terminology_name: Passed as both constructor arguments of the
        new ``Terminology`` (presumably its name and id -- confirm) and used
        as the prefix of every concept identifier.
    :param embedding_model: Model used to embed the descriptions. When
        ``None``, mappings are labelled with
        ``"sentence-transformers/all-mpnet-base-v2"``.
    :raises ValueError: If the number of embeddings differs from the number
        of descriptions, since the two could not be paired reliably.
    """
    # Do all parsing and embedding work before touching the repository, so a
    # failure here does not leave an orphaned terminology behind.
    descriptions = data_dictionary.to_dataframe()["description"].tolist()

    if embedding_model is None:
        # NOTE(review): presumably the model get_embeddings falls back to
        # when it is given None -- confirm against DataDictionarySource.
        embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
    else:
        embedding_model_name = embedding_model.get_model_name()

    variable_to_embedding = data_dictionary.get_embeddings(embedding_model)

    # Variables and descriptions are paired purely by position. zip() would
    # silently truncate on a length mismatch and attach descriptions to the
    # wrong variables, so fail loudly instead.
    if len(variable_to_embedding) != len(descriptions):
        raise ValueError(
            f"Cannot import data dictionary: got {len(descriptions)} descriptions "
            f"but {len(variable_to_embedding)} embeddings "
            "(possible cause: duplicate variable names)."
        )

    terminology = Terminology(terminology_name, terminology_name)

    model_objects = []
    for (variable, embedding), description in zip(variable_to_embedding.items(), descriptions):
        concept = Concept(
            terminology=terminology,
            pref_label=variable,
            concept_identifier=f"{terminology_name}:{variable}",
        )
        mapping = Mapping(
            concept=concept,
            text=description,
            embedding=embedding,
            sentence_embedder=embedding_model_name,
        )
        # Keep each concept immediately ahead of its mapping so the storage
        # order matches the original per-variable calls.
        model_objects.extend((concept, mapping))

    self.store(terminology)
    self.store_all(model_objects)

def store_all(self, model_object_instances: List[Union[Terminology, Concept, Mapping]]):
    """Store the given model objects one by one, in iteration order, via ``store``.

    :param model_object_instances: The objects to persist.
    """
    for model_object in model_object_instances:
        self.store(model_object)

Expand Down Expand Up @@ -94,7 +124,7 @@ def get_concept(self, concept_id: str) -> Concept:
terminology_id = str(terminology_data.uuid)
terminology = Terminology(terminology_name, terminology_id)

id = concept_data.uuid
id = str(concept_data.uuid)
concept_name = str(concept_data.properties["prefLabel"])
concept = Concept(terminology, concept_name, concept_id, id)
except Exception as e:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_sql_repository.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import unittest
import os

from datastew.process.parsing import DataDictionarySource
from datastew.repository.model import Terminology, Concept, Mapping
from datastew.repository.sqllite import SQLLiteRepository


class TestGetClosestEmbedding(unittest.TestCase):

TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))

def setUp(self):
self.repository = SQLLiteRepository(mode="memory")

Expand Down Expand Up @@ -52,3 +56,11 @@ def test_get_all_sentence_embedders(self):
self.assertEqual(len(sentence_embedders), 2)
self.assertEqual(sentence_embedders[0], "sentence-transformers/all-mpnet-base-v2")
self.assertEqual(sentence_embedders[1], "text-embedding-ada-002")

def test_import_data_dictionary(self):
    """Importing the test data dictionary adds its terminology and a prefixed concept identifier."""
    csv_path = os.path.join(self.TEST_DIR_PATH, "resources", "test_data_dict.csv")
    source = DataDictionarySource(csv_path, "VAR_1", "DESC")

    self.repository.import_data_dictionary(source, terminology_name="import_test")

    stored_terminology_names = [entry.name for entry in self.repository.get_all_terminologies()]
    stored_concept_ids = [entry.concept_identifier for entry in self.repository.get_all_concepts()]

    self.assertIn("import_test", stored_terminology_names)
    # Concept identifiers take the form "<terminology_name>:<variable>".
    self.assertIn("import_test:A", stored_concept_ids)
15 changes: 14 additions & 1 deletion tests/test_weaviate_repository.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import unittest
from unittest import TestCase

from datastew import MPNetAdapter
from datastew.process.parsing import DataDictionarySource
from datastew.repository import Terminology, Concept, Mapping
from datastew.repository.weaviate import WeaviateRepository

class TestWeaviateRepository(TestCase):

TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))

@classmethod
def setUpClass(cls):
"""Set up reusable components for the tests."""
Expand Down Expand Up @@ -65,7 +69,7 @@ def test_terminology_retrieval(self):
terminologies = self.repository.get_all_terminologies()
terminology_names = [t.name for t in terminologies]
self.assertEqual(terminology.name, "snomed CT")
self.assertEqual(len(terminologies), 2)
self.assertEqual(len(terminologies), 3)
self.assertIn("NCI Thesaurus OBO Edition", terminology_names)
self.assertIn("snomed CT", terminology_names)

Expand Down Expand Up @@ -112,6 +116,15 @@ def test_terminology_and_model_specific_mappings_with_similarities(self):
self.assertEqual(specific_mappings_with_similarities[0][0].sentence_embedder, self.model_name1)
self.assertAlmostEqual(specific_mappings_with_similarities[0][1], 0.3947341, 3)

def test_import_data_dictionary(self):
    """Check that an imported data dictionary can be read back by terminology name and concept id."""
    csv_path = os.path.join(self.TEST_DIR_PATH, "resources", "test_data_dict.csv")
    source = DataDictionarySource(csv_path, "VAR_1", "DESC")

    self.repository.import_data_dictionary(source, terminology_name="import_test")

    fetched_terminology = self.repository.get_terminology("import_test")
    fetched_concept = self.repository.get_concept("import_test:A")

    self.assertEqual("import_test", fetched_terminology.name)
    # Concept identifiers take the form "<terminology_name>:<variable>".
    self.assertEqual("import_test:A", fetched_concept.concept_identifier)

@unittest.skip("currently broken on github workflows")
def test_repository_restart(self):
"""Test the repository restart functionality to ensure no data is lost or corrupted."""
Expand Down

0 comments on commit 6229a90

Please sign in to comment.