From 256f966195ce95e6d5c6c7fbb4a0f802010c602d Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 19 Sep 2023 10:39:59 -0400 Subject: [PATCH] DH-4690/reformat with black --- .../sql_generator/dataherald_sqlagent.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index d095b0c8..12112519 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -19,7 +19,11 @@ CallbackManagerForToolRun, ) from langchain.chains.llm import LLMChain -from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) from langchain.schema import AgentAction from langchain.tools.base import BaseTool from overrides import override @@ -101,6 +105,7 @@ {sql_query} """ + def catch_exceptions(): # noqa: C901 def decorator(fn: Callable[[str], str]) -> Callable[[str], str]: # noqa: C901 @wraps(fn) @@ -191,12 +196,18 @@ def _run( query: str, # noqa: ARG002 run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: - human_message_prompt = HumanMessagePromptTemplate.from_template(POST_PROCESSING_HUMAN_TEMPLATE) - system_message_prompt = SystemMessagePromptTemplate.from_template(POST_PROCESSING_SYSTEM_TEMPLATE) - chat_prompt_template = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt]) + human_message_prompt = HumanMessagePromptTemplate.from_template( + POST_PROCESSING_HUMAN_TEMPLATE + ) + system_message_prompt = SystemMessagePromptTemplate.from_template( + POST_PROCESSING_SYSTEM_TEMPLATE + ) + chat_prompt_template = ChatPromptTemplate.from_messages( + [system_message_prompt, human_message_prompt] + ) chain = LLMChain(llm=self.llm, prompt=chat_prompt_template) return chain.run( - instructions= ",".join(SQL_QUERY_INSTRUCTIONS), + instructions=",".join(SQL_QUERY_INSTRUCTIONS), sql_query=query, ) @@ -501,7 +512,9 @@ class Config: def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" tools = [] - post_processing_tool = PostProcessingTool(db=self.db, context=self.context, llm=self.llm) + post_processing_tool = PostProcessingTool( + db=self.db, context=self.context, llm=self.llm + ) tools.append(post_processing_tool) query_sql_db_tool = QuerySQLDataBaseTool(db=self.db, context=self.context) tools.append(query_sql_db_tool)