Skip to content

Commit

Permalink
Merge pull request #50 from SCAI-BIO/fix-restart
Browse files Browse the repository at this point in the history
Fix Weaviate Database Client Restart Error
  • Loading branch information
tiadams authored Nov 11, 2024
2 parents 6307836 + 5d0cd92 commit 0e4cbcd
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
27 changes: 24 additions & 3 deletions datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import shutil
import socket

from typing import List, Tuple, Union, Optional

import weaviate
from weaviate import WeaviateClient
from weaviate.util import generate_uuid5
from weaviate.classes.query import Filter, QueryReference, MetadataQuery

Expand All @@ -17,15 +19,30 @@
class WeaviateRepository(BaseRepository):
logger = logging.getLogger(__name__)

def __init__(self, mode="memory", path=None, port=80):
def __init__(self, mode="memory", path=None, port=80, http_port=8079, grpc_port=50050):
self.mode = mode
self.client: Union[None, WeaviateClient] = None
try:
if mode == "memory":
self.client = weaviate.connect_to_embedded(persistence_data_path="db")
# Check if there is an existing instance of Weaviate client for the default ports
if self._is_port_in_use(http_port) and self._is_port_in_use(grpc_port):
# if the client exists, first close then re-connect
if self.client:
self.client.close()
self.client = weaviate.connect_to_local(port=http_port, grpc_port=grpc_port)
else:
self.client = weaviate.connect_to_embedded(persistence_data_path="db")
elif mode == "disk":
if path is None:
raise ValueError("Path must be provided for disk mode.")
self.client = weaviate.connect_to_embedded(persistence_data_path=path)
# Check if there is an existing instance of Weaviate client for the default ports
if self._is_port_in_use(http_port) and self._is_port_in_use(grpc_port):
# if the client exists, first close then re-connect
if self.client:
self.client.close()
self.client = weaviate.connect_to_local(port=http_port, grpc_port=grpc_port)
else:
self.client = weaviate.connect_to_embedded(persistence_data_path=path)
elif mode == "remote":
if path is None:
raise ValueError("Remote URL must be provided for remote mode.")
Expand Down Expand Up @@ -548,3 +565,7 @@ def _mapping_exists(self, embedding) -> bool:
return False
except Exception as e:
raise RuntimeError(f"Failed to check if mapping exists: {e}")

def _is_port_in_use(self, port) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
3 changes: 1 addition & 2 deletions tests/test_weaviate_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def test_import_data_dictionary(self):
self.assertEqual(mapping.concept.concept_identifier, f"import_test:{variable}")
self.assertEqual(mapping.sentence_embedder, "sentence-transformers/all-mpnet-base-v2")

@unittest.skip("currently broken on github workflows")
def test_repository_restart(self):
"""Test the repository restart functionality to ensure no data is lost or corrupted."""
# Re-initialize repository
Expand All @@ -153,4 +152,4 @@ def test_repository_restart(self):
self.assertEqual(len(mappings), 5)

concepts = repository.get_all_concepts()
self.assertEqual(len(concepts), 9)
self.assertEqual(len(concepts), 20)

0 comments on commit 0e4cbcd

Please sign in to comment.