Skip to content

Commit

Permalink
feat: support native model ability to invoke tools (#50)
Browse files Browse the repository at this point in the history
* Support native tools
  • Loading branch information
Oleksii-Klimov authored Jan 15, 2024
1 parent 2f22c9d commit 09923ec
Show file tree
Hide file tree
Showing 35 changed files with 1,229 additions and 682 deletions.
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
CONFIG_DIR=aidial_assistant/configs
LOG_LEVEL=DEBUG
OPENAI_API_BASE=http://localhost:5001
WEB_CONCURRENCY=1
WEB_CONCURRENCY=1
TOOLS_SUPPORTING_DEPLOYMENTS=gpt-4-1106-preview
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ make serve

Copy .env.example to .env and customize it for your environment:

| Variable | Default | Description |
|-----------------|--------------------------|--------------------------------------------------------|
| CONFIG_DIR | aidial_assistant/configs | Configuration directory |
| LOG_LEVEL | INFO | Log level. Use DEBUG for dev purposes and INFO in prod |
| OPENAI_API_BASE | N/A | OpenAI API Base |
| WEB_CONCURRENCY | 1 | Number of workers for the server |
| Variable | Default | Description |
|------------------------------|--------------------------|--------------------------------------------------------------------------------|
| CONFIG_DIR | aidial_assistant/configs | Configuration directory |
| LOG_LEVEL | INFO | Log level. Use DEBUG for dev purposes and INFO in prod |
| OPENAI_API_BASE | | OpenAI API Base |
| WEB_CONCURRENCY | 1 | Number of workers for the server |
| TOOLS_SUPPORTING_DEPLOYMENTS | | Comma-separated deployment names that support tools in chat completion request |

### Docker

Expand Down
22 changes: 12 additions & 10 deletions aidial_assistant/app.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
#!/usr/bin/env python3
import logging.config
import os
from pathlib import Path

from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig, TracingConfig
from starlette.responses import Response

from aidial_assistant.application.assistant_application import (
AssistantApplication,
)
from aidial_assistant.utils.log_config import get_log_config

log_level = os.getenv("LOG_LEVEL", "INFO")
config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs"))

logging.config.dictConfig(get_log_config(log_level))

telemetry_config = TelemetryConfig(
service_name="aidial-assistant", tracing=TracingConfig()
)
app = DIALApp(telemetry_config=telemetry_config)
app.add_chat_completion("assistant", AssistantApplication(config_dir))

# A delayed import is necessary to set up the httpx hook before the openai client inherits from AsyncClient.
from aidial_assistant.application.assistant_application import ( # noqa: E402
AssistantApplication,
)

@app.get("/healthcheck/status200")
def status200() -> Response:
return Response("Service is running...", status_code=200)
config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs"))
tools_supporting_deployments: set[str] = set(
os.getenv("TOOLS_SUPPORTING_DEPLOYMENTS", "").split(",")
)
app.add_chat_completion(
"assistant",
AssistantApplication(config_dir, tools_supporting_deployments),
)
7 changes: 5 additions & 2 deletions aidial_assistant/application/addons_dialogue_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
LimitExceededException,
ModelRequestLimiter,
)
from aidial_assistant.model.model_client import Message, ModelClient
from aidial_assistant.model.model_client import (
ChatCompletionMessageParam,
ModelClient,
)


class AddonsDialogueLimiter(ModelRequestLimiter):
Expand All @@ -16,7 +19,7 @@ def __init__(self, max_dialogue_tokens: int, model_client: ModelClient):
self._initial_tokens: int | None = None

@override
async def verify_limit(self, messages: list[Message]):
async def verify_limit(self, messages: list[ChatCompletionMessageParam]):
if self._initial_tokens is None:
self._initial_tokens = await self.model_client.count_tokens(
messages
Expand Down
153 changes: 126 additions & 27 deletions aidial_assistant/application/assistant_application.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
from pathlib import Path
from typing import Tuple

from aidial_sdk.chat_completion import FinishReason
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Addon, Message, Request, Role
from aidial_sdk.chat_completion.response import Response
from openai.lib.azure import AsyncAzureOpenAI
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel

from aidial_assistant.application.addons_dialogue_limiter import (
Expand All @@ -18,18 +21,29 @@
MAIN_BEST_EFFORT_TEMPLATE,
MAIN_SYSTEM_DIALOG_MESSAGE,
)
from aidial_assistant.chain.command_chain import CommandChain, CommandDict
from aidial_assistant.chain.command_chain import (
CommandChain,
CommandConstructor,
CommandDict,
)
from aidial_assistant.chain.history import History
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.commands.run_tool import RunTool
from aidial_assistant.model.model_client import (
ModelClient,
ReasonLengthException,
)
from aidial_assistant.tools_chain.tools_chain import (
CommandToolDict,
ToolsChain,
convert_commands_to_tools,
)
from aidial_assistant.utils.exceptions import (
RequestParameterValidationError,
unhandled_exception_handler,
)
from aidial_assistant.utils.open_ai import construct_tool
from aidial_assistant.utils.open_ai_plugin import (
AddonTokenSource,
get_open_ai_plugin_info,
Expand All @@ -49,8 +63,6 @@ def _get_request_args(request: Request) -> dict[str, str]:
args = {
"model": request.model,
"temperature": request.temperature,
"api_version": request.api_version,
"api_key": request.api_key,
"user": request.user,
}

Expand Down Expand Up @@ -83,68 +95,114 @@ def _validate_messages(messages: list[Message]) -> None:
)


def _construct_tool(name: str, description: str) -> ChatCompletionToolParam:
    """Build the chat-completion tool definition for an addon.

    Every addon is exposed to the model as a tool with a single required
    string parameter, ``query``, carrying the task in natural language.
    """
    query_properties = {
        "query": {
            "type": "string",
            "description": "A task written in natural language",
        }
    }
    required_fields = ["query"]
    return construct_tool(name, description, query_properties, required_fields)


class AssistantApplication(ChatCompletion):
def __init__(self, config_dir: Path):
def __init__(
self, config_dir: Path, tools_supporting_deployments: set[str]
):
self.args = parse_args(config_dir)
self.tools_supporting_deployments = tools_supporting_deployments

@unhandled_exception_handler
async def chat_completion(
self, request: Request, response: Response
) -> None:
_validate_messages(request.messages)
addon_references = _validate_addons(request.addons)
chat_args = self.args.openai_conf.dict() | _get_request_args(request)
chat_args = _get_request_args(request)

model = ModelClient(
model_args=chat_args
| {
"deployment_id": chat_args["model"],
"api_type": "azure",
"stream": True,
},
buffer_size=self.args.chat_conf.buffer_size,
client=AsyncAzureOpenAI(
azure_endpoint=self.args.openai_conf.api_base,
api_key=request.api_key,
# 2023-12-01-preview is needed to support tools
api_version="2023-12-01-preview",
),
model_args=chat_args,
)

token_source = AddonTokenSource(
request.headers,
(addon_reference.url for addon_reference in addon_references),
)

addons: dict[str, PluginInfo] = {}
plugins: list[PluginInfo] = []
# DIAL Core has its own names for addons, so in the stages we need to map them to the names used by the user
addon_name_mapping: dict[str, str] = {}
for addon_reference in addon_references:
info = await get_open_ai_plugin_info(addon_reference.url)
addons[info.ai_plugin.name_for_model] = PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon_reference.url,
token_source,
),
plugins.append(
PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon_reference.url,
token_source,
),
)
)

