Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tooling for agents #109

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Empty file added agents/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions agents/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import inspect
from textwrap import dedent
from typing import List

from langchain.tools.base import StructuredTool

from agents.encoder import FunctionDefinition, Parameter


# This is temporary until we have a better way to represent parameters
def get_parameters_from_tool(tool: StructuredTool) -> List[Parameter]:
"""Convert a langchain tool to a tool user tool."""
schema = tool.args_schema.schema()

properties = schema["properties"]
parameters = []
# Is this needed or is string OK?
type_adapter = {
"string": "str", # str or string?
"integer": "int",
"number": "float",
"boolean": "bool",
}
for key, value in properties.items():
parameters.append(
{
"name": key,
"type": type_adapter.get(value["type"], value["type"]),
"description": value.get("description", ""),
}
)

return parameters


#
def convert_tool_to_function_definition(tool: StructuredTool) -> FunctionDefinition:
"""Convert a langchain tool to a tool user tool."""
# Here we re-inspect the underlying function to get the doc-string
# since StructuredTool modifies it, but we want the raw one for maximum
# flexibility.
description = inspect.getdoc(tool.func)

parameters = get_parameters_from_tool(tool)
return {
"name": tool.name,
"description": dedent(description),
"parameters": parameters,
"return_value": {
"type": "Any",
},
}
105 changes: 105 additions & 0 deletions agents/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import List, Literal, Sequence, Tuple, Union

from langchain.agents import AgentOutputParser
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema.messages import HumanMessage
from langchain.schema.runnable import Runnable
from langchain.tools import StructuredTool
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
from langchain_core.messages import BaseMessage
from langchain_core.prompts import MessagesPlaceholder
from typing_extensions import NotRequired, TypedDict

from agents.adapters import convert_tool_to_function_definition
from agents.encoder import AstPrinter, TypeScriptEncoder
from agents.prompts import AGENT_INSTRUCTIONS_BLOB_STYLE


def format_observation(tool_name: str, observation: str) -> BaseMessage:
"""Format the observation."""
result = (
"<tool_output>\n"
f"<tool_name>{tool_name}</tool_name>\n"
f"<output>{observation}</output>\n"
"</tool_output>"
)

return HumanMessage(content=result)


def format_steps_for_chat(
intermediate_steps: List[Tuple[AgentAction, str]]
) -> List[BaseMessage]:
"""Format the steps."""
messages = []
for action, observation in intermediate_steps:
if not isinstance(action, AgentAction):
if action.tool != "_Exception":
raise AssertionError(f"Unexpected step: {action}. type: {type(action)}")

messages.append(HumanMessage(content=observation))
messages.extend(action.messages)
messages.append(format_observation(action.tool, observation))
return messages


# PUBLIC API


class AgentInput(TypedDict):
"""The input to the agent."""

input: str
"""The input to the agent."""
intermediate_steps: List[Tuple[AgentAction, str]]
"""The intermediate steps taken by the agent."""
examples: NotRequired[List[BaseMessage]]
"""A list of messages that can be used to form example traces."""


