Skip to content

Commit

Permalink
Add bad_output_process_model option and use_fixed_model_version o…
Browse files Browse the repository at this point in the history
…ption for all generation methods, to prevent future OpenAI API changes from breaking Sotopia. (#196)

* Two major updates: 1) add a "bad_output_process_model" option to all `agenerate_xxx()` methods so users can decide which model to use for handling bad outputs. By default, this is set to `gpt-4o-mini`. 2) add a `use_fixed_model_version` option to all generation methods, as some fixed model versions may no longer be available in the future. Users should be able to bypass the fixed model version mapping instead of getting stuck in an error. The documentation (`generation.md`) has been updated accordingly for these two major changes.

* [autofix.ci] apply automated fixes

---------

Co-authored-by: Chenghao Yang <yangalan1996@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 24, 2024
1 parent 78e8eb8 commit b8a6dbc
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 7 deletions.
16 changes: 16 additions & 0 deletions docs/pages/concepts/generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ async def agenerate(
output_parser: BaseOutputParser[OutputType],
temperature: float = 0.7,
structured_output: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True
) -> OutputType:
input_variables = re.findall(r"(?<!{){([^{}]+)}(?!})", template)
```
Expand All @@ -23,6 +25,12 @@ The `agenerate` function is versatile by taking the output_parser as an argument
* `gpt-4o-mini-2024-07-18` and later
* `gpt-4o-2024-08-06` and later

The `bad_output_process_model` is used to process bad output. `DEFAULT_BAD_OUTPUT_PROCESS_MODEL` is set to `gpt-4o-mini` (at the time the Sotopia paper was published, we used `gpt-3.5-turbo-0613`, but OpenAI has since discontinued that model).

The `use_fixed_model_version` flag determines whether to pin the model version. If set to `True`, the model version is fixed to the exact version that was used in the Sotopia paper. If set to `False`, the latest available version of the model is used instead.

Warning: As some fixed model versions might not be available in the OpenAI API, setting `use_fixed_model_version = True` might result in an error.

</Callout>

Here are a few examples of how to use the `agenerate` function:
Expand All @@ -37,6 +45,8 @@ async def agenerate_env_profile(
inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex",
examples: str = "",
temperature: float = 0.7,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True
) -> tuple[EnvironmentProfile, str]:
"""
Using langchain to generate the background
Expand All @@ -56,6 +66,8 @@ async def agenerate_env_profile(
),
output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version
)
```
### Other generation functions
Expand All @@ -66,6 +78,8 @@ Similarly, there are other utility functions that builds upon the `agenerate` fu
async def agenerate_relationship_profile(
model_name: str,
agents_profiles: list[str],
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True
) -> tuple[RelationshipProfile, str]
```

Expand All @@ -78,5 +92,7 @@ async def agenerate_script(
agent_name: str = "",
history: str = "",
single_step: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True
) -> tuple[ScriptInteractionReturnType, str]
```
70 changes: 63 additions & 7 deletions sotopia/generation_utils/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
"redis",
"groq/llama3-70b-8192",
]
# subject to future OpenAI changes
DEFAULT_BAD_OUTPUT_PROCESS_MODEL = "gpt-4o-mini"

