Skip to content

Commit

Permalink
Use DatabaseConnectionRepository
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Sep 12, 2023
1 parent 60e9838 commit 40f171c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
27 changes: 12 additions & 15 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,12 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse:
nl_question_repository = NLQuestionRepository(self.storage)
user_question = nl_question_repository.insert(user_question)

db_connection = self.storage.find_by_id(
"database_connection", question_request.db_connection_id
db_connection_repository = DatabaseConnectionRepository(self.storage)
database_connection = db_connection_repository.find_by_id(
question_request.db_connection_id
)
if not db_connection:
if not database_connection:
raise HTTPException(status_code=404, detail="Database connection not found")
database_connection = DatabaseConnection(**db_connection)
database_connection.id = question_request.db_connection_id
context = context_store.retrieve_context_for_question(user_question)
start_generated_answer = time.time()
try:
Expand Down Expand Up @@ -211,13 +210,12 @@ def add_golden_records(
@override
def execute_sql_query(self, query: Query) -> tuple[str, dict]:
"""Executes a SQL query against the database and returns the results"""
db_connection = self.storage.find_by_id(
"database_connection", query.db_connection_id
db_connection_repository = DatabaseConnectionRepository(self.storage)
database_connection = db_connection_repository.find_by_id(
query.db_connection_id
)
if not db_connection:
if not database_connection:
raise HTTPException(status_code=404, detail="Database connection not found")
database_connection = DatabaseConnection(**db_connection)
database_connection.id = query.db_connection_id
database = SQLDatabase.get_sql_engine(database_connection)
try:
result = database.run_sql(query.sql_query)
Expand All @@ -238,15 +236,14 @@ def update_nl_query_response(
if nl_query_response.sql_query.strip() != query.sql_query.strip():
nl_query_response.sql_query = query.sql_query
evaluator = self.system.instance(Evaluator)
db_connection = self.storage.find_by_id(
"database_connection", nl_question.db_connection_id
db_connection_repository = DatabaseConnectionRepository(self.storage)
database_connection = db_connection_repository.find_by_id(
nl_question.db_connection_id
)
if not db_connection:
if not database_connection:
raise HTTPException(
status_code=404, detail="Database connection not found"
)
database_connection = DatabaseConnection(**db_connection)
database_connection.id = nl_question.db_connection_id
try:
confidence_score = evaluator.get_confidence_score(
nl_question, nl_query_response, database_connection
Expand Down
9 changes: 4 additions & 5 deletions dataherald/sql_generator/generates_nl_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
)

from dataherald.model.chat_model import ChatModel
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.nl_question import NLQuestionRepository
from dataherald.sql_database.base import SQLDatabase
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
from dataherald.types import NLQueryResponse

Expand Down Expand Up @@ -36,11 +36,10 @@ def execute(self, nl_query_response: NLQueryResponse) -> NLQueryResponse:
nl_query_response.nl_question_id
)

db_connection = self.storage.find_by_id(
"database_connection", nl_question.db_connection_id
db_connection_repository = DatabaseConnectionRepository(self.storage)
database_connection = db_connection_repository.find_by_id(
nl_question.db_connection_id
)
database_connection = DatabaseConnection(**db_connection)
database_connection.id = nl_question.db_connection_id
database = SQLDatabase.get_sql_engine(database_connection)
nl_query_response = create_sql_query_status(
database, nl_query_response.sql_query, nl_query_response
Expand Down

0 comments on commit 40f171c

Please sign in to comment.