From 459fb7f82a24aa14fe9803f6ce1079ff7ba699c8 Mon Sep 17 00:00:00 2001 From: Mohammadreza Pourreza <71866535+MohammadrezaPourreza@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:29:08 -0400 Subject: [PATCH] Dh4670/agent bug with error handling (#165) * DH4670/solving agent bug with handling errors * DH4670/reformat with black * Reduce try attempts --------- Co-authored-by: Juan Carlos Jose Camacho --- .../sql_generator/dataherald_sqlagent.py | 95 +++++++++---------- dataherald/types.py | 2 +- 2 files changed, 47 insertions(+), 50 deletions(-) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 587cfd3a..63fe314e 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -29,7 +29,7 @@ from dataherald.db import DB from dataherald.db_scanner.models.types import TableSchemaDetail from dataherald.db_scanner.repository.base import DBScannerRepository -from dataherald.sql_database.base import SQLDatabase +from dataherald.sql_database.base import SQLDatabase, SQLInjectionError from dataherald.sql_database.models.types import ( DatabaseConnection, ) @@ -84,43 +84,34 @@ Thought: I should Collect examples of Question/SQL pairs to identify possibly relevant tables, columns, and SQL query styles. If there is a similar question among the examples, I can use the SQL query from the example and modify it to fit the given question. {agent_scratchpad}""" # noqa: E501 -MAX_HANDLING_EXCEPTIONS_FOR_EACH_TOOL = 5 - -def catch_exceptions(max_handling_list): # noqa: C901 +def catch_exceptions(): # noqa: C901 def decorator(fn: Callable[[str], str]) -> Callable[[str], str]: # noqa: C901 @wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: PLR0911 - nonlocal max_handling_list - max_handling = max_handling_list[0] - max_handling -= 1 - max_handling_list[0] = max_handling - if max_handling < 0: + try: return fn(*args, **kwargs) - else: # noqa: RET505 - try: - return fn(*args, **kwargs) - except openai.error.APIError as e: - # Handle API error here, e.g. retry or log - return f"OpenAI API returned an API Error: {e}" - except openai.error.APIConnectionError as e: - # Handle connection error here - return f"Failed to connect to OpenAI API: {e}" - except openai.error.RateLimitError as e: - # Handle rate limit error (we recommend using exponential backoff) - return f"OpenAI API request exceeded rate limit: {e}" - except openai.error.Timeout as e: - # Handle timeout error (we recommend using exponential backoff) - return f"OpenAI API request timed out: {e}" - except openai.error.ServiceUnavailableError as e: - # Handle service unavailable error (we recommend using exponential backoff) - return f"OpenAI API service unavailable: {e}" - except openai.error.InvalidRequestError as e: - return f"OpenAI API request was invalid: {e}" - except GoogleAPIError as e: - return f"Google API returned an error: {e}" - except SQLAlchemyError as e: - return f"Error: {e}" + except openai.error.APIError as e: + # Handle API error here, e.g. retry or log + return f"OpenAI API returned an API Error: {e}" + except openai.error.APIConnectionError as e: + # Handle connection error here + return f"Failed to connect to OpenAI API: {e}" + except openai.error.RateLimitError as e: + # Handle rate limit error (we recommend using exponential backoff) + return f"OpenAI API request exceeded rate limit: {e}" + except openai.error.Timeout as e: + # Handle timeout error (we recommend using exponential backoff) + return f"OpenAI API request timed out: {e}" + except openai.error.ServiceUnavailableError as e: + # Handle service unavailable error (we recommend using exponential backoff) + return f"OpenAI API service unavailable: {e}" + except openai.error.InvalidRequestError as e: + return f"OpenAI API request was invalid: {e}" + except GoogleAPIError as e: + return f"Google API returned an error: {e}" + except SQLAlchemyError as e: + return f"Error: {e}" return wrapper @@ -149,9 +140,8 @@ class GetCurrentTimeTool(BaseSQLDatabaseTool, BaseTool): Input is an empty string, output is the current data and time. Always use this tool before generating a query if there is any time or date in the given question. """ - max_handling: list = [MAX_HANDLING_EXCEPTIONS_FOR_EACH_TOOL] - @catch_exceptions(max_handling) + @catch_exceptions() def _run( self, tool_input: str = "", # noqa: ARG002 @@ -179,9 +169,8 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): If an error occurs, rewrite the query and retry. Use this tool to execute SQL queries. """ - max_handling: list = [MAX_HANDLING_EXCEPTIONS_FOR_EACH_TOOL] - @catch_exceptions(max_handling) + @catch_exceptions() def _run( self, query: str, @@ -208,7 +197,6 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): Use this tool to identify the relevant tables for the given question. """ db_scan: List[TableSchemaDetail] - max_handling: list = [MAX_HANDLING_EXCEPTIONS_FOR_EACH_TOOL] def get_embedding( self, text: str, model: str = "text-embedding-ada-002" @@ -221,7 +209,7 @@ def get_embedding( def cosine_similarity(self, a: List[float], b: List[float]) -> float: return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4) - @catch_exceptions(max_handling) + @catch_exceptions() def _run( self, user_question: str, @@ -273,7 +261,6 @@ class ColumnEntityChecker(BaseSQLDatabaseTool, BaseTool): Example Input: table1 -> column2, entity """ - max_handling: list = [MAX_HANDLING_EXCEPTIONS_FOR_EACH_TOOL] def find_similar_strings( self, input_list: List[tuple], target_string: str, threshold=0.6 @@ -288,7 +275,7 @@ def find_similar_strings( similar_strings.sort(key=lambda x: x[1], reverse=True) return similar_strings[:25] - @catch_exceptions(max_handling) + @catch_exceptions() def _run( self, tool_input: str, @@ -324,9 +311,8 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): Example Input: table1, table2, table3 """ db_scan: List[TableSchemaDetail] - max_handling: list = [MAX_HANDLING_EXCEPTIONS_FOR_EACH_TOOL] - @catch_exceptions(max_handling) + @catch_exceptions() def _run( self, table_names: str, @@ -362,9 +348,8 @@ class InfoRelevantColumns(BaseSQLDatabaseTool, BaseTool): Example Input: table1 -> column1, table1 -> column2, table2 -> column1 """ db_scan: List[TableSchemaDetail] - max_handling: list = [MAX_HANDLING_EXCEPTIONS_FOR_EACH_TOOL] - @catch_exceptions(max_handling) + @catch_exceptions() def _run( self, column_names: str, @@ -416,9 +401,8 @@ class GetFewShotExamples(BaseSQLDatabaseTool, BaseTool): Always use this tool first and before any other tool! """ # noqa: E501 few_shot_examples: List[dict] - max_handling: list = [MAX_HANDLING_EXCEPTIONS_FOR_EACH_TOOL] - @catch_exceptions(max_handling) + @catch_exceptions() def _run( self, number_of_samples: str, @@ -519,7 +503,7 @@ def create_sql_agent( input_variables: List[str] | None = None, max_examples: int = 20, top_k: int = 13, - max_iterations: int | None = 20, + max_iterations: int | None = 10, max_execution_time: float | None = None, early_stopping_method: str = "force", verbose: bool = False, @@ -597,7 +581,20 @@ def generate_response( agent_executor.return_intermediate_steps = True agent_executor.handle_parsing_errors = True with get_openai_callback() as cb: - result = agent_executor({"input": user_question.question}) + try: + result = agent_executor({"input": user_question.question}) + except SQLInjectionError as e: + raise SQLAlchemyError(e) from e + except Exception as e: + return NLQueryResponse( + nl_question_id=user_question.id, + total_tokens=cb.total_tokens, + total_cost=cb.total_cost, + sql_query="", + sql_generation_status="INVALID", + sql_query_result=None, + error_message=str(e), + ) sql_query_list = [] for step in result["intermediate_steps"]: action = step[0] diff --git a/dataherald/types.py b/dataherald/types.py index f235b39b..28d08e65 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -65,7 +65,7 @@ class NLQueryResponse(BaseModel): intermediate_steps: list[str] | None = None sql_query: str sql_query_result: SQLQueryResult | None - sql_generation_status: str = "NONE" + sql_generation_status: str = "INVALID" error_message: str | None exec_time: float | None = None total_tokens: int | None = None