Skip to content

Commit

Permalink
DH-4690/reformat with black
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Sep 19, 2023
1 parent 8813287 commit 256f966
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction
from langchain.tools.base import BaseTool
from overrides import override
Expand Down Expand Up @@ -101,6 +105,7 @@
{sql_query}
"""


def catch_exceptions(): # noqa: C901
def decorator(fn: Callable[[str], str]) -> Callable[[str], str]: # noqa: C901
@wraps(fn)
Expand Down Expand Up @@ -191,12 +196,18 @@ def _run(
query: str, # noqa: ARG002
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
) -> str:
human_message_prompt = HumanMessagePromptTemplate.from_template(POST_PROCESSING_HUMAN_TEMPLATE)
system_message_prompt = SystemMessagePromptTemplate.from_template(POST_PROCESSING_SYSTEM_TEMPLATE)
chat_prompt_template = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
human_message_prompt = HumanMessagePromptTemplate.from_template(
POST_PROCESSING_HUMAN_TEMPLATE
)
system_message_prompt = SystemMessagePromptTemplate.from_template(
POST_PROCESSING_SYSTEM_TEMPLATE
)
chat_prompt_template = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
chain = LLMChain(llm=self.llm, prompt=chat_prompt_template)
return chain.run(
instructions= ",".join(SQL_QUERY_INSTRUCTIONS),
instructions=",".join(SQL_QUERY_INSTRUCTIONS),
sql_query=query,
)

Expand Down Expand Up @@ -501,7 +512,9 @@ class Config:
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tools = []
post_processing_tool = PostProcessingTool(db=self.db, context=self.context, llm=self.llm)
post_processing_tool = PostProcessingTool(
db=self.db, context=self.context, llm=self.llm
)
tools.append(post_processing_tool)
query_sql_db_tool = QuerySQLDataBaseTool(db=self.db, context=self.context)
tools.append(query_sql_db_tool)
Expand Down

0 comments on commit 256f966

Please sign in to comment.