Skip to content

Commit

Permalink
Fix history formatting. (#34) (#35)
Browse files Browse the repository at this point in the history
(cherry picked from commit faea324)
  • Loading branch information
Oleksii-Klimov authored Nov 28, 2023
1 parent eea6a04 commit b5753b2
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 13 deletions.
14 changes: 9 additions & 5 deletions aidial_assistant/chain/command_chain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, AsyncIterator, Callable, Tuple
from typing import Any, AsyncIterator, Callable, Tuple, cast

from aidial_sdk.chat_completion.request import Role
from openai import InvalidRequestError
Expand All @@ -10,8 +10,10 @@
from aidial_assistant.chain.callbacks.command_callback import CommandCallback
from aidial_assistant.chain.callbacks.result_callback import ResultCallback
from aidial_assistant.chain.command_result import (
CommandInvocation,
CommandResult,
Status,
commands_to_text,
responses_to_text,
)
from aidial_assistant.chain.dialogue import Dialogue
Expand Down Expand Up @@ -127,7 +129,7 @@ async def _run_with_protocol_failure_retries(
)

if responses:
request_text = json.dumps({"commands": commands})
request_text = commands_to_text(commands)
response_text = responses_to_text(responses)

callback.on_state(request_text, response_text)
Expand Down Expand Up @@ -162,12 +164,12 @@ async def _run_with_protocol_failure_retries(

async def _run_commands(
self, chunk_stream: AsyncIterator[str], callback: ChainCallback
) -> Tuple[list[dict[str, Any]], list[CommandResult]]:
) -> Tuple[list[CommandInvocation], list[CommandResult]]:
char_stream = CharacterStream(chunk_stream)
await skip_to_json_start(char_stream)

async with JsonParser.parse(char_stream) as root_node:
commands: list[dict[str, Any]] = []
commands: list[CommandInvocation] = []
responses: list[CommandResult] = []
request_reader = CommandsReader(root_node)
async for invocation in request_reader.parse_invocations():
Expand All @@ -190,7 +192,9 @@ async def _run_commands(
command_name, command, args, callback
)

commands.append(invocation.node.value())
commands.append(
cast(CommandInvocation, invocation.node.value())
)
responses.append(response)

return commands, responses
Expand Down
9 changes: 9 additions & 0 deletions aidial_assistant/chain/command_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,14 @@ class CommandResult(TypedDict):
error messages for the failed one."""


class CommandInvocation(TypedDict):
command: str
args: list[str]


def responses_to_text(responses: List[CommandResult]) -> str:
return json.dumps({"responses": responses})


def commands_to_text(commands: List[CommandInvocation]) -> str:
return json.dumps({"commands": commands})
18 changes: 10 additions & 8 deletions aidial_assistant/chain/history.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
from enum import Enum

from aidial_sdk.chat_completion import Role
from jinja2 import Template
from pydantic import BaseModel

from aidial_assistant.chain.command_result import (
CommandInvocation,
commands_to_text,
)
from aidial_assistant.chain.dialogue import Dialogue
from aidial_assistant.chain.model_client import Message
from aidial_assistant.commands.reply import Reply
Expand Down Expand Up @@ -54,13 +57,12 @@ def to_protocol_messages(self) -> list[Message]:

elif scope == MessageScope.USER and message.role == Role.ASSISTANT:
# Clients see replies in plain text, but the model should understand how to reply appropriately.
content = json.dumps(
{
"commands": {
"command": Reply.token(),
"args": [message.content],
}
}
content = commands_to_text(
[
CommandInvocation(
command=Reply.token(), args=[message.content]
)
]
)
messages.append(Message.assistant(content=content))
else:
Expand Down
37 changes: 37 additions & 0 deletions tests/unit_tests/chain/test_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from jinja2 import Template

from aidial_assistant.chain.history import History, MessageScope, ScopedMessage
from aidial_assistant.chain.model_client import Message

SYSTEM_MESSAGE = "<system message>"
USER_MESSAGE = "<user message>"
ASSISTANT_MESSAGE = "<assistant message>"


def test_protocol_messages():
history = History(
assistant_system_message_template=Template(
"system message={{system_prefix}}"
),
best_effort_template=Template(""),
scoped_messages=[
ScopedMessage(
scope=MessageScope.USER, message=Message.system(SYSTEM_MESSAGE)
),
ScopedMessage(
scope=MessageScope.USER, message=Message.user(USER_MESSAGE)
),
ScopedMessage(
scope=MessageScope.USER,
message=Message.assistant(ASSISTANT_MESSAGE),
),
],
)

assert history.to_protocol_messages() == [
Message.system(f"system message={SYSTEM_MESSAGE}"),
Message.user(USER_MESSAGE),
Message.assistant(
f'{{"commands": [{{"command": "reply", "args": ["{ASSISTANT_MESSAGE}"]}}]}}'
),
]

0 comments on commit b5753b2

Please sign in to comment.