From e9dfe8fe01b517871094132db10679db6ab12ef9 Mon Sep 17 00:00:00 2001 From: Haofei Yu <1125027232@qq.com> Date: Fri, 10 Nov 2023 22:28:28 -0500 Subject: [PATCH] Feature/support better sceanrio goal generation (#92) * support inference on the whole dataset * add initial code for scenario and social goal generation * modify readme * add step1 and step2 but still not correct * add a test * support generating env and match it with existing agents to be a combo * add readme * support gpt-4-turbo change and add db doc and conert to json * fix prompt to generate better scenario based on gpt-4-turbo * complete the overall 4 steps * modify readme * support bert score filtering * change name * delete file * delete jsonl * delete dump.rdb * modify readme --- README.md | 3 +- data_generate/README.md | 20 + data_generate/generate.py | 741 ++++++++++++++++++ data_generate/generate_specific_envs.py | 135 ++++ data_generate/requirments.txt | 5 + ...ep1.5_delete_too_sim_scenario_with_test.py | 47 ++ data_generate/step1.sh | 1 + data_generate/step1_generate_env_profile.py | 60 ++ data_generate/step2.sh | 4 + ...step2_push_agent_relationship_env_to_db.py | 151 ++++ data_generate/step3_convert_db_into_json.py | 31 + .../step4_convert_json_to_gen_input.py | 146 ++++ data_generate/test_redis.py | 10 + data_generate/test_redisjson.py | 11 + 14 files changed, 1364 insertions(+), 1 deletion(-) create mode 100644 data_generate/README.md create mode 100644 data_generate/generate.py create mode 100644 data_generate/generate_specific_envs.py create mode 100644 data_generate/requirments.txt create mode 100644 data_generate/step1.5_delete_too_sim_scenario_with_test.py create mode 100644 data_generate/step1.sh create mode 100644 data_generate/step1_generate_env_profile.py create mode 100644 data_generate/step2.sh create mode 100644 data_generate/step2_push_agent_relationship_env_to_db.py create mode 100644 data_generate/step3_convert_db_into_json.py create mode 100644 data_generate/step4_convert_json_to_gen_input.py create mode 100644 data_generate/test_redis.py create mode 100644 data_generate/test_redisjson.py diff --git a/README.md b/README.md index ee4e743f..d1baa453 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,10 @@ We split our overall framework into multiple parts +0. Data Generate --> Input None / Output new data on redis 1. Data Processing --> Output general form of sotopia train and test data 2. Together AI Finetuning --> Input the train and test data / Output model checkpoint 3. LLM Finetuning --> Input the train and test data / Output model checkpoint 4. LLM Deplyment --> Input LLM Finetuned model checkpoint / Output Deployable OpenAI type API 5. Eval --> Input model checkpoint / Output evaluation scores -6. Generate --> Input None / Output new data on redis \ No newline at end of file + diff --git a/data_generate/README.md b/data_generate/README.md new file mode 100644 index 00000000..54b4ca9c --- /dev/null +++ b/data_generate/README.md @@ -0,0 +1,20 @@ +# Data Generation + +For the first step, we generate envProfile (including scenario / social goal / relationship restriction) based on inspiring prompt. + +For the 2.1 step, we put the original agentProfile and relationshipProfile into our new redis database + +For the 2.2 step, we combine them together to be combos based on conditiona sampling (the restriction is the relationship) + +All the EnvProfile (new generated), AgentProfile (sotopia original), RelationshipProfile (sotopia original), and envagentcombo are on the redis database that is new created. + +For the third step, we need to use another version of redis and convert it into json file and save the whole data in the database on the local machine. + +For the final step, we convert the whole thing into Ruiyi's format. + +# Local Redis Setting +Since the redis-server cannot directly input json data, it requires loading a RedisJson model into the redis-server to enable this function. Therefore, we need to load a docker based on RedisJson: + +docker run -p 6379:6379 --name redis-stack redis/redis-stack:latest + +Link: diff --git a/data_generate/generate.py b/data_generate/generate.py new file mode 100644 index 00000000..57401720 --- /dev/null +++ b/data_generate/generate.py @@ -0,0 +1,741 @@ +import logging +import re +from typing import TypeVar, cast + +import gin +from beartype import beartype +from beartype.typing import Type +from langchain.callbacks import StdOutCallbackHandler +from langchain.chains import LLMChain +from langchain.chat_models import ChatOpenAI +from langchain.llms import OpenAI +from langchain.output_parsers import PydanticOutputParser +from langchain.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, +) +from langchain.schema import ( + BaseOutputParser, + HumanMessage, + OutputParserException, +) +from pydantic import BaseModel, Field, validator +from rich import print +from rich.logging import RichHandler +from typing_extensions import Literal + +from sotopia.database import EnvironmentProfile, RelationshipProfile +from sotopia.messages import ( + ActionType, + AgentAction, + ScriptBackground, + ScriptEnvironmentResponse, +) +from sotopia.utils import format_docstring + +from sotopia.generation_utils.langchain_callback_handler import LoggingCallbackHandler +from sotopia.generation_utils.llama2 import Llama2 + +log = logging.getLogger("generate") +logging_handler = LoggingCallbackHandler("langchain") + +LLM_Name = Literal[ + "togethercomputer/llama-2-7b-chat", + "togethercomputer/llama-2-70b-chat", + "togethercomputer/mpt-30b-chat", + "gpt-3.5-turbo", + "text-davinci-003", + "gpt-4", + "gpt-4-turbo", + "human", + "redis", +] + +OutputType = TypeVar("OutputType", bound=object) + + +class EnvResponse(BaseModel): + reasoning: str = Field( + description="first reiterate agents' social goals and then reason about what agents say/do and whether that aligns with their goals." + ) + p1_rate: int = Field( + description="rating of participant 1, on the scale of 0 to 9" + ) + p2_rate: int = Field( + description="rating of participant 2, on the scale of 0 to 9" + ) + + +class EnvResponsePydanticOutputParser(PydanticOutputParser[EnvResponse]): + def __init__(self, pydantic_object: Type[BaseModel] = EnvResponse) -> None: + super(EnvResponsePydanticOutputParser, self).__init__( + pydantic_object=pydantic_object + ) + + def parse(self, text: str) -> EnvResponse: + # remove trailing commas before ) or ] from text + text = re.sub(r",\s*(\)|\])", r"\1", text) + return super().parse(text) + + def get_format_instructions(self) -> str: + format_instruction = super().get_format_instructions() + return format_instruction + + +class ListOfIntOutputParser(BaseOutputParser[list[int]]): + number_of_int: int | None + range_of_int: tuple[int, int] | None + + def __init__( + self, + number_of_int: int | None = None, + range_of_int: tuple[int, int] | None = None, + ): + """ + Parse the output to a list of integers + + Args: + number_of_int (int | None): The number of integers in the output. If None, the number of integers is not fixed. + """ + super().__init__() + self.number_of_int = number_of_int + self.range_of_int = range_of_int + + def _get_description_text(self) -> str: + return f"a list of{' ' + str(self.number_of_int) if self.number_of_int else ''} intergers{' within the range of' + str(self.range_of_int) if self.range_of_int else ''} separated by space" + + def get_format_instructions(self) -> str: + return "Please output " + self._get_description_text() + + def parse(self, output: str) -> list[int]: + try: + if ":" in output: + output = output.split(":")[1] + result = [int(x) for x in output.split(" ") if x] + if self.number_of_int and len(result) != self.number_of_int: + msg = ( + f"Expect {self.number_of_int} integers, got {len(result)}" + ) + raise OutputParserException(msg) + if self.range_of_int: + for x in result: + if x < self.range_of_int[0] or x > self.range_of_int[1]: + msg = f"Expect integers within the range of {self.range_of_int}, got {result}" + raise OutputParserException(msg) + return result + except KeyboardInterrupt: + raise KeyboardInterrupt + except Exception as e: + msg = f"Exception {e}: the output format is not correct. Expect {self._get_description_text()}, got {output}" + raise OutputParserException(msg) + + @property + def _type(self) -> str: + """Return the type key.""" + return "list[int]" + + +class ListOfStrOutputParser(BaseOutputParser[list[str]]): + number_of_str: int | None + + def __init__( + self, + number_of_str: int | None = None, + ): + """ + Parse the output to a list of strings + + Args: + number_of_str (int | None): The number of strings in the output. If None, the number of strings is not fixed. + """ + super().__init__() + self.number_of_str = number_of_str + + def _get_description_text(self) -> str: + return f"a list of{' ' + str(self.number_of_str) if self.number_of_str else ''} strings separated by space" + + def get_format_instructions(self) -> str: + return "Please output " + self._get_description_text() + + def parse(self, output: str) -> list[str]: + try: + result = output.split(" ") + if self.number_of_str and len(result) != self.number_of_str: + msg = f"Expect {self.number_of_str} strings, got {len(result)}" + raise OutputParserException(msg) + return result + except KeyboardInterrupt: + raise KeyboardInterrupt + except Exception as e: + msg = f"Exception {e}: the output format is not correct. Expect {self._get_description_text()}, got {output}" + raise OutputParserException(msg) + + @property + def _type(self) -> str: + """Return the type key.""" + return "list[str]" + + +class StrOutputParser(BaseOutputParser[str]): + def __init__(self) -> None: + super().__init__() + + def get_format_instructions(self) -> str: + return "Please output a string" + + def parse(self, output: str) -> str: + return output + + @property + def _type(self) -> str: + """Return the type key.""" + return "str" + + +def _return_fixed_model_version( + model_name: Literal["gpt-3.5-turbo", "gpt-4"] +) -> str: + return { + "gpt-3.5-turbo": "gpt-3.5-turbo-0613", + "gpt-4": "gpt-4-0613", + "gpt-4-turbo": "gpt-4-1106-preview" + }[model_name] + + +@gin.configurable +@beartype +def obtain_chain( + model_name: LLM_Name, + template: str, + input_variables: list[str], + temperature: float = 0.7, + max_retries: int = 6, +) -> LLMChain: + """ + Using langchain to sample profiles for participants + """ + match model_name: + case "gpt-3.5-turbo" | "gpt-4" | "gpt-4-turbo": + human_message_prompt = HumanMessagePromptTemplate( + prompt=PromptTemplate( + template=template, + input_variables=input_variables, + ) + ) + chat_prompt_template = ChatPromptTemplate.from_messages( + [human_message_prompt] + ) + chat = ChatOpenAI( + model_name=_return_fixed_model_version(model_name), + temperature=temperature, + max_retries=max_retries, + ) + chain = LLMChain(llm=chat, prompt=chat_prompt_template) + return chain + case "text-davinci-003": + # Warning: no interactive mode for 003 + llm = OpenAI( + model_name=model_name, + temperature=temperature, + max_retries=max_retries, + ) + prompt = PromptTemplate( + input_variables=input_variables, + template=template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + return chain + case "togethercomputer/llama-2-7b-chat" | "togethercomputer/llama-2-70b-chat": + human_message_prompt = HumanMessagePromptTemplate( + prompt=PromptTemplate( + template=template, + input_variables=input_variables, + ) + ) + chat_prompt_template = ChatPromptTemplate.from_messages( + [human_message_prompt] + ) + together_llm = Llama2( + model_name=model_name, temperature=temperature + ) + chain = LLMChain(llm=together_llm, prompt=chat_prompt_template) + return chain + case _: + raise ValueError(f"Invalid model name: {model_name}") + + +@beartype +def format_bad_output( + ill_formed_output: str, + format_instructions: str, + model_name: LLM_Name = "gpt-3.5-turbo", +) -> str: + template = """ + Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser. + Original string: {ill_formed_output} + + Format instructions: {format_instructions} + + Please only generate the JSON: + """ + chain = obtain_chain( + model_name=model_name, + template=template, + input_variables=re.findall(r"{(.*?)}", template), + ) + input_values = { + "ill_formed_output": ill_formed_output, + "format_instructions": format_instructions, + } + reformat = chain.predict([logging_handler], **input_values) + log.info(f"Reformated output: {reformat}") + return reformat + + +@beartype +def generate( + model_name: LLM_Name, + template: str, + input_values: dict[str, str], + output_parser: BaseOutputParser[OutputType], + temperature: float = 0.7, +) -> OutputType: + input_variables = re.findall(r"{(.*?)}", template) + assert set(input_variables) == set( + list(input_values.keys()) + ["format_instructions"] + ) or set(input_variables) == set( + list(input_values.keys()) + ), f"The variables in the template must match input_values except for format_instructions. Got {sorted(input_values.keys())}, expect {sorted(input_variables)}" + # process template + template = format_docstring(template) + chain = obtain_chain( + model_name=model_name, + template=template, + input_variables=input_variables, + temperature=temperature, + ) + if "format_instructions" not in input_values: + input_values[ + "format_instructions" + ] = output_parser.get_format_instructions() + result = chain.predict([logging_handler], **input_values) + import pdb; pdb.set_trace() + try: + parsed_result = output_parser.parse(result) + except KeyboardInterrupt: + raise KeyboardInterrupt + except Exception as e: + log.debug( + f"[red] Failed to parse result: {result}\nEncounter Exception {e}\nstart to reparse", + extra={"markup": True}, + ) + reformat_parsed_result = format_bad_output( + result, format_instructions=output_parser.get_format_instructions() + ) + parsed_result = output_parser.parse(reformat_parsed_result) + log.info(f"Generated result: {parsed_result}") + return parsed_result + + +@gin.configurable +@beartype +async def agenerate( + model_name: LLM_Name, + template: str, + input_values: dict[str, str], + output_parser: BaseOutputParser[OutputType], + temperature: float = 0.7, +) -> tuple[OutputType, str]: + input_variables = re.findall(r"{(.*?)}", template) + assert set(input_variables) == set( + list(input_values.keys()) + ["format_instructions"] + ) or set(input_variables) == set( + list(input_values.keys()) + ), f"The variables in the template must match input_values except for format_instructions. Got {sorted(input_values.keys())}, expect {sorted(input_variables)}" + # process template + template = format_docstring(template) + chain = obtain_chain( + model_name=model_name, + template=template, + input_variables=input_variables, + temperature=temperature, + ) + if "format_instructions" not in input_values: + input_values[ + "format_instructions" + ] = output_parser.get_format_instructions() + result = await chain.apredict([logging_handler], **input_values) + prompt = logging_handler.retrive_prompt() + try: + parsed_result = output_parser.parse(result) + except Exception as e: + log.debug( + f"[red] Failed to parse result: {result}\nEncounter Exception {e}\nstart to reparse", + extra={"markup": True}, + ) + reformat_parsed_result = format_bad_output( + result, format_instructions=output_parser.get_format_instructions() + ) + parsed_result = output_parser.parse(reformat_parsed_result) + log.info(f"Generated result: {parsed_result}") + return parsed_result, prompt + + +# deprecated function +@beartype +def generate_episode( + model_name: LLM_Name, + participants: str = "Jack (a greedy person), Rose", + topic: str = "lawsuit", + extra_info: str = "", +) -> EnvResponse: + """ + Using langchain to generate an example episode + """ + return generate( + model_name=model_name, + template=""" + Please generate a episode for the interaction between {participants} regarding {topic}. + You should generate the personal backgrounds and goals in this interaction. + Use the following extra info if given: {extra_info} + Please use the following format: + {format_instructions} + """, + input_values=dict( + participants=participants, + topic=topic, + extra_info=extra_info, + ), + output_parser=EnvResponsePydanticOutputParser(), + ) + + +@gin.configurable +@beartype +async def agenerate_env_profile( + model_name: LLM_Name, + inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex", + examples: str = "", + temperature: float = 0.7, +) -> tuple[EnvironmentProfile, str]: + """ + Using langchain to generate the background + """ + return await agenerate( + model_name=model_name, + template="""Please generate scenarios and goals following the examples below. + Examples: + {examples} + Additionally, generate creative scenarios based on one or more inspirational prompt. The scenario and social goal is motivated by them but not very related to those prompts, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it. + Inspirational prompt: {inspiration_prompt} + Please use the following format and follow that format strictly: + {format_instructions} + """, + input_values=dict( + inspiration_prompt=inspiration_prompt, + examples=examples, + ), + output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile), + temperature=temperature, + ) + + +@beartype +async def agenerate_relationship_profile( + model_name: LLM_Name, + agents_profiles: list[str], +) -> tuple[RelationshipProfile, str]: + """ + Using langchain to generate the background + """ + agent_profile = "\n".join(agents_profiles) + return await agenerate( + model_name=model_name, + template="""Please generate relationship between two agents based on the agents' profiles below. Note that you generate + {agent_profile} + Please use the following format: + {format_instructions} + """, + input_values=dict( + agent_profile=agent_profile, + ), + output_parser=PydanticOutputParser( + pydantic_object=RelationshipProfile + ), + ) + + +@beartype +async def agenerate_enviroment_profile( + model_name: LLM_Name, + inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex", + examples: str = "", +) -> tuple[EnvironmentProfile, str]: + """ + Using langchain to generate the background + """ + return await agenerate( + model_name=model_name, + template="""Please generate scenarios and goals based on the examples below as well as the inspirational prompt, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it. + Examples: + {examples} + Inspirational prompt: {inspiration_prompt} + Please use the following format: + {format_instructions} + """, + input_values=dict( + inspiration_prompt=inspiration_prompt, + examples=examples, + ), + output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile), + ) + + +@beartype +def fill_in_background( + model_name: LLM_Name, + partial_background: ScriptBackground, +) -> ScriptBackground: + """ + Fill in the missing information of the background + """ + return generate( + model_name=model_name, + template="""Please fill in all missing information of the given background, don't leave any tag: + {partial_background} + Please use the following format: + {format_instructions} + """, + input_values=dict( + partial_background=partial_background.to_natural_language(), + ), + output_parser=PydanticOutputParser(pydantic_object=ScriptBackground), + ) + + +@beartype +def generate_action( + model_name: LLM_Name, + history: str, + turn_number: int, + action_types: list[ActionType], + agent: str, + goal: str, +) -> AgentAction: + """ + Using langchain to generate an example episode + """ + try: + return generate( + model_name=model_name, + template=""" + Imagine you are {agent}, your task is to act/speak like {agent} with {agent}'s social goal in mind. + You can find {agent}'s background and goal in the following history: + {history} + You are at Turn #{turn_number}. Your available action types are + {action_list}. + Note: You can "leave" this conversation if 1. this conversation makes you uncomfortable, 2. you find it uninteresting/you lose your patience, 3. you have achieved your social goals, 4. or for other reasons you want to leave. + + Please only generate a JSON string including the action type and the argument. + Your action should follow the given format: + {format_instructions} + """, + input_values=dict( + agent=agent, + turn_number=str(turn_number), + history=history, + action_list=" ".join(action_types), + ), + output_parser=PydanticOutputParser(pydantic_object=AgentAction), + ) + except KeyboardInterrupt: + raise KeyboardInterrupt + except: + return AgentAction(action_type="none", argument="") + + +@beartype +def generate_action_speak( + model_name: LLM_Name, + history: str, + turn_number: int, + action_types: list[ActionType], + agent: str, + goal: str, +) -> AgentAction: + """ + Using langchain to generate the action but only speak action is allowed + """ + try: + utterance = generate( + model_name=model_name, + template=""" + You are {agent}. + {history} + + You are at Turn #{turn_number}. Your available action type is speak. + Your goal is: {goal} + Follow the given format: + {agent} said: + should not include any quotation marks, "Turn #", or etc. + """, + input_values=dict( + agent=agent, + turn_number=str(turn_number), + history=history, + goal=goal, + ), + output_parser=StrOutputParser(), + ) + # delete the first line + utterance = utterance.replace(f"{agent} said:", "") + utterance = utterance.replace(f"Turn #{turn_number}:", "") + utterance = utterance.strip() + utterance = utterance.replace('"', "") + return AgentAction(action_type="speak", argument=utterance) + except KeyboardInterrupt: + raise KeyboardInterrupt + except: + return AgentAction(action_type="none", argument="") + + +@gin.configurable +@beartype +async def agenerate_action( + model_name: LLM_Name, + history: str, + turn_number: int, + action_types: list[ActionType], + agent: str, + goal: str, + temperature: float = 0.7, +) -> tuple[AgentAction, str]: + """ + Using langchain to generate an example episode + """ + try: + return await agenerate( + model_name=model_name, + template=""" + Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal. + You can find {agent}'s background and goal in the 'Here is the context of the interaction' field. + Note that {agent}'s secret and goal is only visible to you. + You should try your best to achieve {agent}'s goal in a way that align with their character traits. + Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before). + {history}. + You are at Turn #{turn_number}. Your available action types are + {action_list}. + Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave. + + Please only generate a JSON string including the action type and the argument. + Your action should follow the given format: + {format_instructions} + """, + input_values=dict( + agent=agent, + turn_number=str(turn_number), + history=history, + action_list=" ".join(action_types), + ), + output_parser=PydanticOutputParser(pydantic_object=AgentAction), + temperature=temperature, + ) + except: + return AgentAction(action_type="none", argument=""), "" + + +@beartype +def process_history( + script: ScriptBackground | EnvResponse | dict[str, AgentAction] +) -> str: + """ + Format the script background + """ + result = "" + if isinstance(script, ScriptBackground | EnvResponse): + script = script.dict() + result = "The initial observation\n\n" + for key, value in script.items(): + if value: + result += f"{key}: {value} \n" + return result + + +@beartype +def generate_init_profile( + model_name: LLM_Name, basic_info: dict[str, str] +) -> str: + """ + Using langchain to generate the background + """ + return generate( + model_name=model_name, + template="""Please expand a fictional background for {name}. Here is the basic information: + {name}'s age: {age} + {name}'s gender identity: {gender_identity} + {name}'s pronouns: {pronoun} + {name}'s occupation: {occupation} + {name}'s big 5 personality traits: {bigfive} + {name}'s moral Foundation: think {mft} is more important than others + {name}'s Schwartz portrait value: {schwartz} + {name}'s decision-making style: {decision_style} + {name}'s secret: {secret} + Include the previous information in the background. + Then expand the personal backgrounds with concrete details (e.g, look, family, hobbies, friends and etc.) + For the personality and values (e.g., MBTI, moral foundation, and etc.), + remember to use examples and behaviors in the person's life to demonstrate it. + """, + input_values=dict( + name=basic_info["name"], + age=basic_info["age"], + gender_identity=basic_info["gender_identity"], + pronoun=basic_info["pronoun"], + occupation=basic_info["occupation"], + bigfive=basic_info["Big_Five_Personality"], + mft=basic_info["Moral_Foundation"], + schwartz=basic_info["Schwartz_Portrait_Value"], + decision_style=basic_info["Decision_making_Style"], + secret=basic_info["secret"], + ), + output_parser=StrOutputParser(), + ) + + +@beartype +def convert_narratives(model_name: LLM_Name, narrative: str, text: str) -> str: + if narrative == "first": + return generate( + model_name=model_name, + template="""Please convert the following text into a first-person narrative. + e.g, replace name, he, she, him, her, his, and hers with I, me, my, and mine. + {text}""", + input_values=dict(text=text), + output_parser=StrOutputParser(), + ) + elif narrative == "second": + return generate( + model_name=model_name, + template="""Please convert the following text into a second-person narrative. + e.g, replace name, he, she, him, her, his, and hers with you, your, and yours. + {text}""", + input_values=dict(text=text), + output_parser=StrOutputParser(), + ) + else: + raise ValueError(f"Narrative {narrative} is not supported.") + + +@beartype +def generate_goal(model_name: LLM_Name, background: str) -> str: + """ + Using langchain to generate the background + """ + return generate( + model_name=model_name, + template="""Please generate your goal based on the background: + {background} + """, + input_values=dict(background=background), + output_parser=StrOutputParser(), + ) \ No newline at end of file diff --git a/data_generate/generate_specific_envs.py b/data_generate/generate_specific_envs.py new file mode 100644 index 00000000..94ab8014 --- /dev/null +++ b/data_generate/generate_specific_envs.py @@ -0,0 +1,135 @@ +"""This file is used to generate specific environments based on existing +datasets. The generation functions below should call agenerate_env_profile +in `sotopia/generation_utils/generate.py` with the appropriate parameters. +Here are the datasets we have so far: +1. Mutual-Friend (https://huggingface.co/datasets/mutual_friends) +""" +import asyncio +from typing import Hashable + +import datasets +import names +import numpy as np +from datasets import DatasetDict, load_dataset + +from generate import ( + ListOfStrOutputParser, + StrOutputParser, + agenerate, + generate, +) + + +async def generate_mutual_friend_envs() -> tuple[str, list[str]]: + """Generate environments based on the mutual-friend dataset.""" + mutual_friend_dataset: DatasetDict = load_dataset("mutual_friends") + all_data = mutual_friend_dataset["train"] + # sample one datum from all data + datum = np.random.choice(all_data) + friends = datum["scenario_kbs"] + num_of_friends_in_total = sum(map(len, friends)) + # generate names for the friends + set_of_names = set() + for _ in range(num_of_friends_in_total): + name = names.get_first_name() + while name in set_of_names: + name = names.get_first_name() + set_of_names.add(name) + list_of_names = list(set_of_names) + friend_map: dict[tuple[str, ...], str] = {} + friend_list_map: list[list[str]] = [[] for _ in range(len(friends))] + friend_description_keys: list[str] = datum["scenario_attributes"]["name"] + name_pointer = 0 + for i, friends_array in enumerate(friends): + for friend in friends_array: + assert ( + len(friend) == 2 + ) # in [[key1, key2, ...], [value1, value2, ...]] format + if not tuple(friend[1]) in friend_map: + friend_map[tuple(friend[1])] = list_of_names[name_pointer] + name_pointer += 1 + friend_list_map[i].append(friend_map[tuple(friend[1])]) + friend_set_map: list[set[str]] = [ + set(friend_list) for friend_list in friend_list_map + ] + common_friends = [] + for friend_description, friend_name in friend_map.items(): + if all([friend_name in friend_set for friend_set in friend_set_map]): + common_friends.append(friend_name) + scenario = ( + f'{len(friends)} strangers are meeting at a party.

They have {len(common_friends)} common friends: ' + f"{', '.join(common_friends[:-1])}" + + (" and " if len(common_friends) > 1 else "") + + common_friends[-1] + + ".

" + ) + goals: list[str] = [] + for friends_array in friends: + template = f"You are trying to figure out whether you have a mutual friend with the other person. \n" + template += f" You know the following friends" + for friend in friends_array: + friend_name = friend_map[tuple(friend[1])] + friend_description = friend[1] + template += f" {friend_name}: {' '.join([(i + ': ' + j + ' ') if i != 'Name' else '' for i, j in zip(friend[0], friend_description)])}\n" + template += f"" + goals.append(template) + + return scenario, goals + + +async def generate_craigslist_bargains_envs() -> tuple[str, list[str]]: + """Generate environments based on the craigslist_bargains dataset.""" + craigslist_bargains_dataset: DatasetDict = load_dataset( + "craigslist_bargains" + ) + all_data = craigslist_bargains_dataset["train"] + # sample one datum from all data + datum = np.random.choice(all_data) + scenario = generate( + model_name="gpt-4", + template="The following sentence is automatically generated with the following" + 'template: "One person is selling for , another person is' + 'trying to buy it. Here is the description of the item: ." with item = {title}, ' + "price={price}, and description={description} Please make the sentence" + "fluent and natural.", + input_values={ + "title": datum["items"]["Title"][0], + "price": datum["items"]["Price"][0], + "description": datum["items"]["Description"][0], + }, + output_parser=StrOutputParser(), + ) + + goals: list[str] = [] + for i in range(2): + if datum["agent_info"]["Role"][i] == "seller": + markup_ratio = np.random.exponential(0.5) + datum["agent_info"]["Target"][i] = datum["items"]["Price"][0] / ( + 1 + markup_ratio + ) + goal = generate( + model_name="gpt-4", + template="The following sentence is automatically generated with the following" + 'template: "You want to this item. Your target price ' + "is $ (round up to two decimals). You will get penalty if you sell or buy it " + "for a price that is significantly lower than (if is seller) or significantly" + "higher than (if is buyer) the target price, but will get bonus if you successfully " + "sell it higher than the target price (if is seller) or buy it for lower than" + 'the target price (if is buyer)." ' + "with role = {role} and price = {price}. Please make the sentence" + "fluent and natural. Do not change the original meaning of the sentence.", + input_values={ + "role": datum["agent_info"]["Role"][i], + "price": datum["agent_info"]["Target"][i], + }, + output_parser=StrOutputParser(), + ) + goals.append(goal) + + return scenario, goals + + +if __name__ == '__main__': + for i in range(10): + scenario, goals = asyncio.run(generate_mutual_friend_envs()) + import pdb; pdb.set_trace() \ No newline at end of file diff --git a/data_generate/requirments.txt b/data_generate/requirments.txt new file mode 100644 index 00000000..242fc8a8 --- /dev/null +++ b/data_generate/requirments.txt @@ -0,0 +1,5 @@ +sotopia +rejson +evaluate +redis==3.5.3 # for step3 and step4 +redis==5.0.1 # for step2 \ No newline at end of file diff --git a/data_generate/step1.5_delete_too_sim_scenario_with_test.py b/data_generate/step1.5_delete_too_sim_scenario_with_test.py new file mode 100644 index 00000000..a8e737d7 --- /dev/null +++ b/data_generate/step1.5_delete_too_sim_scenario_with_test.py @@ -0,0 +1,47 @@ +import csv +import pandas as pd +import evaluate +from tqdm import tqdm + +bertscore = evaluate.load("bertscore") + +def convert_string_to_list(string): + # Assuming the list is separated by commas and without extra spaces + return string.strip('[]').split(',') + +def is_similar_to_any(text_tgt, list_src, similarity_threshold): + scores = bertscore.compute(predictions=[text_tgt]*len(list_src), references=list_src, lang="en") + print(scores['f1']) + return any(score > similarity_threshold for score in scores['f1']) + +tgt_df = pd.read_csv( + 'env_filtered.csv', + converters={'agent_goals': convert_string_to_list} + ) + +tgt_scenario_list = tgt_df['scenario'].tolist() + +src_df = pd.read_csv( + 'HardEnvProfile.csv', + converters={'agent_goals': convert_string_to_list} + ) + +src_scenario_list = src_df['scenario'].tolist() + +similarity_threshold = 0.875 # Adjust as needed + +dropped_tgt_scenario_list = [] +for text_tgt in tqdm(tgt_scenario_list): + if is_similar_to_any(text_tgt, src_scenario_list, similarity_threshold): + dropped_tgt_scenario_list.append(text_tgt) + #print(text_tgt) + #print("===") + #for text_src in src_scenario_list: + # print(text_src) + +# iterate the df and drop the rows +for text_tgt in dropped_tgt_scenario_list: + tgt_df = tgt_df[tgt_df["scenario"] != text_tgt] + +# save tgt_df to csv +tgt_df.to_csv('env_filtered_filtered_0.875.csv', index=False) diff --git a/data_generate/step1.sh b/data_generate/step1.sh new file mode 100644 index 00000000..64adf09e --- /dev/null +++ b/data_generate/step1.sh @@ -0,0 +1 @@ +python step1_generate_env_profile.py \ No newline at end of file diff --git a/data_generate/step1_generate_env_profile.py b/data_generate/step1_generate_env_profile.py new file mode 100644 index 00000000..e73020e2 --- /dev/null +++ b/data_generate/step1_generate_env_profile.py @@ -0,0 +1,60 @@ +import asyncio +import random +from typing import TypeVar +from tqdm import tqdm + +import pandas as pd +import rich +from pydantic import BaseModel + +from sotopia.database import EnvironmentProfile +from generate import agenerate_env_profile + +T = TypeVar("T", bound=BaseModel) + +def pydantics_to_csv(filename: str, data: list[T]) -> None: + pd.DataFrame([item.dict() for item in data]).to_csv(filename, index=False) + + +#random.seed(41) + +envs = EnvironmentProfile.find().all() +ins_prompts = pd.read_csv("./inspirational_prompt_for_env.csv") +prompts = [prompt.strip().replace('\"', '') for prompt in ins_prompts["prompt"].tolist()] + +# randomly choose 3 prompts +sampled_examples = [] +sampled_prompts = [] + +target_num = 500 + +for i in range(target_num): + sampled_envs = random.sample(envs, 5) + sampled_prompt = random.sample(prompts, 5) + sampled_examples.append(f"1.{sampled_envs[0].json()}\n2.{sampled_envs[1].json()}\n3.{sampled_envs[2].json()}\n4.{sampled_envs[3].json()}\n5.{sampled_envs[4].json()}") + sampled_prompts.append(f"1.{sampled_prompt[0]}\n2.{sampled_prompt[1]}\n3.{sampled_prompt[2]}\n4.{sampled_prompt[3]}\n5.{sampled_prompt[4]}") + +assert len(sampled_examples) == target_num +assert len(sampled_prompts) == target_num + +backgrounds = [] +for prompt, sampled_example in tqdm(zip(sampled_prompts, sampled_examples), total=target_num): + rich.print(prompt) + try: + background, prompt_full = asyncio.run( + agenerate_env_profile( + model_name="gpt-4-turbo", + inspiration_prompt=prompt, + examples=sampled_example, + temperature=0.5, + ) + ) + except Exception as e: + print(e) + print('error! Skip') + continue + rich.print(prompt_full) + rich.print(background) + backgrounds.append(background) + + pydantics_to_csv("./backgrounds_gpt-4-turbo_jason.csv", backgrounds) \ No newline at end of file diff --git a/data_generate/step2.sh b/data_generate/step2.sh new file mode 100644 index 00000000..2043883a --- /dev/null +++ b/data_generate/step2.sh @@ -0,0 +1,4 @@ +python step2_push_agent_relationship_env_to_db.py ./env_filtered.csv environment +python step2_push_agent_relationship_env_to_db.py ./AgentProfile.csv agent +python step2_push_agent_relationship_env_to_db.py ./RelationshipProfile.csv relationship +python step2_push_agent_relationship_env_to_db.py ./env_filtered.csv agentenvcombo \ No newline at end of file diff --git a/data_generate/step2_push_agent_relationship_env_to_db.py b/data_generate/step2_push_agent_relationship_env_to_db.py new file mode 100644 index 00000000..55ec62ed --- /dev/null +++ b/data_generate/step2_push_agent_relationship_env_to_db.py @@ -0,0 +1,151 @@ +import ast +import sys +from typing import Any, cast + +import pandas as pd +from redis_om import Migrator + +from sotopia.database.persistent_profile import ( + AgentProfile, + EnvironmentProfile, + RelationshipProfile, +) +from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage +from sotopia.samplers import ConstraintBasedSampler +from sotopia.messages import AgentAction, Observation +from sotopia.agents import LLMAgent + + + +def add_agent_to_database(**kwargs: dict[str, Any]) -> None: + agent = AgentProfile(**kwargs) + agent.save() + + +def add_agents_to_database(agents: list[dict[str, Any]]) -> None: + for agent in agents: + add_agent_to_database(**agent) + + +def retrieve_agent_by_first_name(first_name: str) -> AgentProfile: + result = AgentProfile.find(AgentProfile.first_name == first_name).all() + if len(result) == 0: + raise ValueError(f"Agent with first name {first_name} not found") + elif len(result) > 1: + raise ValueError(f"Multiple agents with first name {first_name} found") + else: + assert isinstance(result[0], AgentProfile) + return result[0] + + +def add_env_profile(**kwargs: dict[str, Any]) -> None: + env_profile = EnvironmentProfile(**kwargs) + env_profile.save() + + +def add_env_profiles(env_profiles: list[dict[str, Any]]) -> None: + for env_profile in env_profiles: + add_env_profile(**env_profile) + + +def add_relationship_profile(**kwargs: dict[str, Any]) -> None: + relationship_profile = RelationshipProfile(**kwargs) + relationship_profile.save() + + +def add_relationship_profiles( + relationship_profiles: list[dict[str, Any]] +) -> None: + for relationship_profile in relationship_profiles: + add_relationship_profile(**relationship_profile) + + +def delete_all_agents() -> None: + pks = AgentProfile.all_pks() + pks_list = list(pks) + #for id in pks: + # AgentProfile.delete(id) + + +def delete_all_env_profiles() -> None: + pks = EnvironmentProfile.all_pks() + #for id in pks: + # EnvironmentProfile.delete(id) + + +def delete_all_relationships() -> None: + pks = list(RelationshipProfile.all_pks()) + #for id in pks: + # RelationshipProfile.delete(id) + pks = list(RelationshipProfile.all_pks()) + print("Relationships deleted, all relationships: ", len(list(pks))) + + +def sample_env_agent_combo_and_push_to_db(env_id: str) -> None: + sampler = ConstraintBasedSampler[Observation, AgentAction]( + env_candidates=[env_id] + ) + try: + env_agent_combo_list = list( + sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False) + ) + except: + return + print(len(env_agent_combo_list)) + for env, agent in env_agent_combo_list: + EnvAgentComboStorage( + env_id=env.profile.pk, + agent_ids=[agent[0].profile.pk, agent[1].profile.pk], + ).save() + + +def relationship_map(relationship: str) -> int: + return int(eval(relationship)) + + +if __name__ == "__main__": + assert ( + len(sys.argv) == 3 + ), "Please provide a csv file with agent or environment profiles, and the type of profile (agent or environment)" + df = pd.read_csv(sys.argv[1]) + type = sys.argv[2] + if type == "agent": + agents = cast(list[dict[str, Any]], df.to_dict(orient="records")) + for agent in agents: + agent["age"] = int(agent["age"]) + agent["moral_values"] = agent["moral_values"].split(",") + agent["schwartz_personal_values"] = agent[ + "schwartz_personal_values" + ].split(",") + add_agents_to_database(agents) + Migrator().run() + elif type == "environment": + df = df[ + [ + "codename", + "scenario", + "agent_goals", + "relationship", + "age_constraint", + "occupation_constraint", + "source", + ] + ] + envs = cast(list[dict[str, Any]], df.to_dict(orient="records")) + for env in envs: + env["agent_goals"] = ast.literal_eval(env["agent_goals"]) + assert isinstance(env["relationship"], int) + add_env_profiles(envs) + Migrator().run() + elif type == "relationship": + relationships = cast( + list[dict[str, Any]], df.to_dict(orient="records") + ) + for relationship in relationships: + assert isinstance(relationship["relationship"], int) + add_relationship_profiles(relationships) + Migrator().run() + elif type == 'agentenvcombo': + env_ids = list(EnvironmentProfile.all_pks()) + for env_id in env_ids: + sample_env_agent_combo_and_push_to_db(env_id) \ No newline at end of file diff --git a/data_generate/step3_convert_db_into_json.py b/data_generate/step3_convert_db_into_json.py new file mode 100644 index 00000000..778ab9a8 --- /dev/null +++ b/data_generate/step3_convert_db_into_json.py @@ -0,0 +1,31 @@ +from rejson import Client, Path +import json + +redis_host = 'localhost' +redis_port = 6379 +redis_password = '' + +rj = Client(host=redis_host, port=redis_port, password=redis_password, decode_responses=True) + +def get_redisjson_value(key): + try: + return rj.jsonget(key, Path.rootPath()) + except Exception as e: + print(f"Could not retrieve JSON for key {key}: {e}") + return None + +cursor = '0' +all_json_data = {} +while cursor != 0: + cursor, keys = rj.scan(cursor=cursor, match='*') + for key in keys: + key_type = rj.type(key) + if key_type == 'ReJSON-RL': + json_value = get_redisjson_value(key) + if json_value is not None: + all_json_data[key] = json_value + else: + print(f"Key {key} is not of type ReJSON-RL, it's type is {key_type}") + +with open('redis_json_data.json', 'w') as f: + json.dump(all_json_data, f, indent=4) \ No newline at end of file diff --git a/data_generate/step4_convert_json_to_gen_input.py b/data_generate/step4_convert_json_to_gen_input.py new file mode 100644 index 00000000..736ddcb9 --- /dev/null +++ b/data_generate/step4_convert_json_to_gen_input.py @@ -0,0 +1,146 @@ +import json +import jsonlines + +format_instruction = 'Your available action types are\n\"none action speak non-verbal communication leave\".\nNote: You can \"leave\" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.\n\nPlease only generate a JSON string including the action type and the argument.\nYour action should follow the given format:\n\nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}\nthe object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.\n\nHere is the output schema:\n\\n{\\"description\\": \\"An interface for messages.\\\\nThere is only one required method: to_natural_language\\", \\"properties\\": {\\"action_type\\": {\\"title\\": \\"Action Type\\", \\"description\\": \\"whether to speak at this turn or choose to not do anything\\", \\"enum\\": [\\"none\\", \\"speak\\", \\"non-verbal communication\\", \\"action\\", \\"leave\\"], \\"type\\": \\"string\\"}, \\"argument\\": {\\"title\\": \\"Argument\\", \\"description\\": \\"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\\", \\"type\\": \\"string\\"}}, \\"required\\": [\\"action_type\\", \\"argument\\"]}\\n\u001b[0m\n' + + +def get_agent_info(agent1_pk, agent2_pk, env_pk, agent_dict, env_dict): + agent1_name = agent_dict[agent1_pk]['first_name'] + ' ' + agent_dict[agent1_pk]['last_name'] + agent2_name = agent_dict[agent2_pk]['first_name'] + ' ' + agent_dict[agent2_pk]['last_name'] + + agent1_age = agent_dict[agent1_pk]['age'] + agent2_age = agent_dict[agent2_pk]['age'] + + agent1_occupation = agent_dict[agent1_pk]['occupation'] + agent2_occupation = agent_dict[agent2_pk]['occupation'] + + if agent_dict[agent1_pk]['gender'] == 'Man': + agent1_gender = 'male' + elif agent_dict[agent1_pk]['gender'] == 'Woman': + agent1_gender = 'female' + elif agent_dict[agent1_pk]['gender'] == 'Nonbinary': + agent1_gender = 'nonbinary' + + # agent2 the same + if agent_dict[agent2_pk]['gender'] == 'Man': + agent2_gender = 'male' + elif agent_dict[agent2_pk]['gender'] == 'Woman': + agent2_gender = 'female' + elif agent_dict[agent2_pk]['gender'] == 'Nonbinary': + agent2_gender = 'nonbinary' + + agent1_public_info = agent_dict[agent1_pk]['public_info'] + agent2_public_info = agent_dict[agent2_pk]['public_info'] + + agent1_personality_and_values = agent_dict[agent1_pk]['personality_and_values'] + agent2_personality_and_values = agent_dict[agent2_pk]['personality_and_values'] + + agent1_secret = agent_dict[agent1_pk]['secret'] + agent2_secret = agent_dict[agent2_pk]['secret'] + + agent1_goal = env_dict[env_pk]['agent_goals'][0].replace('', '') + agent2_goal = env_dict[env_pk]['agent_goals'][1].replace('', '') + + agent1_info = { + 'agent_name': agent1_name, + 'agent_age': agent1_age, + 'agent_occupation': agent1_occupation, + 'agent_gender': agent1_gender, + 'agent_public_info': agent1_public_info, + 'agent_personality_and_values': agent1_personality_and_values, + 'agent_secret': agent1_secret, + 'agent_goal': agent1_goal, + } + + agent2_info = { + 'agent_name': agent2_name, + 'agent_age': agent2_age, + 'agent_occupation': agent2_occupation, + 'agent_gender': agent2_gender, + 'agent_public_info': agent2_public_info, + 'agent_personality_and_values': agent2_personality_and_values, + 'agent_secret': agent2_secret, + 'agent_goal': agent2_goal, + } + return agent1_info, agent2_info + + +def fill_template(agent1_info, agent2_info, scenario): + # Assuming the scenario is a string that is passed to the function + # Gender pronouns are typically 'he/him', 'she/her', 'they/them', etc. + # I'm adding placeholders for these pronouns; you'll need to replace them with actual values. + agent1_pronoun = "their" # Replace with actual pronoun + agent2_pronoun = "their" # Replace with actual pronoun + + prompt_template = ( + "Prompt after formatting:\n" + "Imagine you are {agent1_name}, your task is to act/speak as {agent1_name} would, " + "keeping in mind {agent1_name}s social goal.\n" + "You can find {agent1_name}'s background and goal in the 'Here is the context of the interaction' field.\n" + "Note that {agent1_name}'s secret and goal is only visible to you.\n" + "You should try your best to achieve {agent1_name}'s goal in a way that align with their character traits.\n" + "Additionally, maintaining the conversation's naturalness and realism is essential " + "(e.g., do not repeat what other people has already said before).\n\n" + "Here is the context of this interaction:\n" + "Scenario: {scenario}\n" + "{agent1_name}'s background: {agent1_name} is a {agent1_age}-year-old {agent1_gender} {agent1_occupation}. " + "{agent1_pronoun} pronouns. {agent1_public_info} " + "Personality and values description: {agent1_personality_and_values} " + "{agent1_name}'s secrets: {agent1_secret}\n" + "{agent2_name}'s goal: Unknown\n" + "{agent1_name}'s goal: {agent1_goal}\n" + "Conversation Starts:\n.\nYou are at Turn #0." + ) + + prompt = prompt_template.format( + agent1_name=agent1_info['agent_name'], + agent1_age=agent1_info['agent_age'], + agent1_gender=agent1_info['agent_gender'], + agent1_occupation=agent1_info['agent_occupation'], + agent1_pronoun=agent1_pronoun, + agent1_public_info=agent1_info['agent_public_info'], + agent1_personality_and_values=agent1_info['agent_personality_and_values'], + agent1_secret=agent1_info['agent_secret'], + agent1_goal=agent1_info['agent_goal'], + agent2_name=agent2_info['agent_name'], + agent2_age=agent2_info['agent_age'], + agent2_gender=agent2_info['agent_gender'], + agent2_occupation=agent2_info['agent_occupation'], + agent2_pronoun=agent2_pronoun, + agent2_public_info=agent2_info['agent_public_info'], + agent2_personality_and_values=agent2_info['agent_personality_and_values'], + scenario=scenario + ) + + return prompt + format_instruction + + + + +with open('redis_json_data.json', 'r') as f: + all_json_data = json.load(f) + +agent_dict = {} +env_dict = {} +for key, data in all_json_data.items(): + if 'AgentProfile' in key: + agent_dict[data['pk']] = data + if 'EnvironmentProfile' in key: + env_dict[data['pk']] = data + +full_prompts = [] +for key, data in all_json_data.items(): + #if data['pk'] != "01HER590MH0W1TPCPYKCAWMNXW": + # continue + if 'EnvAgentComboStorage' in key: + env_id = data['env_id'] + agent_ids = data['agent_ids'] + agent1_info, agent2_info = get_agent_info(agent_ids[0], agent_ids[1], env_id, agent_dict, env_dict) + full_prompt = fill_template(agent1_info, agent2_info, env_dict[env_id]['scenario']) + full_prompts.append({'text': full_prompt}) + full_prompt = fill_template(agent2_info, agent1_info, env_dict[env_id]['scenario']) + full_prompts.append({'text': full_prompt}) + +print('Total number of prompts: ', len(full_prompts)) +with jsonlines.open('full_prompts.jsonl', 'w') as writer: + writer.write_all(full_prompts) diff --git a/data_generate/test_redis.py b/data_generate/test_redis.py new file mode 100644 index 00000000..34c18a52 --- /dev/null +++ b/data_generate/test_redis.py @@ -0,0 +1,10 @@ +import redis + +r = redis.Redis( + host='us1-normal-burro-37804.upstash.io', + port=37804, + password='a870a438f928424bb507d5895b3ab3fc' +) + +r.set('foo', 'bar') +print(r.get('foo')) \ No newline at end of file diff --git a/data_generate/test_redisjson.py b/data_generate/test_redisjson.py new file mode 100644 index 00000000..06c691c9 --- /dev/null +++ b/data_generate/test_redisjson.py @@ -0,0 +1,11 @@ +from redis_om import JsonModel, get_redis_connection + +class Person(JsonModel): + name: str + age: int + +# Create an instance of your model +person = Person(name="John", age=30) + +# Save to Redis with a specific key +person.save() \ No newline at end of file