if addon_reference.name:
addon_name_mapping[
info.ai_plugin.name_for_model
] = addon_reference.name

if request.model in self.tools_supporting_deployments:
await AssistantApplication._run_native_tools_chat(
model, plugins, addon_name_mapping, request, response
)
else:
await AssistantApplication._run_emulated_tools_chat(
model, plugins, addon_name_mapping, request, response
)

@staticmethod
async def _run_emulated_tools_chat(
model: ModelClient,
addons: list[PluginInfo],
addon_name_mapping: dict[str, str],
request: Request,
response: Response,
):
# TODO: Add max_addons_dialogue_tokens as a request parameter
max_addons_dialogue_tokens = 1000

def create_command(addon: PluginInfo):
return lambda: RunPlugin(model, addon, max_addons_dialogue_tokens)

command_dict: CommandDict = {
RunPlugin.token(): lambda: RunPlugin(
model, addons, max_addons_dialogue_tokens
),
Reply.token(): Reply,
addon.info.ai_plugin.name_for_model: create_command(addon)
for addon in addons
}
if Reply.token() in command_dict:
RequestParameterValidationError(
f"Addon with name '{Reply.token()}' is not allowed for model {request.model}.",
param="addons",
)

command_dict[Reply.token()] = Reply

chain = CommandChain(
model_client=model, name="ASSISTANT", command_dict=command_dict
)
addon_descriptions = {
name: addon.info.open_api.info.description
addon.info.ai_plugin.name_for_model: addon.info.open_api.info.description
or addon.info.ai_plugin.description_for_human
for name, addon in addons.items()
for addon in addons
}
history = History(
assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
Expand Down Expand Up @@ -187,3 +245,44 @@ async def chat_completion(

if discarded_messages is not None:
response.set_discarded_messages(discarded_messages)

@staticmethod
async def _run_native_tools_chat(
    model: ModelClient,
    plugins: list[PluginInfo],
    addon_name_mapping: dict[str, str],
    request: Request,
    response: Response,
):
    """Run the assistant dialogue using the model's native tool-calling.

    Each addon plugin is registered as one chat-completion tool; the
    resulting tool invocations are streamed to the client via the chain
    callback and recorded in the choice state.
    """

    def create_command_tool(
        plugin: PluginInfo,
    ) -> Tuple[CommandConstructor, ChatCompletionToolParam]:
        # The lambda closes over `plugin`, so each tool call constructs a
        # fresh RunTool command bound to that specific addon.
        return lambda: RunTool(model, plugin), _construct_tool(
            plugin.info.ai_plugin.name_for_model,
            plugin.info.ai_plugin.description_for_human,
        )

    # Tool name -> (command constructor, tool schema), keyed by the
    # model-facing plugin name.
    commands: CommandToolDict = {
        plugin.info.ai_plugin.name_for_model: create_command_tool(plugin)
        for plugin in plugins
    }
    chain = ToolsChain(model, commands)

    choice = response.create_single_choice()
    choice.open()

    callback = AssistantChainCallback(choice, addon_name_mapping)
    finish_reason = FinishReason.STOP
    # NOTE(review): presumably converts prior emulated-command history into
    # native tool-call messages — confirm against tools_chain.
    messages = convert_commands_to_tools(parse_history(request.messages))
    try:
        await chain.run_chat(messages, callback)
    except ReasonLengthException:
        # The model ran out of tokens mid-answer; report LENGTH, not STOP.
        finish_reason = FinishReason.LENGTH

    if callback.invocations:
        # Persist invocations in the choice state so a follow-up request
        # can continue the tool dialogue.
        choice.set_state(State(invocations=callback.invocations))
    choice.close(finish_reason)

    response.set_usage(
        model.total_prompt_tokens, model.total_completion_tokens
    )
Loading

0 comments on commit 09923ec

Please sign in to comment.