From 5f7ff1f57ce52c6cb751ee78e51153c021064b57 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Tue, 12 Sep 2023 11:54:20 -0600 Subject: [PATCH] Fix nl-query-responses POST endpoint --- dataherald/api/__init__.py | 5 ++--- dataherald/api/fastapi.py | 8 +++++--- dataherald/server/fastapi/__init__.py | 12 +++++------- dataherald/types.py | 2 +- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index 410ca067..dcfa9b19 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, List +from typing import List from dataherald.api.types import Query from dataherald.config import Component from dataherald.db_scanner.models.types import TableSchemaDetail -from dataherald.eval import Evaluation from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings from dataherald.types import ( DatabaseConnectionRequest, @@ -83,7 +82,7 @@ def update_nl_query_response( @abstractmethod def get_nl_query_response( - self, query_id: str, query: ExecuteTempQueryRequest + self, query_request: ExecuteTempQueryRequest ) -> NLQueryResponse: pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 5dfc5cc0..38ec1ca1 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -258,11 +258,13 @@ def update_nl_query_response( @override def get_nl_query_response( - self, query_id: str, query: ExecuteTempQueryRequest # noqa: ARG002 + self, query_request: ExecuteTempQueryRequest # noqa: ARG002 ) -> NLQueryResponse: nl_query_response_repository = NLQueryResponseRepository(self.storage) - nl_query_response = nl_query_response_repository.find_by_id(query_id) - nl_query_response.sql_query = query.sql_query + nl_query_response = nl_query_response_repository.find_by_id( + query_request.query_id + ) + nl_query_response.sql_query = query_request.sql_query try: generates_nl_answer = GeneratesNlAnswer(self.system, self.storage) nl_query_response = generates_nl_answer.execute(nl_query_response) diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index fe661c13..6f35fdd8 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -1,9 +1,8 @@ -from typing import Any, List, Union +from typing import Any, List import fastapi from fastapi import FastAPI as _FastAPI from fastapi import status -from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.routing import APIRoute @@ -11,7 +10,6 @@ from dataherald.api.types import Query from dataherald.config import Settings from dataherald.db_scanner.models.types import TableSchemaDetail -from dataherald.eval import Evaluation from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings from dataherald.types import ( DatabaseConnectionRequest, @@ -116,9 +114,9 @@ def __init__(self, settings: Settings): ) self.router.add_api_route( - "/api/v1/nl-query-responses/{query_id}", + "/api/v1/nl-query-responses", self.get_nl_query_response, - methods=["GET"], + methods=["POST"], tags=["NL query responses"], ) @@ -205,10 +203,10 @@ def update_nl_query_response( return self._api.update_nl_query_response(query_id, query) def get_nl_query_response( - self, query_id: str, query: ExecuteTempQueryRequest + self, query_request: ExecuteTempQueryRequest ) -> NLQueryResponse: """Executes a query on the given db_connection_id""" - return self._api.get_nl_query_response(query_id, query) + return self._api.get_nl_query_response(query_request) def delete_golden_record(self, golden_record_id: str) -> dict: """Deletes a golden record""" diff --git a/dataherald/types.py b/dataherald/types.py index 6aa4eba3..f235b39b 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -1,4 +1,3 @@ -# from datetime import datetime add this later from enum import Enum from typing import Any @@ -26,6 +25,7 @@ class UpdateQueryRequest(BaseModel): class ExecuteTempQueryRequest(BaseModel): + query_id: str sql_query: str