Skip to content

Commit

Permalink
Refactor context and chat_id wise-agents#363
Browse files Browse the repository at this point in the history
  • Loading branch information
maeste committed Sep 30, 2024
1 parent d7d68f1 commit e50f23a
Show file tree
Hide file tree
Showing 21 changed files with 488 additions and 764 deletions.
15 changes: 9 additions & 6 deletions src/wiseagents/agents/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class AssistantAgent(WiseAgent):
_response_delivery = None
_cond = threading.Condition()
_response : WiseAgentMessage = None
_chat_id = None
_ctx = None

def __new__(cls, *args, **kwargs):
"""Create a new instance of the class, setting default values for the optional instance variables."""
Expand Down Expand Up @@ -52,14 +52,17 @@ def __repr__(self):

def start_agent(self):
super().start_agent()
self._chat_id = str(uuid.uuid4())
WiseAgentRegistry.get_or_create_context("default").set_collaboration_type(self._chat_id,
WiseAgentCollaborationType.CHAT)
self._ctx = f'{self.name}.{str(uuid.uuid4())}'
WiseAgentRegistry.create_context(self._ctx).set_collaboration_type(WiseAgentCollaborationType.CHAT)
gradio.ChatInterface(self.slow_echo).launch(prevent_thread_lock=True)

def stop_agent(self):
super().stop_agent()
WiseAgentRegistry.remove_context(self._ctx)

def slow_echo(self, message, history):
with self._cond:
self.handle_request(WiseAgentMessage(message=message, sender=self.name, chat_id=self._chat_id))
self.handle_request(WiseAgentMessage(message=message, sender=self.name, context_name=self._ctx))
self._cond.wait()
return self._response.message

