diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index 6e96faf3..410ca067 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -72,15 +72,17 @@ def add_golden_records( pass @abstractmethod - def execute_query(self, query: Query) -> tuple[str, dict]: + def execute_sql_query(self, query: Query) -> tuple[str, dict]: pass @abstractmethod - def update_query(self, query_id: str, query: UpdateQueryRequest) -> NLQueryResponse: + def update_nl_query_response( + self, query_id: str, query: UpdateQueryRequest + ) -> NLQueryResponse: pass @abstractmethod - def execute_temp_query( + def get_nl_query_response( self, query_id: str, query: ExecuteTempQueryRequest ) -> NLQueryResponse: pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 888e5319..4b61cd0e 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -209,7 +209,7 @@ def add_golden_records( return context_store.add_golden_records(golden_records) @override - def execute_query(self, query: Query) -> tuple[str, dict]: + def execute_sql_query(self, query: Query) -> tuple[str, dict]: """Executes a SQL query against the database and returns the results""" db_connection = self.storage.find_by_id( "database_connection", query.db_connection_id @@ -226,7 +226,7 @@ def execute_query(self, query: Query) -> tuple[str, dict]: return result @override - def update_query( + def update_nl_query_response( self, query_id: str, query: UpdateQueryRequest # noqa: ARG002 ) -> NLQueryResponse: nl_query_response_repository = NLQueryResponseRepository(self.storage) @@ -260,7 +260,7 @@ def update_query( return json.loads(json_util.dumps(nl_query_response)) @override - def execute_temp_query( + def get_nl_query_response( self, query_id: str, query: ExecuteTempQueryRequest # noqa: ARG002 ) -> NLQueryResponse: nl_query_response_repository = NLQueryResponseRepository(self.storage) diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 11eaf57a..d1628fa6 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -151,7 +151,9 @@ def scan( tables = inspector.get_table_names() if table_names: table_names = [table.lower() for table in table_names] - tables = [table for table in tables if table and table.lower() in table_names] + tables = [ + table for table in tables if table and table.lower() in table_names + ] if len(tables) == 0: raise ValueError("No table found") result = [] diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 5ac6d519..fe661c13 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -116,21 +116,24 @@ def __init__(self, settings: Settings): ) self.router.add_api_route( - "/api/v1/queries", self.execute_query, methods=["POST"], tags=["Queries"] + "/api/v1/nl-query-responses/{query_id}", + self.get_nl_query_response, + methods=["GET"], + tags=["NL query responses"], ) self.router.add_api_route( - "/api/v1/query-responses/{query_id}", - self.execute_temp_query, - methods=["GET"], - tags=["Query responses"], + "/api/v1/nl-query-responses/{query_id}", + self.update_nl_query_response, + methods=["PATCH"], + tags=["NL query responses"], ) self.router.add_api_route( - "/api/v1/query-responses/{query_id}", - self.update_query, - methods=["PATCH"], - tags=["Query responses"], + "/api/v1/sql-query-executions", + self.execute_sql_query, + methods=["POST"], + tags=["SQL queries"], ) self.router.add_api_route( @@ -191,19 +194,21 @@ def list_table_descriptions( """List table descriptions""" return self._api.list_table_descriptions(db_connection_id, table_name) - def execute_query(self, query: Query) -> tuple[str, dict]: + def execute_sql_query(self, query: Query) -> tuple[str, dict]: """Executes a query on the given db_connection_id""" - return self._api.execute_query(query) + return self._api.execute_sql_query(query) - def update_query(self, query_id: str, query: UpdateQueryRequest) -> NLQueryResponse: + def update_nl_query_response( + self, query_id: str, query: UpdateQueryRequest + ) -> NLQueryResponse: """Executes a query on the given db_connection_id""" - return self._api.update_query(query_id, query) + return self._api.update_nl_query_response(query_id, query) - def execute_temp_query( + def get_nl_query_response( self, query_id: str, query: ExecuteTempQueryRequest ) -> NLQueryResponse: """Executes a query on the given db_connection_id""" - return self._api.execute_temp_query(query_id, query) + return self._api.get_nl_query_response(query_id, query) def delete_golden_record(self, golden_record_id: str) -> dict: """Deletes a golden record"""