diff --git a/datastew/repository/weaviate.py b/datastew/repository/weaviate.py index 12db520..57edb56 100644 --- a/datastew/repository/weaviate.py +++ b/datastew/repository/weaviate.py @@ -1,9 +1,11 @@ import logging import shutil +import socket from typing import List, Tuple, Union, Optional import weaviate +from weaviate import WeaviateClient from weaviate.util import generate_uuid5 from weaviate.classes.query import Filter, QueryReference, MetadataQuery @@ -17,15 +19,30 @@ class WeaviateRepository(BaseRepository): logger = logging.getLogger(__name__) - def __init__(self, mode="memory", path=None, port=80): + def __init__(self, mode="memory", path=None, port=80, http_port=8079, grpc_port=50050): self.mode = mode + self.client: Union[None, WeaviateClient] = None try: if mode == "memory": - self.client = weaviate.connect_to_embedded(persistence_data_path="db") + # Check if there is an existing instance of Weaviate client for the default ports + if self._is_port_in_use(http_port) and self._is_port_in_use(grpc_port): + # if the client exists, first close then re-connect + if self.client: + self.client.close() + self.client = weaviate.connect_to_local(port=http_port, grpc_port=grpc_port) + else: + self.client = weaviate.connect_to_embedded(persistence_data_path="db") elif mode == "disk": if path is None: raise ValueError("Path must be provided for disk mode.") - self.client = weaviate.connect_to_embedded(persistence_data_path=path) + # Check if there is an existing instance of Weaviate client for the default ports + if self._is_port_in_use(http_port) and self._is_port_in_use(grpc_port): + # if the client exists, first close then re-connect + if self.client: + self.client.close() + self.client = weaviate.connect_to_local(port=http_port, grpc_port=grpc_port) + else: + self.client = weaviate.connect_to_embedded(persistence_data_path=path) elif mode == "remote": if path is None: raise ValueError("Remote URL must be provided for remote mode.") @@ -548,3 +565,7 @@ def _mapping_exists(self, embedding) -> bool: return False except Exception as e: raise RuntimeError(f"Failed to check if mapping exists: {e}") + + def _is_port_in_use(self, port) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 diff --git a/tests/test_weaviate_repository.py b/tests/test_weaviate_repository.py index 51adf1a..85c92e9 100644 --- a/tests/test_weaviate_repository.py +++ b/tests/test_weaviate_repository.py @@ -139,7 +139,6 @@ def test_import_data_dictionary(self): self.assertEqual(mapping.concept.concept_identifier, f"import_test:{variable}") self.assertEqual(mapping.sentence_embedder, "sentence-transformers/all-mpnet-base-v2") - @unittest.skip("currently broken on github workflows") def test_repository_restart(self): """Test the repository restart functionality to ensure no data is lost or corrupted.""" # Re-initialize repository @@ -153,4 +152,4 @@ def test_repository_restart(self): self.assertEqual(len(mappings), 5) concepts = repository.get_all_concepts() - self.assertEqual(len(concepts), 9) \ No newline at end of file + self.assertEqual(len(concepts), 20) \ No newline at end of file