Skip to content

Commit

Permalink
Fix nl-query-responses POST endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Sep 12, 2023
1 parent 0c4cb94 commit b81c687
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 14 deletions.
5 changes: 2 additions & 3 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any, List
from typing import List

from dataherald.api.types import Query
from dataherald.config import Component
from dataherald.db_scanner.models.types import TableSchemaDetail
from dataherald.eval import Evaluation
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.types import (
DatabaseConnectionRequest,
Expand Down Expand Up @@ -83,7 +82,7 @@ def update_nl_query_response(

@abstractmethod
def get_nl_query_response(
self, query_id: str, query: ExecuteTempQueryRequest
self, query_request: ExecuteTempQueryRequest
) -> NLQueryResponse:
pass

Expand Down
8 changes: 5 additions & 3 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,13 @@ def update_nl_query_response(

@override
def get_nl_query_response(
self, query_id: str, query: ExecuteTempQueryRequest # noqa: ARG002
self, query_request: ExecuteTempQueryRequest # noqa: ARG002
) -> NLQueryResponse:
nl_query_response_repository = NLQueryResponseRepository(self.storage)
nl_query_response = nl_query_response_repository.find_by_id(query_id)
nl_query_response.sql_query = query.sql_query
nl_query_response = nl_query_response_repository.find_by_id(
query_request.query_id
)
nl_query_response.sql_query = query_request.sql_query
try:
generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
nl_query_response = generates_nl_answer.execute(nl_query_response)
Expand Down
12 changes: 5 additions & 7 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from typing import Any, List, Union
from typing import Any, List

import fastapi
from fastapi import FastAPI as _FastAPI
from fastapi import status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute

import dataherald
from dataherald.api.types import Query
from dataherald.config import Settings
from dataherald.db_scanner.models.types import TableSchemaDetail
from dataherald.eval import Evaluation
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.types import (
DatabaseConnectionRequest,
Expand Down Expand Up @@ -116,9 +114,9 @@ def __init__(self, settings: Settings):
)

self.router.add_api_route(
"/api/v1/nl-query-responses/{query_id}",
"/api/v1/nl-query-responses",
self.get_nl_query_response,
methods=["GET"],
methods=["POST"],
tags=["NL query responses"],
)

Expand Down Expand Up @@ -205,10 +203,10 @@ def update_nl_query_response(
return self._api.update_nl_query_response(query_id, query)

def get_nl_query_response(
self, query_id: str, query: ExecuteTempQueryRequest
self, query_request: ExecuteTempQueryRequest
) -> NLQueryResponse:
"""Executes a query on the given db_connection_id"""
return self._api.get_nl_query_response(query_id, query)
return self._api.get_nl_query_response(query_request)

def delete_golden_record(self, golden_record_id: str) -> dict:
"""Deletes a golden record"""
Expand Down
2 changes: 1 addition & 1 deletion dataherald/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# from datetime import datetime add this later
from enum import Enum
from typing import Any

Expand Down Expand Up @@ -26,6 +25,7 @@ class UpdateQueryRequest(BaseModel):


class ExecuteTempQueryRequest(BaseModel):
query_id: str
sql_query: str


Expand Down

0 comments on commit b81c687

Please sign in to comment.