Skip to content

Commit

Permalink
Add scrip to update version
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Sep 13, 2023
1 parent 40f955d commit 00d9340
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 6 deletions.
4 changes: 4 additions & 0 deletions dataherald/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ def __init__(self, system: System):
def insert_one(self, collection: str, obj: dict) -> int:
pass

@abstractmethod
def rename(self, old_collection_name: str, new_collection_name) -> None:
pass

@abstractmethod
def update_or_create(self, collection: str, query: dict, obj: dict) -> int:
pass
Expand Down
4 changes: 4 additions & 0 deletions dataherald/db/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def find_one(self, collection: str, query: dict) -> dict:
def insert_one(self, collection: str, obj: dict) -> int:
return self._data_store[collection].insert_one(obj).inserted_id

@override
def rename(self, old_collection_name: str, new_collection_name) -> None:
self._data_store[old_collection_name].rename(new_collection_name)

@override
def update_or_create(self, collection: str, query: dict, obj: dict) -> int:
row = self.find_one(collection, query)
Expand Down
2 changes: 1 addition & 1 deletion dataherald/db_scanner/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataherald.db_scanner.models.types import TableSchemaDetail

DB_COLLECTION = "table_schema_detail"
DB_COLLECTION = "table_descriptions"


class DBScannerRepository:
Expand Down
2 changes: 1 addition & 1 deletion dataherald/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataherald.types import NLQueryResponse

DB_COLLECTION = "nl_query_response"
DB_COLLECTION = "nl_query_responses"


class NLQueryResponseRepository:
Expand Down
2 changes: 1 addition & 1 deletion dataherald/repositories/database_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataherald.sql_database.models.types import DatabaseConnection

DB_COLLECTION = "database_connection"
DB_COLLECTION = "database_connections"


class DatabaseConnectionRepository:
Expand Down
2 changes: 1 addition & 1 deletion dataherald/repositories/nl_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataherald.types import NLQuery

DB_COLLECTION = "nl_question"
DB_COLLECTION = "nl_questions"


class NLQuestionRepository:
Expand Down
Empty file added dataherald/scripts/__init__.py
Empty file.
69 changes: 69 additions & 0 deletions dataherald/scripts/migrate_v001_to_v002.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os

from sql_metadata import Parser

import dataherald.config
from dataherald.config import System
from dataherald.db import DB
from dataherald.vector_store import VectorStore


def add_db_connection_id(collection_name: str, storage) -> None:
collection_rows = storage.find_all(collection_name)
for collection_row in collection_rows:
if "db_alias" not in collection_row:
continue
database_connection = storage.find_one(
"database_connection", {"alias": collection_row["db_alias"]}
)
if not database_connection:
continue
collection_row["db_connection_id"] = str(database_connection["_id"])
# update object
storage.update_or_create(
collection_name, {"_id": collection_row["_id"]}, collection_row
)


if __name__ == "__main__":
settings = dataherald.config.Settings()
system = System(settings)
system.start()
storage = system.instance(DB)
# Update relations
add_db_connection_id("table_schema_detail", storage)
add_db_connection_id("golden_records", storage)
add_db_connection_id("nl_question", storage)
# Refresh vector stores
golden_record_collection = os.environ.get(
"GOLDEN_RECORD_COLLECTION", "dataherald-staging"
)
vector_store = system.instance(VectorStore)
try:
vector_store.delete_collection(golden_record_collection)
except Exception: # noqa: S110
pass
# Upload golden records
golden_records = storage.find_all("golden_records")
for golden_record in golden_records:
tables = Parser(golden_record["sql_query"]).tables
question = golden_record["question"]
vector_store.add_record(
documents=question,
collection=golden_record_collection,
metadata=[
{
"tables_used": tables[0],
"db_connection_id": golden_record["db_connection_id"],
}
], # this should be updated for multiple tables
ids=[str(golden_record["_id"])],
)
# Re-name collections
try:
storage.rename("nl_query_response", "nl_query_responses")
storage.rename("nl_question", "nl_questions")
storage.rename("database_connection", "database_connections")
storage.rename("table_schema_detail", "table_descriptions")
except Exception: # noqa: S110
pass
2 changes: 1 addition & 1 deletion dataherald/tests/db/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestDB(DB):
def __init__(self, system: System):
super().__init__(system)
self.memory = {}
self.memory["database_connection"] = [
self.memory["database_connections"] = [
{
"_id": "64dfa0e103f5134086f7090c",
"alias": "alias",
Expand Down
3 changes: 2 additions & 1 deletion dataherald/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def test_heartbeat():

def test_scan_all_tables():
response = client.post(
"/api/v1/table-descriptions/scan", json={"db_connection_id": "64dfa0e103f5134086f7090c"}
"/api/v1/table-descriptions/scan",
json={"db_connection_id": "64dfa0e103f5134086f7090c"},
)
assert response.status_code == HTTP_200_CODE

Expand Down

0 comments on commit 00d9340

Please sign in to comment.