From 09923ece480d6209cd56ba8f51de31620e839e28 Mon Sep 17 00:00:00 2001 From: Oleksii-Klimov <133792808+Oleksii-Klimov@users.noreply.github.com> Date: Mon, 15 Jan 2024 13:30:39 +0000 Subject: [PATCH] feat: support native model ability to invoke tools (#50) * Support native tools --- .env.example | 3 +- README.md | 13 +- aidial_assistant/app.py | 22 +- .../application/addons_dialogue_limiter.py | 7 +- .../application/assistant_application.py | 153 ++++++++++-- .../application/assistant_callback.py | 58 +---- aidial_assistant/application/prompts.py | 42 ++-- .../chain/callbacks/arg_callback.py | 21 -- .../chain/callbacks/args_callback.py | 11 +- aidial_assistant/chain/command_chain.py | 105 ++++---- aidial_assistant/chain/command_result.py | 16 +- aidial_assistant/chain/dialogue.py | 9 +- aidial_assistant/chain/history.py | 61 ++--- .../chain/model_response_reader.py | 12 +- aidial_assistant/commands/base.py | 25 +- aidial_assistant/commands/open_api.py | 8 +- aidial_assistant/commands/run_plugin.py | 40 ++- aidial_assistant/commands/run_tool.py | 106 ++++++++ aidial_assistant/model/model_client.py | 146 ++++++----- aidial_assistant/tools_chain/__init__.py | 0 aidial_assistant/tools_chain/tools_chain.py | 231 ++++++++++++++++++ aidial_assistant/utils/exceptions.py | 22 +- aidial_assistant/utils/open_ai.py | 73 ++++++ aidial_assistant/utils/state.py | 64 ++++- poetry.lock | 211 ++++++++++++---- pyproject.toml | 6 +- .../test_addons_dialogue_limiter.py | 21 +- .../chain/test_command_chain_best_effort.py | 100 ++++---- tests/unit_tests/chain/test_history.py | 51 ++-- tests/unit_tests/chain/test_model_client.py | 105 -------- tests/unit_tests/model/test_model_client.py | 113 +++++---- tests/unit_tests/tools_chain/__init__.py | 0 .../utils/test_exception_handler.py | 29 ++- tests/unit_tests/utils/test_state.py | 23 +- tests/utils/async_helper.py | 4 + 35 files changed, 1229 insertions(+), 682 deletions(-) delete mode 100644 aidial_assistant/chain/callbacks/arg_callback.py create mode 100644 aidial_assistant/commands/run_tool.py create mode 100644 aidial_assistant/tools_chain/__init__.py create mode 100644 aidial_assistant/tools_chain/tools_chain.py create mode 100644 aidial_assistant/utils/open_ai.py delete mode 100644 tests/unit_tests/chain/test_model_client.py create mode 100644 tests/unit_tests/tools_chain/__init__.py diff --git a/.env.example b/.env.example index a8584bc..1c225bc 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,5 @@ CONFIG_DIR=aidial_assistant/configs LOG_LEVEL=DEBUG OPENAI_API_BASE=http://localhost:5001 -WEB_CONCURRENCY=1 \ No newline at end of file +WEB_CONCURRENCY=1 +TOOLS_SUPPORTING_DEPLOYMENTS=gpt-4-1106-preview \ No newline at end of file diff --git a/README.md b/README.md index a581f6d..939717d 100644 --- a/README.md +++ b/README.md @@ -81,12 +81,13 @@ make serve Copy .env.example to .env and customize it for your environment: -| Variable | Default | Description | -|-----------------|--------------------------|--------------------------------------------------------| -| CONFIG_DIR | aidial_assistant/configs | Configuration directory | -| LOG_LEVEL | INFO | Log level. Use DEBUG for dev purposes and INFO in prod | -| OPENAI_API_BASE | N/A | OpenAI API Base | -| WEB_CONCURRENCY | 1 | Number of workers for the server | +| Variable | Default | Description | +|------------------------------|--------------------------|--------------------------------------------------------------------------------| +| CONFIG_DIR | aidial_assistant/configs | Configuration directory | +| LOG_LEVEL | INFO | Log level. Use DEBUG for dev purposes and INFO in prod | +| OPENAI_API_BASE | | OpenAI API Base | +| WEB_CONCURRENCY | 1 | Number of workers for the server | +| TOOLS_SUPPORTING_DEPLOYMENTS | | Comma-separated deployment names that support tools in chat completion request | ### Docker diff --git a/aidial_assistant/app.py b/aidial_assistant/app.py index db7ecb3..d14541d 100644 --- a/aidial_assistant/app.py +++ b/aidial_assistant/app.py @@ -1,19 +1,13 @@ -#!/usr/bin/env python3 import logging.config import os from pathlib import Path from aidial_sdk import DIALApp from aidial_sdk.telemetry.types import TelemetryConfig, TracingConfig -from starlette.responses import Response -from aidial_assistant.application.assistant_application import ( - AssistantApplication, -) from aidial_assistant.utils.log_config import get_log_config log_level = os.getenv("LOG_LEVEL", "INFO") -config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs")) logging.config.dictConfig(get_log_config(log_level)) @@ -21,9 +15,17 @@ service_name="aidial-assistant", tracing=TracingConfig() ) app = DIALApp(telemetry_config=telemetry_config) -app.add_chat_completion("assistant", AssistantApplication(config_dir)) +# A delayed import is necessary to set up the httpx hook before the openai client inherits from AsyncClient. +from aidial_assistant.application.assistant_application import ( # noqa: E402 + AssistantApplication, +) -@app.get("/healthcheck/status200") -def status200() -> Response: - return Response("Service is running...", status_code=200) +config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs")) +tools_supporting_deployments: set[str] = set( + os.getenv("TOOLS_SUPPORTING_DEPLOYMENTS", "").split(",") +) +app.add_chat_completion( + "assistant", + AssistantApplication(config_dir, tools_supporting_deployments), +) diff --git a/aidial_assistant/application/addons_dialogue_limiter.py b/aidial_assistant/application/addons_dialogue_limiter.py index fc3b130..e850931 100644 --- a/aidial_assistant/application/addons_dialogue_limiter.py +++ b/aidial_assistant/application/addons_dialogue_limiter.py @@ -4,7 +4,10 @@ LimitExceededException, ModelRequestLimiter, ) -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ( + ChatCompletionMessageParam, + ModelClient, +) class AddonsDialogueLimiter(ModelRequestLimiter): @@ -16,7 +19,7 @@ def __init__(self, max_dialogue_tokens: int, model_client: ModelClient): self._initial_tokens: int | None = None @override - async def verify_limit(self, messages: list[Message]): + async def verify_limit(self, messages: list[ChatCompletionMessageParam]): if self._initial_tokens is None: self._initial_tokens = await self.model_client.count_tokens( messages diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index 69fbc9c..0001617 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -1,10 +1,13 @@ import logging from pathlib import Path +from typing import Tuple from aidial_sdk.chat_completion import FinishReason from aidial_sdk.chat_completion.base import ChatCompletion from aidial_sdk.chat_completion.request import Addon, Message, Request, Role from aidial_sdk.chat_completion.response import Response +from openai.lib.azure import AsyncAzureOpenAI +from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel from aidial_assistant.application.addons_dialogue_limiter import ( @@ -18,18 +21,29 @@ MAIN_BEST_EFFORT_TEMPLATE, MAIN_SYSTEM_DIALOG_MESSAGE, ) -from aidial_assistant.chain.command_chain import CommandChain, CommandDict +from aidial_assistant.chain.command_chain import ( + CommandChain, + CommandConstructor, + CommandDict, +) from aidial_assistant.chain.history import History from aidial_assistant.commands.reply import Reply from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin +from aidial_assistant.commands.run_tool import RunTool from aidial_assistant.model.model_client import ( ModelClient, ReasonLengthException, ) +from aidial_assistant.tools_chain.tools_chain import ( + CommandToolDict, + ToolsChain, + convert_commands_to_tools, +) from aidial_assistant.utils.exceptions import ( RequestParameterValidationError, unhandled_exception_handler, ) +from aidial_assistant.utils.open_ai import construct_tool from aidial_assistant.utils.open_ai_plugin import ( AddonTokenSource, get_open_ai_plugin_info, @@ -49,8 +63,6 @@ def _get_request_args(request: Request) -> dict[str, str]: args = { "model": request.model, "temperature": request.temperature, - "api_version": request.api_version, - "api_key": request.api_key, "user": request.user, } @@ -83,9 +95,26 @@ def _validate_messages(messages: list[Message]) -> None: ) +def _construct_tool(name: str, description: str) -> ChatCompletionToolParam: + return construct_tool( + name, + description, + { + "query": { + "type": "string", + "description": "A task written in natural language", + } + }, + ["query"], + ) + + class AssistantApplication(ChatCompletion): - def __init__(self, config_dir: Path): + def __init__( + self, config_dir: Path, tools_supporting_deployments: set[str] + ): self.args = parse_args(config_dir) + self.tools_supporting_deployments = tools_supporting_deployments @unhandled_exception_handler async def chat_completion( @@ -93,16 +122,16 @@ async def chat_completion( ) -> None: _validate_messages(request.messages) addon_references = _validate_addons(request.addons) - chat_args = self.args.openai_conf.dict() | _get_request_args(request) + chat_args = _get_request_args(request) model = ModelClient( - model_args=chat_args - | { - "deployment_id": chat_args["model"], - "api_type": "azure", - "stream": True, - }, - buffer_size=self.args.chat_conf.buffer_size, + client=AsyncAzureOpenAI( + azure_endpoint=self.args.openai_conf.api_base, + api_key=request.api_key, + # 2023-12-01-preview is needed to support tools + api_version="2023-12-01-preview", + ), + model_args=chat_args, ) token_source = AddonTokenSource( @@ -110,19 +139,21 @@ async def chat_completion( (addon_reference.url for addon_reference in addon_references), ) - addons: dict[str, PluginInfo] = {} + plugins: list[PluginInfo] = [] # DIAL Core has own names for addons, so in stages we need to map them to the names used by the user addon_name_mapping: dict[str, str] = {} for addon_reference in addon_references: info = await get_open_ai_plugin_info(addon_reference.url) - addons[info.ai_plugin.name_for_model] = PluginInfo( - info=info, - auth=get_plugin_auth( - info.ai_plugin.auth.type, - info.ai_plugin.auth.authorization_type, - addon_reference.url, - token_source, - ), + plugins.append( + PluginInfo( + info=info, + auth=get_plugin_auth( + info.ai_plugin.auth.type, + info.ai_plugin.auth.authorization_type, + addon_reference.url, + token_source, + ), + ) ) if addon_reference.name: @@ -130,21 +161,48 @@ async def chat_completion( info.ai_plugin.name_for_model ] = addon_reference.name + if request.model in self.tools_supporting_deployments: + await AssistantApplication._run_native_tools_chat( + model, plugins, addon_name_mapping, request, response + ) + else: + await AssistantApplication._run_emulated_tools_chat( + model, plugins, addon_name_mapping, request, response + ) + + @staticmethod + async def _run_emulated_tools_chat( + model: ModelClient, + addons: list[PluginInfo], + addon_name_mapping: dict[str, str], + request: Request, + response: Response, + ): # TODO: Add max_addons_dialogue_tokens as a request parameter max_addons_dialogue_tokens = 1000 + + def create_command(addon: PluginInfo): + return lambda: RunPlugin(model, addon, max_addons_dialogue_tokens) + command_dict: CommandDict = { - RunPlugin.token(): lambda: RunPlugin( - model, addons, max_addons_dialogue_tokens - ), - Reply.token(): Reply, + addon.info.ai_plugin.name_for_model: create_command(addon) + for addon in addons } + if Reply.token() in command_dict: + RequestParameterValidationError( + f"Addon with name '{Reply.token()}' is not allowed for model {request.model}.", + param="addons", + ) + + command_dict[Reply.token()] = Reply + chain = CommandChain( model_client=model, name="ASSISTANT", command_dict=command_dict ) addon_descriptions = { - name: addon.info.open_api.info.description + addon.info.ai_plugin.name_for_model: addon.info.open_api.info.description or addon.info.ai_plugin.description_for_human - for name, addon in addons.items() + for addon in addons } history = History( assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build( @@ -187,3 +245,44 @@ async def chat_completion( if discarded_messages is not None: response.set_discarded_messages(discarded_messages) + + @staticmethod + async def _run_native_tools_chat( + model: ModelClient, + plugins: list[PluginInfo], + addon_name_mapping: dict[str, str], + request: Request, + response: Response, + ): + def create_command_tool( + plugin: PluginInfo, + ) -> Tuple[CommandConstructor, ChatCompletionToolParam]: + return lambda: RunTool(model, plugin), _construct_tool( + plugin.info.ai_plugin.name_for_model, + plugin.info.ai_plugin.description_for_human, + ) + + commands: CommandToolDict = { + plugin.info.ai_plugin.name_for_model: create_command_tool(plugin) + for plugin in plugins + } + chain = ToolsChain(model, commands) + + choice = response.create_single_choice() + choice.open() + + callback = AssistantChainCallback(choice, addon_name_mapping) + finish_reason = FinishReason.STOP + messages = convert_commands_to_tools(parse_history(request.messages)) + try: + await chain.run_chat(messages, callback) + except ReasonLengthException: + finish_reason = FinishReason.LENGTH + + if callback.invocations: + choice.set_state(State(invocations=callback.invocations)) + choice.close(finish_reason) + + response.set_usage( + model.total_prompt_tokens, model.total_completion_tokens + ) diff --git a/aidial_assistant/application/assistant_callback.py b/aidial_assistant/application/assistant_callback.py index b75ecab..e7f0bf0 100644 --- a/aidial_assistant/application/assistant_callback.py +++ b/aidial_assistant/application/assistant_callback.py @@ -1,67 +1,18 @@ from types import TracebackType -from typing import Callable from aidial_sdk.chat_completion import Status from aidial_sdk.chat_completion.choice import Choice from aidial_sdk.chat_completion.stage import Stage from typing_extensions import override -from aidial_assistant.chain.callbacks.arg_callback import ArgCallback from aidial_assistant.chain.callbacks.args_callback import ArgsCallback from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.command_callback import CommandCallback from aidial_assistant.chain.callbacks.result_callback import ResultCallback from aidial_assistant.commands.base import ExecutionCallback, ResultObject -from aidial_assistant.commands.run_plugin import RunPlugin from aidial_assistant.utils.state import Invocation -class PluginNameArgCallback(ArgCallback): - def __init__( - self, - callback: Callable[[str], None], - addon_name_mapping: dict[str, str], - ): - super().__init__(0, callback) - self.addon_name_mapping = addon_name_mapping - - self._plugin_name = "" - - @override - def on_arg(self, chunk: str): - chunk = chunk.replace('"', "") - self._plugin_name += chunk - - @override - def on_arg_end(self): - self.callback( - self.addon_name_mapping.get(self._plugin_name, self._plugin_name) - + "(" - ) - - -class RunPluginArgsCallback(ArgsCallback): - def __init__( - self, - callback: Callable[[str], None], - addon_name_mapping: dict[str, str], - ): - super().__init__(callback) - self.addon_name_mapping = addon_name_mapping - - @override - def on_args_start(self): - pass - - @override - def arg_callback(self) -> ArgCallback: - self.arg_index += 1 - if self.arg_index == 0: - return PluginNameArgCallback(self.callback, self.addon_name_mapping) - else: - return ArgCallback(self.arg_index - 1, self.callback) - - class AssistantCommandCallback(CommandCallback): def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]): self.stage = stage @@ -71,12 +22,7 @@ def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]): @override def on_command(self, command: str): - if command == RunPlugin.token(): - self._args_callback = RunPluginArgsCallback( - self._on_stage_name, self.addon_name_mapping - ) - else: - self._on_stage_name(command) + self._on_stage_name(self.addon_name_mapping.get(command, command)) @override def execution_callback(self) -> ExecutionCallback: @@ -84,7 +30,7 @@ def execution_callback(self) -> ExecutionCallback: @override def args_callback(self) -> ArgsCallback: - return self._args_callback + return ArgsCallback(self._on_stage_name) @override def on_result(self, result: ResultObject): diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py index a1f8e22..a54d169 100644 --- a/aidial_assistant/application/prompts.py +++ b/aidial_assistant/application/prompts.py @@ -30,29 +30,36 @@ def build(self, **kwargs) -> Template: _REQUEST_FORMAT_TEXT = """ You should ALWAYS reply with a JSON containing an array of commands: +```json { "commands": [ { "command": "", - "args": [ - "", "", ... - ] + "arguments": { + "": "" + } } ] } -The commands are invoked by system on user's behalf. +``` +The commands are invoked by the system on the user's behalf. """.strip() _PROTOCOL_FOOTER = """ * reply -The command delivers final response to the user. +The last command that delivers the final response to the user. + Arguments: - - MESSAGE is a string containing the final and complete result for the user. + - 'message' is a string containing the final and complete result for the user. Your goal is to answer user questions. Use relevant commands when they help to achieve the goal. ## Example -{"commands": [{"command": "reply", "args": ["Hello, world!"]}]} +```json +{"commands": [{"command": "reply", "arguments": {"message": "Hello, world!"}}]} +``` + +End of the protocol. """.strip() _SYSTEM_TEXT = """ @@ -67,17 +74,13 @@ def build(self, **kwargs) -> Template: {{request_format}} ## Commands -{%- if addons %} -* run-addon -This command executes a specified addon to address a one-time task described in natural language. -Addons do not see current conversation and require all details to be provided in the query to solve the task. -Arguments: - - NAME is one of the following addons: {%- for name, description in addons.items() %} - * {{name}} - {{description | decap}} -{%- endfor %} - - QUERY is the query string. -{%- endif %} +* {{name}} +{{description.strip()}} + +Arguments: + - 'query' is a query written in natural language. +{% endfor %} {{protocol_footer}} """.strip() @@ -100,9 +103,10 @@ def build(self, **kwargs) -> Template: ## Commands {%- for command_name in command_names %} * {{command_name}} + Arguments: - - -{%- endfor %} + - +{% endfor %} {{protocol_footer}} """.strip() diff --git a/aidial_assistant/chain/callbacks/arg_callback.py b/aidial_assistant/chain/callbacks/arg_callback.py deleted file mode 100644 index 5102637..0000000 --- a/aidial_assistant/chain/callbacks/arg_callback.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Callable - - -class ArgCallback: - """Callback for reporting arguments""" - - def __init__(self, arg_index: int, callback: Callable[[str], None]): - self.arg_index = arg_index - self.callback = callback - - def on_arg_start(self): - """Called when the arg starts""" - if self.arg_index > 0: - self.callback(", ") - - def on_arg(self, chunk: str): - """Called when an argument chunk is read""" - self.callback(chunk) - - def on_arg_end(self): - """Called when the arg ends""" diff --git a/aidial_assistant/chain/callbacks/args_callback.py b/aidial_assistant/chain/callbacks/args_callback.py index ca5b59f..5730ec9 100644 --- a/aidial_assistant/chain/callbacks/args_callback.py +++ b/aidial_assistant/chain/callbacks/args_callback.py @@ -1,24 +1,17 @@ from typing import Callable -from aidial_assistant.chain.callbacks.arg_callback import ArgCallback - class ArgsCallback: """Callback for reporting arguments""" def __init__(self, callback: Callable[[str], None]): self.callback = callback - self.arg_index = -1 def on_args_start(self): - """Called when the arguments start""" self.callback("(") - def arg_callback(self) -> ArgCallback: - """Returns a callback for reporting an argument""" - self.arg_index += 1 - return ArgCallback(self.arg_index, self.callback) + def on_args_chunk(self, chunk: str): + self.callback(chunk) def on_args_end(self): - """Called when the arguments end""" self.callback(")") diff --git a/aidial_assistant/chain/command_chain.py b/aidial_assistant/chain/command_chain.py index 48a8a63..ea76ab4 100644 --- a/aidial_assistant/chain/command_chain.py +++ b/aidial_assistant/chain/command_chain.py @@ -1,14 +1,13 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Callable, Tuple, cast +from typing import Any, AsyncIterator, Tuple, cast -from aidial_sdk.chat_completion.request import Role -from openai import InvalidRequestError +from openai import BadRequestError from aidial_assistant.application.prompts import ENFORCE_JSON_FORMAT_TEMPLATE +from aidial_assistant.chain.callbacks.args_callback import ArgsCallback from aidial_assistant.chain.callbacks.chain_callback import ChainCallback -from aidial_assistant.chain.callbacks.command_callback import CommandCallback from aidial_assistant.chain.callbacks.result_callback import ResultCallback from aidial_assistant.chain.command_result import ( CommandInvocation, @@ -24,13 +23,20 @@ CommandsReader, skip_to_json_start, ) -from aidial_assistant.commands.base import Command, FinalCommand +from aidial_assistant.commands.base import ( + Command, + CommandConstructor, + FinalCommand, +) from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream from aidial_assistant.json_stream.exceptions import JsonParsingException -from aidial_assistant.json_stream.json_node import JsonNode -from aidial_assistant.json_stream.json_parser import JsonParser +from aidial_assistant.json_stream.json_object import JsonObject +from aidial_assistant.json_stream.json_parser import JsonParser, string_node from aidial_assistant.json_stream.json_string import JsonString -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ( + ChatCompletionMessageParam, + ModelClient, +) from aidial_assistant.utils.stream import CumulativeStream logger = logging.getLogger(__name__) @@ -41,7 +47,6 @@ # Later, the upper limit will be provided by the DIAL Core (proxy). MAX_MODEL_COMPLETION_CHUNKS = 32000 -CommandConstructor = Callable[[], Command] CommandDict = dict[str, CommandConstructor] @@ -51,7 +56,7 @@ class LimitExceededException(Exception): class ModelRequestLimiter(ABC): @abstractmethod - async def verify_limit(self, messages: list[Message]): + async def verify_limit(self, messages: list[ChatCompletionMessageParam]): pass @@ -74,13 +79,13 @@ def __init__( ) self.max_retry_count = max_retry_count - def _log_message(self, role: Role, content: str): - logger.debug(f"[{self.name}] {role.value}: {content}") + def _log_message(self, role: str, content: str | None): + logger.debug(f"[{self.name}] {role}: {content or ''}") - def _log_messages(self, messages: list[Message]): + def _log_messages(self, messages: list[ChatCompletionMessageParam]): if logger.isEnabledFor(logging.DEBUG): for message in messages: - self._log_message(message.role, message.content) + self._log_message(message["role"], message.get("content")) async def run_chat( self, @@ -112,9 +117,9 @@ async def run_chat( else history.to_user_messages() ) await self._generate_result(messages, callback) - except (InvalidRequestError, LimitExceededException) as e: + except (BadRequestError, LimitExceededException) as e: if dialogue.is_empty() or ( - isinstance(e, InvalidRequestError) and e.code == "429" + isinstance(e, BadRequestError) and e.code == "429" ): raise @@ -128,7 +133,7 @@ async def run_chat( async def _run_with_protocol_failure_retries( self, callback: ChainCallback, - messages: list[Message], + messages: list[ChatCompletionMessageParam], model_request_limiter: ModelRequestLimiter | None = None, ) -> DialogueTurn | None: last_error: Exception | None = None @@ -186,8 +191,8 @@ async def _run_with_protocol_failure_retries( ) ) finally: - self._log_message(Role.ASSISTANT, chunk_stream.buffer) - except (InvalidRequestError, LimitExceededException) as e: + self._log_message("assistant", chunk_stream.buffer) + except (BadRequestError, LimitExceededException) as e: if last_error: # Retries can increase the prompt size, which may lead to token overflow. # Thus, if the original error was a protocol error, it should be thrown instead. @@ -210,11 +215,11 @@ async def _run_commands( async for invocation in request_reader.parse_invocations(): command_name = await invocation.parse_name() command = self._create_command(command_name) - args = invocation.parse_args() + args = await invocation.parse_args() if isinstance(command, FinalCommand): if len(responses) > 0: continue - message = await anext(args) + message = string_node(await args.get("message")) await CommandChain._to_result( message if isinstance(message, JsonString) @@ -237,47 +242,44 @@ async def _run_commands( def _create_command(self, name: str) -> Command: if name not in self.command_dict: raise AssistantProtocolException( - f"The command '{name}' is expected to be one of {[*self.command_dict.keys()]}" + f"The command '{name}' is expected to be one of {list(self.command_dict.keys())}" ) return self.command_dict[name]() async def _generate_result( - self, messages: list[Message], callback: ChainCallback + self, + messages: list[ChatCompletionMessageParam], + callback: ChainCallback, ): stream = self.model_client.agenerate(messages) await CommandChain._to_result(stream, callback.result_callback()) @staticmethod - def _reinforce_json_format(messages: list[Message]) -> list[Message]: - last_message = messages[-1] - return messages[:-1] + [ - Message( - role=last_message.role, - content=ENFORCE_JSON_FORMAT_TEMPLATE.render( - response=last_message.content - ), - ), - ] + def _reinforce_json_format( + messages: list[ChatCompletionMessageParam], + ) -> list[ChatCompletionMessageParam]: + last_message = messages[-1].copy() + last_message["content"] = ENFORCE_JSON_FORMAT_TEMPLATE.render( + response=last_message.get("content", "") + ) + return messages[:-1] + [last_message] @staticmethod async def _to_args( - args: AsyncIterator[JsonNode], callback: CommandCallback - ) -> AsyncIterator[Any]: - args_callback = callback.args_callback() + args: JsonObject, args_callback: ArgsCallback + ) -> dict[str, Any]: args_callback.on_args_start() - async for arg in args: - arg_callback = args_callback.arg_callback() - arg_callback.on_arg_start() - result = "" - async for chunk in arg.to_chunks(): - arg_callback.on_arg(chunk) - result += chunk - arg_callback.on_arg_end() - yield json.loads(result) + result = "" + async for chunk in args.to_chunks(): + args_callback.on_args_chunk(chunk) + result += chunk + parsed_args = json.loads(result) args_callback.on_args_end() + return parsed_args + @staticmethod async def _to_result(stream: AsyncIterator[str], callback: ResultCallback): try: @@ -294,20 +296,17 @@ async def _to_result(stream: AsyncIterator[str], callback: ResultCallback): async def _execute_command( name: str, command: Command, - args: AsyncIterator[JsonNode], + args: JsonObject, chain_callback: ChainCallback, ) -> CommandResult: try: with chain_callback.command_callback() as command_callback: command_callback.on_command(name) - args_list = [ - arg - async for arg in CommandChain._to_args( - args, command_callback - ) - ] response = await command.execute( - args_list, command_callback.execution_callback() + await CommandChain._to_args( + args, command_callback.args_callback() + ), + command_callback.execution_callback(), ) command_callback.on_result(response) diff --git a/aidial_assistant/chain/command_result.py b/aidial_assistant/chain/command_result.py index 133685d..1c6ef38 100644 --- a/aidial_assistant/chain/command_result.py +++ b/aidial_assistant/chain/command_result.py @@ -1,6 +1,6 @@ import json from enum import Enum -from typing import List, TypedDict +from typing import Any, List, TypedDict class Status(str, Enum): @@ -18,12 +18,20 @@ class CommandResult(TypedDict): class CommandInvocation(TypedDict): command: str - args: list[str] + arguments: dict[str, Any] + + +class Commands(TypedDict): + commands: list[CommandInvocation] + + +class Responses(TypedDict): + responses: list[CommandResult] def responses_to_text(responses: List[CommandResult]) -> str: - return json.dumps({"responses": responses}) + return json.dumps(Responses(responses=responses)) def commands_to_text(commands: List[CommandInvocation]) -> str: - return json.dumps({"commands": commands}) + return json.dumps(Commands(commands=commands)) diff --git a/aidial_assistant/chain/dialogue.py b/aidial_assistant/chain/dialogue.py index b8b3077..f1abda5 100644 --- a/aidial_assistant/chain/dialogue.py +++ b/aidial_assistant/chain/dialogue.py @@ -1,6 +1,7 @@ +from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel -from aidial_assistant.model.model_client import Message +from aidial_assistant.utils.open_ai import assistant_message, user_message class DialogueTurn(BaseModel): @@ -10,11 +11,11 @@ class DialogueTurn(BaseModel): class Dialogue: def __init__(self): - self.messages: list[Message] = [] + self.messages: list[ChatCompletionMessageParam] = [] def append(self, dialogue_turn: DialogueTurn): - self.messages.append(Message.assistant(dialogue_turn.assistant_message)) - self.messages.append(Message.user(dialogue_turn.user_message)) + self.messages.append(assistant_message(dialogue_turn.assistant_message)) + self.messages.append(user_message(dialogue_turn.user_message)) def pop(self): self.messages.pop() diff --git a/aidial_assistant/chain/history.py b/aidial_assistant/chain/history.py index 27bfa07..6e8db05 100644 --- a/aidial_assistant/chain/history.py +++ b/aidial_assistant/chain/history.py @@ -1,17 +1,17 @@ from enum import Enum -from aidial_sdk.chat_completion import Role from jinja2 import Template +from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel -from aidial_assistant.application.prompts import ENFORCE_JSON_FORMAT_TEMPLATE from aidial_assistant.chain.command_result import ( CommandInvocation, commands_to_text, ) from aidial_assistant.chain.dialogue import Dialogue from aidial_assistant.commands.reply import Reply -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ModelClient +from aidial_assistant.utils.open_ai import assistant_message, system_message class ContextLengthExceeded(Exception): @@ -25,19 +25,7 @@ class MessageScope(str, Enum): class ScopedMessage(BaseModel): scope: MessageScope = MessageScope.USER - message: Message - - -def enforce_json_format(messages: list[Message]) -> list[Message]: - last_message = messages[-1] - return messages[:-1] + [ - Message( - role=last_message.role, - content=ENFORCE_JSON_FORMAT_TEMPLATE.render( - response=last_message.content - ), - ), - ] + message: ChatCompletionMessageParam class History: @@ -58,44 +46,45 @@ def __init__( if message.scope == MessageScope.USER ) - def to_protocol_messages(self) -> list[Message]: - messages: list[Message] = [] + def to_protocol_messages(self) -> list[ChatCompletionMessageParam]: + messages: list[ChatCompletionMessageParam] = [] for index, scoped_message in enumerate(self.scoped_messages): message = scoped_message.message scope = scoped_message.scope if index == 0: - if message.role == Role.SYSTEM: + if message["role"] == "system": messages.append( - Message.system( + system_message( self.assistant_system_message_template.render( - system_prefix=message.content + system_prefix=message["content"] ) ) ) else: messages.append( - Message.system( + system_message( self.assistant_system_message_template.render() ) ) messages.append(message) - elif scope == MessageScope.USER and message.role == Role.ASSISTANT: + elif scope == MessageScope.USER and message["role"] == "assistant": # Clients see replies in plain text, but the model should understand how to reply appropriately. content = commands_to_text( [ CommandInvocation( - command=Reply.token(), args=[message.content] + command=Reply.token(), + arguments={"message": message.get("content", "")}, ) ] ) - messages.append(Message.assistant(content=content)) + messages.append(assistant_message(content)) else: messages.append(message) return messages - def to_user_messages(self) -> list[Message]: + def to_user_messages(self) -> list[ChatCompletionMessageParam]: return [ scoped_message.message for scoped_message in self.scoped_messages @@ -104,18 +93,16 @@ def to_user_messages(self) -> list[Message]: def to_best_effort_messages( self, error: str, dialogue: Dialogue - ) -> list[Message]: + ) -> list[ChatCompletionMessageParam]: messages = self.to_user_messages() - last_message = messages[-1] - messages[-1] = Message( - role=last_message.role, - content=self.best_effort_template.render( - message=last_message.content, - error=error, - dialogue=dialogue.messages, - ), + last_message = messages[-1].copy() + last_message["content"] = self.best_effort_template.render( + message=last_message.get("content", ""), + error=error, + dialogue=dialogue.messages, ) + messages[-1] = last_message return messages @@ -146,7 +133,7 @@ def _skip_messages(self, discarded_messages: int) -> list[ScopedMessage]: message_iterator = iter(self.scoped_messages) for _ in range(discarded_messages): current_message = next(message_iterator) - while current_message.message.role == Role.SYSTEM: + while current_message.message["role"] == "system": # System messages should be kept in the history messages.append(current_message) current_message = next(message_iterator) @@ -157,7 +144,7 @@ def _skip_messages(self, discarded_messages: int) -> list[ScopedMessage]: # Internal messages (i.e. addon requests/responses) are always followed by an assistant reply assert ( - current_message.message.role == Role.ASSISTANT + current_message.message["role"] == "assistant" ), "Internal messages must be followed by an assistant reply." remaining_messages = list(message_iterator) diff --git a/aidial_assistant/chain/model_response_reader.py b/aidial_assistant/chain/model_response_reader.py index dffe11f..8e63f1a 100644 --- a/aidial_assistant/chain/model_response_reader.py +++ b/aidial_assistant/chain/model_response_reader.py @@ -46,18 +46,12 @@ async def parse_name(self) -> str: except (TypeError, KeyError) as e: raise AssistantProtocolException(f"Cannot parse command name: {e}") - async def parse_args(self) -> AsyncIterator[JsonNode]: + async def parse_args(self) -> JsonObject: try: - args = await self.node.get("args") - # HACK: model not always passes args as an array - if isinstance(args, JsonArray): - async for arg in array_node(args): - yield arg - else: - yield args + return object_node(await self.node.get("arguments")) except (TypeError, KeyError) as e: raise AssistantProtocolException( - f"Cannot parse command args array: {e}" + f"Cannot parse command arguments array: {e}" ) diff --git a/aidial_assistant/commands/base.py b/aidial_assistant/commands/base.py index b4f2fca..1c12d8b 100644 --- a/aidial_assistant/commands/base.py +++ b/aidial_assistant/commands/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Callable, List, TypedDict +from typing import Any, Callable, List, TypedDict, TypeVar from typing_extensions import override @@ -43,24 +43,18 @@ def token() -> str: pass async def execute( - self, args: List[Any], execution_callback: ExecutionCallback + self, args: dict[str, Any], execution_callback: ExecutionCallback ) -> ResultObject: raise Exception(f"Command {self} isn't implemented") def __str__(self) -> str: return self.token() - def assert_arg_count(self, args: List[Any], count: int): - if len(args) != count: - raise ValueError( - f"Command {self} expects {count} args, but got {len(args)}" - ) - class FinalCommand(Command, ABC): @override async def execute( - self, args: List[Any], execution_callback: ExecutionCallback + self, args: dict[str, Any], execution_callback: ExecutionCallback ) -> ResultObject: raise Exception( f"Internal error: command {self} is final and can't be executed" @@ -70,3 +64,16 @@ async def execute( class CommandObject(TypedDict): command: str args: List[str] + + +CommandConstructor = Callable[[], Command] + + +T = TypeVar("T") + + +def get_required_field(args: dict[str, T], field: str) -> T: + value = args.get(field) + if value is None: + raise Exception(f"Parameter '{field}' is required") + return value diff --git a/aidial_assistant/commands/open_api.py b/aidial_assistant/commands/open_api.py index 4f90303..8f2a679 100644 --- a/aidial_assistant/commands/open_api.py +++ b/aidial_assistant/commands/open_api.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from langchain.tools.openapi.utils.api_models import APIOperation from typing_extensions import override @@ -22,10 +22,8 @@ def __init__(self, op: APIOperation, plugin_auth: str | None): @override async def execute( - self, args: List[Any], execution_callback: ExecutionCallback + self, args: dict[str, Any], execution_callback: ExecutionCallback ) -> ResultObject: - self.assert_arg_count(args, 1) - return await OpenAPIEndpointRequester( self.op, self.plugin_auth - ).execute(args[0]) + ).execute(args) diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py index 15cc00f..96f2913 100644 --- a/aidial_assistant/commands/run_plugin.py +++ b/aidial_assistant/commands/run_plugin.py @@ -1,5 +1,3 @@ -from typing import List - from langchain.tools import APIOperation from pydantic.main import BaseModel from typing_extensions import override @@ -18,16 +16,17 @@ ExecutionCallback, ResultObject, TextResult, + get_required_field, ) from aidial_assistant.commands.open_api import OpenAPIChatCommand from aidial_assistant.commands.plugin_callback import PluginChainCallback from aidial_assistant.commands.reply import Reply from aidial_assistant.model.model_client import ( - Message, ModelClient, ReasonLengthException, ) from aidial_assistant.open_api.operation_selector import collect_operations +from aidial_assistant.utils.open_ai import user_message from aidial_assistant.utils.open_ai_plugin import OpenAIPluginInfo @@ -40,11 +39,11 @@ class RunPlugin(Command): def __init__( self, model_client: ModelClient, - plugins: dict[str, PluginInfo], + plugin: PluginInfo, max_completion_tokens: int, ): self.model_client = model_client - self.plugins = plugins + self.plugin = plugin self.max_completion_tokens = max_completion_tokens @staticmethod @@ -53,35 +52,29 @@ def token(): @override async def execute( - self, args: List[str], execution_callback: ExecutionCallback + self, args: dict[str, str], execution_callback: ExecutionCallback ) -> ResultObject: - self.assert_arg_count(args, 2) - name = args[0] - query = args[1] + query = get_required_field(args, "query") - return await self._run_plugin(name, query, execution_callback) + return await self._run_plugin(query, execution_callback) async def _run_plugin( - self, name: str, query: str, execution_callback: ExecutionCallback + self, query: str, execution_callback: ExecutionCallback ) -> ResultObject: - if name not in self.plugins: - raise ValueError( - f"Unknown addon: {name}. Available addons: {[*self.plugins.keys()]}" - ) - - plugin = self.plugins[name] - info = plugin.info + info = self.plugin.info ops = collect_operations(info.open_api, info.ai_plugin.api.url) api_schema = "\n\n".join([op.to_typescript() for op in ops.values()]) # type: ignore def create_command(op: APIOperation): - return lambda: OpenAPIChatCommand(op, plugin.auth) + return lambda: OpenAPIChatCommand(op, self.plugin.auth) command_dict: dict[str, CommandConstructor] = {} for name, op in ops.items(): # The function is necessary to capture the current value of op. # Otherwise, only first op will be used for all commands command_dict[name] = create_command(op) + if Reply.token() in command_dict: + Exception(f"Operation with name '{Reply.token()}' is not allowed.") command_dict[Reply.token()] = Reply @@ -94,12 +87,12 @@ def create_command(op: APIOperation): best_effort_template=ADDON_BEST_EFFORT_TEMPLATE.build( api_schema=api_schema ), - scoped_messages=[ScopedMessage(message=Message.user(query))], + scoped_messages=[ScopedMessage(message=user_message(query))], ) chat = CommandChain( model_client=self.model_client, - name="PLUGIN:" + name, + name="PLUGIN:" + self.plugin.info.ai_plugin.name_for_model, command_dict=command_dict, max_completion_tokens=self.max_completion_tokens, ) @@ -107,6 +100,7 @@ def create_command(op: APIOperation): callback = PluginChainCallback(execution_callback) try: await chat.run_chat(history, callback) - return TextResult(callback.result) except ReasonLengthException: - return TextResult(callback.result) + pass + + return TextResult(callback.result) diff --git a/aidial_assistant/commands/run_tool.py b/aidial_assistant/commands/run_tool.py new file mode 100644 index 0000000..b1a147e --- /dev/null +++ b/aidial_assistant/commands/run_tool.py @@ -0,0 +1,106 @@ +from typing import Any + +from langchain_community.tools.openapi.utils.api_models import ( + APIOperation, + APIPropertyBase, +) +from openai.types.chat import ChatCompletionToolParam +from typing_extensions import override + +from aidial_assistant.commands.base import ( + Command, + ExecutionCallback, + ResultObject, + TextResult, + get_required_field, +) +from aidial_assistant.commands.open_api import OpenAPIChatCommand +from aidial_assistant.commands.plugin_callback import PluginChainCallback +from aidial_assistant.commands.run_plugin import PluginInfo +from aidial_assistant.model.model_client import ( + ModelClient, + ReasonLengthException, +) +from aidial_assistant.open_api.operation_selector import collect_operations +from aidial_assistant.tools_chain.tools_chain import ( + CommandTool, + CommandToolDict, + ToolsChain, +) +from aidial_assistant.utils.open_ai import ( + construct_tool, + system_message, + user_message, +) + + +def _construct_property(p: APIPropertyBase) -> dict[str, Any]: + parameter = { + "type": p.type, + "description": p.description, + } + return {k: v for k, v in parameter.items() if v is not None} + + +def _construct_tool(op: APIOperation) -> ChatCompletionToolParam: + properties = {} + required = [] + for p in op.properties: + properties[p.name] = _construct_property(p) + + if p.required: + required.append(p.name) + + if op.request_body is not None: + for p in op.request_body.properties: + properties[p.name] = _construct_property(p) + + if p.required: + required.append(p.name) + + return construct_tool( + op.operation_id, op.description or "", properties, required + ) + + +class RunTool(Command): + def __init__(self, model: ModelClient, plugin: PluginInfo): + self.model = model + self.plugin = plugin + + @staticmethod + def token(): + return "run-tool" + + @override + async def execute( + self, args: dict[str, Any], execution_callback: ExecutionCallback + ) -> ResultObject: + query = get_required_field(args, "query") + + ops = collect_operations( + self.plugin.info.open_api, self.plugin.info.ai_plugin.api.url + ) + + def create_command_tool(op: APIOperation) -> CommandTool: + return lambda: OpenAPIChatCommand( + op, self.plugin.auth + ), _construct_tool(op) + + commands: CommandToolDict = { + name: create_command_tool(op) for name, op in ops.items() + } + + chain = ToolsChain(self.model, commands) + + messages = [ + system_message(self.plugin.info.ai_plugin.description_for_model), + user_message(query), + ] + chain_callback = PluginChainCallback(execution_callback) + try: + await chain.run_chat(messages, chain_callback) + except ReasonLengthException: + pass + + return TextResult(chain_callback.result) diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py index cb5499b..83dfa8e 100644 --- a/aidial_assistant/model/model_client.py +++ b/aidial_assistant/model/model_client.py @@ -1,41 +1,20 @@ from abc import ABC -from typing import Any, AsyncIterator, List, TypedDict +from typing import Any, AsyncIterator, List -import openai -from aidial_sdk.chat_completion import Role -from aiohttp import ClientSession -from pydantic import BaseModel +from aidial_sdk.utils.merge_chunks import merge +from openai import AsyncOpenAI +from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, +) + +from aidial_assistant.utils.open_ai import Usage class ReasonLengthException(Exception): pass -class Message(BaseModel): - role: Role - content: str - - def to_openai_message(self) -> dict[str, str]: - return {"role": self.role.value, "content": self.content} - - @classmethod - def system(cls, content): - return cls(role=Role.SYSTEM, content=content) - - @classmethod - def user(cls, content): - return cls(role=Role.USER, content=content) - - @classmethod - def assistant(cls, content): - return cls(role=Role.ASSISTANT, content=content) - - -class Usage(TypedDict): - prompt_tokens: int - completion_tokens: int - - class ExtraResultsCallback: def on_discarded_messages(self, discarded_messages: int): pass @@ -43,6 +22,11 @@ def on_discarded_messages(self, discarded_messages: int): def on_prompt_tokens(self, prompt_tokens: int): pass + def on_tool_calls( + self, tool_calls: list[ChatCompletionMessageToolCallParam] + ): + pass + async def _flush_stream(stream: AsyncIterator[str]): try: @@ -53,64 +37,78 @@ async def _flush_stream(stream: AsyncIterator[str]): class ModelClient(ABC): - def __init__( - self, - model_args: dict[str, Any], - buffer_size: int, - ): + def __init__(self, client: AsyncOpenAI, model_args: dict[str, Any]): + self.client = client self.model_args = model_args - self.buffer_size = buffer_size self._total_prompt_tokens: int = 0 self._total_completion_tokens: int = 0 async def agenerate( self, - messages: List[Message], + messages: List[ChatCompletionMessageParam], extra_results_callback: ExtraResultsCallback | None = None, **kwargs, ) -> AsyncIterator[str]: - async with ClientSession(read_bufsize=self.buffer_size) as session: - openai.aiosession.set(session) - - model_result = await openai.ChatCompletion.acreate( - messages=[message.to_openai_message() for message in messages], - **self.model_args | kwargs, - ) - - finish_reason_length = False - async for chunk in model_result: # type: ignore - usage: Usage | None = chunk.get("usage") - if usage: - prompt_tokens = usage["prompt_tokens"] - self._total_prompt_tokens += prompt_tokens - self._total_completion_tokens += usage["completion_tokens"] - if extra_results_callback: - extra_results_callback.on_prompt_tokens(prompt_tokens) + model_result = await self.client.chat.completions.create( + **self.model_args, + extra_body=kwargs, + stream=True, + messages=messages, + ) + finish_reason_length = False + tool_calls_chunks: list[list[dict[str, Any]]] = [] + async for chunk in model_result: + chunk_dict = chunk.dict() + usage: Usage | None = chunk_dict.get("usage") + if usage: + prompt_tokens = usage["prompt_tokens"] + self._total_prompt_tokens += prompt_tokens + self._total_completion_tokens += usage["completion_tokens"] if extra_results_callback: - discarded_messages: int | None = chunk.get( - "statistics", {} - ).get("discarded_messages") - if discarded_messages is not None: - extra_results_callback.on_discarded_messages( - discarded_messages - ) - - choice = chunk["choices"][0] - text = choice["delta"].get("content") - if text: - yield text - - if choice.get("finish_reason") == "length": - finish_reason_length = True - - if finish_reason_length: - raise ReasonLengthException() + extra_results_callback.on_prompt_tokens(prompt_tokens) + + if extra_results_callback: + discarded_messages: int | None = chunk_dict.get( + "statistics", {} + ).get("discarded_messages") + if discarded_messages is not None: + extra_results_callback.on_discarded_messages( + discarded_messages + ) + + choice = chunk.choices[0] + delta = choice.delta + if delta.content: + yield delta.content + + if delta.tool_calls: + tool_calls_chunks.append( + [ + tool_call_chunk.dict() + for tool_call_chunk in delta.tool_calls + ] + ) + + if choice.finish_reason == "length": + finish_reason_length = True + + if finish_reason_length: + raise ReasonLengthException() + + if extra_results_callback and tool_calls_chunks: + tool_calls: list[ChatCompletionMessageToolCallParam] = [ + ChatCompletionMessageToolCallParam(**tool_call) + for tool_call in merge(*tool_calls_chunks) + ] + extra_results_callback.on_tool_calls(tool_calls) # TODO: Use a dedicated endpoint for counting tokens. # This request may throw an error if the number of tokens is too large. - async def count_tokens(self, messages: list[Message]) -> int: + async def count_tokens( + self, messages: list[ChatCompletionMessageParam] + ) -> int: class PromptTokensCallback(ExtraResultsCallback): def __init__(self): self.token_count: int | None = None @@ -131,7 +129,7 @@ def on_prompt_tokens(self, prompt_tokens: int): # TODO: Use a dedicated endpoint for discarded_messages. async def get_discarded_messages( - self, messages: list[Message], max_prompt_tokens: int + self, messages: list[ChatCompletionMessageParam], max_prompt_tokens: int ) -> int: class DiscardedMessagesCallback(ExtraResultsCallback): def __init__(self): diff --git a/aidial_assistant/tools_chain/__init__.py b/aidial_assistant/tools_chain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aidial_assistant/tools_chain/tools_chain.py b/aidial_assistant/tools_chain/tools_chain.py new file mode 100644 index 0000000..87fbadb --- /dev/null +++ b/aidial_assistant/tools_chain/tools_chain.py @@ -0,0 +1,231 @@ +import json +from typing import Any, Tuple, cast + +from openai import BadRequestError +from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionToolParam, +) +from openai.types.chat.chat_completion_message_tool_call_param import Function + +from aidial_assistant.chain.callbacks.chain_callback import ChainCallback +from aidial_assistant.chain.callbacks.command_callback import CommandCallback +from aidial_assistant.chain.command_chain import CommandConstructor +from aidial_assistant.chain.command_result import ( + CommandInvocation, + CommandResult, + Commands, + Responses, + Status, + commands_to_text, + responses_to_text, +) +from aidial_assistant.chain.history import MessageScope, ScopedMessage +from aidial_assistant.chain.model_response_reader import ( + AssistantProtocolException, +) +from aidial_assistant.commands.base import Command +from aidial_assistant.model.model_client import ( + ExtraResultsCallback, + ModelClient, +) +from aidial_assistant.utils.exceptions import RequestParameterValidationError +from aidial_assistant.utils.open_ai import tool_calls_message, tool_message + + +def convert_commands_to_tools( + scoped_messages: list[ScopedMessage], +) -> list[ChatCompletionMessageParam]: + messages: list[ChatCompletionMessageParam] = [] + next_tool_id: int = 0 + last_call_count: int = 0 + for scoped_message in scoped_messages: + message = scoped_message.message + if scoped_message.scope == MessageScope.INTERNAL: + content = cast(str, message.get("content")) + if not content: + raise RequestParameterValidationError( + "State is broken. Content cannot be empty.", + param="messages", + ) + + if message["role"] == "assistant": + commands: Commands = json.loads(content) + messages.append( + tool_calls_message( + [ + ChatCompletionMessageToolCallParam( + id=str(next_tool_id + index), + function=Function( + name=command["command"], + arguments=json.dumps(command["arguments"]), + ), + type="function", + ) + for index, command in enumerate( + commands["commands"] + ) + ], + ) + ) + last_call_count = len(commands["commands"]) + next_tool_id += last_call_count + elif message["role"] == "user": + responses: Responses = json.loads(content) + response_count = len(responses["responses"]) + if response_count != last_call_count: + raise RequestParameterValidationError( + f"Expected {last_call_count} responses, but got {response_count}.", + param="messages", + ) + first_tool_id = next_tool_id - last_call_count + messages.extend( + [ + tool_message( + content=response["response"], + tool_call_id=str(first_tool_id + index), + ) + for index, response in enumerate(responses["responses"]) + ] + ) + else: + messages.append(scoped_message.message) + return messages + + +def _publish_command( + command_callback: CommandCallback, name: str, arguments: str +): + command_callback.on_command(name) + args_callback = command_callback.args_callback() + args_callback.on_args_start() + args_callback.on_args_chunk(arguments) + args_callback.on_args_end() + + +CommandTool = Tuple[CommandConstructor, ChatCompletionToolParam] +CommandToolDict = dict[str, CommandTool] + + +class ToolCallsCallback(ExtraResultsCallback): + def __init__(self): + self.tool_calls: list[ChatCompletionMessageToolCallParam] = [] + + def on_tool_calls( + self, tool_calls: list[ChatCompletionMessageToolCallParam] + ): + self.tool_calls = tool_calls + + +class ToolsChain: + def __init__(self, model: ModelClient, commands: CommandToolDict): + self.model = model + self.commands = commands + + async def run_chat( + self, + messages: list[ChatCompletionMessageParam], + callback: ChainCallback, + ): + result_callback = callback.result_callback() + dialogue: list[ChatCompletionMessageParam] = [] + last_message_block_length = 0 + tools = [tool for _, tool in self.commands.values()] + while True: + tool_calls_callback = ToolCallsCallback() + try: + async for chunk in self.model.agenerate( + messages + dialogue, tool_calls_callback, tools=tools + ): + result_callback.on_result(chunk) + except BadRequestError as e: + if len(dialogue) == 0 or e.code == "429": + raise + + # If the dialog size exceeds model context size then remove last message block + # and try again without tools. + dialogue = dialogue[:-last_message_block_length] + async for chunk in self.model.agenerate( + messages + dialogue, tool_calls_callback + ): + result_callback.on_result(chunk) + break + + if not tool_calls_callback.tool_calls: + break + + dialogue.append( + tool_calls_message( + tool_calls_callback.tool_calls, + ) + ) + result_messages = await self._run_tools( + tool_calls_callback.tool_calls, callback + ) + dialogue.extend(result_messages) + last_message_block_length = len(result_messages) + 1 + + def _create_command(self, name: str) -> Command: + if name not in self.commands: + raise AssistantProtocolException( + f"The tool '{name}' is expected to be one of {list(self.commands.keys())}" + ) + + command, _ = self.commands[name] + + return command() + + async def _run_tools( + self, + tool_calls: list[ChatCompletionMessageToolCallParam], + callback: ChainCallback, + ): + commands: list[CommandInvocation] = [] + command_results: list[CommandResult] = [] + result_messages: list[ChatCompletionMessageParam] = [] + for tool_call in tool_calls: + function = tool_call["function"] + name = function["name"] + arguments: dict[str, Any] = json.loads(function["arguments"]) + with callback.command_callback() as command_callback: + _publish_command(command_callback, name, json.dumps(arguments)) + command = self._create_command(name) + result = await self._execute_command( + command, + arguments, + command_callback, + ) + result_messages.append( + tool_message( + content=result["response"], + tool_call_id=tool_call["id"], + ) + ) + command_results.append(result) + + commands.append( + CommandInvocation(command=name, arguments=arguments) + ) + + callback.on_state( + commands_to_text(commands), responses_to_text(command_results) + ) + + return result_messages + + @staticmethod + async def _execute_command( + command: Command, + args: dict[str, Any], + command_callback: CommandCallback, + ) -> CommandResult: + try: + result = await command.execute( + args, command_callback.execution_callback() + ) + command_callback.on_result(result) + return CommandResult(status=Status.SUCCESS, response=result.text) + except Exception as e: + command_callback.on_error(e) + return CommandResult(status=Status.ERROR, response=str(e)) diff --git a/aidial_assistant/utils/exceptions.py b/aidial_assistant/utils/exceptions.py index bb03218..791fb9b 100644 --- a/aidial_assistant/utils/exceptions.py +++ b/aidial_assistant/utils/exceptions.py @@ -2,7 +2,7 @@ from functools import wraps from aidial_sdk import HTTPException -from openai import OpenAIError +from openai import APIError logger = logging.getLogger(__name__) @@ -26,18 +26,14 @@ def _to_http_exception(e: Exception) -> HTTPException: param=e.param, ) - if isinstance(e, OpenAIError): - http_status = e.http_status or 500 - if e.error: - return HTTPException( - message=e.error.message, - status_code=http_status, - type=e.error.type, - code=e.error.code, - param=e.error.param, - ) - - return HTTPException(message=str(e), status_code=http_status) + if isinstance(e, APIError): + raise HTTPException( + message=e.message, + status_code=getattr(e, "status_code") or 500, + type=e.type or "runtime_error", + code=e.code, + param=e.param, + ) return HTTPException( message=str(e), status_code=500, type="internal_server_error" diff --git a/aidial_assistant/utils/open_ai.py b/aidial_assistant/utils/open_ai.py new file mode 100644 index 0000000..b72acfc --- /dev/null +++ b/aidial_assistant/utils/open_ai.py @@ -0,0 +1,73 @@ +from typing import TypedDict + +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, +) +from openai.types.shared_params import FunctionDefinition + + +class Usage(TypedDict): + prompt_tokens: int + completion_tokens: int + + +class Property(TypedDict, total=False): + type: str + description: str + + +def construct_tool( + name: str, + description: str, + properties: dict[str, Property], + required: list[str], +) -> ChatCompletionToolParam: + return ChatCompletionToolParam( + type="function", + function=FunctionDefinition( + name=name, + description=description, + parameters={ + "type": "object", + "properties": properties, + "required": required, + }, + ), + ) + + +def system_message(content: str) -> ChatCompletionSystemMessageParam: + return ChatCompletionSystemMessageParam(role="system", content=content) + + +def user_message(content: str) -> ChatCompletionUserMessageParam: + return ChatCompletionUserMessageParam(role="user", content=content) + + +def assistant_message(content: str) -> ChatCompletionAssistantMessageParam: + return ChatCompletionAssistantMessageParam( + role="assistant", content=content + ) + + +def tool_calls_message( + tool_calls: list[ChatCompletionMessageToolCallParam], +) -> ChatCompletionAssistantMessageParam: + return ChatCompletionAssistantMessageParam( + role="assistant", tool_calls=tool_calls + ) + + +def tool_message( + content: str, tool_call_id: str +) -> ChatCompletionToolMessageParam: + return ChatCompletionToolMessageParam( + role="tool", + content=content, + tool_call_id=tool_call_id, + ) diff --git a/aidial_assistant/utils/state.py b/aidial_assistant/utils/state.py index a1a6734..f8c9dd5 100644 --- a/aidial_assistant/utils/state.py +++ b/aidial_assistant/utils/state.py @@ -1,9 +1,19 @@ +import json from typing import TypedDict from aidial_sdk.chat_completion.request import CustomContent, Message, Role +from aidial_assistant.chain.command_result import ( + CommandInvocation, + commands_to_text, +) from aidial_assistant.chain.history import MessageScope, ScopedMessage -from aidial_assistant.model.model_client import Message as ModelMessage +from aidial_assistant.utils.exceptions import RequestParameterValidationError +from aidial_assistant.utils.open_ai import ( + assistant_message, + system_message, + user_message, +) class Invocation(TypedDict): @@ -32,6 +42,32 @@ def _get_invocations(custom_content: CustomContent | None) -> list[Invocation]: return invocations +def _convert_old_commands(string: str) -> str: + """Converts old commands to new format. + Previously saved conversations with assistant will stop working if state is not updated. + + Old format: + {"commands": [{"command": "run-addon", "args": ["", ""]}]} + New format: + {"commands": [{"command": "", "arguments": {"query": ""}}]} + """ + commands = json.loads(string) + result: list[CommandInvocation] = [] + + for command in commands["commands"]: + command_name = command["command"] + # run-addon was previously called run-plugin + if command_name in ("run-addon", "run-plugin"): + args = command["args"] + result.append( + CommandInvocation(command=args[0], arguments={"query": args[1]}) + ) + else: + result.append(command) + + return commands_to_text(result) + + def parse_history(history: list[Message]) -> list[ScopedMessage]: messages: list[ScopedMessage] = [] for message in history: @@ -41,22 +77,32 @@ def parse_history(history: list[Message]) -> list[ScopedMessage]: messages.append( ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.assistant(invocation["request"]), + message=assistant_message( + _convert_old_commands(invocation["request"]) + ), ) ) messages.append( ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.user(invocation["response"]), + message=user_message(invocation["response"]), ) ) - messages.append( - ScopedMessage( - message=ModelMessage( - role=message.role, content=message.content or "" - ) + messages.append( + ScopedMessage(message=assistant_message(message.content or "")) + ) + elif message.role == Role.USER: + messages.append( + ScopedMessage(message=user_message(message.content or "")) + ) + elif message.role == Role.SYSTEM: + messages.append( + ScopedMessage(message=system_message(message.content or "")) + ) + else: + raise RequestParameterValidationError( + f"Role {message.role} is not supported.", param="messages" ) - ) return messages diff --git a/poetry.lock b/poetry.lock index 7728ec9..55e2b0a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "aidial-sdk" -version = "0.5.0" +version = "0.5.1" description = "Framework to create applications and model adapters for AI DIAL" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "aidial_sdk-0.5.0-py3-none-any.whl", hash = "sha256:db0cb45440d055a4361cdd35baf3b7db4d51c3c5b7c63a901ca920638937a26f"}, - {file = "aidial_sdk-0.5.0.tar.gz", hash = "sha256:29df146c44953ed90cecb07fb58c2087c800c511fa6a1a515392ed4de3b44621"}, + {file = "aidial_sdk-0.5.1-py3-none-any.whl", hash = "sha256:345e8f59593adf616be9b9bad6f46b98b0a6e7fcd7cc17932fabe8c266b3cfe4"}, + {file = "aidial_sdk-0.5.1.tar.gz", hash = "sha256:5bb327882c90719b3054b52f1e211c00fb9667b2c2010aeb6bbd60f6f40ea1d4"}, ] [package.dependencies] @@ -21,19 +21,20 @@ opentelemetry-exporter-prometheus = {version = "1.12.0rc1", optional = true, mar opentelemetry-instrumentation = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-aiohttp-client = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-fastapi = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} +opentelemetry-instrumentation-httpx = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-logging = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-requests = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-system-metrics = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-urllib = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-sdk = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""} +prometheus-client = {version = "0.17.1", optional = true, markers = "extra == \"telemetry\""} pydantic = ">=1.10,<3" requests = ">=2.19,<3.0" -starlette-exporter = {version = "0.16.0", optional = true, markers = "extra == \"telemetry\""} uvicorn = ">=0.19,<1.0" wrapt = ">=1.14,<2.0" [package.extras] -telemetry = ["opentelemetry-api (==1.20.0)", "opentelemetry-distro (==0.41b0)", "opentelemetry-exporter-otlp-proto-grpc (==1.20.0)", "opentelemetry-exporter-prometheus (==1.12.0rc1)", "opentelemetry-instrumentation (==0.41b0)", "opentelemetry-instrumentation-aiohttp-client (==0.41b0)", "opentelemetry-instrumentation-fastapi (==0.41b0)", "opentelemetry-instrumentation-logging (==0.41b0)", "opentelemetry-instrumentation-requests (==0.41b0)", "opentelemetry-instrumentation-system-metrics (==0.41b0)", "opentelemetry-instrumentation-urllib (==0.41b0)", "opentelemetry-sdk (==1.20.0)", "starlette-exporter (==0.16.0)"] +telemetry = ["opentelemetry-api (==1.20.0)", "opentelemetry-distro (==0.41b0)", "opentelemetry-exporter-otlp-proto-grpc (==1.20.0)", "opentelemetry-exporter-prometheus (==1.12.0rc1)", "opentelemetry-instrumentation (==0.41b0)", "opentelemetry-instrumentation-aiohttp-client (==0.41b0)", "opentelemetry-instrumentation-fastapi (==0.41b0)", "opentelemetry-instrumentation-httpx (==0.41b0)", "opentelemetry-instrumentation-logging (==0.41b0)", "opentelemetry-instrumentation-requests (==0.41b0)", "opentelemetry-instrumentation-system-metrics (==0.41b0)", "opentelemetry-instrumentation-urllib (==0.41b0)", "opentelemetry-sdk (==1.20.0)", "prometheus-client (==0.17.1)"] [[package]] name = "aiocache" @@ -490,6 +491,17 @@ files = [ {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "fastapi" version = "0.103.2" @@ -782,6 +794,51 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "httpcore" +version = "1.0.2" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.23.0)"] + +[[package]] +name = "httpx" +version = "0.26.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.26.0-py3-none-any.whl", hash = "sha256:8915f5a3627c4d47b73e8202457cb28f1266982d1159bd5779d86a80c0eab1cd"}, + {file = "httpx-0.26.0.tar.gz", hash = "sha256:451b55c30d5185ea6b23c2c793abf9bb237d2a7dfb901ced6ff69ad37ec1dfaf"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "idna" version = "3.4" @@ -884,21 +941,22 @@ files = [ [[package]] name = "langchain" -version = "0.0.329" +version = "0.0.350" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain-0.0.329-py3-none-any.whl", hash = "sha256:5f3e884991271e8b55eda4c63a11105dcd7da119682ce0e3d5d1385b3a4103d2"}, - {file = "langchain-0.0.329.tar.gz", hash = "sha256:488f3cb68a587696f136d4f01f97df8d8270e295b3cc56158057dab0f61f4166"}, + {file = "langchain-0.0.350-py3-none-any.whl", hash = "sha256:11b605f325a4271a7815baaec05bc7622e3ad1f10f26b05c752cafa27663ed38"}, + {file = "langchain-0.0.350.tar.gz", hash = "sha256:f0e68a92d200bb722586688ab7b411b2430bd98ad265ca03b264e7e7acbb6c01"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" -anyio = "<4.0" dataclasses-json = ">=0.5.7,<0.7" jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.0.52,<0.1.0" +langchain-community = ">=0.0.2,<0.1" +langchain-core = ">=0.1,<0.2" +langsmith = ">=0.0.63,<0.1.0" numpy = ">=1,<2" pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -907,29 +965,78 @@ SQLAlchemy = ">=1.4,<3" tenacity = ">=8.1.0,<9.0.0" [package.extras] -all = ["O365 (>=2.0.26,<3.0.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "amadeus (>=8.1.0)", "arxiv (>=1.4,<2.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "awadb (>=0.3.9,<0.4.0)", "azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "beautifulsoup4 (>=4,<5)", "clarifai (>=9.1.0)", "clickhouse-connect (>=0.5.14,<0.6.0)", "cohere (>=4,<5)", "deeplake (>=3.8.3,<4.0.0)", "docarray[hnswlib] (>=0.32.0,<0.33.0)", "duckduckgo-search (>=3.8.3,<4.0.0)", "elasticsearch (>=8,<9)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "google-api-python-client (==2.70.0)", "google-auth (>=2.18.1,<3.0.0)", "google-search-results (>=2,<3)", "gptcache (>=0.1.7)", "html2text (>=2020.1.16,<2021.0.0)", "huggingface_hub (>=0,<1)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "lancedb (>=0.1,<0.2)", "langkit (>=0.0.6,<0.1.0)", "lark (>=1.1.5,<2.0.0)", "librosa (>=0.10.0.post2,<0.11.0)", "lxml (>=4.9.2,<5.0.0)", "manifest-ml (>=0.0.1,<0.0.2)", "marqo (>=1.2.4,<2.0.0)", "momento (>=1.10.1,<2.0.0)", "nebula3-python (>=3.4.0,<4.0.0)", "neo4j (>=5.8.1,<6.0.0)", "networkx (>=2.6.3,<4)", "nlpcloud (>=1,<2)", "nltk (>=3,<4)", "nomic (>=1.0.43,<2.0.0)", "openai (>=0,<1)", "openlm (>=0.0.5,<0.0.6)", "opensearch-py (>=2.0.0,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pexpect (>=4.8.0,<5.0.0)", "pgvector (>=0.1.6,<0.2.0)", "pinecone-client (>=2,<3)", "pinecone-text (>=0.4.2,<0.5.0)", "psycopg2-binary (>=2.9.5,<3.0.0)", "pymongo (>=4.3.3,<5.0.0)", "pyowm (>=3.3.0,<4.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pytesseract (>=0.3.10,<0.4.0)", "python-arango (>=7.5.9,<8.0.0)", "pyvespa (>=0.33.0,<0.34.0)", "qdrant-client (>=1.3.1,<2.0.0)", "rdflib (>=6.3.2,<7.0.0)", "redis (>=4,<5)", "requests-toolbelt (>=1.0.0,<2.0.0)", "sentence-transformers (>=2,<3)", "singlestoredb (>=0.7.1,<0.8.0)", "tensorflow-text (>=2.11.0,<3.0.0)", "tigrisdb (>=1.0.0b6,<2.0.0)", "tiktoken (>=0.3.2,<0.6.0)", "torch (>=1,<3)", "transformers (>=4,<5)", "weaviate-client (>=3,<4)", "wikipedia (>=1,<2)", "wolframalpha (==5.0.0)"] -azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (>=0,<1)"] +azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (<2)"] clarifai = ["clarifai (>=9.1.0)"] cli = ["typer (>=0.9.0,<0.10.0)"] cohere = ["cohere (>=4,<5)"] docarray = ["docarray[hnswlib] (>=0.32.0,<0.33.0)"] embeddings = ["sentence-transformers (>=2,<3)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "dashvector (>=1.0.1,<2.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.6.0,<0.7.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "html2text (>=2020.1.16,<2021.0.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (>=0,<1)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "couchbase (>=4.1.9,<5.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] javascript = ["esprima (>=4.0.1,<5.0.0)"] -llms = ["clarifai (>=9.1.0)", "cohere (>=4,<5)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (>=0,<1)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"] -openai = ["openai (>=0,<1)", "tiktoken (>=0.3.2,<0.6.0)"] +llms = ["clarifai (>=9.1.0)", "cohere (>=4,<5)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"] +openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"] qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"] text-helpers = ["chardet (>=5.1.0,<6.0.0)"] +[[package]] +name = "langchain-community" +version = "0.0.10" +description = "Community contributed LangChain integrations." +optional = false +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "langchain_community-0.0.10-py3-none-any.whl", hash = "sha256:37123ce31018bc7ad3ffda8af73c46e16d568270527a546c34e8dbce713377af"}, + {file = "langchain_community-0.0.10.tar.gz", hash = "sha256:4d7b3510e04b80dfddace32fb5db0878e9bab7d4be7288f86112ed22dc5faf68"}, +] + +[package.dependencies] +aiohttp = ">=3.8.3,<4.0.0" +dataclasses-json = ">=0.5.7,<0.7" +langchain-core = ">=0.1.8,<0.2" +langsmith = ">=0.0.63,<0.1.0" +numpy = ">=1,<2" +PyYAML = ">=5.3" +requests = ">=2,<3" +SQLAlchemy = ">=1.4,<3" +tenacity = ">=8.1.0,<9.0.0" + +[package.extras] +cli = ["typer (>=0.9.0,<0.10.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"] + +[[package]] +name = "langchain-core" +version = "0.1.8" +description = "Building applications with LLMs through composability" +optional = false +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "langchain_core-0.1.8-py3-none-any.whl", hash = "sha256:f4d1837d6d814ed36528b642211933d1f0bd84e1eff361f4630a8c750acc27d0"}, + {file = "langchain_core-0.1.8.tar.gz", hash = "sha256:93ab72f5ab202526310fad389a45626501fd76ecf56d451111c0d4abe8183407"}, +] + +[package.dependencies] +anyio = ">=3,<5" +jsonpatch = ">=1.33,<2.0" +langsmith = ">=0.0.63,<0.1.0" +packaging = ">=23.2,<24.0" +pydantic = ">=1,<3" +PyYAML = ">=5.3" +requests = ">=2,<3" +tenacity = ">=8.1.0,<9.0.0" + +[package.extras] +extended-testing = ["jinja2 (>=3,<4)"] + [[package]] name = "langsmith" -version = "0.0.54" +version = "0.0.78" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.0.54-py3-none-any.whl", hash = "sha256:55eca5967cadb661a49ad32aecda48a824fadef202ca384575209a9d6f823b74"}, - {file = "langsmith-0.0.54.tar.gz", hash = "sha256:76c8e34b4d10ad93541107138089635829f9d60601a7f6bddf5ba582d178e521"}, + {file = "langsmith-0.0.78-py3-none-any.whl", hash = "sha256:d7c8300700dde0cea87388177c2552187e87fb4ae789510712e7654db72b5c04"}, + {file = "langsmith-0.0.78.tar.gz", hash = "sha256:a7d7f1639072aeb12115a931eb6d4c53810a480a1fec90bc8744f232765f3c81"}, ] [package.dependencies] @@ -1200,25 +1307,26 @@ files = [ [[package]] name = "openai" -version = "0.28.1" -description = "Python client library for the OpenAI API" +version = "1.6.1" +description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-0.28.1-py3-none-any.whl", hash = "sha256:d18690f9e3d31eedb66b57b88c2165d760b24ea0a01f150dd3f068155088ce68"}, - {file = "openai-0.28.1.tar.gz", hash = "sha256:4be1dad329a65b4ce1a660fe6d5431b438f429b5855c883435f0f7fcb6d2dcc8"}, + {file = "openai-1.6.1-py3-none-any.whl", hash = "sha256:bc9f774838d67ac29fb24cdeb2d58faf57de8b311085dcd1348f7aa02a96c7ee"}, + {file = "openai-1.6.1.tar.gz", hash = "sha256:d553ca9dbf9486b08e75b09e8671e4f638462aaadccfced632bf490fc3d75fa2"}, ] [package.dependencies] -aiohttp = "*" -requests = ">=2.20" -tqdm = "*" +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" [package.extras] -datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] -dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"] -embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] -wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "openapi-pydantic" @@ -1419,6 +1527,26 @@ opentelemetry-util-http = "0.41b0" instruments = ["fastapi (>=0.58,<1.0)"] test = ["httpx (>=0.22,<1.0)", "opentelemetry-instrumentation-fastapi[instruments]", "opentelemetry-test-utils (==0.41b0)", "requests (>=2.23,<3.0)"] +[[package]] +name = "opentelemetry-instrumentation-httpx" +version = "0.41b0" +description = "OpenTelemetry HTTPX Instrumentation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_instrumentation_httpx-0.41b0-py3-none-any.whl", hash = "sha256:6ada84b7caa95a2889b2d883c089a977546b0102c815658b88f1c2dae713e9b2"}, + {file = "opentelemetry_instrumentation_httpx-0.41b0.tar.gz", hash = "sha256:96ebc54f3f41bfcd2fc043349c8cee4b11737602512383d437e24c39a1e4adff"}, +] + +[package.dependencies] +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-instrumentation = "0.41b0" +opentelemetry-semantic-conventions = "0.41b0" + +[package.extras] +instruments = ["httpx (>=0.18.0)"] +test = ["opentelemetry-instrumentation-httpx[instruments]", "opentelemetry-sdk (>=1.12,<2.0)", "opentelemetry-test-utils (==0.41b0)"] + [[package]] name = "opentelemetry-instrumentation-logging" version = "0.41b0" @@ -1604,13 +1732,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "prometheus-client" -version = "0.19.0" +version = "0.17.1" description = "Python client for the Prometheus monitoring system." optional = false -python-versions = ">=3.8" +python-versions = ">=3.6" files = [ - {file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"}, - {file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"}, + {file = "prometheus_client-0.17.1-py3-none-any.whl", hash = "sha256:e537f37160f6807b8202a6fc4764cdd19bac5480ddd3e0d463c3002b34462101"}, + {file = "prometheus_client-0.17.1.tar.gz", hash = "sha256:21e674f39831ae3f8acde238afd9a27a37d0d2fb5a28ea094f0ce25d2cbf2091"}, ] [package.extras] @@ -2010,21 +2138,6 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] -[[package]] -name = "starlette-exporter" -version = "0.16.0" -description = "Prometheus metrics exporter for Starlette applications." -optional = false -python-versions = "*" -files = [ - {file = "starlette_exporter-0.16.0-py3-none-any.whl", hash = "sha256:9dbe8dc647acbeb8680d53cedbbb8042ca75ca1b6987f609c5601ea96ddb7422"}, - {file = "starlette_exporter-0.16.0.tar.gz", hash = "sha256:728cccf975c85d3cf2844b0110b51e1fa2dce628ef68bc38da58ad691f9b5d68"}, -] - -[package.dependencies] -prometheus-client = ">=0.12" -starlette = "*" - [[package]] name = "tenacity" version = "8.2.3" @@ -2329,4 +2442,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "8628476fd0575c2f255e186b6893df2086abfdc9224a4cdebc417ab7beae8e85" +content-hash = "147e4e4a549e48fafe8133a5aeb00839fb1ee7dea31f2720c8630e9b10dc7f4c" diff --git a/pyproject.toml b/pyproject.toml index d690c78..6271286 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,12 +19,12 @@ clean = "scripts.clean:main" python = "^3.11" aiocache = "^0.12.2" jinja2 = "^3.1.2" -langchain = "^0.0.329" -openai = "^0.28.0" +langchain = "^0.0.350" +openai = "^1.3.9" pydantic = "1.10.13" pyyaml = "^6.0.1" typing-extensions = "^4.8.0" -aidial-sdk = { version = "^0.5.0", extras = ["telemetry"] } +aidial-sdk = { version = "^0.5.1", extras = ["telemetry"] } aiohttp = "^3.9.0" openapi-schema-pydantic = "^1.2.4" openapi-pydantic = "^0.3.2" diff --git a/tests/unit_tests/application/test_addons_dialogue_limiter.py b/tests/unit_tests/application/test_addons_dialogue_limiter.py index b323b0d..871f4f6 100644 --- a/tests/unit_tests/application/test_addons_dialogue_limiter.py +++ b/tests/unit_tests/application/test_addons_dialogue_limiter.py @@ -6,7 +6,12 @@ AddonsDialogueLimiter, ) from aidial_assistant.chain.command_chain import LimitExceededException -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ModelClient +from aidial_assistant.utils.open_ai import ( + assistant_message, + system_message, + user_message, +) MAX_TOKENS = 1 @@ -17,8 +22,11 @@ async def test_dialogue_size_is_ok(): model.count_tokens.side_effect = [1, 2] limiter = AddonsDialogueLimiter(MAX_TOKENS, model) - initial_messages = [Message.system("a"), Message.user("b")] - dialogue_messages = [Message.assistant("c"), Message.user("d")] + initial_messages = [system_message("a"), user_message("b")] + dialogue_messages = [ + assistant_message("c"), + user_message("d"), + ] await limiter.verify_limit(initial_messages) await limiter.verify_limit(initial_messages + dialogue_messages) @@ -35,8 +43,11 @@ async def test_dialogue_overflow(): model.count_tokens.side_effect = [1, 3] limiter = AddonsDialogueLimiter(MAX_TOKENS, model) - initial_messages = [Message.system("a"), Message.user("b")] - dialogue_messages = [Message.assistant("c"), Message.user("d")] + initial_messages = [system_message("a"), user_message("b")] + dialogue_messages = [ + assistant_message("c"), + user_message("d"), + ] await limiter.verify_limit(initial_messages) with pytest.raises(LimitExceededException) as exc_info: diff --git a/tests/unit_tests/chain/test_command_chain_best_effort.py b/tests/unit_tests/chain/test_command_chain_best_effort.py index 74244d9..f17fd4e 100644 --- a/tests/unit_tests/chain/test_command_chain_best_effort.py +++ b/tests/unit_tests/chain/test_command_chain_best_effort.py @@ -1,10 +1,10 @@ import json from unittest.mock import MagicMock, Mock, call +import httpx import pytest -from aidial_sdk.chat_completion import Role from jinja2 import Template -from openai import InvalidRequestError +from openai import BadRequestError from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.result_callback import ResultCallback @@ -15,7 +15,12 @@ ) from aidial_assistant.chain.history import History, ScopedMessage from aidial_assistant.commands.base import Command, TextResult -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ModelClient +from aidial_assistant.utils.open_ai import ( + assistant_message, + system_message, + user_message, +) from tests.utils.async_helper import to_async_string, to_async_strings SYSTEM_MESSAGE = "" @@ -28,7 +33,11 @@ TEST_COMMAND_NAME = "" TEST_COMMAND_OUTPUT = "" TEST_COMMAND_REQUEST = json.dumps( - {"commands": [{"command": TEST_COMMAND_NAME, "args": ["test_arg"]}]} + { + "commands": [ + {"command": TEST_COMMAND_NAME, "arguments": {"arg": "value"}} + ] + } ) TEST_COMMAND_RESPONSE = json.dumps( {"responses": [{"status": "SUCCESS", "response": TEST_COMMAND_OUTPUT}]} @@ -41,10 +50,8 @@ "user_message={{message}}, error={{error}}, dialogue={{dialogue}}" ), scoped_messages=[ - ScopedMessage( - message=Message(role=Role.SYSTEM, content=SYSTEM_MESSAGE) - ), - ScopedMessage(message=Message(role=Role.USER, content=USER_MESSAGE)), + ScopedMessage(message=system_message(SYSTEM_MESSAGE)), + ScopedMessage(message=user_message(USER_MESSAGE)), ], ) @@ -76,14 +83,14 @@ async def test_model_doesnt_support_protocol(): assert model_client.agenerate.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(SYSTEM_MESSAGE), - Message.user(USER_MESSAGE), + system_message(SYSTEM_MESSAGE), + user_message(USER_MESSAGE), ] ), ] @@ -111,8 +118,8 @@ async def test_model_partially_supports_protocol(): result_callback = Mock(spec=ResultCallback) chain_callback.result_callback.return_value = result_callback succeeded_dialogue = [ - Message.assistant(TEST_COMMAND_REQUEST), - Message.user(TEST_COMMAND_RESPONSE), + assistant_message(TEST_COMMAND_REQUEST), + user_message(TEST_COMMAND_RESPONSE), ] await command_chain.run_chat(history=TEST_HISTORY, callback=chain_callback) @@ -126,22 +133,22 @@ async def test_model_partially_supports_protocol(): assert model_client.agenerate.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(USER_MESSAGE), - Message.assistant(TEST_COMMAND_REQUEST), - Message.user(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(USER_MESSAGE), + assistant_message(TEST_COMMAND_REQUEST), + user_message(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(SYSTEM_MESSAGE), - Message.user( + system_message(SYSTEM_MESSAGE), + user_message( f"user_message={USER_MESSAGE}, error={FAILED_PROTOCOL_ERROR}, dialogue={succeeded_dialogue}" ), ] @@ -154,7 +161,18 @@ async def test_no_tokens_for_tools(): model_client = Mock(spec=ModelClient) model_client.agenerate.side_effect = [ to_async_string(TEST_COMMAND_REQUEST), - InvalidRequestError(NO_TOKENS_ERROR, ""), + BadRequestError( + message=NO_TOKENS_ERROR, + response=httpx.Response( + request=httpx.Request("GET", "http://localhost"), + status_code=400, + ), + body={ + "type": "", + "code": "", + "param": "", + }, + ), to_async_string(BEST_EFFORT_ANSWER), ] test_command = Mock(spec=Command) @@ -180,22 +198,22 @@ async def test_no_tokens_for_tools(): assert model_client.agenerate.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(USER_MESSAGE), - Message.assistant(TEST_COMMAND_REQUEST), - Message.user(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(USER_MESSAGE), + assistant_message(TEST_COMMAND_REQUEST), + user_message(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(SYSTEM_MESSAGE), - Message.user( + system_message(SYSTEM_MESSAGE), + user_message( f"user_message={USER_MESSAGE}, error={NO_TOKENS_ERROR}, dialogue=[]" ), ] @@ -238,14 +256,14 @@ async def test_model_request_limit_exceeded(): assert model_client.agenerate.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(SYSTEM_MESSAGE), - Message.user( + system_message(SYSTEM_MESSAGE), + user_message( f"user_message={USER_MESSAGE}, error={LIMIT_EXCEEDED_ERROR}, dialogue=[]" ), ] @@ -254,16 +272,16 @@ async def test_model_request_limit_exceeded(): assert model_request_limiter.verify_limit.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(USER_MESSAGE), - Message.assistant(TEST_COMMAND_REQUEST), - Message.user(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(USER_MESSAGE), + assistant_message(TEST_COMMAND_REQUEST), + user_message(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), ] ), ] diff --git a/tests/unit_tests/chain/test_history.py b/tests/unit_tests/chain/test_history.py index 0916ec1..c3e6317 100644 --- a/tests/unit_tests/chain/test_history.py +++ b/tests/unit_tests/chain/test_history.py @@ -4,7 +4,12 @@ from jinja2 import Template from aidial_assistant.chain.history import History, MessageScope, ScopedMessage -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ModelClient +from aidial_assistant.utils.open_ai import ( + assistant_message, + system_message, + user_message, +) TRUNCATION_TEST_DATA = [ (0, [0, 1, 2, 3, 4, 5, 6]), @@ -28,18 +33,19 @@ async def test_history_truncation( assistant_system_message_template=Template(""), best_effort_template=Template(""), scoped_messages=[ - ScopedMessage(message=Message.system(content="a")), - ScopedMessage(message=Message.user(content="b")), - ScopedMessage(message=Message.system(content="c")), + ScopedMessage(message=system_message("a")), + ScopedMessage(message=user_message("b")), + ScopedMessage(message=system_message("c")), ScopedMessage( - message=Message.assistant(content="d"), + message=assistant_message("d"), scope=MessageScope.INTERNAL, ), ScopedMessage( - message=Message.user(content="e"), scope=MessageScope.INTERNAL + message=user_message(content="e"), + scope=MessageScope.INTERNAL, ), - ScopedMessage(message=Message.assistant(content="f")), - ScopedMessage(message=Message.user(content="g")), + ScopedMessage(message=assistant_message("f")), + ScopedMessage(message=user_message("g")), ], ) @@ -64,8 +70,8 @@ async def test_truncation_overflow(): assistant_system_message_template=Template(""), best_effort_template=Template(""), scoped_messages=[ - ScopedMessage(message=Message.system(content="a")), - ScopedMessage(message=Message.user(content="b")), + ScopedMessage(message=system_message("a")), + ScopedMessage(message=user_message("b")), ], ) @@ -87,9 +93,10 @@ async def test_truncation_with_incorrect_message_sequence(): best_effort_template=Template(""), scoped_messages=[ ScopedMessage( - message=Message.user(content="a"), scope=MessageScope.INTERNAL + message=user_message("a"), + scope=MessageScope.INTERNAL, ), - ScopedMessage(message=Message.user(content="b")), + ScopedMessage(message=user_message("b")), ], ) @@ -106,25 +113,25 @@ async def test_truncation_with_incorrect_message_sequence(): def test_protocol_messages_with_system_message(): - system_message = "" - user_message = "" - assistant_message = "" + system_content = "" + user_content = "" + assistant_content = "" history = History( assistant_system_message_template=Template( "system message={{system_prefix}}" ), best_effort_template=Template(""), scoped_messages=[ - ScopedMessage(message=Message.system(system_message)), - ScopedMessage(message=Message.user(user_message)), - ScopedMessage(message=Message.assistant(assistant_message)), + ScopedMessage(message=system_message(system_content)), + ScopedMessage(message=user_message(user_content)), + ScopedMessage(message=assistant_message(assistant_content)), ], ) assert history.to_protocol_messages() == [ - Message.system(f"system message={system_message}"), - Message.user(user_message), - Message.assistant( - f'{{"commands": [{{"command": "reply", "args": ["{assistant_message}"]}}]}}' + system_message(f"system message={system_content}"), + user_message(user_content), + assistant_message( + f'{{"commands": [{{"command": "reply", "arguments": {{"message": "{assistant_content}"}}}}]}}' ), ] diff --git a/tests/unit_tests/chain/test_model_client.py b/tests/unit_tests/chain/test_model_client.py deleted file mode 100644 index 3457901..0000000 --- a/tests/unit_tests/chain/test_model_client.py +++ /dev/null @@ -1,105 +0,0 @@ -from unittest import mock -from unittest.mock import Mock, call - -import pytest - -from aidial_assistant.model.model_client import ( - ExtraResultsCallback, - Message, - ModelClient, - ReasonLengthException, -) -from aidial_assistant.utils.text import join_string -from tests.utils.async_helper import to_async_iterator - -API_METHOD = "openai.ChatCompletion.acreate" -MODEL_ARGS = {"model": "args"} -BUFFER_SIZE = 321 - - -@mock.patch(API_METHOD) -@pytest.mark.asyncio -async def test_discarded_messages(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( - [ - { - "choices": [{"delta": {"content": ""}}], - "statistics": {"discarded_messages": 2}, - } - ] - ) - extra_results_callback = Mock(spec=ExtraResultsCallback) - - await join_string(model_client.agenerate([], extra_results_callback)) - - assert extra_results_callback.on_discarded_messages.call_args_list == [ - call(2) - ] - - -@mock.patch(API_METHOD) -@pytest.mark.asyncio -async def test_content(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( - [ - {"choices": [{"delta": {"content": "one, "}}]}, - {"choices": [{"delta": {"content": "two, "}}]}, - {"choices": [{"delta": {"content": "three"}}]}, - ] - ) - - assert await join_string(model_client.agenerate([])) == "one, two, three" - - -@mock.patch(API_METHOD) -@pytest.mark.asyncio -async def test_reason_length_with_usage(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( - [ - {"choices": [{"delta": {"content": "text"}}]}, - { - "choices": [ - {"delta": {"content": ""}, "finish_reason": "length"} # type: ignore - ] - }, - { - "choices": [{"delta": {"content": ""}}], - "usage": {"prompt_tokens": 1, "completion_tokens": 2}, - }, - ] - ) - - with pytest.raises(ReasonLengthException): - async for chunk in model_client.agenerate([]): - assert chunk == "text" - - assert model_client.total_prompt_tokens == 1 - assert model_client.total_completion_tokens == 2 - - -@mock.patch(API_METHOD) -@pytest.mark.asyncio -async def test_api_args(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator([]) - messages = [ - Message.system(content="a"), - Message.user(content="b"), - Message.assistant(content="c"), - ] - - await join_string(model_client.agenerate(messages)) - - assert api.call_args_list == [ - call( - messages=[ - {"role": "system", "content": "a"}, - {"role": "user", "content": "b"}, - {"role": "assistant", "content": "c"}, - ], - **MODEL_ARGS, - ) - ] diff --git a/tests/unit_tests/model/test_model_client.py b/tests/unit_tests/model/test_model_client.py index 3457901..a5ed1cf 100644 --- a/tests/unit_tests/model/test_model_client.py +++ b/tests/unit_tests/model/test_model_client.py @@ -1,34 +1,56 @@ -from unittest import mock from unittest.mock import Mock, call import pytest +from openai import AsyncOpenAI +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall +from pydantic import BaseModel from aidial_assistant.model.model_client import ( ExtraResultsCallback, - Message, ModelClient, ReasonLengthException, ) +from aidial_assistant.utils.open_ai import ( + Usage, + assistant_message, + system_message, + user_message, +) from aidial_assistant.utils.text import join_string -from tests.utils.async_helper import to_async_iterator +from tests.utils.async_helper import to_awaitable_iterator -API_METHOD = "openai.ChatCompletion.acreate" MODEL_ARGS = {"model": "args"} -BUFFER_SIZE = 321 -@mock.patch(API_METHOD) +class Delta(BaseModel): + content: str + tool_calls: list[ChoiceDeltaToolCall] | None = None + + +class Choice(BaseModel): + delta: Delta + finish_reason: str | None = None + + +class Chunk(BaseModel): + choices: list[Choice] + statistics: dict[str, int] | None = None + usage: Usage | None = None + + @pytest.mark.asyncio -async def test_discarded_messages(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( +async def test_discarded_messages(): + openai_client = Mock(spec=AsyncOpenAI) + openai_client.chat = Mock() + openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ - { - "choices": [{"delta": {"content": ""}}], - "statistics": {"discarded_messages": 2}, - } + Chunk( + choices=[Choice(delta=Delta(content=""))], + statistics={"discarded_messages": 2}, + ) ] ) + model_client = ModelClient(openai_client, MODEL_ARGS) extra_results_callback = Mock(spec=ExtraResultsCallback) await join_string(model_client.agenerate([], extra_results_callback)) @@ -38,39 +60,41 @@ async def test_discarded_messages(api): ] -@mock.patch(API_METHOD) @pytest.mark.asyncio -async def test_content(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( +async def test_content(): + openai_client = Mock(spec=AsyncOpenAI) + openai_client.chat = Mock() + openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ - {"choices": [{"delta": {"content": "one, "}}]}, - {"choices": [{"delta": {"content": "two, "}}]}, - {"choices": [{"delta": {"content": "three"}}]}, + Chunk(choices=[Choice(delta=Delta(content="one, "))]), + Chunk(choices=[Choice(delta=Delta(content="two, "))]), + Chunk(choices=[Choice(delta=Delta(content="three"))]), ] ) + model_client = ModelClient(openai_client, MODEL_ARGS) assert await join_string(model_client.agenerate([])) == "one, two, three" -@mock.patch(API_METHOD) @pytest.mark.asyncio -async def test_reason_length_with_usage(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( +async def test_reason_length_with_usage(): + openai_client = Mock(spec=AsyncOpenAI) + openai_client.chat = Mock() + openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ - {"choices": [{"delta": {"content": "text"}}]}, - { - "choices": [ - {"delta": {"content": ""}, "finish_reason": "length"} # type: ignore + Chunk(choices=[Choice(delta=Delta(content="text"))]), + Chunk( + choices=[ + Choice(delta=Delta(content=""), finish_reason="length") ] - }, - { - "choices": [{"delta": {"content": ""}}], - "usage": {"prompt_tokens": 1, "completion_tokens": 2}, - }, + ), + Chunk( + choices=[Choice(delta=Delta(content=""))], + usage=Usage(prompt_tokens=1, completion_tokens=2), + ), ] ) + model_client = ModelClient(openai_client, MODEL_ARGS) with pytest.raises(ReasonLengthException): async for chunk in model_client.agenerate([]): @@ -80,20 +104,23 @@ async def test_reason_length_with_usage(api): assert model_client.total_completion_tokens == 2 -@mock.patch(API_METHOD) @pytest.mark.asyncio -async def test_api_args(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator([]) +async def test_api_args(): + openai_client = Mock(spec=AsyncOpenAI) + openai_client.chat = Mock() + openai_client.chat.completions.create.return_value = to_awaitable_iterator( + [] + ) + model_client = ModelClient(openai_client, MODEL_ARGS) messages = [ - Message.system(content="a"), - Message.user(content="b"), - Message.assistant(content="c"), + system_message("a"), + user_message("b"), + assistant_message("c"), ] - await join_string(model_client.agenerate(messages)) + await join_string(model_client.agenerate(messages, extra="args")) - assert api.call_args_list == [ + assert openai_client.chat.completions.create.call_args_list == [ call( messages=[ {"role": "system", "content": "a"}, @@ -101,5 +128,7 @@ async def test_api_args(api): {"role": "assistant", "content": "c"}, ], **MODEL_ARGS, + stream=True, + extra_body={"extra": "args"}, ) ] diff --git a/tests/unit_tests/tools_chain/__init__.py b/tests/unit_tests/tools_chain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/utils/test_exception_handler.py b/tests/unit_tests/utils/test_exception_handler.py index caf73b2..a9bfc31 100644 --- a/tests/unit_tests/utils/test_exception_handler.py +++ b/tests/unit_tests/utils/test_exception_handler.py @@ -1,6 +1,7 @@ +import httpx import pytest from aidial_sdk import HTTPException -from openai import OpenAIError +from openai import APIStatusError, OpenAIError from aidial_assistant.utils.exceptions import ( RequestParameterValidationError, @@ -29,19 +30,17 @@ async def function(): @pytest.mark.asyncio async def test_openai_error(): - http_status = 123 - @unhandled_exception_handler async def function(): - raise OpenAIError(message=ERROR_MESSAGE, http_status=http_status) + raise OpenAIError(ERROR_MESSAGE) with pytest.raises(HTTPException) as exc_info: await function() assert ( repr(exc_info.value) - == f"HTTPException(message='{ERROR_MESSAGE}', status_code={http_status}," - f" type='runtime_error', param=None, code=None)" + == f"HTTPException(message='{ERROR_MESSAGE}', status_code=500," + f" type='internal_server_error', param=None, code=None)" ) @@ -51,17 +50,21 @@ async def test_openai_error_with_json_body(): error_type = "" error_code = "" json_body = { - "error": { - "message": ERROR_MESSAGE, - "type": error_type, - "code": error_code, - "param": PARAM, - } + "type": error_type, + "code": error_code, + "param": PARAM, } @unhandled_exception_handler async def function(): - raise OpenAIError(json_body=json_body, http_status=http_status) + raise APIStatusError( + ERROR_MESSAGE, + response=httpx.Response( + request=httpx.Request("GET", "http://localhost"), + status_code=http_status, + ), + body=json_body, + ) with pytest.raises(HTTPException) as exc_info: await function() diff --git a/tests/unit_tests/utils/test_state.py b/tests/unit_tests/utils/test_state.py index d7b29a2..58569e8 100644 --- a/tests/unit_tests/utils/test_state.py +++ b/tests/unit_tests/utils/test_state.py @@ -1,15 +1,16 @@ from aidial_sdk.chat_completion import CustomContent, Message, Role from aidial_assistant.chain.history import MessageScope, ScopedMessage -from aidial_assistant.model.model_client import Message as ModelMessage +from aidial_assistant.utils.open_ai import assistant_message, user_message from aidial_assistant.utils.state import parse_history FIRST_USER_MESSAGE = "" SECOND_USER_MESSAGE = "" FIRST_ASSISTANT_MESSAGE = "" SECOND_ASSISTANT_MESSAGE = "" -FIRST_REQUEST = "" -SECOND_REQUEST = "" +FIRST_REQUEST = '{"commands": [{"command": "run-addon", "args": ["", ""]}]}' +FIRST_REQUEST_FIXED = '{"commands": [{"command": "", "arguments": {"query": ""}}]}' +SECOND_REQUEST = '{"commands": [{"command": "", "arguments": {"query": ""}}]}' FIRST_RESPONSE = "" SECOND_RESPONSE = "" @@ -44,34 +45,34 @@ def test_parse_history(): assert parse_history(messages) == [ ScopedMessage( scope=MessageScope.USER, - message=ModelMessage.user(FIRST_USER_MESSAGE), + message=user_message(FIRST_USER_MESSAGE), ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.assistant(FIRST_REQUEST), + message=assistant_message(FIRST_REQUEST_FIXED), ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.user(FIRST_RESPONSE), + message=user_message(FIRST_RESPONSE), ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.assistant(SECOND_REQUEST), + message=assistant_message(SECOND_REQUEST), ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.user(content=SECOND_RESPONSE), + message=user_message(content=SECOND_RESPONSE), ), ScopedMessage( scope=MessageScope.USER, - message=ModelMessage.assistant(FIRST_ASSISTANT_MESSAGE), + message=assistant_message(FIRST_ASSISTANT_MESSAGE), ), ScopedMessage( scope=MessageScope.USER, - message=ModelMessage.user(SECOND_USER_MESSAGE), + message=user_message(SECOND_USER_MESSAGE), ), ScopedMessage( scope=MessageScope.USER, - message=ModelMessage.assistant(SECOND_ASSISTANT_MESSAGE), + message=assistant_message(SECOND_ASSISTANT_MESSAGE), ), ] diff --git a/tests/utils/async_helper.py b/tests/utils/async_helper.py index 00e3bbb..5021323 100644 --- a/tests/utils/async_helper.py +++ b/tests/utils/async_helper.py @@ -20,3 +20,7 @@ def to_async_repeated_string( async def to_async_iterator(sequence: Iterable[T]) -> AsyncIterator[T]: for item in sequence: yield item + + +async def to_awaitable_iterator(sequence: Iterable[T]) -> AsyncIterator[T]: + return to_async_iterator(sequence)