From 91661094f0e7bda10c4004fc9fbad72f4a8443b8 Mon Sep 17 00:00:00 2001
From: Mohammadreza Pourreza <71866535+MohammadrezaPourreza@users.noreply.github.com>
Date: Fri, 17 May 2024 12:50:59 -0400
Subject: [PATCH] DH-5776/fixing the azure openai (#487)

* DH-5776/fixing the azure openai

* Fixing the linter

* reformat with black
---
 services/engine/dataherald/api/fastapi.py    | 12 +++---
 services/engine/dataherald/config.py         |  2 +-
 .../finetuning/openai_finetuning.py          | 20 +++++++---
 .../engine/dataherald/model/base_model.py    | 10 ++---
 .../engine/dataherald/model/chat_model.py    | 12 +++---
 .../dataherald/services/sql_generations.py   |  4 +-
 .../dataherald/sql_generator/__init__.py     |  4 +-
 .../dataherald_finetuning_agent.py           | 36 +++++++++++------
 .../sql_generator/dataherald_sqlagent.py     | 39 ++++++++-----------
 .../exceptions/exception_handlers.py         |  1 -
 services/enterprise/exceptions/exceptions.py |  1 -
 .../modules/db_connection/controller.py      |  1 -
 .../modules/db_connection/service.py         | 10 +++--
 .../modules/organization/invoice/service.py  | 19 ++++-----
 14 files changed, 93 insertions(+), 78 deletions(-)

diff --git a/services/engine/dataherald/api/fastapi.py b/services/engine/dataherald/api/fastapi.py
index 23e3976c..e9edbd57 100644
--- a/services/engine/dataherald/api/fastapi.py
+++ b/services/engine/dataherald/api/fastapi.py
@@ -110,8 +110,8 @@ def async_scanning(scanner, database, table_descriptions, storage):
     )
 
 
-def async_fine_tuning(storage, model):
-    openai_fine_tuning = OpenAIFineTuning(storage, model)
+def async_fine_tuning(system, storage, model):
+    openai_fine_tuning = OpenAIFineTuning(system, storage, model)
     openai_fine_tuning.create_fintuning_dataset()
     openai_fine_tuning.create_fine_tuning_job()
 
@@ -626,7 +626,7 @@ def create_finetuning_job(
                 e, fine_tuning_request.dict(), "finetuning_not_created"
             )
 
-        background_tasks.add_task(async_fine_tuning, self.storage, model)
+        background_tasks.add_task(async_fine_tuning, self.system, self.storage, model)
 
         return model
 
@@ -652,7 +652,7 @@ def cancel_finetuning_job(
                 status_code=400, detail="Model has already been cancelled."
             )
 
-        openai_fine_tuning = OpenAIFineTuning(self.storage, model)
+        openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model)
         return openai_fine_tuning.cancel_finetuning_job()
 
@@ -665,7 +665,7 @@ def get_finetunings(self, db_connection_id: str | None = None) -> list[Finetunin
         models = model_repository.find_by(query)
         result = []
         for model in models:
-            openai_fine_tuning = OpenAIFineTuning(self.storage, model)
+            openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model)
             result.append(
                 Finetuning(**openai_fine_tuning.retrieve_finetuning_job().dict())
             )
@@ -685,7 +685,7 @@ def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
         model = model_repository.find_by_id(finetuning_job_id)
         if not model:
             raise HTTPException(status_code=404, detail="Model not found")
-        openai_fine_tuning = OpenAIFineTuning(self.storage, model)
+        openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model)
         return openai_fine_tuning.retrieve_finetuning_job()
 
     @override
diff --git a/services/engine/dataherald/config.py b/services/engine/dataherald/config.py
index 370947c1..9a43e334 100644
--- a/services/engine/dataherald/config.py
+++ b/services/engine/dataherald/config.py
@@ -45,7 +45,7 @@ class Settings(BaseSettings):
     encrypt_key: str = os.environ.get("ENCRYPT_KEY")
     s3_aws_access_key_id: str | None = os.environ.get("S3_AWS_ACCESS_KEY_ID")
     s3_aws_secret_access_key: str | None = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
-    #Needed for Azure OpenAI integration:
+    # Needed for Azure OpenAI integration:
     azure_api_key: str | None = os.environ.get("AZURE_API_KEY")
     embedding_model: str | None = os.environ.get("EMBEDDING_MODEL")
     azure_api_version: str | None = os.environ.get("AZURE_API_VERSION")
diff --git a/services/engine/dataherald/finetuning/openai_finetuning.py b/services/engine/dataherald/finetuning/openai_finetuning.py
index 58fe89c1..2876d2c8 100644
--- a/services/engine/dataherald/finetuning/openai_finetuning.py
+++ b/services/engine/dataherald/finetuning/openai_finetuning.py
@@ -7,12 +7,13 @@
 
 import numpy as np
 import tiktoken
-from langchain_openai import OpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from openai import OpenAI
 from overrides import override
 from sql_metadata import Parser
 from tiktoken import Encoding
 
+from dataherald.config import System
 from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus
 from dataherald.db_scanner.repository.base import TableDescriptionRepository
 from dataherald.finetuning import FinetuningModel
@@ -36,17 +37,24 @@ class OpenAIFineTuning(FinetuningModel):
     storage: Any
     client: OpenAI
 
-    def __init__(self, storage: Any, fine_tuning_model: Finetuning):
+    def __init__(self, system: System, storage: Any, fine_tuning_model: Finetuning):
         self.storage = storage
+        self.system = system
         self.fine_tuning_model = fine_tuning_model
         db_connection_repository = DatabaseConnectionRepository(storage)
         db_connection = db_connection_repository.find_by_id(
             fine_tuning_model.db_connection_id
         )
-        self.embedding = OpenAIEmbeddings(  #TODO AzureOpenAIEmbeddings when Azure
-            openai_api_key=db_connection.decrypt_api_key(),
-            model=EMBEDDING_MODEL,
-        )
+        if self.system.settings["azure_api_key"] is not None:
+            self.embedding = AzureOpenAIEmbeddings(
+                azure_api_key=db_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
+            )
+        else:
+            self.embedding = OpenAIEmbeddings(
+                openai_api_key=db_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
+            )
         self.encoding = tiktoken.encoding_for_model(
             fine_tuning_model.base_llm.model_name
         )
diff --git a/services/engine/dataherald/model/base_model.py b/services/engine/dataherald/model/base_model.py
index 6b7fb23a..398655a6 100644
--- a/services/engine/dataherald/model/base_model.py
+++ b/services/engine/dataherald/model/base_model.py
@@ -1,7 +1,7 @@
 import os
 from typing import Any
 
-from langchain.llms import AlephAlpha, Anthropic, Cohere, OpenAI
+from langchain.llms import AlephAlpha, Anthropic, AzureOpenAI, Cohere, OpenAI
 from overrides import override
 
 from dataherald.model import LLMModel
@@ -19,7 +19,7 @@ def __init__(self, system):
         self.azure_api_key = os.environ.get("AZURE_API_KEY")
 
     @override
-    def get_model(
+    def get_model(  # noqa: C901
         self,
         database_connection: DatabaseConnection,
         model_family="openai",
@@ -27,8 +27,8 @@ def get_model(
         api_base: str | None = None,  # noqa: ARG002
         **kwargs: Any
     ) -> Any:
-        if self.system.settings['azure_api_key'] != None:
-            model_family = 'azure'
+        if self.system.settings["azure_api_key"] is not None:
+            model_family = "azure"
         if database_connection.llm_api_key is not None:
             fernet_encrypt = FernetEncrypt()
             api_key = fernet_encrypt.decrypt(database_connection.llm_api_key)
@@ -39,7 +39,7 @@ def get_model(
             elif model_family == "google":
                 self.google_api_key = api_key
             elif model_family == "azure":
-                self.azure_api_key == api_key
+                self.azure_api_key = api_key
         if self.openai_api_key:
             self.model = OpenAI(model_name=model_name, **kwargs)
         elif self.aleph_alpha_api_key:
diff --git a/services/engine/dataherald/model/chat_model.py b/services/engine/dataherald/model/chat_model.py
index 3ab1fcce..4c7d57f9 100644
--- a/services/engine/dataherald/model/chat_model.py
+++ b/services/engine/dataherald/model/chat_model.py
@@ -1,7 +1,7 @@
 from typing import Any
 
 from langchain_community.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm
-from langchain_openai import ChatOpenAI, AzureChatOpenAI
+from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from overrides import override
 
 from dataherald.model import LLMModel
@@ -22,16 +22,16 @@ def get_model(
         **kwargs: Any
     ) -> Any:
         api_key = database_connection.decrypt_api_key()
-        if self.system.settings['azure_api_key'] != None:
-            model_family = 'azure'
+        if self.system.settings["azure_api_key"] is not None:
+            model_family = "azure"
         if model_family == "azure":
-            if api_base.endswith("/"):  #TODO check where final "/" is added to api_base
+            if api_base.endswith("/"):  # check where final "/" is added to api_base
                 api_base = api_base[:-1]
             return AzureChatOpenAI(
                 deployment_name=model_name,
                 openai_api_key=api_key,
-                azure_endpoint= api_base,
-                api_version=self.system.settings['azure_api_version'],
+                azure_endpoint=api_base,
+                api_version=self.system.settings["azure_api_version"],
                 **kwargs
             )
         if model_family == "openai":
diff --git a/services/engine/dataherald/services/sql_generations.py b/services/engine/dataherald/services/sql_generations.py
index b1890443..413101ca 100644
--- a/services/engine/dataherald/services/sql_generations.py
+++ b/services/engine/dataherald/services/sql_generations.py
@@ -63,9 +63,9 @@ def update_the_initial_sql_generation(
         initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps
         return self.sql_generation_repository.update(initial_sql_generation)
 
-    def create(
+    def create(  # noqa: PLR0912
         self, prompt_id: str, sql_generation_request: SQLGenerationRequest
-    ) -> SQLGeneration:
+    ) -> SQLGeneration:  # noqa: PLR0912
         initial_sql_generation = SQLGeneration(
             prompt_id=prompt_id,
             created_at=datetime.now(),
diff --git a/services/engine/dataherald/sql_generator/__init__.py b/services/engine/dataherald/sql_generator/__init__.py
index e997920e..6612332b 100644
--- a/services/engine/dataherald/sql_generator/__init__.py
+++ b/services/engine/dataherald/sql_generator/__init__.py
@@ -179,7 +179,7 @@ def generate_response(
         """Generates a response to a user question."""
         pass
 
-    def stream_agent_steps(  # noqa: C901
+    def stream_agent_steps(  # noqa: PLR0912, C901
         self,
         question: str,
         agent_executor: AgentExecutor,
@@ -187,7 +187,7 @@ def stream_agent_steps(  # noqa: C901
         sql_generation_repository: SQLGenerationRepository,
         queue: Queue,
         metadata: dict = None,
-    ):
+    ):  # noqa: PLR0912
         try:
             with get_openai_callback() as cb:
                 for chunk in agent_executor.stream(
diff --git a/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py b/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py
index 6fc64b95..fe54dcf4 100644
--- a/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py
+++ b/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py
@@ -21,7 +21,7 @@
 from langchain.chains.llm import LLMChain
 from langchain.tools.base import BaseTool
 from langchain_community.callbacks import get_openai_callback
-from langchain_openai import OpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from openai import OpenAI
 from overrides import override
 from pydantic import BaseModel, Field
@@ -587,7 +587,7 @@ def generate_response(
             )
         finetunings_repository = FinetuningsRepository(storage)
         finetuning = finetunings_repository.find_by_id(self.finetuning_id)
-        openai_fine_tuning = OpenAIFineTuning(storage, finetuning)
+        openai_fine_tuning = OpenAIFineTuning(self.system, storage, finetuning)
         finetuning = openai_fine_tuning.retrieve_finetuning_job()
         if finetuning.status != FineTuningStatus.SUCCEEDED.value:
             raise FinetuningNotAvailableError(
@@ -595,6 +595,16 @@ def generate_response(
                 f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
             )
         self.database = SQLDatabase.get_sql_engine(database_connection)
+        if self.system.settings["azure_api_key"] is not None:
+            embedding = AzureOpenAIEmbeddings(
+                openai_api_key=database_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
+            )
+        else:
+            embedding = OpenAIEmbeddings(
+                openai_api_key=database_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
+            )
         toolkit = SQLDatabaseToolkit(
             db=self.database,
             instructions=instructions,
@@ -605,10 +615,7 @@ def generate_response(
             use_finetuned_model_only=self.use_fintuned_model_only,
             model_name=finetuning.base_llm.model_name,
             openai_fine_tuning=openai_fine_tuning,
-            embedding=OpenAIEmbeddings(  #TODO AzureOpenAIEmbeddings when Azure
-                openai_api_key=database_connection.decrypt_api_key(),
-                model=EMBEDDING_MODEL,
-            ),
+            embedding=embedding,
         )
         agent_executor = self.create_sql_agent(
             toolkit=toolkit,
@@ -693,7 +700,7 @@ def stream_response(
             )
         finetunings_repository = FinetuningsRepository(storage)
         finetuning = finetunings_repository.find_by_id(self.finetuning_id)
-        openai_fine_tuning = OpenAIFineTuning(storage, finetuning)
+        openai_fine_tuning = OpenAIFineTuning(self.system, storage, finetuning)
         finetuning = openai_fine_tuning.retrieve_finetuning_job()
         if finetuning.status != FineTuningStatus.SUCCEEDED.value:
             raise FinetuningNotAvailableError(
@@ -701,6 +708,16 @@ def stream_response(
                 f"Finetuning not available for the database connection {self.database_connection_id}, "
                 f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
             )
         self.database = SQLDatabase.get_sql_engine(database_connection)
+        if self.system.settings["azure_api_key"] is not None:
+            embedding = AzureOpenAIEmbeddings(
+                openai_api_key=database_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
+            )
+        else:
+            embedding = OpenAIEmbeddings(
+                openai_api_key=database_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
+            )
         toolkit = SQLDatabaseToolkit(
             db=self.database,
             instructions=instructions,
@@ -710,10 +727,7 @@ def stream_response(
             few_shot_examples=new_fewshot_examples,
             use_finetuned_model_only=self.use_fintuned_model_only,
             model_name=finetuning.base_llm.model_name,
             openai_fine_tuning=openai_fine_tuning,
-            embedding=OpenAIEmbeddings(  #TODO AzureOpenAIEmbeddings when Azure
-                openai_api_key=database_connection.decrypt_api_key(),
-                model=EMBEDDING_MODEL,
-            ),
+            embedding=embedding,
         )
         agent_executor = self.create_sql_agent(
             toolkit=toolkit,
diff --git a/services/engine/dataherald/sql_generator/dataherald_sqlagent.py b/services/engine/dataherald/sql_generator/dataherald_sqlagent.py
index a9635ff5..414ab089 100644
--- a/services/engine/dataherald/sql_generator/dataherald_sqlagent.py
+++ b/services/engine/dataherald/sql_generator/dataherald_sqlagent.py
@@ -22,7 +22,7 @@
 from langchain.chains.llm import LLMChain
 from langchain.tools.base import BaseTool
 from langchain_community.callbacks import get_openai_callback
-from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from overrides import override
 from pydantic import BaseModel, Field
 from sql_metadata import Parser
@@ -710,13 +710,13 @@ def create_sql_agent(
         )
 
     @override
-    def generate_response(
+    def generate_response(  # noqa: PLR0912
        self,
         user_prompt: Prompt,
         database_connection: DatabaseConnection,
         context: List[dict] = None,
         metadata: dict = None,
-    ) -> SQLGeneration:
+    ) -> SQLGeneration:  # noqa: PLR0912
         context_store = self.system.instance(ContextStore)
         storage = self.system.instance(DB)
         response = SQLGeneration(
@@ -753,8 +753,8 @@ def generate_response(
             number_of_samples = 0
         logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
         self.database = SQLDatabase.get_sql_engine(database_connection)
-        #Set Embeddings class depending on azure / not azure
-        if self.llm.openai_api_type == "azure":
+        # Set Embeddings class depending on azure / not azure
+        if self.system.settings["azure_api_key"] is not None:
             toolkit = SQLDatabaseToolkit(
                 db=self.database,
                 context=context,
@@ -873,21 +873,17 @@ def stream_response(
             new_fewshot_examples = None
             number_of_samples = 0
         self.database = SQLDatabase.get_sql_engine(database_connection)
-        #Set Embeddings class depending on azure / not azure
-        if self.llm.openai_api_type == "azure":
-            toolkit = SQLDatabaseToolkit(
-                db=self.database,
-                context=context,
-                few_shot_examples=new_fewshot_examples,
-                instructions=instructions,
-                is_multiple_schema=True if user_prompt.schemas else False,
-                db_scan=db_scan,
-                embedding=AzureOpenAIEmbeddings(
-                    openai_api_key=database_connection.decrypt_api_key(),
-                    model=EMBEDDING_MODEL,
-                ),
+        # Set Embeddings class depending on azure / not azure
+        if self.system.settings["azure_api_key"] is not None:
+            embedding = AzureOpenAIEmbeddings(
+                openai_api_key=database_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
+            )
+        else:
+            embedding = OpenAIEmbeddings(
+                openai_api_key=database_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
             )
-        else:
             toolkit = SQLDatabaseToolkit(
                 queuer=queue,
                 db=self.database,
                 context=context,
                 few_shot_examples=new_fewshot_examples,
                 instructions=instructions,
                 is_multiple_schema=True if user_prompt.schemas else False,
                 db_scan=db_scan,
-                embedding=OpenAIEmbeddings(
-                    openai_api_key=database_connection.decrypt_api_key(),
-                    model=EMBEDDING_MODEL,
-                ),
+                embedding=embedding,
             )
         agent_executor = self.create_sql_agent(
             toolkit=toolkit,
diff --git a/services/enterprise/exceptions/exception_handlers.py b/services/enterprise/exceptions/exception_handlers.py
index 658e5b1f..cf599533 100644
--- a/services/enterprise/exceptions/exception_handlers.py
+++ b/services/enterprise/exceptions/exception_handlers.py
@@ -13,7 +13,6 @@
 
 
 async def exception_handler(request: Request, exc: BaseError):  # noqa: ARG001
-
     trace_id = exc.trace_id
     error_code = exc.error_code
     status_code = exc.status_code
diff --git a/services/enterprise/exceptions/exceptions.py b/services/enterprise/exceptions/exceptions.py
index 32480d67..a1ad997f 100644
--- a/services/enterprise/exceptions/exceptions.py
+++ b/services/enterprise/exceptions/exceptions.py
@@ -39,7 +39,6 @@ def __init__(
         description: str = None,
         detail: dict = None,
     ) -> None:
-
         if type(self) is BaseError:
             raise TypeError("BaseError class may not be instantiated directly")
 
diff --git a/services/enterprise/modules/db_connection/controller.py b/services/enterprise/modules/db_connection/controller.py
index d104182f..750e4e40 100644
--- a/services/enterprise/modules/db_connection/controller.py
+++ b/services/enterprise/modules/db_connection/controller.py
@@ -95,7 +95,6 @@ async def ac_get_db_connection(
     id: ObjectIdString,
     user: User = Security(authenticate_user),
 ) -> DBConnectionResponse:
-
     return db_connection_service.get_db_connection(id, user.organization_id)
 
 
diff --git a/services/enterprise/modules/db_connection/service.py b/services/enterprise/modules/db_connection/service.py
index 3c482675..4c924dc9 100644
--- a/services/enterprise/modules/db_connection/service.py
+++ b/services/enterprise/modules/db_connection/service.py
@@ -84,7 +84,9 @@ async def add_db_connection(
         )
 
         if organization.llm_api_key:
-            db_connection_internal_request.llm_api_key = FernetEncrypt().decrypt(organization.llm_api_key)
+            db_connection_internal_request.llm_api_key = FernetEncrypt().decrypt(
+                organization.llm_api_key
+            )
 
         if db_connection_request.use_ssh:
             db_connection_internal_request.ssh_settings.private_key_password = (
@@ -140,8 +142,9 @@ async def update_db_connection(
         )
 
         if organization.llm_api_key:
-            db_connection_internal_request.llm_api_key = FernetEncrypt().decrypt(organization.llm_api_key)
-
+            db_connection_internal_request.llm_api_key = FernetEncrypt().decrypt(
+                organization.llm_api_key
+            )
 
         if db_connection_request.use_ssh:
             db_connection_internal_request.ssh_settings.private_key_password = (
@@ -163,7 +166,6 @@ async def update_db_connection(
     async def add_sample_db_connection(
         self, sample_request: SampleDBRequest, org_id: str
     ) -> DBConnectionResponse:
-
         sample_db_dict = await self.sample_db.add_sample_db(
             sample_request.sample_db_id, org_id
         )
diff --git a/services/enterprise/modules/organization/invoice/service.py b/services/enterprise/modules/organization/invoice/service.py
index 6cdf2b27..8b8bb6b1 100644
--- a/services/enterprise/modules/organization/invoice/service.py
+++ b/services/enterprise/modules/organization/invoice/service.py
@@ -84,12 +84,12 @@ def update_spending_limit(
             raise CannotUpdateSpendingLimitError(org_id)
 
     def get_pending_invoice(self, org_id: str) -> InvoiceResponse:
-
         organization = self.org_repo.get_organization(org_id)
-        current_period_start, current_period_end = (
-            self.billing.get_current_subscription_period_with_anchor(
-                organization.invoice_details.billing_cycle_anchor
-            )
+        (
+            current_period_start,
+            current_period_end,
+        ) = self.billing.get_current_subscription_period_with_anchor(
+            organization.invoice_details.billing_cycle_anchor
         )
         upcoming_invoice = self.billing.get_upcoming_invoice(
             organization.invoice_details.stripe_customer_id
         )
@@ -304,10 +305,11 @@ def check_usage(
             ):
                 raise SubscriptionCanceledError(org_id)
             raise UnknownSubscriptionStatusError(org_id)
-        start_date, end_date = (
-            self.billing.get_current_subscription_period_with_anchor(
-                organization.invoice_details.billing_cycle_anchor
-            )
+        (
+            start_date,
+            end_date,
+        ) = self.billing.get_current_subscription_period_with_anchor(
+            organization.invoice_details.billing_cycle_anchor
         )
         usages = self.repo.get_usages(org_id, start_date, end_date)
         usage = Usage(
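
For reference, the change repeated across the engine files above reduces to one selection rule: build AzureOpenAIEmbeddings when the AZURE_API_KEY setting is present, otherwise OpenAIEmbeddings. Below is a minimal sketch of that rule, not part of the patch; `select_embedding`, `settings`, `api_key`, and `model` are hypothetical stand-ins for System.settings, DatabaseConnection.decrypt_api_key(), and the EMBEDDING_MODEL constant used in the diff.

from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings


def select_embedding(settings: dict, api_key: str, model: str):
    # Mirror the patch's switch: Azure embeddings when an Azure API key is configured,
    # standard OpenAI embeddings otherwise.
    if settings.get("azure_api_key") is not None:
        # Azure endpoint and API version are assumed to come from the environment
        # (e.g. AZURE_OPENAI_ENDPOINT, OPENAI_API_VERSION) or additional kwargs.
        return AzureOpenAIEmbeddings(openai_api_key=api_key, model=model)
    return OpenAIEmbeddings(openai_api_key=api_key, model=model)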