Return discarded messages as a list of message indices. (#75)
Oleksii-Klimov authored Mar 5, 2024
1 parent 0520ba9 commit c2442ad
Showing 12 changed files with 210 additions and 181 deletions.
41 changes: 32 additions & 9 deletions aidial_assistant/application/assistant_application.py
@@ -26,7 +26,7 @@
CommandConstructor,
CommandDict,
)
from aidial_assistant.chain.history import History
from aidial_assistant.chain.history import History, ScopedMessage
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.commands.run_tool import RunTool
@@ -109,6 +109,25 @@ def _construct_tool(name: str, description: str) -> ChatCompletionToolParam:
)


def _create_history(
messages: list[ScopedMessage], plugins: list[PluginInfo]
) -> History:
plugin_descriptions = {
plugin.info.ai_plugin.name_for_model: plugin.info.open_api.info.description
or plugin.info.ai_plugin.description_for_human
for plugin in plugins
}
return History(
assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
addons=plugin_descriptions
),
best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build(
addons=plugin_descriptions
),
scoped_messages=messages,
)


class AssistantApplication(ChatCompletion):
def __init__(
self, config_dir: Path, tools_supporting_deployments: set[str]
@@ -204,21 +223,25 @@ def create_command(addon: PluginInfo):
or addon.info.ai_plugin.description_for_human
for addon in addons
}
scoped_messages = parse_history(request.messages)
history = History(
assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
addons=addon_descriptions
),
best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build(
addons=addon_descriptions
),
scoped_messages=parse_history(request.messages),
scoped_messages=scoped_messages,
)
discarded_messages: int | None = None
discarded_user_messages: set[int] | None = None
if request.max_prompt_tokens is not None:
original_size = history.user_message_count
history = await history.truncate(request.max_prompt_tokens, model)
truncated_size = history.user_message_count
discarded_messages = original_size - truncated_size
history, discarded_messages = await history.truncate(
model, request.max_prompt_tokens
)
discarded_user_messages = set(
scoped_messages[index].user_index
for index in discarded_messages
)
# TODO: else compare the history size to the max prompt tokens of the underlying model

choice = response.create_single_choice()
@@ -243,8 +266,8 @@ def create_command(addon: PluginInfo):
model.total_prompt_tokens, model.total_completion_tokens
)

if discarded_messages is not None:
response.set_discarded_messages(discarded_messages)
if discarded_user_messages is not None:
response.set_discarded_messages(list(discarded_user_messages))

@staticmethod
async def _run_native_tools_chat(
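For orientation, here is a minimal sketch of the final mapping step above, using plain lists instead of the repo's ScopedMessage objects. The assumption that user_index is the position of the originating message in request.messages (so several scoped messages can share one value) is inferred from the diff, not stated in it.

# Hypothetical mapping: request message 1 produced two scoped messages
# (an internal addon exchange plus the visible reply), so both carry user_index 1.
scoped_user_indices = [0, 1, 1, 2]

# Indices of the scoped messages discarded by History.truncate().
discarded_messages = [0, 1, 2]

# Collapse them to request-level indices, as the application code above does.
discarded_user_messages = {scoped_user_indices[i] for i in discarded_messages}
print(sorted(discarded_user_messages))  # [0, 1]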
136 changes: 73 additions & 63 deletions aidial_assistant/chain/history.py
@@ -1,7 +1,11 @@
from enum import Enum
from typing import Tuple, cast

from jinja2 import Template
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
)
from pydantic import BaseModel

from aidial_assistant.chain.command_result import (
@@ -26,6 +30,7 @@ class MessageScope(str, Enum):
class ScopedMessage(BaseModel):
scope: MessageScope = MessageScope.USER
message: ChatCompletionMessageParam
user_index: int


class History:
@@ -40,35 +45,32 @@ def __init__(
)
self.best_effort_template = best_effort_template
self.scoped_messages = scoped_messages
self._user_message_count = sum(
1
for message in scoped_messages
if message.scope == MessageScope.USER
)

def to_protocol_messages(self) -> list[ChatCompletionMessageParam]:
messages: list[ChatCompletionMessageParam] = []
for index, scoped_message in enumerate(self.scoped_messages):
scoped_message_iterator = iter(self.scoped_messages)
if self._is_first_system_message():
message = cast(
ChatCompletionSystemMessageParam,
next(scoped_message_iterator).message,
)
messages.append(
system_message(
self.assistant_system_message_template.render(
system_prefix=message["content"]
)
)
)
else:
messages.append(
system_message(self.assistant_system_message_template.render())
)

for scoped_message in scoped_message_iterator:
message = scoped_message.message
scope = scoped_message.scope

if index == 0:
if message["role"] == "system":
messages.append(
system_message(
self.assistant_system_message_template.render(
system_prefix=message["content"]
)
)
)
else:
messages.append(
system_message(
self.assistant_system_message_template.render()
)
)
messages.append(message)
elif scope == MessageScope.USER and message["role"] == "assistant":
if scope == MessageScope.USER and message["role"] == "assistant":
# Clients see replies in plain text, but the model should understand how to reply appropriately.
content = commands_to_text(
[
@@ -107,51 +109,59 @@ def to_best_effort_messages(
return messages

async def truncate(
self, max_prompt_tokens: int, model_client: ModelClient
) -> "History":
discarded_messages = await model_client.get_discarded_messages(
self.to_protocol_messages(),
max_prompt_tokens,
self, model_client: ModelClient, max_prompt_tokens: int
) -> Tuple["History", list[int]]:
discarded_messages = await self._get_discarded_messages(
model_client, max_prompt_tokens
)

if discarded_messages > 0:
return History(
if not discarded_messages:
return self, []

discarded_messages_set = set(discarded_messages)
return (
History(
assistant_system_message_template=self.assistant_system_message_template,
best_effort_template=self.best_effort_template,
scoped_messages=self._skip_messages(discarded_messages),
)

return self

@property
def user_message_count(self) -> int:
return self._user_message_count

def _skip_messages(self, discarded_messages: int) -> list[ScopedMessage]:
messages: list[ScopedMessage] = []
current_message = self.scoped_messages[0]
message_iterator = iter(self.scoped_messages)
for _ in range(discarded_messages):
current_message = next(message_iterator)
while current_message.message["role"] == "system":
# System messages should be kept in the history
messages.append(current_message)
current_message = next(message_iterator)
scoped_messages=[
scoped_message
for index, scoped_message in enumerate(self.scoped_messages)
if index not in discarded_messages_set
],
),
discarded_messages,
)

if current_message.scope == MessageScope.INTERNAL:
while current_message.scope == MessageScope.INTERNAL:
current_message = next(message_iterator)
async def _get_discarded_messages(
self, model_client: ModelClient, max_prompt_tokens: int
) -> list[int]:
discarded_protocol_messages = await model_client.get_discarded_messages(
self.to_protocol_messages(),
max_prompt_tokens,
)

# Internal messages (i.e. addon requests/responses) are always followed by an assistant reply
assert (
current_message.message["role"] == "assistant"
), "Internal messages must be followed by an assistant reply."
if discarded_protocol_messages:
discarded_protocol_messages.sort()
discarded_messages = (
discarded_protocol_messages
if self._is_first_system_message()
else [index - 1 for index in discarded_protocol_messages]
)
user_indices = set(
self.scoped_messages[index].user_index
for index in discarded_messages
)

remaining_messages = list(message_iterator)
assert (
len(remaining_messages) > 0
), "No user messages left after history truncation."
return [
index
for index, scoped_message in enumerate(self.scoped_messages)
if scoped_message.user_index in user_indices
]

messages += remaining_messages
return discarded_protocol_messages

return messages
def _is_first_system_message(self) -> bool:
return (
len(self.scoped_messages) > 0
and self.scoped_messages[0].message["role"] == "system"
)
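The index arithmetic in _get_discarded_messages is the subtle part: to_protocol_messages() always emits a system message first, and that message corresponds to a scoped message only when the request itself started with one, hence the "- 1" shift; the result is then widened so that every scoped message sharing a user_index with a discarded one is dropped as well, and the returned indices refer to scoped messages. A small sketch of the shift with plain lists (assumed values, not the repo's types):

# Protocol-message indices reported by the model client.
discarded_protocol = [1, 3]

# When the leading system message was synthesized by the template (i.e. the
# request had no system message of its own), protocol index i maps to scoped
# message i - 1; otherwise the indices already line up.
first_message_is_system = False
discarded_scoped = (
    discarded_protocol
    if first_message_is_system
    else [i - 1 for i in discarded_protocol]
)
print(discarded_scoped)  # [0, 2]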
4 changes: 3 additions & 1 deletion aidial_assistant/commands/run_plugin.py
@@ -87,7 +87,9 @@ def create_command(op: APIOperation):
best_effort_template=ADDON_BEST_EFFORT_TEMPLATE.build(
api_schema=api_schema
),
scoped_messages=[ScopedMessage(message=user_message(query))],
scoped_messages=[
ScopedMessage(message=user_message(query), user_index=0)
],
)

chat = CommandChain(
41 changes: 31 additions & 10 deletions aidial_assistant/model/model_client.py
@@ -1,4 +1,5 @@
from abc import ABC
from itertools import islice
from typing import Any, AsyncIterator, List

from aidial_sdk.utils.merge_chunks import merge
@@ -16,7 +17,7 @@ class ReasonLengthException(Exception):


class ExtraResultsCallback:
def on_discarded_messages(self, discarded_messages: int):
def on_discarded_messages(self, discarded_messages: list[int]):
pass

def on_prompt_tokens(self, prompt_tokens: int):
@@ -36,6 +37,21 @@ async def _flush_stream(stream: AsyncIterator[str]):
pass


def _discarded_messages_count_to_indices(
messages: list[ChatCompletionMessageParam], discarded_messages: int
) -> list[int]:
return list(
islice(
(
i
for i, message in enumerate(messages)
if message["role"] != "system"
),
discarded_messages,
)
)


class ModelClient(ABC):
def __init__(self, client: AsyncOpenAI, model_args: dict[str, Any]):
self.client = client
@@ -70,12 +86,16 @@ async def agenerate(
extra_results_callback.on_prompt_tokens(prompt_tokens)

if extra_results_callback:
discarded_messages: int | None = chunk_dict.get(
discarded_messages: int | list[int] | None = chunk_dict.get(
"statistics", {}
).get("discarded_messages")
if discarded_messages is not None:
extra_results_callback.on_discarded_messages(
discarded_messages
_discarded_messages_count_to_indices(
messages, discarded_messages
)
if isinstance(discarded_messages, int)
else discarded_messages
)

choice = chunk.choices[0]
@@ -128,15 +148,16 @@ def on_prompt_tokens(self, prompt_tokens: int):
return callback.token_count

# TODO: Use a dedicated endpoint for discarded_messages.
# https://github.com/epam/ai-dial-assistant/issues/39
async def get_discarded_messages(
self, messages: list[ChatCompletionMessageParam], max_prompt_tokens: int
) -> int:
) -> list[int]:
class DiscardedMessagesCallback(ExtraResultsCallback):
def __init__(self):
self.message_count: int | None = None
self.discarded_messages: list[int] | None = None

def on_discarded_messages(self, discarded_messages: int):
self.message_count = discarded_messages
def on_discarded_messages(self, discarded_messages: list[int]):
self.discarded_messages = discarded_messages

callback = DiscardedMessagesCallback()
await _flush_stream(
@@ -147,10 +168,10 @@ def on_discarded_messages(self, discarded_messages: int):
max_tokens=1,
)
)
if callback.message_count is None:
raise Exception("No message count received.")
if callback.discarded_messages is None:
raise Exception("Discarded messages were not provided.")

return callback.message_count
return callback.discarded_messages

@property
def total_prompt_tokens(self) -> int:
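The upstream model may still report discarded_messages as a plain count rather than a list of indices; the new helper converts such a count by taking the first N non-system messages. A quick self-contained check of that behavior (the helper is copied from the diff above; the sample messages are illustrative):

from itertools import islice

def _discarded_messages_count_to_indices(messages, discarded_messages: int) -> list[int]:
    # First `discarded_messages` indices of non-system messages, as in the diff.
    return list(
        islice(
            (i for i, message in enumerate(messages) if message["role"] != "system"),
            discarded_messages,
        )
    )

messages = [
    {"role": "system", "content": "You are an assistant."},
    {"role": "user", "content": "first question"},
    {"role": "assistant", "content": "first answer"},
    {"role": "user", "content": "second question"},
]
# A count of 2 means the two oldest non-system messages were dropped.
print(_discarded_messages_count_to_indices(messages, 2))  # [1, 2]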