Skip to content

Commit

Permalink
Structured action routing - working version
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Apr 19, 2024
1 parent 17ffc27 commit 97c275d
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 48 deletions.
17 changes: 7 additions & 10 deletions src/main/askai/core/features/actions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import inspect
import re
from functools import lru_cache, cached_property
from textwrap import dedent
from typing import Optional
from typing import Optional, Any

from clitt.core.tui.line_input.line_input import line_input
from hspylib.core.metaclass.singleton import Singleton
Expand All @@ -14,6 +13,7 @@
from askai.core.features.tools.generation import generate_content
from askai.core.features.tools.summarization import summarize
from askai.core.features.tools.terminal import execute_command, list_contents, open_command
from askai.core.model.action_plan import ActionPlan
from askai.exception.exceptions import ImpossibleQuery, TerminatingQuery


Expand All @@ -34,21 +34,18 @@ def __init__(self):
def tool_names(self) -> list[str]:
return [str(dk) for dk in self._all.keys()]

def invoke(self, tool: str, context: str = '') -> Optional[str]:
def invoke(self, action: ActionPlan.Action, context: str = '') -> Optional[str]:
"""Invoke the tool with its arguments and context.
:param tool: The tool to be performed.
:param action: The action to be performed.
:param context: the tool context.
"""
fn_name = None
try:
if tool_fn := re.findall(r'([a-zA-Z]\w+)\s*\((.*)\)', tool.strip()):
fn_name = tool_fn[0][0].lower()
fn = self._all[fn_name]
args: list[str] = re.split(r'(?!\\),', tool_fn[0][1], re.MULTILINE)
if fn := self._all[action.tool]:
args: list[Any] = action.params
args.append(context)
return fn(*list(map(str.strip, args)))
except KeyError as err:
raise ImpossibleQuery(f"Tool not found: {fn_name} => {str(err)}")
raise ImpossibleQuery(f"Tool not found: {action.tool} => {str(err)}")

return None

Expand Down
24 changes: 13 additions & 11 deletions src/main/askai/core/features/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.features.actions import features
from askai.core.features.tools.analysis import assert_accuracy
from askai.core.features.tools.general import final_answer
from askai.core.features.tools.analysis import assert_accuracy, ASSERT_MSG
from askai.core.model.action_plan import ActionPlan
from askai.core.support.langchain_support import lc_llm
from askai.core.support.object_mapper import object_mapper
from askai.core.support.shared_instances import shared
from askai.exception.exceptions import InaccurateResponse, MaxInteractionsReached

Expand Down Expand Up @@ -68,27 +69,27 @@ def _process_wrapper(question: str) -> Optional[str]:
)
if response := chain.invoke({"input": query}, config={"configurable": {"session_id": "HISTORY"}}):
log.info("Router::[RESPONSE] Received from AI: \n%s.", str(response))
# actions: list[str] = re.sub(r'\d+[.:)-]\s+', '', response.content).split(os.linesep)
# actions = list(filter(len, map(str.strip, actions)))
# AskAiEvents.ASKAI_BUS.events.reply.emit(
# message=msg.action_plan('` -> `'.join(actions)), verbosity='debug')
# output = self._route(question, actions)
return ''
action_plan: ActionPlan = object_mapper.of_json(response.content, ActionPlan)
if not isinstance(action_plan, ActionPlan):
raise InaccurateResponse(ASSERT_MSG.substitute(reason='Invalid Json Format'))
AskAiEvents.ASKAI_BUS.events.reply.emit(
message=msg.action_plan(str(action_plan)), verbosity='debug')
output = self._route(question, action_plan.actions)
else:
output = response
return output

return _process_wrapper(query)

def _route(self, query: str, actions: list[str]) -> str:
def _route(self, query: str, actions: list[ActionPlan.Action]) -> str:
"""Route the actions to the proper function invocations.
:param query: The user query to complete.
"""
last_result: str = ''
accumulated: list[str] = []
for idx, action in enumerate(actions):
AskAiEvents.ASKAI_BUS.events.reply.emit(message=f"> `{action}`", verbosity='debug')
AskAiEvents.ASKAI_BUS.events.reply.emit(message=f"> `{action.tool}({', '.join(action.params)})`", verbosity='debug')
if idx > self.MAX_REQUESTS:
AskAiEvents.ASKAI_BUS.events.reply_error.emit(message=msg.too_many_actions())
raise MaxInteractionsReached(f"Maximum number of action was reached")
Expand All @@ -107,7 +108,8 @@ def _final_answer(self, question: str, response: str) -> str:
:param response: The AI response.
"""
# TODO For now we are just using Taius, but we can opt to use Taius, STT, no persona, or custom.
return final_answer(question, response)
# return final_answer(question, response)
return response


assert (router := Router().INSTANCE) is not None
8 changes: 4 additions & 4 deletions src/main/askai/core/features/tools/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from askai.exception.exceptions import InaccurateResponse

ASSERT_MSG: Template = Template(
"You (AI) provided an unsatisfactory answer. Please enhance your response in the next attempt. \n\n"
"Address these issues: \n\n"
"'{problems}'\n\n"
"(Remember to use at least one tool in your response, adhere to tool syntax, and refrain from direct responses).\n"
"You (AI) provided an unsatisfactory answer. Improve your response in the next attempt. \n"
"Address these issues: \n"
"'{problems}'\n"
"(Reminder to always respond with a valid JSON blob of an action plan. Use the necessary tools).\n"
)


Expand Down
39 changes: 19 additions & 20 deletions src/main/askai/core/features/tools/generation.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,46 @@
import logging as log
from textwrap import dedent
from os.path import dirname
from pathlib import Path

from hspylib.core.preconditions import check_not_none
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate

from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.cache_service import cache
from askai.core.engine.openai.temperature import Temperature
from askai.core.support.langchain_support import lc_llm
from askai.core.support.shared_instances import shared

GENERATION_TEMPLATE = PromptTemplate(
input_variables=[
'mime_type', 'input'
], template=dedent(
"""
You are a highly sophisticated GPT tailored for creating '{mime_type}' content. Please be as accurate as possible while
creating the content. Ensure that your content has a good quality. To create the content use the following prompt:

