Skip to content

Commit

Permalink
Fix nl-query-responses endpoints and sql-query-executions endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Sep 12, 2023
1 parent afca33e commit 2b4f81b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 22 deletions.
8 changes: 5 additions & 3 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,17 @@ def add_golden_records(
pass

@abstractmethod
def execute_query(self, query: Query) -> tuple[str, dict]:
def execute_sql_query(self, query: Query) -> tuple[str, dict]:
pass

@abstractmethod
def update_query(self, query_id: str, query: UpdateQueryRequest) -> NLQueryResponse:
def update_nl_query_response(
self, query_id: str, query: UpdateQueryRequest
) -> NLQueryResponse:
pass

@abstractmethod
def execute_temp_query(
def get_nl_query_response(
self, query_id: str, query: ExecuteTempQueryRequest
) -> NLQueryResponse:
pass
Expand Down
6 changes: 3 additions & 3 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def add_golden_records(
return context_store.add_golden_records(golden_records)

@override
def execute_query(self, query: Query) -> tuple[str, dict]:
def execute_sql_query(self, query: Query) -> tuple[str, dict]:
"""Executes a SQL query against the database and returns the results"""
db_connection = self.storage.find_by_id(
"database_connection", query.db_connection_id
Expand All @@ -226,7 +226,7 @@ def execute_query(self, query: Query) -> tuple[str, dict]:
return result

@override
def update_query(
def update_nl_query_response(
self, query_id: str, query: UpdateQueryRequest # noqa: ARG002
) -> NLQueryResponse:
nl_query_response_repository = NLQueryResponseRepository(self.storage)
Expand Down Expand Up @@ -260,7 +260,7 @@ def update_query(
return json.loads(json_util.dumps(nl_query_response))

@override
def execute_temp_query(
def get_nl_query_response(
self, query_id: str, query: ExecuteTempQueryRequest # noqa: ARG002
) -> NLQueryResponse:
nl_query_response_repository = NLQueryResponseRepository(self.storage)
Expand Down
4 changes: 3 additions & 1 deletion dataherald/db_scanner/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def scan(
tables = inspector.get_table_names()
if table_names:
table_names = [table.lower() for table in table_names]
tables = [table for table in tables if table and table.lower() in table_names]
tables = [
table for table in tables if table and table.lower() in table_names
]
if len(tables) == 0:
raise ValueError("No table found")
result = []
Expand Down
35 changes: 20 additions & 15 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,24 @@ def __init__(self, settings: Settings):
)

self.router.add_api_route(
"/api/v1/queries", self.execute_query, methods=["POST"], tags=["Queries"]
"/api/v1/nl-query-responses/{query_id}",
self.get_nl_query_response,
methods=["GET"],
tags=["NL query responses"],
)

self.router.add_api_route(
"/api/v1/query-responses/{query_id}",
self.execute_temp_query,
methods=["GET"],
tags=["Query responses"],
"/api/v1/nl-query-responses/{query_id}",
self.update_nl_query_response,
methods=["PATCH"],
tags=["NL query responses"],
)

self.router.add_api_route(
"/api/v1/query-responses/{query_id}",
self.update_query,
methods=["PATCH"],
tags=["Query responses"],
"/api/v1/sql-query-executions",
self.execute_sql_query,
methods=["POST"],
tags=["SQL queries"],
)

self.router.add_api_route(
Expand Down Expand Up @@ -191,19 +194,21 @@ def list_table_descriptions(
"""List table descriptions"""
return self._api.list_table_descriptions(db_connection_id, table_name)

def execute_query(self, query: Query) -> tuple[str, dict]:
def execute_sql_query(self, query: Query) -> tuple[str, dict]:
"""Executes a query on the given db_connection_id"""
return self._api.execute_query(query)
return self._api.execute_sql_query(query)

def update_query(self, query_id: str, query: UpdateQueryRequest) -> NLQueryResponse:
def update_nl_query_response(
self, query_id: str, query: UpdateQueryRequest
) -> NLQueryResponse:
"""Executes a query on the given db_connection_id"""
return self._api.update_query(query_id, query)
return self._api.update_nl_query_response(query_id, query)

def execute_temp_query(
def get_nl_query_response(
self, query_id: str, query: ExecuteTempQueryRequest
) -> NLQueryResponse:
"""Executes a query on the given db_connection_id"""
return self._api.execute_temp_query(query_id, query)
return self._api.get_nl_query_response(query_id, query)

def delete_golden_record(self, golden_record_id: str) -> dict:
"""Deletes a golden record"""
Expand Down

0 comments on commit 2b4f81b

Please sign in to comment.