def create_agent(
model: Union[BaseChatModel, BaseLanguageModel],
tools: Sequence[StructuredTool],
parser: AgentOutputParser,
*,
ast_printer: Union[AstPrinter, Literal["xml"]] = "xml",
) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]:
"""Create an agent for a chat model."""
if isinstance(ast_printer, str):
if ast_printer == "xml":
ast_printer = AstPrinter()
elif ast_printer == "typescript":
ast_printer = TypeScriptEncoder()
else:
raise ValueError(f"Unknown ast printer: {ast_printer}")
elif isinstance(ast_printer, AstPrinter):
pass
else:
raise TypeError(
f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`"
)

function_definitions = [convert_tool_to_function_definition(tool) for tool in tools]
tool_description = ast_printer.visit_function_definitions(function_definitions)

template = ChatPromptTemplate.from_messages(
[
("system", AGENT_INSTRUCTIONS_BLOB_STYLE),
MessagesPlaceholder("examples"), # Can use to add example traces
("human", "{input}"),
MessagesPlaceholder("history"),
]
).partial(tool_description=tool_description)

agent = (
{
"input": lambda x: x["input"],
"history": lambda x: format_steps_for_chat(x["intermediate_steps"]),
"examples": lambda x: x.get("examples", []),
}
| template
| model.bind(stop=["</tool>"])
| parser
)
return agent
226 changes: 226 additions & 0 deletions agents/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""Prototyping code for rendering function definitions, invocations, and results.

Types are simplified for now to `str`.

We should actually support something like pydantic or jsonschema for the types, so
we can expand them recursively for nested types.
"""
import abc
from typing import Any, List, Optional

from typing_extensions import NotRequired, TypedDict


class Parameter(TypedDict):
"""Representation for a parameter."""

name: str
type: str
description: str


class Arguments(TypedDict):
"""Arguments are passed to a function during function invocation."""

name: Optional[str]
value: Any


class ReturnValue(TypedDict):
"""Representation for a return value of a function call."""

type: str
description: NotRequired[str]


class FunctionDefinition(TypedDict):
"""Representation for a function."""

name: str
description: str # Function description
parameters: List[Parameter]
return_value: ReturnValue


class FunctionInvocation(TypedDict):
"""Representation for a function invocation."""

id: NotRequired[str]
name: str
arguments: List[Arguments]


class FunctionResult(TypedDict):
"""Representation for a function result."""

id: NotRequired[str]
name: str
result: Optional[str]
error: Optional[str]


class Visitor(abc.ABC):
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
"""Render a function."""
raise NotImplementedError()

def visit_function_definitions(
self, function_definitions: List[FunctionDefinition]
) -> str:
"""Render a function."""
raise NotImplementedError()

def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str:
"""Render a function invocation."""
raise NotImplementedError()

def visit_function_result(self, function_result: FunctionResult) -> str:
"""Render a function result."""
raise NotImplementedError()


class AstPrinter(Visitor):
"""Print the AST."""


class XMLEncoder(AstPrinter):
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
"""Render a function."""
parameters_as_strings = [
"<parameter>\n"
f"<name>{parameter['name']}</name>\n"
f"<type>{parameter['type']}</type>\n"
f"<description>{parameter['description']}</description>\n"
"</parameter>\n"
for parameter in function_definition["parameters"]
]
function = (
"<function>\n"
f"<function_name>{function_definition['name']}</function_name>\n"
"<description>\n"
f"{function_definition['description']}\n"
"</description>\n"
"<parameters>\n"
f"{''.join(parameters_as_strings)}" # Already includes trailing newline
"</parameters>\n"
"<return_value>\n"
f"<type>{function_definition['return_value']['type']}</type>\n"
f"<description>{function_definition['return_value']['description']}</description>\n"
"</return_value>\n"
"</function>"
)
return function

def visit_function_definitions(
self, function_definitions: List[FunctionDefinition]
) -> str:
"""Render a function."""
strs = [
self.visit_function_definition(function_definition)
for function_definition in function_definitions
]
return "<functions>\n" + "\n".join(strs) + "\n</functions>"

def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
"""Render a function invocation."""
arguments_as_strings = [
"<argument>\n"
f"<name>{argument['name']}</name>\n"
f"<value>{argument['value']}</value>\n"
"</argument>\n"
for argument in invocation["arguments"]
]
lines = ["<function_invocation>"]

if invocation.get("id"):
lines.append(f"<id>{invocation['id']}</id>")

lines.extend(
[
f"<function_name>{invocation['name']}</function_name>\n"
"<arguments>\n"
f"{''.join(arguments_as_strings)}" # Already includes trailing newline
"</arguments>\n"
"</function_invocation>"
]
)
return "\n".join(lines)

def visit_function_result(self, function_result: FunctionResult) -> str:
"""Render a function result."""
lines = [
"<function_result>",
]

if function_result.get("id"):
lines.append(f"<id>{function_result['id']}</id>")

lines.extend(
[
f"<function_name>{function_result['name']}</function_name>",
f"<result>{function_result['result']}</result>",
f"<error>{function_result['error']}</error>",
"</function_result>",
]
)

return "\n".join(lines)


class TypeScriptEncoder(AstPrinter):
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
"""Render a function."""
parameters_as_strings = [
f"{parameter['name']}: {parameter['type']}"
for parameter in function_definition["parameters"]
]
# Let's use JSdoc style comments
# First the function description
lines = [
f"// {function_definition['description']}",
# Then the parameter descriptions
*[
f"// @param {parameter['name']} {parameter['description']}"
for parameter in function_definition["parameters"]
],
# Then the return value description
f"// @returns {function_definition['return_value']['description']}",
# Then the function definition
f"function {function_definition['name']}("
f"{', '.join(parameters_as_strings)}): "
f"{function_definition['return_value']['type']};",
]

# finally join
function = "\n".join(lines)
return function

def visit_function_definitions(
self, function_definitions: List[FunctionDefinition]
) -> str:
"""Render a function."""
strs = [
self.visit_function_definition(function_definition)
for function_definition in function_definitions
]
return "\n\n".join(strs)

def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
"""Render a function invocation."""
arguments_as_strings = [
f"{argument['name']}: {argument['value']}"
for argument in invocation["arguments"]
]
lines = [f"{invocation['name']}(" f"{', '.join(arguments_as_strings)});"]
return "\n".join(lines)

def visit_function_result(self, function_result: FunctionResult) -> str:
"""Render a function result."""
lines = []
if function_result["error"]:
lines.append(f"ERROR: {function_result['error']}")
else:
lines.append(f"> {function_result['result']}")
if function_result.get("id"):
lines.append(f"// ID: {function_result['id']}")
return "\n".join(lines)
Loading