Skip to content

Commit

Permalink
Improved common error handling in thread
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Aug 19, 2024
1 parent 0c233b9 commit d9c3245
Showing 1 changed file with 48 additions and 38 deletions.
86 changes: 48 additions & 38 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from typing import List, Optional, Type, Union

from openai import BadRequestError
from openai import APIError, BadRequestError
from openai.types.beta import AssistantToolChoice
from openai.types.beta.threads.message import Attachment
from openai.types.beta.threads.run import TruncationStrategy
Expand Down Expand Up @@ -40,6 +40,8 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent):
self.run = None
self.stream = None

self.num_run_retries = 0

def init_thread(self):
if self.id:
self.thread = self.client.beta.threads.retrieve(self.id)
Expand Down Expand Up @@ -235,22 +237,22 @@ def handle_output(tool_call, output):
# error
elif self.run.status == "failed":
full_message += self._get_last_message_text()
# retry run 2 times
if error_attempts < 1 and ("something went wrong" in self.run.last_error.message.lower() or "The server had an error processing your request" in self.run.last_error.message.lower()):
time.sleep(1)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)
error_attempts += 1
elif 1 <= error_attempts < 5 and "something went wrong" in self.run.last_error.message.lower():
self.create_message(
message="Continue.",
role="user"
)
self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format)
common_errors = ["something went wrong", "the server had an error processing your request", "rate limit reached"]
error_message = self.run.last_error.message.lower()

if error_attempts < 3 and any(error in error_message for error in common_errors):
if error_attempts < 2:
time.sleep(1 + error_attempts)
else:
self.create_message(message="Continue.", role="user")

self._create_run(recipient_agent, additional_instructions, event_handler,
tool_choice, response_format=response_format)
error_attempts += 1
else:
raise Exception("OpenAI Run Failed. Error: ", self.run.last_error.message)
elif self.run.status == "incomplete":
raise Exception("OpenAI Run Incomplete. Error: ", self.run.incomplete_details)
raise Exception("OpenAI Run Incomplete. Details: ", self.run.incomplete_details)
# return assistant message
else:
message_obj = self._get_last_assistant_message()
Expand Down Expand Up @@ -302,40 +304,48 @@ def handle_output(tool_call, output):
return last_message

def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice, temperature=None, response_format: Optional[dict] = None):
    """Create an OpenAI run on this thread and store it in ``self.run``.

    With an ``event_handler`` the run is streamed and ``self.run`` is set to the
    final run object; otherwise the run is created and then polled to a terminal
    state. Transient server-side ``APIError``s ("The server had an error
    processing your request") are retried up to 3 times with linear backoff
    (1s, 2s, 3s); any other APIError is re-raised.

    Args:
        recipient_agent: Agent whose assistant id and limits configure the run.
        additional_instructions: Extra instructions appended for this run.
        event_handler: Optional event-handler class (instantiated per run) for streaming.
        tool_choice: Tool-choice constraint forwarded to the API.
        temperature: Optional sampling temperature for this run.
        response_format: Optional response-format dict forwarded to the API.

    Raises:
        openai.APIError: When the error is not a known transient server error,
            or the retry budget (3) is exhausted.
    """
    try:
        if event_handler:
            # Streaming path: events go to the caller-supplied handler; the
            # context manager blocks until the stream is done.
            with self.client.beta.threads.runs.stream(
                thread_id=self.thread.id,
                event_handler=event_handler(),
                assistant_id=recipient_agent.id,
                additional_instructions=additional_instructions,
                tool_choice=tool_choice,
                max_prompt_tokens=recipient_agent.max_prompt_tokens,
                max_completion_tokens=recipient_agent.max_completion_tokens,
                truncation_strategy=recipient_agent.truncation_strategy,
                temperature=temperature,
                # NOTE(review): the streaming endpoint takes parallel_tool_calls
                # via extra_body here, while runs.create below takes it as a
                # first-class kwarg — presumably an SDK-version quirk; confirm
                # against the installed openai SDK.
                extra_body={"parallel_tool_calls": recipient_agent.parallel_tool_calls},
                response_format=response_format
            ) as stream:
                stream.until_done()
                self.run = stream.get_final_run()
        else:
            # Non-streaming path: create the run, then poll it to a terminal state.
            self.run = self.client.beta.threads.runs.create(
                thread_id=self.thread.id,
                assistant_id=recipient_agent.id,
                additional_instructions=additional_instructions,
                tool_choice=tool_choice,
                max_prompt_tokens=recipient_agent.max_prompt_tokens,
                max_completion_tokens=recipient_agent.max_completion_tokens,
                truncation_strategy=recipient_agent.truncation_strategy,
                temperature=temperature,
                parallel_tool_calls=recipient_agent.parallel_tool_calls,
                response_format=response_format
            )
            self.run = self.client.beta.threads.runs.poll(
                thread_id=self.thread.id,
                run_id=self.run.id,
                # poll_interval_ms=500,
            )
        # Success: reset the retry budget so a later transient failure on this
        # thread gets a fresh 3 attempts instead of inheriting old failures.
        self.num_run_retries = 0
    except APIError as e:
        if "The server had an error processing your request" in e.message and self.num_run_retries < 3:
            time.sleep(1 + self.num_run_retries)  # linear backoff: 1s, 2s, 3s
            # Increment BEFORE recursing. The original incremented after the
            # recursive call returned, so the guard above always compared the
            # stale counter and a persistent server error recursed without bound.
            self.num_run_retries += 1
            # Forward temperature on retry — the original dropped it, silently
            # retrying with the default temperature.
            self._create_run(recipient_agent, additional_instructions, event_handler,
                             tool_choice, temperature=temperature, response_format=response_format)
        else:
            # Bare raise preserves the original traceback ("raise e" resets it).
            raise

def _run_until_done(self):
while self.run.status in ['queued', 'in_progress', "cancelling"]:
Expand Down

0 comments on commit d9c3245

Please sign in to comment.