Expand All @@ -79,7 +82,7 @@ def process_request(self, request: WiseAgentMessage,
no string response yet
"""
print(f"AssistantAgent: process_request: {request}")
WiseAgentRegistry.get_or_create_context("default").append_chat_completion(self._chat_id, {"role": "user", "content": request.message})
WiseAgentRegistry.get_context(request.context_name).append_chat_completion({"role": "user", "content": request.message})
self.send_request(request, self.destination_agent_name)
return None

Expand Down
121 changes: 59 additions & 62 deletions src/wiseagents/agents/coordinator_wise_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@ def handle_request(self, request):
logging.debug(f"Sequential coordinator received request: {request}")

# Generate a chat ID that will be used to collaborate on this query
chat_id = str(uuid.uuid4())
sub_ctx_name = f'{self.name}.{str(uuid.uuid4())}'

ctx = WiseAgentRegistry.get_or_create_context(request.context_name)
ctx.set_collaboration_type(chat_id, WiseAgentCollaborationType.SEQUENTIAL)
ctx.set_agents_sequence(chat_id, self._agents)
ctx.set_route_response_to(chat_id, request.sender)
self.send_request(WiseAgentMessage(message=request.message, sender=self.name, context_name=request.context_name,
chat_id=chat_id), self._agents[0])
ctx = WiseAgentRegistry.create_sub_context(request.context_name, sub_ctx_name)
ctx.set_collaboration_type(WiseAgentCollaborationType.SEQUENTIAL)
ctx.set_agents_sequence(self._agents)
ctx.set_route_response_to(request.sender)
self.send_request(WiseAgentMessage(message=request.message, sender=self.name, context_name=ctx.name), self._agents[0])

def process_response(self, response):
"""
Expand Down Expand Up @@ -134,17 +133,16 @@ def handle_request(self, request):
logging.debug(f"Sequential coordinator received request: {request}")

# Generate a chat ID that will be used to collaborate on this query
chat_id = str(uuid.uuid4())
sub_ctx_name = f'{self.name}.{str(uuid.uuid4())}'

ctx = WiseAgentRegistry.get_or_create_context(request.context_name)
ctx.set_collaboration_type(chat_id, WiseAgentCollaborationType.SEQUENTIAL_MEMORY)
ctx.append_chat_completion(chat_uuid=chat_id, messages={"role": "system", "content": self.metadata.system_message})
ctx = WiseAgentRegistry.create_sub_context(request.context_name, sub_ctx_name)
ctx.set_collaboration_type(WiseAgentCollaborationType.SEQUENTIAL_MEMORY)
ctx.append_chat_completion(messages={"role": "system", "content": self.metadata.system_message})

ctx.set_agents_sequence(chat_id, self._agents)
ctx.set_route_response_to(chat_id, request.sender)
ctx.add_query(chat_id, request.message)
self.send_request(WiseAgentMessage(message=request.message, sender=self.name, context_name=request.context_name,
chat_id=chat_id), self._agents[0])
ctx.set_agents_sequence(self._agents)
ctx.set_route_response_to(request.sender)
ctx.add_query(request.message)
self.send_request(WiseAgentMessage(message=request.message, sender=self.name, context_name=ctx.name), self._agents[0])


class PhasedCoordinatorWiseAgent(WiseAgent):
Expand Down Expand Up @@ -216,25 +214,27 @@ def handle_request(self, request):
logging.debug(f"Coordinator received request: {request}")

# Generate a chat ID that will be used to collaborate on this query
chat_id = str(uuid.uuid4())

ctx = WiseAgentRegistry.get_or_create_context(request.context_name)
ctx.set_collaboration_type(chat_id, WiseAgentCollaborationType.PHASED)
ctx.set_route_response_to(chat_id, request.sender)

sub_ctx_name = f'{self.name}.{str(uuid.uuid4())}'

ctx = WiseAgentRegistry.create_sub_context(request.context_name, sub_ctx_name)
ctx.set_collaboration_type(WiseAgentCollaborationType.PHASED)
ctx.set_route_response_to(request.sender)
logging.debug(f"set_collaboration_type (0) parent context: {WiseAgentRegistry.get_context(request.context_name)} use_redis: {WiseAgentRegistry.get_context(request.context_name)._use_redis}")
logging.debug(f"set_collaboration_type (0) Created context: {ctx} use_redis: {ctx._use_redis}, config: {ctx._config}")
logging.debug(f"Registred context: {WiseAgentRegistry.get_context(ctx.name)}")
# Determine the agents required to answer the query
agent_selection_prompt = ("Given the following query and a description of the agents that are available," +
" determine all of the agents that could be required to solve the query." +
" Format the response as a space separated list of agent names and don't include " +
" anything else in the response.\n" +
" Query: " + request.message + "\n" + "Available agents:\n" +
"\n".join(WiseAgentRegistry.get_agent_names_and_descriptions()) + "\n")
ctx.append_chat_completion(chat_uuid=chat_id, messages={"role": "system", "content": self.metadata.system_message or self.llm.system_message})
ctx.append_chat_completion(chat_uuid=chat_id, messages={"role": "user", "content": agent_selection_prompt})
ctx.append_chat_completion(messages={"role": "system", "content": self.metadata.system_message or self.llm.system_message})
ctx.append_chat_completion(messages={"role": "user", "content": agent_selection_prompt})

logging.debug(f"messages: {ctx.llm_chat_completion[chat_id]}")
llm_response = self.llm.process_chat_completion(ctx.llm_chat_completion[chat_id], tools=[])
ctx.append_chat_completion(chat_uuid=chat_id, messages=llm_response.choices[0].message)
logging.debug(f"messages: {ctx.llm_chat_completion}")
llm_response = self.llm.process_chat_completion(ctx.llm_chat_completion, tools=[])
ctx.append_chat_completion(messages=llm_response.choices[0].message)

# Assign the agents to phases
agent_assignment_prompt = ("Assign each of the agents that will be required to solve the query to one of the following phases:\n" +
Expand All @@ -243,20 +243,20 @@ def handle_request(self, request):
" Format the response as a space separated list of agents for each phase, where the first"
" line contains the list of agents for the first phase and second line contains the list of"
" agents for the second phase and so on. Don't include anything else in the response.\n")
ctx.append_chat_completion(chat_uuid=chat_id, messages={"role": "user", "content": agent_assignment_prompt})
llm_response = self.llm.process_chat_completion(ctx.llm_chat_completion[chat_id], tools=[])
ctx.append_chat_completion(chat_uuid=chat_id, messages=llm_response.choices[0].message)
ctx.append_chat_completion(messages={"role": "user", "content": agent_assignment_prompt})
llm_response = self.llm.process_chat_completion(ctx.llm_chat_completion, tools=[])
ctx.append_chat_completion(messages=llm_response.choices[0].message)
phases = [phase.split() for phase in llm_response.choices[0].message.content.splitlines()]
ctx.set_agent_phase_assignments(chat_id, phases)
ctx.set_current_phase(chat_id, 0)
ctx.add_query(chat_id, request.message)
ctx.set_agent_phase_assignments(phases)
ctx.set_current_phase(0)
ctx.add_query(request.message)

# Kick off the first phase
for agent in phases[0]:
self.send_request(WiseAgentMessage(message=request.message, sender=self.name,
context_name=request.context_name, chat_id=chat_id), agent)
context_name=ctx.name), agent)

def process_response(self, response):
def process_response(self, response : WiseAgentMessage):
"""
Process a response message. If this message is from the last agent remaining in the current phase, then
kick off the next phase of collaboration if there are more phases. Otherwise, determine if we should
Expand All @@ -265,19 +265,18 @@ def process_response(self, response):
Args:
response (WiseAgentMessage): the response message to process
"""
ctx = WiseAgentRegistry.get_or_create_context(response.context_name)
chat_id = response.chat_id

ctx = WiseAgentRegistry.get_context(response.context_name)

