diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index ccab5567..5c3d0393 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -5,7 +5,7 @@ import time from typing import List, Optional, Type, Union -from openai import BadRequestError +from openai import APIError, BadRequestError from openai.types.beta import AssistantToolChoice from openai.types.beta.threads.message import Attachment from openai.types.beta.threads.run import TruncationStrategy @@ -40,6 +40,8 @@ def __init__(self, agent: Union[Agent, User], recipient_agent: Agent): self.run = None self.stream = None + self.num_run_retries = 0 + def init_thread(self): if self.id: self.thread = self.client.beta.threads.retrieve(self.id) @@ -235,22 +237,22 @@ def handle_output(tool_call, output): # error elif self.run.status == "failed": full_message += self._get_last_message_text() - # retry run 2 times - if error_attempts < 1 and ("something went wrong" in self.run.last_error.message.lower() or "The server had an error processing your request" in self.run.last_error.message.lower()): - time.sleep(1) - self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) - error_attempts += 1 - elif 1 <= error_attempts < 5 and "something went wrong" in self.run.last_error.message.lower(): - self.create_message( - message="Continue.", - role="user" - ) - self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, response_format=response_format) + common_errors = ["something went wrong", "the server had an error processing your request", "rate limit reached"] + error_message = self.run.last_error.message.lower() + + if error_attempts < 3 and any(error in error_message for error in common_errors): + if error_attempts < 2: + time.sleep(1 + error_attempts) + else: + self.create_message(message="Continue.", role="user") + + self._create_run(recipient_agent, additional_instructions, 
event_handler, + tool_choice, response_format=response_format) error_attempts += 1 else: raise Exception("OpenAI Run Failed. Error: ", self.run.last_error.message) elif self.run.status == "incomplete": - raise Exception("OpenAI Run Incomplete. Error: ", self.run.incomplete_details) + raise Exception("OpenAI Run Incomplete. Details: ", self.run.incomplete_details) # return assistant message else: message_obj = self._get_last_assistant_message() @@ -302,10 +304,26 @@ def handle_output(tool_call, output): return last_message def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice, temperature=None, response_format: Optional[dict] = None): - if event_handler: - with self.client.beta.threads.runs.stream( + try: + if event_handler: + with self.client.beta.threads.runs.stream( + thread_id=self.thread.id, + event_handler=event_handler(), + assistant_id=recipient_agent.id, + additional_instructions=additional_instructions, + tool_choice=tool_choice, + max_prompt_tokens=recipient_agent.max_prompt_tokens, + max_completion_tokens=recipient_agent.max_completion_tokens, + truncation_strategy=recipient_agent.truncation_strategy, + temperature=temperature, + extra_body={"parallel_tool_calls": recipient_agent.parallel_tool_calls}, + response_format=response_format + ) as stream: + stream.until_done() + self.run = stream.get_final_run() + else: + self.run = self.client.beta.threads.runs.create( thread_id=self.thread.id, - event_handler=event_handler(), assistant_id=recipient_agent.id, additional_instructions=additional_instructions, tool_choice=tool_choice, @@ -313,29 +331,21 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t max_completion_tokens=recipient_agent.max_completion_tokens, truncation_strategy=recipient_agent.truncation_strategy, temperature=temperature, - extra_body={"parallel_tool_calls": recipient_agent.parallel_tool_calls}, + parallel_tool_calls=recipient_agent.parallel_tool_calls, 
response_format=response_format - ) as stream: - stream.until_done() - self.run = stream.get_final_run() - else: - self.run = self.client.beta.threads.runs.create( - thread_id=self.thread.id, - assistant_id=recipient_agent.id, - additional_instructions=additional_instructions, - tool_choice=tool_choice, - max_prompt_tokens=recipient_agent.max_prompt_tokens, - max_completion_tokens=recipient_agent.max_completion_tokens, - truncation_strategy=recipient_agent.truncation_strategy, - temperature=temperature, - parallel_tool_calls=recipient_agent.parallel_tool_calls, - response_format=response_format - ) - self.run = self.client.beta.threads.runs.poll( - thread_id=self.thread.id, - run_id=self.run.id, - # poll_interval_ms=500, - ) + ) + self.run = self.client.beta.threads.runs.poll( + thread_id=self.thread.id, + run_id=self.run.id, + # poll_interval_ms=500, + ) + except APIError as e: + if "The server had an error processing your request" in e.message and self.num_run_retries < 3: + time.sleep(1 + self.num_run_retries) + self.num_run_retries += 1  # count this attempt before recursing so the retry cap is enforced + self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice, temperature=temperature, response_format=response_format) + else: + raise e def _run_until_done(self): while self.run.status in ['queued', 'in_progress', "cancelling"]: