Improving the router to use the context
Hugo Saporetti Junior committed Apr 6, 2024
1 parent a3fe2b3 commit ba2972f
Showing 7 changed files with 56 additions and 55 deletions.
2 changes: 1 addition & 1 deletion docs/devel/askai-questions.txt
@@ -17,6 +17,6 @@ pt_BR

Composite:

1. list my music and let me know if there is any ac/dc song. If so, show me the file name.
1. list my music and let me know if there is any ac/dc song. If so, show me the file name and open it.
2. open the first reminder file you find at my downloads.

18 changes: 8 additions & 10 deletions src/main/askai/core/askai.py
@@ -45,7 +45,8 @@
from askai.core.support.langchain_support import lc_llm
from askai.core.support.shared_instances import shared
from askai.core.support.utilities import display_text, read_stdin
from askai.exception.exceptions import ImpossibleQuery, UnintelligibleQuery, TerminatingQuery, MaxInteractionsReached
from askai.exception.exceptions import ImpossibleQuery, UnintelligibleQuery, TerminatingQuery, MaxInteractionsReached, \
InaccurateResponse


class AskAi:
@@ -225,26 +226,23 @@ def _ask_and_reply(self, question: str) -> bool:
"""Ask the question and provide the reply.
:param question: The question to ask to the AI engine.
"""
status = False
status = True
try:
if not (reply := cache.read_reply(question)):
log.debug('Response not found for "%s" in cache. Querying from %s.', question, self.engine.nickname())
AskAiEvents.ASKAI_BUS.events.reply.emit(message=msg.wait())
if (output := router.process(question)) and output.response:
if output := router.process(question):
self.reply(output.response)
else:
log.debug("Reply found for '%s' in cache.", question)
self.reply(reply)
shared.context.forget()
status = True
except (NotImplementedError, ImpossibleQuery, UnintelligibleQuery) as err:
self.reply_error(str(err))
status = True
except MaxInteractionsReached as err:
self.reply_error(str(err))
except (MaxInteractionsReached, InaccurateResponse) as err:
self.reply_error(msg.unprocessable(str(err)))
except TerminatingQuery:
pass
status = False

return status

def _get_query_string(self, interactive: bool, query_arg: str | list[str]) -> None:
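
In practice, the reworked `_ask_and_reply` inverts the status logic: the REPL now assumes it keeps running (`status = True`) and only a `TerminatingQuery` flips it to `False`, while `MaxInteractionsReached` and the new `InaccurateResponse` are both surfaced through `msg.unprocessable`. A minimal stand-alone sketch of that control flow (the processing and reply callables are stubs, not the project's API):

```python
# Stand-alone sketch of the new _ask_and_reply control flow; the processing and
# reply callables are stubs, only the exception routing mirrors the commit.
from typing import Callable, Optional

class TerminatingQuery(Exception): ...
class MaxInteractionsReached(Exception): ...
class InaccurateResponse(Exception): ...

def ask_and_reply(
    question: str,
    process: Callable[[str], Optional[str]],
    reply: Callable[[str], None],
    reply_error: Callable[[str], None],
) -> bool:
    status = True  # keep the REPL running by default
    try:
        if output := process(question):
            reply(output)
    except (MaxInteractionsReached, InaccurateResponse) as err:
        # both now surface as a friendly "unable to process" message
        reply_error(f"Sorry, I was unable to process your request => {err}")
    except TerminatingQuery:
        status = False  # only an explicit terminating query stops the loop
    return status
```

The `ImpossibleQuery`/`UnintelligibleQuery` branch is unchanged and omitted from the sketch.
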
8 changes: 4 additions & 4 deletions src/main/askai/core/askai_messages.py
@@ -81,10 +81,6 @@ def press_esc_enter(self) -> str:

# Warnings and alerts

@lru_cache
def exec_result(self, exit_code: ExitStatus) -> str:
return self.translate(f"Execution result: `{exit_code}`")

@lru_cache
def search_empty(self) -> str:
return self.translate("The search didn't return an output !")
@@ -143,5 +139,9 @@ def fail_to_search(self, error: str) -> str:
def too_many_actions(self) -> str:
return self.translate("Failed to complete the request => 'Max chained actions reached' !")

@lru_cache
def unprocessable(self, reason: str) -> str:
return self.translate(f"Sorry, I was unable to process your request => {reason}")


assert (msg := AskAiMessages().INSTANCE) is not None
63 changes: 32 additions & 31 deletions src/main/askai/core/proxy/router.py
@@ -5,11 +5,9 @@
from typing import TypeAlias, Optional

from hspylib.core.metaclass.singleton import Singleton
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.globals import set_llm_cache
from langchain_community.cache import InMemoryCache
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.runnables.utils import Input, Output
from retry import retry
@@ -40,27 +38,49 @@ def __init__(self):
def template(self) -> str:
return prompt.read_prompt("router-prompt.txt")

@staticmethod
def _assert_accuracy(question: str, ai_response: str) -> None:
"""Function responsible for asserting that the question was properly answered."""
if ai_response:
template = PromptTemplate(input_variables=[
'question', 'response'
], template=prompt.read_prompt('rag-prompt'))
final_prompt = template.format(
question=question, response=ai_response or '')
llm = lc_llm.create_chat_model(Temperature.DATA_ANALYSIS.temp)
if (output := llm.predict(final_prompt)) and (mat := RagResponse.matches(output)):
status, reason = mat.group(1), mat.group(2)
log.info("Accuracy status: '%s' reason: '%s'", status, reason)
AskAiEvents.ASKAI_BUS.events.reply.emit(message=msg.assert_acc(output), verbosity='debug')
if RagResponse.of_value(status.strip()).is_bad:
raise InaccurateResponse(f"The RAG response was not 'Green' => '{output}' ")
return

raise InaccurateResponse(f"The RAG response was not 'Green'")

def process(self, query: str) -> Optional[str]:
"""Process the user query and retrieve the final response."""
@retry(exceptions=InaccurateResponse, tries=2, delay=0, backoff=0)
def _process_wrapper(question: str) -> Optional[str]:
"""Wrapper to allow RAG retries."""
template = PromptTemplate(input_variables=[], template=self.template())
final_prompt = template.format(features=features.enlist(), objective=question)
ctx: str = shared.context.flat("OUTPUT", "ANALYSIS", "INTERNET", "GENERAL")
log.info("Router::[QUESTION] '%s' context: '%s'", question, ctx)
chat_prompt = ChatPromptTemplate.from_messages([("system", "{query}\n\n{context}")])
chain = create_stuff_documents_chain(lc_llm.create_chat_model(), chat_prompt)
context = [Document(ctx)]
if response := chain.invoke({"query": final_prompt, "context": context}):
template = PromptTemplate(input_variables=[
'features', 'context', 'objective'
], template=self.template())
context: str = shared.context.flat("OUTPUT", "ANALYSIS", "INTERNET", "GENERAL")
final_prompt = template.format(
features=features.enlist(), context=context or 'Nothing yet', objective=question)
log.info("Router::[QUESTION] '%s' context: '%s'", question, context)
llm = lc_llm.create_chat_model(Temperature.DATA_ANALYSIS.temp)

if response := llm.predict(final_prompt):
log.info("Router::[RESPONSE] Received from AI: \n%s.", str(response))
output = self._route(question, re.sub(r'\d+[.:)-]\s+', '', response))
else:
output = response
return output

return _process_wrapper(query)

@lru_cache
def _route(self, question: str, action_plan: str) -> Optional[str]:
"""Route the actions to the proper function invocations."""
set_llm_cache(InMemoryCache())
@@ -78,24 +98,5 @@ def _route(self, question: str, action_plan: str) -> Optional[str]:

return self._assert_accuracy(question, result)

@lru_cache
def _assert_accuracy(self, question: str, ai_response: str) -> None:
"""Function responsible for asserting that the question was properly answered."""
if ai_response:
template = PromptTemplate(
input_variables=['question', 'response'],
template=prompt.read_prompt('rag-prompt'))
final_prompt = template.format(question=question, response=ai_response or '')
llm = lc_llm.create_chat_model(Temperature.DATA_ANALYSIS.temp)
if (output := llm.predict(final_prompt)) and (mat := RagResponse.matches(output)):
status, reason = mat.group(1), mat.group(2)
log.info("Accuracy status: '%s' reason: '%s'", status, reason)
AskAiEvents.ASKAI_BUS.events.reply.emit(message=msg.assert_acc(output), verbosity='debug')
if RagResponse.of_value(status.strip()).is_bad:
raise InaccurateResponse(f"The RAG response was not 'Green' => '{output}' ")
return

raise InaccurateResponse(f"The RAG response was not 'Green'")


assert (router := Router().INSTANCE) is not None
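
The substance of the router change is twofold: the flattened chat context is now formatted straight into the router prompt (via the new `context` input variable), and the whole routing pass is wrapped in a `@retry` on `InaccurateResponse`, so a failed RAG self-check triggers one more attempt. A rough sketch of that retry loop, with the two LLM calls replaced by hypothetical stubs:

```python
# Rough sketch of the retry-driven accuracy loop; rag_grade and run_router stand in
# for the real LLM calls (rag-prompt grading and router-prompt planning/execution).
from typing import Optional
from retry import retry

class InaccurateResponse(Exception):
    """Raised when the RAG self-check does not grade the final answer 'Green'."""

def rag_grade(question: str, answer: str) -> str:
    """Stub for the 'rag-prompt' call that grades the answer Green/Yellow/Red."""
    return "Green" if answer else "Red"

def run_router(question: str, context: str) -> str:
    """Stub for formatting the router prompt with the flattened context and invoking the model."""
    return f"1. resolve: {question}" if question else ""

def assert_accuracy(question: str, answer: str) -> None:
    grade = rag_grade(question, answer)
    if grade != "Green":
        raise InaccurateResponse(f"The RAG response was not 'Green' => '{grade}'")

def process(question: str, context: str = "") -> Optional[str]:
    @retry(exceptions=InaccurateResponse, tries=2, delay=0, backoff=0)
    def _process_wrapper(q: str) -> Optional[str]:
        response = run_router(q, context or "Nothing yet")
        assert_accuracy(q, response)  # a non-Green grade raises and @retry re-runs the wrapper
        return response
    return _process_wrapper(question)
```

With `tries=2`, a single bad grade costs one extra routing pass; a second failure propagates `InaccurateResponse` up to `_ask_and_reply`, which reports it via `msg.unprocessable`.
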
8 changes: 4 additions & 4 deletions src/main/askai/core/proxy/tools/analysis.py
@@ -23,8 +23,8 @@ def check_output(question: str, context: str) -> Optional[str]:
final_prompt = template.format(context=context, question=question)

if output := llm.predict(final_prompt):
shared.context.set("ANALYSIS", f"\nUser:\n{question}")
shared.context.push("ANALYSIS", f"\nAI:\n{output}", "assistant")
shared.context.set("ANALYSIS", f"\nUser: {question}")
shared.context.push("ANALYSIS", f"\nAI: {output}", "assistant")

return text_formatter.ensure_ln(output)

@@ -42,7 +42,7 @@ def stt(question: str, existing_answer: str) -> str:
llm = lc_llm.create_chat_model(temperature=Temperature.CREATIVE_WRITING.temp)

if output := llm.predict(final_prompt):
shared.context.set("ANALYSIS", f"\nUser:\n{question}")
shared.context.push("ANALYSIS", f"\nAI:\n{output}", "assistant")
shared.context.set("ANALYSIS", f"\nUser: {question}")
shared.context.push("ANALYSIS", f"\nAI: {output}", "assistant")

return text_formatter.ensure_ln(output)
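
Both analysis tools now write compact single-line `User: …` / `AI: …` turns into the shared context instead of the previous two-line form, which keeps the flattened text the router prompt receives shorter. The exact `ChatContext` API is not shown in this diff; the sketch below only assumes the behaviour implied by the calls here: `set()` replaces a named buffer, `push()` appends to it, and `flat()` joins several buffers into one string.

```python
# Minimal stand-in for the shared chat context, assuming only the behaviour
# implied by the calls in this diff (set replaces, push appends, flat joins).
from collections import defaultdict

class ChatContext:
    def __init__(self) -> None:
        self._store: dict[str, list[str]] = defaultdict(list)

    def set(self, key: str, text: str, role: str = "user") -> None:
        self._store[key] = [text]          # replace whatever was stored under the key

    def push(self, key: str, text: str, role: str = "user") -> None:
        self._store[key].append(text)      # append another turn to the key

    def flat(self, *keys: str) -> str:
        return "".join(t for k in keys for t in self._store.get(k, []))

    def forget(self) -> None:
        self._store.clear()

ctx = ChatContext()
ctx.set("ANALYSIS", "\nUser: is there any ac/dc song in my music?")
ctx.push("ANALYSIS", "\nAI: Yes, 'Back in Black.mp3' was found.", "assistant")
print(ctx.flat("OUTPUT", "ANALYSIS"))      # -> the compact, single-line turns
```
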
6 changes: 2 additions & 4 deletions src/main/askai/core/proxy/tools/terminal.py
@@ -65,18 +65,16 @@ def _execute_shell(command_line: str) -> Tuple[bool, Optional[str]]:
output, exit_code = Terminal.INSTANCE.shell_exec(command, shell=True)
if exit_code == ExitStatus.SUCCESS:
log.info("Command succeeded.\nCODE=%s \nPATH: %s \nCMD: %s ", exit_code, os.getcwd(), command)
AskAiEvents.ASKAI_BUS.events.reply.emit(message=msg.cmd_success(command_line, exit_code), verbosity='debug')
if _path_ := extract_path(command):
os.chdir(_path_)
log.info("Current directory changed to '%s'", _path_)
else:
log.warning("Directory '%s' does not exist. Current dir unchanged!", _path_)
if not output:
output = msg.exec_result(exit_code)
output = msg.cmd_success(command_line, exit_code)
else:
output = f"\n```bash\n{output}\n```"
shared.context.set("OUTPUT", f"\n\nUser:\nCommand `{command_line}' output:")
shared.context.push("OUTPUT", f"\nAI:{output}", "assistant")
shared.context.set("OUTPUT", f"\nUser: Command `{command_line}' output:\n{output}")
status = True
else:
log.error("Command failed.\nCODE=%s \nPATH=%s \nCMD=%s ", exit_code, os.getcwd(), command)
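
On the terminal side, the successful-command branch now records the command line and its output as a single `OUTPUT` entry, and the dropped `exec_result` message is replaced by `cmd_success` when the command produced no output. A short, self-contained sketch of that branch (helpers and the context object are stubbed):

```python
# Sketch of the success branch: the command line and its output now land in the
# OUTPUT context key as one entry instead of a set + push pair (stubbed helpers).
def record_command_output(context: dict, command_line: str, output: str, exit_code: int) -> None:
    if not output:
        # stands in for msg.cmd_success(command_line, exit_code)
        output = f"OK, command `{command_line}` succeeded with code {exit_code}"
    context["OUTPUT"] = f"\nUser: Command `{command_line}` output:\n{output}"

history: dict = {}
record_command_output(history, "ls ~/Music", "Back in Black.mp3\nThunderstruck.mp3", 0)
print(history["OUTPUT"])
```
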
6 changes: 5 additions & 1 deletion src/main/askai/resources/assets/prompts/router-prompt.txt
@@ -2,7 +2,11 @@ As the interface with your computer, you have the following features:

{features}

Given the question at the end, Your task is to review it and provide a structured list of actions employing one or more of the specified features. If the prompt calls for multiple features, delineate the necessary steps in the order required to meet the request. For each feature, you must include the corresponding command associated with that feature.
Use the following context to answer the question at the end:

'''{context}'''

Your task is to review it and provide a structured list of actions employing one or more of the specified features. If the prompt calls for multiple features, delineate the necessary steps in the order required to meet the request. For each feature, you must include the corresponding command associated with that feature.

If you encounter any challenges understanding the query due to ambiguity, context dependency, or lack of clarity, please refer to the command output for clarification. Pay attention to file or folder names, mentions, and contents like 'file contents', 'folder listing', 'dates', and 'spoken names' to disambiguate effectively.

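
The prompt change is what actually lets the router "use the context": the flattened chat history is injected between triple quotes ahead of the task description. A small sketch of how the template is filled on the router side, with an abbreviated template string (the real text is read from router-prompt.txt):

```python
# Sketch of how the updated router prompt is rendered; the template string is
# abbreviated here, the real one lives in router-prompt.txt.
from langchain_core.prompts import PromptTemplate

template = PromptTemplate(
    input_variables=["features", "context", "objective"],
    template=(
        "As the interface with your computer, you have the following features:\n\n"
        "{features}\n\n"
        "Use the following context to answer the question at the end:\n\n"
        "'''{context}'''\n\n"
        "Question: {objective}"
    ),
)

flat_context = ""  # e.g. shared.context.flat("OUTPUT", "ANALYSIS", "INTERNET", "GENERAL")
final_prompt = template.format(
    features="1. Execute terminal commands\n2. Analyze command output",
    context=flat_context or "Nothing yet",   # same fallback the router uses
    objective="list my music and let me know if there is any ac/dc song",
)
print(final_prompt)
```

When no prior output exists, the `'Nothing yet'` fallback keeps the triple-quoted block non-empty instead of handing the model an empty context.
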