if response.message_type != WiseAgentMessageType.ACK:
raise ValueError(f"Unexpected response message: {response.message}")
raise ValueError(f"Unexpected response message_type: {response.message_type} with message: {response.message}")

# Remove the agent from the required agents for this phase
ctx.remove_required_agent_for_current_phase(chat_id, response.sender)
ctx.remove_required_agent_for_current_phase(response.sender)

# If there are no more agents remaining in this phase, move on to the next phase,
# return the final answer, or iterate
if len(ctx.get_required_agents_for_current_phase(chat_id)) == 0:
next_phase = ctx.get_agents_for_next_phase(chat_id)
if len(ctx.get_required_agents_for_current_phase()) == 0:
next_phase = ctx.get_agents_for_next_phase()
if next_phase is None:
# Determine the final answer
final_answer_prompt = ("What is the final answer for the original query? Provide the answer followed" +
Expand All @@ -286,9 +285,8 @@ def process_response(self, response):
" the confidence score on the next line. For example:\n" +
" Your answer goes here.\n"
" 85\n")
ctx.append_chat_completion(chat_uuid=chat_id,
messages={"role": "user", "content": final_answer_prompt})
llm_response = self.llm.process_chat_completion(ctx.llm_chat_completion[chat_id], tools=[])
ctx.append_chat_completion(messages={"role": "user", "content": final_answer_prompt})
llm_response = self.llm.process_chat_completion(ctx.llm_chat_completion, tools=[])
final_answer_and_score = llm_response.choices[0].message.content.splitlines()
final_answer = "\n".join(final_answer_and_score[:-1])
if final_answer_and_score[-1].strip().isnumeric():
Expand All @@ -300,37 +298,36 @@ def process_response(self, response):
# Determine if we should return the final answer or iterate
if score >= self.confidence_score_threshold:
self.send_response(WiseAgentMessage(message=final_answer, sender=self.name,
context_name=response.context_name, chat_id=chat_id), ctx.get_route_response_to(chat_id))
elif len(ctx.get_queries(chat_id)) == self.max_iterations:
context_name=response.context_name), ctx.get_route_response_to())
elif len(ctx.get_queries()) == self.max_iterations:
self.send_response(WiseAgentMessage(message=CANNOT_ANSWER, message_type=WiseAgentMessageType.CANNOT_ANSWER,
sender=self.name, context_name=response.context_name, chat_id=chat_id),
ctx.get_route_response_to(chat_id))
sender=self.name, context_name=response.context_name),
ctx.get_route_response_to())
else:
# Rephrase the query and iterate
if len(ctx.get_queries(chat_id)) < self.max_iterations:
if len(ctx.get_queries()) < self.max_iterations:
rephrase_query_prompt = ("The final answer was not considered good enough to respond to the original query.\n" +
" The original query was: " + ctx.get_queries(chat_id)[0] + "\n" +
" The original query was: " + ctx.get_queries()[0] + "\n" +
" Your task is to analyze the original query for its intent along with the conversation" +
" history and final answer to rephrase the original query to yield a better final answer." +
" The response should contain only the rephrased query."
" Don't include anything else in the response.\n")
ctx.append_chat_completion(chat_uuid=chat_id,
messages={"role": "user", "content": rephrase_query_prompt})
# Note that llm_chat_completion[chat_id] is being used here so we have the full history
llm_response = self.llm.process_chat_completion(ctx.llm_chat_completion[chat_id], tools=[])
ctx.append_chat_completion(messages={"role": "user", "content": rephrase_query_prompt})
# Note that llm_chat_completion is being used here so we have the full history
llm_response = self.llm.process_chat_completion(ctx.llm_chat_completion, tools=[])
rephrased_query = llm_response.choices[0].message.content
ctx.append_chat_completion(chat_uuid=chat_id, messages=llm_response.choices[0].message)
ctx.set_current_phase(chat_id, 0)
ctx.add_query(chat_id, rephrased_query)
for agent in ctx.get_required_agents_for_current_phase(chat_id):
ctx.append_chat_completion(messages=llm_response.choices[0].message)
ctx.set_current_phase(0)
ctx.add_query(rephrased_query)
for agent in ctx.get_required_agents_for_current_phase():
self.send_request(WiseAgentMessage(message=rephrased_query, sender=self.name,
context_name=response.context_name, chat_id=chat_id),
context_name=response.context_name),
agent)
else:
# Kick off the next phase
for agent in next_phase:
self.send_request(WiseAgentMessage(message=ctx.get_current_query(chat_id), sender=self.name,
context_name=response.context_name, chat_id=chat_id), agent)
self.send_request(WiseAgentMessage(message=ctx.get_current_query(), sender=self.name,
context_name=response.context_name), agent)
return True

def process_event(self, event):
Expand Down
Loading

0 comments on commit e50f23a

Please sign in to comment.