Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update .env.example #508

Merged
merged 3 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
- uses: rickstaa/action-black@v1
with:
black_args: ". --check"
- uses: psf/black@stable

2 changes: 1 addition & 1 deletion services/engine/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ AGENT_MAX_ITERATIONS = 15
#timeout in seconds for the engine to return a response. Defaults to 150 seconds
DH_ENGINE_TIMEOUT = 150
#timeout for SQL execution, our agents execute the SQL query to recover from errors, this is the timeout for that execution. Defaults to 30 seconds
SQL_EXECUTION_TIMEOUT =
SQL_EXECUTION_TIMEOUT = 30
#The upper limit on number of rows returned from the query engine (equivalent to using LIMIT N in PostgreSQL/MySQL/SQLite). Defaults to 50
UPPER_LIMIT_QUERY_RETURN_ROWS = 50
#Encryption key for storing DB connection data in Mongo
Expand Down
8 changes: 5 additions & 3 deletions services/engine/dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,11 @@ def create_finetuning_job(
Finetuning(
db_connection_id=fine_tuning_request.db_connection_id,
schemas=fine_tuning_request.schemas,
alias=fine_tuning_request.alias
if fine_tuning_request.alias
else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}",
alias=(
fine_tuning_request.alias
if fine_tuning_request.alias
else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}"
),
base_llm=base_llm,
golden_sqls=[str(golden_sql.id) for golden_sql in golden_sqls],
metadata=fine_tuning_request.metadata,
Expand Down
1 change: 1 addition & 0 deletions services/engine/dataherald/db_scanner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base class that all scanner classes inherit from."""

from abc import ABC, abstractmethod

from dataherald.config import Component
Expand Down
16 changes: 9 additions & 7 deletions services/engine/dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,15 @@ def create_fine_tuning_job(self):
finetuning_request = self.client.fine_tuning.jobs.create(
training_file=model.finetuning_file_id,
model=model.base_llm.model_name,
hyperparameters=model.base_llm.model_parameters
if model.base_llm.model_parameters
else {
"batch_size": 1,
"learning_rate_multiplier": "auto",
"n_epochs": 3,
},
hyperparameters=(
model.base_llm.model_parameters
if model.base_llm.model_parameters
else {
"batch_size": 1,
"learning_rate_multiplier": "auto",
"n_epochs": 3,
}
),
)
model.finetuning_job_id = finetuning_request.id
if finetuning_request.status == "failed":
Expand Down
40 changes: 23 additions & 17 deletions services/engine/dataherald/scripts/migrate_v006_to_v100.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def update_object_id_fields(field_name: str, collection_name: str):
"_id": question["_id"],
"db_connection_id": str(question["db_connection_id"]),
"text": question["question"],
"created_at": None
if len(responses) == 0
else responses[0]["created_at"],
"created_at": (
None if len(responses) == 0 else responses[0]["created_at"]
),
"metadata": None,
},
)
Expand All @@ -112,17 +112,21 @@ def update_object_id_fields(field_name: str, collection_name: str):
{
"_id": response["_id"],
"prompt_id": str(response["question_id"]),
"evaluate": False
if response["confidence_score"] is None
else True,
"evaluate": (
False if response["confidence_score"] is None else True
),
"sql": response["sql_query"],
"status": "VALID"
if response["sql_generation_status"] == "VALID"
else "INVALID",
"completed_at": response["created_at"]
+ timedelta(seconds=response["exec_time"])
if response["exec_time"]
else None,
"status": (
"VALID"
if response["sql_generation_status"] == "VALID"
else "INVALID"
),
"completed_at": (
response["created_at"]
+ timedelta(seconds=response["exec_time"])
if response["exec_time"]
else None
),
"tokens_used": response["total_tokens"],
"confidence_score": response["confidence_score"],
"error": response["error_message"],
Expand All @@ -140,10 +144,12 @@ def update_object_id_fields(field_name: str, collection_name: str):
{
"sql_generation_id": str(response["_id"]),
"text": response["response"],
"created_at": response["created_at"]
+ timedelta(seconds=response["exec_time"])
if response["exec_time"]
else response["created_at"],
"created_at": (
response["created_at"]
+ timedelta(seconds=response["exec_time"])
if response["exec_time"]
else response["created_at"]
),
"metadata": None,
},
)
Expand Down
6 changes: 3 additions & 3 deletions services/engine/dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,9 @@ def export_csv_file(self, sql_generation_id: str) -> StreamingResponse:
stream = self._api.export_csv_file(sql_generation_id)

response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
response.headers[
"Content-Disposition"
] = f"attachment; filename=sql_generation_{sql_generation_id}.csv"
response.headers["Content-Disposition"] = (
f"attachment; filename=sql_generation_{sql_generation_id}.csv"
)
return response

def delete_golden_sql(self, golden_sql_id: str) -> dict:
Expand Down
16 changes: 10 additions & 6 deletions services/engine/dataherald/services/nl_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ def create(
initial_nl_generation = NLGeneration(
sql_generation_id=sql_generation_id,
created_at=datetime.now(),
llm_config=nl_generation_request.llm_config
if nl_generation_request.llm_config
else LLMConfig(),
llm_config=(
nl_generation_request.llm_config
if nl_generation_request.llm_config
else LLMConfig()
),
metadata=nl_generation_request.metadata,
)
self.nl_generation_repository.insert(initial_nl_generation)
Expand All @@ -46,9 +48,11 @@ def create(
nl_generator = GeneratesNlAnswer(
self.system,
self.storage,
nl_generation_request.llm_config
if nl_generation_request.llm_config
else LLMConfig(),
(
nl_generation_request.llm_config
if nl_generation_request.llm_config
else LLMConfig()
),
)
try:
nl_generation = nl_generator.execute(
Expand Down
48 changes: 30 additions & 18 deletions services/engine/dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def create( # noqa: PLR0912
initial_sql_generation = SQLGeneration(
prompt_id=prompt_id,
created_at=datetime.now(),
llm_config=sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
llm_config=(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
metadata=sql_generation_request.metadata,
)
langsmith_metadata = (
Expand Down Expand Up @@ -115,16 +117,20 @@ def create( # noqa: PLR0912
)
sql_generator = DataheraldSQLAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
)
else:
sql_generator = DataheraldFinetuningAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
)
sql_generator.finetuning_id = sql_generation_request.finetuning_id
sql_generator.use_fintuned_model_only = (
Expand Down Expand Up @@ -184,9 +190,11 @@ def start_streaming(
initial_sql_generation = SQLGeneration(
prompt_id=prompt_id,
created_at=datetime.now(),
llm_config=sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
llm_config=(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
metadata=sql_generation_request.metadata,
)
langsmith_metadata = (
Expand Down Expand Up @@ -215,16 +223,20 @@ def start_streaming(
)
sql_generator = DataheraldSQLAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
)
else:
sql_generator = DataheraldFinetuningAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
(
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig()
),
)
sql_generator.finetuning_id = sql_generation_request.finetuning_id
sql_generator.use_fintuned_model_only = (
Expand Down
1 change: 1 addition & 0 deletions services/engine/dataherald/smart_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base class that all cache classes inherit from."""

