Commit 6dccaab: tool_segment_fix
ieaves committed Dec 15, 2023 (1 parent: e881f12)
Showing 2 changed files with 61 additions and 35 deletions.
grai-server/app/grAI/summarization.py (54 additions & 33 deletions)
@@ -6,7 +6,7 @@
 from openai import AsyncOpenAI
 from django.conf import settings
 from asyncio import gather
-from grAI.chat_types import SupportedMessageTypes, SystemMessage, FunctionMessage
+from grAI.chat_types import SupportedMessageTypes, SystemMessage, FunctionMessage, UserMessage
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat import ChatCompletion

@@ -83,10 +83,11 @@ def __init__(
     ):
         super().__init__(prompt_string=prompt_string, model=model, client=client, max_tokens=max_tokens)
 
-    async def call(self, input_obj: SupportedMessageTypes):
+    async def call(self, input_obj: SupportedMessageTypes) -> str:
         query = self.query(input_obj)
         self.validate(query["content"])
-        return self.completion(query)
+        response = await self.completion(query)
+        return response.choices[0].message.content
 
     def validate(self, content: str) -> list[int]:
         encoding = self.encoder.encode(content)
@@ -110,16 +111,26 @@ def __init__(
     ):
         super().__init__(prompt_string=prompt_string, model=model, client=client, max_tokens=max_tokens)
 
-    def prompt(self, input_obj: list[SupportedMessageTypes]) -> str:
+    def prompt(self, content: str | list[SupportedMessageTypes], **kwargs) -> str:
+        if isinstance(content, list):
+            content = self.prompt_content(content)
+        return self.prompt_string.format(content=content, **kwargs)
+
+    @staticmethod
+    def prompt_content(input_obj: list[SupportedMessageTypes]) -> str:
         component_iter = (f"{inp.role}\n---\n{inp.content}" for inp in input_obj)
         content = "\n---\n".join(component_iter)
-        return self.prompt_string.format(content=content)
+        return content
 
-    def query(self, content: list[SupportedMessageTypes] | str) -> dict:
-        if isinstance(content, list):
-            content = self.prompt(content)
-
-        return {"role": "system", "content": content}
+    def query(self, content: list[SupportedMessageTypes] | str, **kwargs) -> dict:
+        prompt = self.prompt(content=content, **kwargs)
+        return {"role": "system", "content": prompt}
+
+    async def call(self, input_obj: list[SupportedMessageTypes], **kwargs) -> str:
+        query = self.query(input_obj, **kwargs)
+        self.validate(query["content"])
+        response = await self.completion(query)
+        return response.choices[0].message.content
 
 
 DEFAULT_REDUCE_PROMPT = DEFAULT_SUMMARIZER_PROMPT
@@ -135,8 +146,8 @@ def __init__(
     ):
         super().__init__(prompt_string=prompt_string, model=model, client=client, max_tokens=max_tokens)
 
-    async def call(self, items: list[SupportedMessageTypes]) -> str:
-        query = self.query(items)
+    async def call(self, items: list[SupportedMessageTypes], **kwargs) -> str:
+        query = self.query(items, **kwargs)
         encoding = self.validate(query["content"])
         response = await self.completion(query)
         return response.choices[0].message.content
@@ -165,8 +176,8 @@ def __init__(
     ):
         super().__init__(prompt_string=prompt_string, model=model, client=client, max_tokens=max_tokens)
 
-    async def call(self, items: list[SupportedMessageTypes]) -> list[str]:
-        encoding = self.encoder.encode(self.prompt(items))
+    async def call(self, items: list[SupportedMessageTypes], **kwargs) -> list[str]:
+        encoding = self.encoder.encode(self.prompt_content(items))
 
         queries = (self.query(self.encoder.decode(chunk)) for chunk in chunker(encoding, self.max_tokens - 50))
         responses = await gather(*[self.completion(query) for query in queries])
@@ -205,8 +216,8 @@ def __init__(
     ):
         super().__init__(prompt_string=prompt_string, model=model, client=client, max_tokens=max_tokens)
 
-    async def call(self, items: list[SupportedMessageTypes]) -> str:
-        content = self.prompt(items)
+    async def call(self, items: list[SupportedMessageTypes], **kwargs) -> str:
+        content = self.prompt(items, **kwargs)
         encoding = self.encoder.encode(content)
         while len(encoding) > self.max_tokens:
             query = self.query(self.encoder.decode(encoding[: self.max_tokens]))
@@ -231,17 +242,17 @@ def __init__(self, strategy: ProgressiveSummarization | MapReduceSummarization,
     def tool_segments(items: list[SupportedMessageTypes]) -> ToolSegmentReturnType:
         return tool_segments(items)
 
-    async def call(self, items: list[SupportedMessageTypes]) -> str:
+    async def call(self, items: list[SupportedMessageTypes], **kwargs) -> str:
         if len(items) == 0:
             return ""
 
         segment: list[SupportedMessageTypes] = []
         for pre_tool_segment, tool_segment in self.tool_segments(items):
-            content = await self.strategy.call([*segment, *pre_tool_segment])
+            content = await self.strategy.call([*segment, *pre_tool_segment], **kwargs)
             pre_tool_context = SystemMessage(content=content)
 
             if tool_segment is not None:
-                tool_summary = await self.strategy.call([pre_tool_context, *tool_segment])
+                tool_summary = await self.strategy.call([pre_tool_context, *tool_segment], **kwargs)
                 segment = [SystemMessage(content=tool_summary)]
             else:
                 segment = [pre_tool_context]
@@ -253,7 +264,7 @@ async def call(self, items: list[SupportedMessageTypes]) -> str:
 '{content}'
 The user is attempting to answer the following question:
 '{question}'
-Keeping in mind the user's question, please distill the conversation such that a future agent can answer the user's question.
+Please distill the conversation to date including all information needed by a future agent to answer the user's question.
 """


@@ -262,17 +273,24 @@ def __init__(self, model, client, max_tokens, **kwargs):
         self.progressive = ProgressiveSummarization(
             model=model, prompt_string=DEFAULT_GRAI_PROMPT, client=client, max_tokens=max_tokens
         )
+
         self.map_reduce = MapReduceSummarization(
             model=model,
             client=client,
             max_tokens=max_tokens,
-            map=Map(model=model, client=client, max_tokens=max_tokens),
-            reduce=Reduce(model=model, client=client, max_tokens=max_tokens),
+            map=Map(model=model, client=client, max_tokens=max_tokens, prompt_string=DEFAULT_GRAI_PROMPT),
+            reduce=Reduce(model=model, client=client, max_tokens=max_tokens, prompt_string=DEFAULT_GRAI_PROMPT),
         )
+
+        self.conversation = ConversationSummarizer(
+            model=model, client=client, max_tokens=max_tokens, prompt_string=DEFAULT_GRAI_PROMPT
+        )
         self.tool = ToolSummarization(strategy=self.progressive)
+
         super().__init__(**kwargs)
 
-    def user_messages(self, items: list[SupportedMessageTypes]) -> int:
+    @staticmethod
+    def user_messages(items: list[SupportedMessageTypes]) -> int:
         i = 0
 
         for i, item in enumerate(items[::-1]):
@@ -281,17 +299,20 @@ def user_messages(self, items: list[SupportedMessageTypes]) -> int:
         idx = len(items) - i
         return idx
 
-    async def call(self, items: list[SupportedMessageTypes]) -> str:
-        prompt = "Please identify the problem the user needs help with."
+    async def call(self, items: list[SupportedMessageTypes]) -> list[SupportedMessageTypes]:
+        prompt = SystemMessage(content="Please identify the problem the user needs help with.")
         last_user_message_idx = self.user_messages(items)
-        prompt_items = [items[:last_user_message_idx], SystemMessage(content=prompt)]
-        content = await self.map_reduce.call(prompt_items)
 
-        self.progressive.call()
+        prompt_items: list[SupportedMessageTypes] = [*items[:last_user_message_idx], prompt]
+        # Conversation up to the last user message should fit in context window
+        user_question = UserMessage(content=await self.conversation.call(prompt_items))
+
+        summary = await self.tool.call(items, question=user_question.content)
+        return [SystemMessage(content=summary), user_question]
 
 
 prompt = """
 You've requested a tool to help you with your problem, however the response from the tool was too long
 to fit in the context window. The tool response requested was {message.message.name} with arguments
 {message.message.args}. Please provide a brief description of the details you're looking for which a future
 agent will use to summarize the tool response. Ensure you do not actually call any tools in your response.
 """
grai-server/app/grAI/utils.py (7 additions & 2 deletions)
@@ -59,6 +59,7 @@ def tool_segments(items: list[SupportedMessageTypes]) -> list[ToolSegmentReturnType]:
         if tool_segment is None:
             if isinstance(item, ChatCompletionMessage) and item.tool_calls is not None:
                 tool_segment = []
+                # pre_tool_segment = []
             elif item.role == "tool":
                 for stuff in result:
                     print(stuff)
@@ -77,8 +78,12 @@ def tool_segments(items: list[SupportedMessageTypes]) -> list[ToolSegmentReturnType]:
             else:
                 result.append((pre_tool_segment, tool_segment))
                 # yield pre_tool_segment, tool_segment
-                pre_tool_segment = [item]
-                tool_segment = None
+                if isinstance(item, ChatCompletionMessage) and item.tool_calls is not None:
+                    tool_segment = []
+                    pre_tool_segment = []
+                else:
+                    pre_tool_segment = [item]
+                    tool_segment = None
     result.append((pre_tool_segment, tool_segment))
     return result
     # yield pre_tool_segment, tool_segment
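This hunk is the fix the commit is named for. Previously, the message that closed a tool segment was always pushed into pre_tool_segment, so an assistant message that itself carried tool_calls (back-to-back tool invocations) left tool_segment as None and its tool responses fell through to the error branch above. A self-contained sketch of the fixed segmentation logic; Msg is a stand-in for the repo's message types, and the debug prints and error handling are omitted:

from dataclasses import dataclass

@dataclass
class Msg:
    role: str
    content: str
    tool_calls: list | None = None

def tool_segments(items: list[Msg]) -> list[tuple[list[Msg], list[Msg] | None]]:
    # Split a conversation into (pre_tool_segment, tool_segment) pairs, where a
    # tool segment holds the responses to an assistant message's tool calls.
    result: list[tuple[list[Msg], list[Msg] | None]] = []
    pre_tool_segment: list[Msg] = []
    tool_segment: list[Msg] | None = None
    for item in items:
        if tool_segment is None:
            if item.role == "assistant" and item.tool_calls is not None:
                tool_segment = []
            else:
                pre_tool_segment.append(item)
        else:
            if item.role == "tool":
                tool_segment.append(item)
            else:
                result.append((pre_tool_segment, tool_segment))
                # The fix: the message that closes a segment may itself open a
                # new tool block, so check before resetting the accumulators.
                if item.role == "assistant" and item.tool_calls is not None:
                    tool_segment = []
                    pre_tool_segment = []
                else:
                    pre_tool_segment = [item]
                    tool_segment = None
    result.append((pre_tool_segment, tool_segment))
    return result

conversation = [
    Msg("user", "Look up the lineage for the orders table."),
    Msg("assistant", "", tool_calls=[{"name": "get_nodes"}]),
    Msg("tool", "nodes..."),
    Msg("assistant", "", tool_calls=[{"name": "get_edges"}]),  # back-to-back tool call
    Msg("tool", "edges..."),
]
for pre, tools in tool_segments(conversation):
    print([m.role for m in pre], "->", tools if tools is None else [m.role for m in tools])

Running the sketch prints ['user'] -> ['tool'] and then [] -> ['tool']: the second tool call opens its own segment instead of being mis-filed as pre-tool context for a segment that never receives a tool block.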
