From b8a6dbcebc6b9228dd23a5bc96ca4b13109debb1 Mon Sep 17 00:00:00 2001 From: "Chenghao (Alan) Yang" Date: Tue, 24 Sep 2024 16:31:27 -0500 Subject: [PATCH] Add `bad_output_process_model` option and `use_fixed_model_version` option for all generation methods, to avoid future OpenAI API changes break Sotopia running. (#196) * Two major updates: 1) add "bad_output_process_model" option to all `agenerate_xxx()` methods so users can decide which model to use for handling bad outputs. By default, this is set to be `gpt-4o-mini`. 2) add `use_fixed_model_version` option for all generation methods, as some fixed model version may no longer available in the future. Users should have the right to bypass the fixed model version mapping instead of getting stuck in an error. Document (`generation.md`) has been updated for these two major changes correspondingly. * [autofix.ci] apply automated fixes --------- Co-authored-by: Chenghao Yang Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- docs/pages/concepts/generation.md | 16 +++++++ sotopia/generation_utils/generate.py | 70 +++++++++++++++++++++++++--- 2 files changed, 79 insertions(+), 7 deletions(-) diff --git a/docs/pages/concepts/generation.md b/docs/pages/concepts/generation.md index 41895aec3..8f633c6bc 100644 --- a/docs/pages/concepts/generation.md +++ b/docs/pages/concepts/generation.md @@ -12,6 +12,8 @@ async def agenerate( output_parser: BaseOutputParser[OutputType], temperature: float = 0.7, structured_output: bool = False, + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True ) -> OutputType: input_variables = re.findall(r"(? Here are a few examples of how to use the `agenerate` function: @@ -37,6 +45,8 @@ async def agenerate_env_profile( inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex", examples: str = "", temperature: float = 0.7, + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True ) -> tuple[EnvironmentProfile, str]: """ Using langchain to generate the background @@ -56,6 +66,8 @@ async def agenerate_env_profile( ), output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile), temperature=temperature, + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version ) ``` ### Other generation functions @@ -66,6 +78,8 @@ Similarly, there are other utility functions that builds upon the `agenerate` fu async def agenerate_relationship_profile( model_name: str, agents_profiles: list[str], + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True ) -> tuple[RelationshipProfile, str] ``` @@ -78,5 +92,7 @@ async def agenerate_script( agent_name: str = "", history: str = "", single_step: bool = False, + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True ) -> tuple[ScriptInteractionReturnType, str] ``` diff --git a/sotopia/generation_utils/generate.py b/sotopia/generation_utils/generate.py index 8d517ec57..92c4c7457 100644 --- a/sotopia/generation_utils/generate.py +++ b/sotopia/generation_utils/generate.py @@ -55,6 +55,8 @@ "redis", "groq/llama3-70b-8192", ] +# subject to future OpenAI changes +DEFAULT_BAD_OUTPUT_PROCESS_MODEL = "gpt-4o-mini" OutputType = TypeVar("OutputType", bound=object) client = OpenAI() @@ -304,6 +306,7 @@ def obtain_chain( input_variables: list[str], temperature: float = 0.7, max_retries: int = 6, + use_fixed_model_version: bool = True, ) -> RunnableSerializable[dict[Any, Any], BaseMessage]: """ Using langchain to sample profiles for participants @@ -315,7 +318,8 @@ def obtain_chain( ) ) chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt]) - model_name = _return_fixed_model_version(model_name) + if use_fixed_model_version: + model_name = _return_fixed_model_version(model_name) if model_name.startswith("together_ai"): model_name = "/".join(model_name.split("/")[1:]) chat_openai = ChatOpenAI( @@ -391,7 +395,8 @@ def format_bad_output_for_script( ill_formed_output: str, format_instructions: str, agents: list[str], - model_name: str = "gpt-4o-mini", + model_name: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, ) -> BaseMessage: template = """ Given the string that can not be parsed by a parser, reformat it to a string that can be parsed by the parser which uses the following format instructions. Do not add or delete any information. @@ -410,6 +415,7 @@ def format_bad_output_for_script( model_name=model_name, template=template, input_variables=re.findall(r"{(.*?)}", template), + use_fixed_model_version=use_fixed_model_version, ) input_values = { "ill_formed_output": ill_formed_output, @@ -425,7 +431,8 @@ def format_bad_output_for_script( def format_bad_output( ill_formed_output: BaseMessage, format_instructions: str, - model_name: str = "gpt-4o-mini", + model_name: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, ) -> BaseMessage: template = """ Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser. @@ -439,6 +446,7 @@ def format_bad_output( model_name=model_name, template=template, input_variables=re.findall(r"{(.*?)}", template), + use_fixed_model_version=use_fixed_model_version, ) input_values = { "ill_formed_output": ill_formed_output.content, @@ -458,6 +466,8 @@ async def agenerate( output_parser: BaseOutputParser[OutputType], temperature: float = 0.7, structured_output: bool = False, + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, ) -> OutputType: input_variables = re.findall( r"(? tuple[EnvironmentProfile, str]: """ Using langchain to generate the background @@ -549,6 +565,8 @@ async def agenerate_env_profile( ), output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile), temperature=temperature, + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, ) @@ -556,6 +574,8 @@ async def agenerate_env_profile( async def agenerate_relationship_profile( model_name: str, agents_profiles: list[str], + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, ) -> tuple[RelationshipProfile, str]: """ Using langchain to generate the background @@ -572,6 +592,8 @@ async def agenerate_relationship_profile( agent_profile=agent_profile, ), output_parser=PydanticOutputParser(pydantic_object=RelationshipProfile), + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, ) @@ -586,6 +608,8 @@ async def agenerate_action( goal: str, temperature: float = 0.7, script_like: bool = False, + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, ) -> AgentAction: """ Using langchain to generate an example episode @@ -635,6 +659,8 @@ async def agenerate_action( ), output_parser=PydanticOutputParser(pydantic_object=AgentAction), temperature=temperature, + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, ) except Exception: return AgentAction(action_type="none", argument="") @@ -650,6 +676,8 @@ async def agenerate_script( agent_name: str = "", history: str = "", single_step: bool = False, + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, ) -> tuple[ScriptInteractionReturnType, str]: """ Using langchain to generate an the script interactions between two agent @@ -683,6 +711,8 @@ async def agenerate_script( single_turn=True, ), temperature=temperature, + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, ) else: @@ -705,6 +735,8 @@ async def agenerate_script( single_turn=False, ), temperature=temperature, + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, ) except Exception as e: # TODO raise(e) # Maybe we do not want to return anything? @@ -733,7 +765,12 @@ def process_history( @beartype -async def agenerate_init_profile(model_name: str, basic_info: dict[str, str]) -> str: +async def agenerate_init_profile( + model_name: str, + basic_info: dict[str, str], + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, +) -> str: """ Using langchain to generate the background """ @@ -767,11 +804,19 @@ async def agenerate_init_profile(model_name: str, basic_info: dict[str, str]) -> secret=basic_info["secret"], ), output_parser=StrOutputParser(), + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, ) @beartype -async def convert_narratives(model_name: str, narrative: str, text: str) -> str: +async def convert_narratives( + model_name: str, + narrative: str, + text: str, + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, +) -> str: if narrative == "first": return await agenerate( model_name=model_name, @@ -780,6 +825,8 @@ async def convert_narratives(model_name: str, narrative: str, text: str) -> str: {text}""", input_values=dict(text=text), output_parser=StrOutputParser(), + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, ) elif narrative == "second": return await agenerate( @@ -789,13 +836,20 @@ async def convert_narratives(model_name: str, narrative: str, text: str) -> str: {text}""", input_values=dict(text=text), output_parser=StrOutputParser(), + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, ) else: raise ValueError(f"Narrative {narrative} is not supported.") @beartype -async def agenerate_goal(model_name: str, background: str) -> str: +async def agenerate_goal( + model_name: str, + background: str, + bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL, + use_fixed_model_version: bool = True, +) -> str: """ Using langchain to generate the background """ @@ -806,4 +860,6 @@ async def agenerate_goal(model_name: str, background: str) -> str: """, input_values=dict(background=background), output_parser=StrOutputParser(), + bad_output_process_model=bad_output_process_model, + use_fixed_model_version=use_fixed_model_version, )