from abc import ABC, abstractmethod
from typing import Any, Union

Expand Down
1 change: 1 addition & 0 deletions services/engine/dataherald/sql_database/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""SQL wrapper around SQLDatabase in langchain."""

import logging
import re
from typing import List
Expand Down
1 change: 1 addition & 0 deletions services/engine/dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base class that all sql generation classes inherit from."""

import datetime
import logging
import os
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,9 @@ def create_sql_agent(
suffix: str = FINETUNING_AGENT_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: List[str] | None = None,
max_iterations: int
| None = int(os.getenv("AGENT_MAX_ITERATIONS", "12")), # noqa: B008
max_iterations: int | None = int(
os.getenv("AGENT_MAX_ITERATIONS", "12")
), # noqa: B008
max_execution_time: float | None = None,
early_stopping_method: str = "generate",
verbose: bool = False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,9 @@ def create_sql_agent(
input_variables: List[str] | None = None,
max_examples: int = 20,
number_of_instructions: int = 1,
max_iterations: int
| None = int(os.getenv("AGENT_MAX_ITERATIONS", "15")), # noqa: B008
max_iterations: int | None = int(
os.getenv("AGENT_MAX_ITERATIONS", "15")
), # noqa: B008
max_execution_time: float | None = None,
early_stopping_method: str = "generate",
verbose: bool = False,
Expand Down
15 changes: 12 additions & 3 deletions services/engine/dataherald/utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ class S3:
def __init__(self):
self.settings = Settings()

def _get_client(self, access_key: str | None = None, secret_access_key: str | None = None, region: str | None = None) -> boto3.client:
def _get_client(
self,
access_key: str | None = None,
secret_access_key: str | None = None,
region: str | None = None,
) -> boto3.client:
_access_key = access_key or self.settings.s3_aws_access_key_id
_secret_access_key = secret_access_key or self.settings.s3_aws_secret_access_key
_region = region or self.settings.s3_region
Expand Down Expand Up @@ -44,7 +49,9 @@ def upload(self, file_location, file_storage: FileStorage | None = None) -> str:
bucket_name = file_storage.bucket
s3_client = self._get_client(
access_key=fernet_encrypt.decrypt(file_storage.access_key_id),
secret_access_key=fernet_encrypt.decrypt(file_storage.secret_access_key),
secret_access_key=fernet_encrypt.decrypt(
file_storage.secret_access_key
),
region=file_storage.region,
)
else:
Expand All @@ -63,7 +70,9 @@ def download(self, path: str, file_storage: FileStorage | None = None) -> str:
fernet_encrypt = FernetEncrypt()
s3_client = self._get_client(
access_key=fernet_encrypt.decrypt(file_storage.access_key_id),
secret_access_key=fernet_encrypt.decrypt(file_storage.secret_access_key),
secret_access_key=fernet_encrypt.decrypt(
file_storage.secret_access_key
),
region=file_storage.region,
)
else:
Expand Down
8 changes: 5 additions & 3 deletions services/engine/dataherald/vector_store/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
{
"_id": str(golden_sqls[key].id),
"$vector": embeds[key],
"tables_used": ", ".join(Parser(golden_sqls[key].sql))
if isinstance(Parser(golden_sqls[key].sql), list)
else "",
"tables_used": (
", ".join(Parser(golden_sqls[key].sql))
if isinstance(Parser(golden_sqls[key].sql), list)
else ""
),
"db_connection_id": str(golden_sqls[key].db_connection_id),
}
)
Expand Down
8 changes: 5 additions & 3 deletions services/engine/dataherald/vector_store/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
collection,
[
{
"tables_used": ", ".join(Parser(golden_sql.sql))
if isinstance(Parser(golden_sql.sql), list)
else "",
"tables_used": (
", ".join(Parser(golden_sql.sql))
if isinstance(Parser(golden_sql.sql), list)
else ""
),
"db_connection_id": str(golden_sql.db_connection_id),
}
],
Expand Down
1 change: 1 addition & 0 deletions services/engine/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Set up the package."""

import os
from pathlib import Path

Expand Down
Loading