{input}
"""
))


def generate_content(prompt: str, mime_type: str, path_name: str) -> str:
def generate_content(description: str, mime_type: str, path_name: str | None) -> str:
"""Display the given texts formatted with markdown.
:param prompt: Specify the prompt to be used to generate the content.
:param description: Specify the prompt to be used to generate the content.
:param mime_type: Specify the content type and format using MIME types.
:param path_name: Specify the directory path where you want to save the generated content.
"""
check_not_none((prompt, mime_type, path_name))
check_not_none((description, mime_type, path_name))
output = None
template = GENERATION_TEMPLATE
template = PromptTemplate(
input_variables=[
'mime_type', 'input'
], template=prompt.read_prompt('generator-prompt')
)
final_prompt = template.format(
mime_type=mime_type, input=input)
mime_type=mime_type, input=description)

log.info("GENERATE::[PROMPT] '%s' Type: '%s' Path: '%s'", prompt, mime_type, path_name)
log.info("GENERATE::[PROMPT] '%s' Type: '%s' Path: '%s'", description, mime_type, path_name)
llm = lc_llm.create_chat_model(temperature=Temperature.CODE_GENERATION.temp)
response: AIMessage = llm.invoke(final_prompt)

if response and (output := response.content) and shared.UNCERTAIN_ID not in response.content:
if response and (output := response.content):
shared.context.push("HISTORY", output, 'assistant')
if path_name:
base_dir = Path(dirname(path_name))
if base_dir.is_dir() and base_dir.exists():
with open(path_name, 'w') as f_path_name:
f_path_name.write(output)
cache.save_reply(prompt, output)
else:
output = msg.translate("Sorry, I don't know.")
Expand Down
2 changes: 1 addition & 1 deletion src/main/askai/core/features/tools/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def open_command(pathname: str) -> str:
match media_type_of(pathname):
case ('audio', _) | ('video', _):
fn_open = partial(_execute_bash, f'ffplay -v 0 -autoexit {pathname} &>/dev/null')
case ('text', 'plain'):
case ('text', _):
fn_open = partial(_execute_bash, f'echo "File \\`{pathname}\\`: \n" && cat {pathname}')
case _:
fn_open = partial(_execute_bash, f'open {pathname} 2>/dev/null')
Expand Down
39 changes: 39 additions & 0 deletions src/main/askai/core/model/action_plan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@project: HsPyLib-AskAI
@package: askai.core.model
@file: action_plan.py
@created: Fri, 19 Apr 2024
   @author: <B>H</B>ugo <B>S</B>aporetti <B>J</B>unior
@site: https://github.com/yorevs/hspylib
@license: MIT - Please refer to <https://opensource.org/licenses/MIT>
Copyright·(c)·2024,·HSPyLib
"""
import json
from dataclasses import dataclass
from functools import cached_property
from types import SimpleNamespace
from typing import Any


@dataclass(frozen=True)
class ActionPlan:
"""Keep track of the router action plan."""

@dataclass
class Action:
"""Represents a single action."""
tool: str
params: list[Any]

plan: list[SimpleNamespace] = None

@cached_property
def actions(self) -> list[Action]:
return [self.Action(a.action, a.inputs) for a in self.plan]

def __str__(self):
return f"Action Plan: {json.dumps(self.__dict__, default=lambda obj: obj.__dict__)}"
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Use the following context to answer the question at the end.

Ensure important details are included, especially when mentioning files, folders, line numbers, etc..

When the analysis includes a list, prefer rendering numbered than bulleted lists.
When the response includes a list, prefer rendering numbered than bulleted lists.

If you lack an answer, respond with: 'bde6f44d-c1a0-4b0c-bd74-8278e468e50c'. No further clarifications or details are necessary.

Expand Down
10 changes: 10 additions & 0 deletions src/main/askai/resources/assets/prompts/generator-prompt.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
You are a highly sophisticated GPT tailored for creating content of the specified type '{mime_type}'.

- Create the content with high quality.
- If the content is code, adhere to any applicable linter rules and include the correct hashbang if necessary.
- Use plain text; do not format the response using markdown.
- Do not provide additional comments.

Begin! Create the content using the following instructions:

{input}
2 changes: 1 addition & 1 deletion src/main/askai/resources/assets/prompts/router-prompt.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ ASSISTANT:
{{{{
"action": "open_command",
"inputs": [
"%cross_referenced_filenames%"
"%first_reminder_found%"
]
}}}},
{{{{
Expand Down

0 comments on commit 97c275d

Please sign in to comment.