OutputType = TypeVar("OutputType", bound=object)
client = OpenAI()
Expand Down Expand Up @@ -304,6 +306,7 @@ def obtain_chain(
input_variables: list[str],
temperature: float = 0.7,
max_retries: int = 6,
use_fixed_model_version: bool = True,
) -> RunnableSerializable[dict[Any, Any], BaseMessage]:
"""
Using langchain to sample profiles for participants
Expand All @@ -315,7 +318,8 @@ def obtain_chain(
)
)
chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
model_name = _return_fixed_model_version(model_name)
if use_fixed_model_version:
model_name = _return_fixed_model_version(model_name)
if model_name.startswith("together_ai"):
model_name = "/".join(model_name.split("/")[1:])
chat_openai = ChatOpenAI(
Expand Down Expand Up @@ -391,7 +395,8 @@ def format_bad_output_for_script(
ill_formed_output: str,
format_instructions: str,
agents: list[str],
model_name: str = "gpt-4o-mini",
model_name: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> BaseMessage:
template = """
Given the string that can not be parsed by a parser, reformat it to a string that can be parsed by the parser which uses the following format instructions. Do not add or delete any information.
Expand All @@ -410,6 +415,7 @@ def format_bad_output_for_script(
model_name=model_name,
template=template,
input_variables=re.findall(r"{(.*?)}", template),
use_fixed_model_version=use_fixed_model_version,
)
input_values = {
"ill_formed_output": ill_formed_output,
Expand All @@ -425,7 +431,8 @@ def format_bad_output_for_script(
def format_bad_output(
ill_formed_output: BaseMessage,
format_instructions: str,
model_name: str = "gpt-4o-mini",
model_name: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> BaseMessage:
template = """
Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
Expand All @@ -439,6 +446,7 @@ def format_bad_output(
model_name=model_name,
template=template,
input_variables=re.findall(r"{(.*?)}", template),
use_fixed_model_version=use_fixed_model_version,
)
input_values = {
"ill_formed_output": ill_formed_output.content,
Expand All @@ -458,6 +466,8 @@ async def agenerate(
output_parser: BaseOutputParser[OutputType],
temperature: float = 0.7,
structured_output: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> OutputType:
input_variables = re.findall(
r"(?<!{){([^{}]+)}(?!})", template
Expand All @@ -473,6 +483,7 @@ async def agenerate(
template=template,
input_variables=input_variables,
temperature=temperature,
use_fixed_model_version=use_fixed_model_version,
)

if "format_instructions" not in input_values:
Expand Down Expand Up @@ -516,7 +527,10 @@ async def agenerate(
extra={"markup": True},
)
reformat_parsed_result = format_bad_output(
result, format_instructions=output_parser.get_format_instructions()
result,
format_instructions=output_parser.get_format_instructions(),
model_name=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
parsed_result = output_parser.invoke(reformat_parsed_result)
log.info(f"Generated result: {parsed_result}")
Expand All @@ -530,6 +544,8 @@ async def agenerate_env_profile(
inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex",
examples: str = "",
temperature: float = 0.7,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> tuple[EnvironmentProfile, str]:
"""
Using langchain to generate the background
Expand All @@ -549,13 +565,17 @@ async def agenerate_env_profile(
),
output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)


@beartype
async def agenerate_relationship_profile(
model_name: str,
agents_profiles: list[str],
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> tuple[RelationshipProfile, str]:
"""
Using langchain to generate the background
Expand All @@ -572,6 +592,8 @@ async def agenerate_relationship_profile(
agent_profile=agent_profile,
),
output_parser=PydanticOutputParser(pydantic_object=RelationshipProfile),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)


Expand All @@ -586,6 +608,8 @@ async def agenerate_action(
goal: str,
temperature: float = 0.7,
script_like: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> AgentAction:
"""
Using langchain to generate an example episode
Expand Down Expand Up @@ -635,6 +659,8 @@ async def agenerate_action(
),
output_parser=PydanticOutputParser(pydantic_object=AgentAction),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
except Exception:
return AgentAction(action_type="none", argument="")
Expand All @@ -650,6 +676,8 @@ async def agenerate_script(
agent_name: str = "",
history: str = "",
single_step: bool = False,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> tuple[ScriptInteractionReturnType, str]:
"""
Using langchain to generate an the script interactions between two agent
Expand Down Expand Up @@ -683,6 +711,8 @@ async def agenerate_script(
single_turn=True,
),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)

else:
Expand All @@ -705,6 +735,8 @@ async def agenerate_script(
single_turn=False,
),
temperature=temperature,
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
except Exception as e:
# TODO raise(e) # Maybe we do not want to return anything?
Expand Down Expand Up @@ -733,7 +765,12 @@ def process_history(


@beartype
async def agenerate_init_profile(model_name: str, basic_info: dict[str, str]) -> str:
async def agenerate_init_profile(
model_name: str,
basic_info: dict[str, str],
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> str:
"""
Using langchain to generate the background
"""
Expand Down Expand Up @@ -767,11 +804,19 @@ async def agenerate_init_profile(model_name: str, basic_info: dict[str, str]) ->
secret=basic_info["secret"],
),
output_parser=StrOutputParser(),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)


@beartype
async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
async def convert_narratives(
model_name: str,
narrative: str,
text: str,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> str:
if narrative == "first":
return await agenerate(
model_name=model_name,
Expand All @@ -780,6 +825,8 @@ async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
{text}""",
input_values=dict(text=text),
output_parser=StrOutputParser(),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
elif narrative == "second":
return await agenerate(
Expand All @@ -789,13 +836,20 @@ async def convert_narratives(model_name: str, narrative: str, text: str) -> str:
{text}""",
input_values=dict(text=text),
output_parser=StrOutputParser(),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)
else:
raise ValueError(f"Narrative {narrative} is not supported.")


@beartype
async def agenerate_goal(model_name: str, background: str) -> str:
async def agenerate_goal(
model_name: str,
background: str,
bad_output_process_model: str = DEFAULT_BAD_OUTPUT_PROCESS_MODEL,
use_fixed_model_version: bool = True,
) -> str:
"""
Using langchain to generate the background
"""
Expand All @@ -806,4 +860,6 @@ async def agenerate_goal(model_name: str, background: str) -> str:
""",
input_values=dict(background=background),
output_parser=StrOutputParser(),
bad_output_process_model=bad_output_process_model,
use_fixed_model_version=use_fixed_model_version,
)

0 comments on commit b8a6dbc

Please sign in to comment.