Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DH-4724/adding support for db level instructions #185

Merged
merged 11 commits into from
Sep 27, 2023
50 changes: 50 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,56 @@ curl -X 'PATCH' \
}'
```

#### adding database level instructions

You can add database-level instructions to the context store manually through the `POST /api/v1/{db_connection_id}/instructions` endpoint.
These instructions are passed directly to the engine and can be used to steer it toward generating SQL that is more in line with your business logic.

```
curl -X 'POST' \
'<host>/api/v1/{db_connection_id}/instructions' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"instruction": "This is a database level instruction"
}'
```

#### getting database level instructions

You can retrieve the database-level instructions of a connection from the `GET /api/v1/{db_connection_id}/instructions` endpoint. Results are paginated with the `page` and `limit` query parameters.

```
curl -X 'GET' \
'<host>/api/v1/{db_connection_id}/instructions?page=1&limit=10' \
-H 'accept: application/json'
```

#### deleting database level instructions

You can delete a database-level instruction through the `DELETE /api/v1/{db_connection_id}/instructions/{instruction_id}` endpoint.

```
curl -X 'DELETE' \
'<host>/api/v1/{db_connection_id}/instructions/{instruction_id}' \
-H 'accept: application/json'
```

#### updating database level instructions

You can update a database-level instruction through the `PATCH /api/v1/{db_connection_id}/instructions/{instruction_id}` endpoint.
Try different instructions to see how they change the SQL the engine generates.

```
curl -X 'PATCH' \
'<host>/api/v1/{db_connection_id}/instructions/{instruction_id}' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"instruction": "This is a database level instruction"
}'
```


### Querying the Database in Natural Language
Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the `POST /api/v1/question` endpoint.
Expand Down
27 changes: 27 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
ExecuteTempQueryRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
InstructionRequest,
NLQueryResponse,
QuestionRequest,
ScannerRequest,
Expand Down Expand Up @@ -97,3 +99,28 @@ def delete_golden_record(self, golden_record_id: str) -> dict:
@abstractmethod
def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecord]:
    """Return one page of golden records.

    Abstract hook with no behaviour of its own; concrete API classes
    provide the implementation.

    Args:
        page: Page number to fetch (defaults to 1).
        limit: Maximum number of records per page (defaults to 10).
    """

@abstractmethod
def add_instruction(
    self, db_connection_id: str, instruction_request: InstructionRequest
) -> Instruction:
    """Add an instruction for a database connection.

    Abstract hook with no behaviour of its own; concrete API classes
    provide the implementation.

    Args:
        db_connection_id: Identifier of the database connection the
            instruction is attached to.
        instruction_request: Request payload carrying the instruction.
    """

@abstractmethod
def get_instructions(
    self, db_connection_id: str, page: int = 1, limit: int = 10
) -> List[Instruction]:
    """Return one page of instructions for a database connection.

    Abstract hook with no behaviour of its own; concrete API classes
    provide the implementation.

    Args:
        db_connection_id: Identifier of the database connection.
        page: Page number to fetch (defaults to 1).
        limit: Maximum number of instructions per page (defaults to 10).
    """

@abstractmethod
def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict:
    """Delete an instruction of a database connection.

    Abstract hook with no behaviour of its own; concrete API classes
    provide the implementation.

    Args:
        db_connection_id: Identifier of the database connection.
        instruction_id: Identifier of the instruction to delete.
    """

@abstractmethod
def update_instruction(
    self,
    db_connection_id: str,
    instruction_id: str,
    instruction_request: InstructionRequest,
) -> Instruction:
    """Update an instruction of a database connection.

    Abstract hook with no behaviour of its own; concrete API classes
    provide the implementation.

    Args:
        db_connection_id: Identifier of the database connection.
        instruction_id: Identifier of the instruction to update.
        instruction_request: Request payload carrying the new instruction.
    """
51 changes: 50 additions & 1 deletion dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dataherald.repositories.base import NLQueryResponseRepository
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.repositories.nl_question import NLQuestionRepository
from dataherald.sql_database.base import (
InvalidDBConnectionError,
Expand All @@ -33,6 +34,8 @@
ExecuteTempQueryRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
InstructionRequest,
NLQuery,
NLQueryResponse,
QuestionRequest,
Expand Down Expand Up @@ -117,7 +120,7 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse:
start_generated_answer = time.time()
try:
generated_answer = sql_generation.generate_response(
user_question, database_connection, context
user_question, database_connection, context[0]
)
logger.info("Starts evaluator...")
confidence_score = evaluator.get_confidence_score(
Expand Down Expand Up @@ -312,3 +315,49 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor
start_idx = (page - 1) * limit
end_idx = start_idx + limit
return all_records[start_idx:end_idx]

@override
def add_instruction(
    self, db_connection_id: str, instruction_request: InstructionRequest
) -> Instruction:
    """Store a new instruction for the given database connection.

    Args:
        db_connection_id: Identifier of the database connection the
            instruction is attached to.
        instruction_request: Request payload; its ``instruction`` field
            supplies the instruction to store.

    Returns:
        Whatever ``InstructionRepository.insert`` returns for the new
        instruction.
    """
    repo = InstructionRepository(self.storage)
    return repo.insert(
        Instruction(
            instruction=instruction_request.instruction,
            db_connection_id=db_connection_id,
        )
    )

@override
def get_instructions(
    self, db_connection_id: str, page: int = 1, limit: int = 10
) -> List[Instruction]:
    """Return one page of the instructions stored for a database connection.

    All matching instructions are loaded through the repository and the
    requested page is then sliced out in memory.

    Args:
        db_connection_id: Identifier of the database connection to filter on.
        page: 1-based page number.
        limit: Maximum number of instructions per page.

    Returns:
        The slice of matching instructions that falls on the requested page.
    """
    repo = InstructionRepository(self.storage)
    matching = repo.find_by({"db_connection_id": db_connection_id})
    offset = (page - 1) * limit
    return matching[offset : offset + limit]

@override
def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict:
    """Delete an instruction that belongs to the given database connection.

    Args:
        db_connection_id: Identifier of the database connection that must
            own the instruction.
        instruction_id: Identifier of the instruction to delete.

    Returns:
        ``{"status": "success"}`` once the instruction has been deleted.

    Raises:
        HTTPException: 404 when no instruction with ``instruction_id``
            exists or when it belongs to a different database connection.
    """
    instruction_repository = InstructionRepository(self.storage)
    instruction = instruction_repository.find_by_id(instruction_id)
    # find_by_id returns None for an unknown id; answer with the same 404
    # used for a mismatched connection instead of failing on attribute access.
    if instruction is None or instruction.db_connection_id != db_connection_id:
        raise HTTPException(status_code=404, detail="Instruction not found")
    instruction_repository.delete_by_id(instruction_id)
    return {"status": "success"}

@override
def update_instruction(
    self,
    db_connection_id: str,
    instruction_id: str,
    instruction_request: InstructionRequest,
) -> Instruction:
    """Overwrite the instruction stored under ``instruction_id``.

    Args:
        db_connection_id: Identifier of the database connection recorded
            on the updated instruction.
        instruction_id: Identifier of the instruction to update.
        instruction_request: Request payload; its ``instruction`` field
            supplies the new instruction.

    Returns:
        The updated instruction after a round trip through
        ``json_util.dumps`` and ``json.loads``.
    """
    repo = InstructionRepository(self.storage)
    updated = Instruction(
        id=instruction_id,
        instruction=instruction_request.instruction,
        db_connection_id=db_connection_id,
    )
    repo.update(updated)
    serialized = json_util.dumps(updated)
    return json.loads(serialized)
4 changes: 2 additions & 2 deletions dataherald/context_store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, List
from typing import List, Tuple

from dataherald.config import Component, System
from dataherald.db import DB
Expand All @@ -25,7 +25,7 @@ def __init__(self, system: System):
@abstractmethod
def retrieve_context_for_question(
self, nl_question: NLQuery, number_of_samples: int = 3
) -> List[dict] | None:
) -> Tuple[List[dict] | None, List[dict] | None]:
pass

@abstractmethod
Expand Down
21 changes: 17 additions & 4 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from typing import List
from typing import List, Tuple

from overrides import override
from sql_metadata import Parser

from dataherald.config import System
from dataherald.context_store import ContextStore
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.types import GoldenRecord, GoldenRecordRequest, NLQuery

logger = logging.getLogger(__name__)
Expand All @@ -19,7 +20,7 @@ def __init__(self, system: System):
@override
def retrieve_context_for_question(
self, nl_question: NLQuery, number_of_samples: int = 3
) -> List[dict] | None:
) -> Tuple[List[dict] | None, List[dict] | None]:
logger.info(f"Getting context for {nl_question.question}")
closest_questions = self.vector_store.query(
query_texts=[nl_question.question],
Expand All @@ -41,9 +42,21 @@ def retrieve_context_for_question(
}
)
if len(samples) == 0:
return None
samples = None
instructions = []
instruction_repository = InstructionRepository(self.db)
all_instructions = instruction_repository.find_all()
for instruction in all_instructions:
if instruction.db_connection_id == nl_question.db_connection_id:
instructions.append(
{
"instruction": instruction.instruction,
}
)
if len(instructions) == 0:
instructions = None

return samples
return samples, instructions

@override
def add_golden_records(
Expand Down
52 changes: 52 additions & 0 deletions dataherald/repositories/instructions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from bson.objectid import ObjectId

from dataherald.types import Instruction

DB_COLLECTION = "instructions"


class InstructionRepository:
    """Read and write `Instruction` documents in the ``DB_COLLECTION`` collection.

    Every operation is delegated to the ``storage`` object passed to the
    constructor. Rows coming back from storage are converted by one shared
    helper, so every finder returns instructions whose ``id`` holds the
    stringified ``_id`` of the stored row.
    """

    def __init__(self, storage):
        """Keep a reference to the storage backend used by every operation."""
        self.storage = storage

    @staticmethod
    def _to_instruction(row: dict) -> Instruction:
        """Build an `Instruction` from a stored row, exposing ``_id`` as ``id``."""
        return Instruction(**{**row, "id": str(row["_id"])})

    def insert(self, instruction: Instruction) -> Instruction:
        """Store ``instruction`` and set its ``id`` from the storage result.

        The ``id`` field itself is excluded from the stored document.
        """
        instruction.id = str(
            self.storage.insert_one(DB_COLLECTION, instruction.dict(exclude={"id"}))
        )
        return instruction

    def find_one(self, query: dict) -> Instruction | None:
        """Return the first instruction matching ``query``, or None."""
        row = self.storage.find_one(DB_COLLECTION, query)
        if not row:
            return None
        return self._to_instruction(row)

    def update(self, instruction: Instruction) -> Instruction:
        """Write ``instruction`` under its ``id`` via ``update_or_create``.

        The ``id`` field is excluded from the stored document; the same
        ``instruction`` object is returned.
        """
        self.storage.update_or_create(
            DB_COLLECTION,
            {"_id": ObjectId(instruction.id)},
            instruction.dict(exclude={"id"}),
        )
        return instruction

    def find_by_id(self, id: str) -> Instruction | None:
        """Return the instruction whose ``_id`` equals ``ObjectId(id)``, or None."""
        row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)})
        if not row:
            return None
        return self._to_instruction(row)

    def find_by(self, query: dict) -> list[Instruction]:
        """Return every instruction matching ``query``."""
        rows = self.storage.find(DB_COLLECTION, query)
        return [self._to_instruction(row) for row in rows]

    def find_all(self) -> list[Instruction]:
        """Return every instruction in the collection."""
        rows = self.storage.find_all(DB_COLLECTION)
        return [self._to_instruction(row) for row in rows]

    def delete_by_id(self, id: str) -> int:
        """Delete the instruction with the given id; return the storage result."""
        return self.storage.delete_by_id(DB_COLLECTION, id)
66 changes: 66 additions & 0 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
ExecuteTempQueryRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
InstructionRequest,
NLQueryResponse,
QuestionRequest,
ScannerRequest,
Expand Down Expand Up @@ -134,6 +136,34 @@ def __init__(self, settings: Settings):
tags=["SQL queries"],
)

self.router.add_api_route(
"/api/v1/{db_connection_id}/instructions",
self.add_instruction,
methods=["POST"],
tags=["Instructions"],
)

self.router.add_api_route(
"/api/v1/{db_connection_id}/instructions",
self.get_instructions,
methods=["GET"],
tags=["Instructions"],
)

self.router.add_api_route(
"/api/v1/{db_connection_id}/instructions/{instruction_id}",
self.delete_instruction,
methods=["DELETE"],
tags=["Instructions"],
)

self.router.add_api_route(
"/api/v1/{db_connection_id}/instructions/{instruction_id}",
self.update_instruction,
methods=["PATCH"],
tags=["Instructions"],
)

self.router.add_api_route(
"/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"]
)
Expand Down Expand Up @@ -229,3 +259,39 @@ def add_golden_records(
def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecord]:
    """Return one page of golden records from the underlying API."""
    records = self._api.get_golden_records(page, limit)
    return records

def add_instruction(
    self, db_connection_id: str, instruction_request: InstructionRequest
) -> Instruction:
    """Create an instruction and answer with HTTP 201.

    The instruction is created through the underlying API, and its
    ``dict()`` form becomes the JSON body of the response.
    """
    instruction = self._api.add_instruction(db_connection_id, instruction_request)

    # Respond with status 201 (Created); only the body and the status code
    # are set here, no additional headers.
    return JSONResponse(
        content=instruction.dict(),
        status_code=status.HTTP_201_CREATED,
    )

def get_instructions(
    self, db_connection_id: str, page: int = 1, limit: int = 10
) -> List[Instruction]:
    """Return one page of instructions for a database connection.

    Thin pass-through to the underlying API.
    """
    instructions = self._api.get_instructions(db_connection_id, page, limit)
    return instructions

def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict:
    """Delete an instruction of a database connection.

    Thin pass-through to the underlying API.
    """
    outcome = self._api.delete_instruction(db_connection_id, instruction_id)
    return outcome

def update_instruction(
    self,
    db_connection_id: str,
    instruction_id: str,
    instruction_request: InstructionRequest,
) -> Instruction:
    """Update an instruction of a database connection.

    Thin pass-through to the underlying API.
    """
    updated = self._api.update_instruction(
        db_connection_id, instruction_id, instruction_request
    )
    return updated
Loading