diff --git a/README.md b/README.md
index 07b2e92a..ee4e743f 100644
--- a/README.md
+++ b/README.md
@@ -6,4 +6,5 @@ We split our overall framework into multiple parts
2. Together AI Finetuning --> Input the train and test data / Output model checkpoint
3. LLM Finetuning --> Input the train and test data / Output model checkpoint
4. LLM Deplyment --> Input LLM Finetuned model checkpoint / Output Deployable OpenAI type API
-5. Eval --> Input model checkpoint / Output evaluation scores
\ No newline at end of file
+5. Eval --> Input model checkpoint / Output evaluation scores
+6. Generate --> Input None / Output new data on redis
\ No newline at end of file
diff --git a/llm_generate/README.md b/llm_generate/README.md
new file mode 100644
index 00000000..869e5537
--- /dev/null
+++ b/llm_generate/README.md
@@ -0,0 +1,9 @@
+# Data Generation
+
+For the first step, we generate envProfile (including scenario / social goal / relationship restriction) based on inspiring prompt.
+
+For the second step, we put the original agentProfile and relationshipProfile into our new redis database
+
+For the third step, we combine them together to be combos based on conditiona sampling (the restriction is the relationship)
+
+All the EnvProfile (new generated), AgentProfile (sotopia original), RelationshipProfile (sotopia original), and envagentcombo are on the redis database that is new created.
\ No newline at end of file
diff --git a/llm_generate/generate_specific_envs.py b/llm_generate/generate_specific_envs.py
new file mode 100644
index 00000000..94ab8014
--- /dev/null
+++ b/llm_generate/generate_specific_envs.py
@@ -0,0 +1,135 @@
+"""This file is used to generate specific environments based on existing
+datasets. The generation functions below should call agenerate_env_profile
+in `sotopia/generation_utils/generate.py` with the appropriate parameters.
+Here are the datasets we have so far:
+1. Mutual-Friend (https://huggingface.co/datasets/mutual_friends)
+"""
+import asyncio
+from typing import Hashable
+
+import datasets
+import names
+import numpy as np
+from datasets import DatasetDict, load_dataset
+
+from generate import (
+ ListOfStrOutputParser,
+ StrOutputParser,
+ agenerate,
+ generate,
+)
+
+
+async def generate_mutual_friend_envs() -> tuple[str, list[str]]:
+ """Generate environments based on the mutual-friend dataset."""
+ mutual_friend_dataset: DatasetDict = load_dataset("mutual_friends")
+ all_data = mutual_friend_dataset["train"]
+ # sample one datum from all data
+ datum = np.random.choice(all_data)
+ friends = datum["scenario_kbs"]
+ num_of_friends_in_total = sum(map(len, friends))
+ # generate names for the friends
+ set_of_names = set()
+ for _ in range(num_of_friends_in_total):
+ name = names.get_first_name()
+ while name in set_of_names:
+ name = names.get_first_name()
+ set_of_names.add(name)
+ list_of_names = list(set_of_names)
+ friend_map: dict[tuple[str, ...], str] = {}
+ friend_list_map: list[list[str]] = [[] for _ in range(len(friends))]
+ friend_description_keys: list[str] = datum["scenario_attributes"]["name"]
+ name_pointer = 0
+ for i, friends_array in enumerate(friends):
+ for friend in friends_array:
+ assert (
+ len(friend) == 2
+ ) # in [[key1, key2, ...], [value1, value2, ...]] format
+ if not tuple(friend[1]) in friend_map:
+ friend_map[tuple(friend[1])] = list_of_names[name_pointer]
+ name_pointer += 1
+ friend_list_map[i].append(friend_map[tuple(friend[1])])
+ friend_set_map: list[set[str]] = [
+ set(friend_list) for friend_list in friend_list_map
+ ]
+ common_friends = []
+ for friend_description, friend_name in friend_map.items():
+ if all([friend_name in friend_set for friend_set in friend_set_map]):
+ common_friends.append(friend_name)
+ scenario = (
+ f'{len(friends)} strangers are meeting at a party.
They have {len(common_friends)} common friends: '
+ f"{', '.join(common_friends[:-1])}"
+ + (" and " if len(common_friends) > 1 else "")
+ + common_friends[-1]
+ + ".
"
+ )
+ goals: list[str] = []
+ for friends_array in friends:
+ template = f"You are trying to figure out whether you have a mutual friend with the other person. \n"
+ template += f" You know the following friends"
+ for friend in friends_array:
+ friend_name = friend_map[tuple(friend[1])]
+ friend_description = friend[1]
+ template += f" {friend_name}: {' '.join([(i + ': ' + j + ' ') if i != 'Name' else '' for i, j in zip(friend[0], friend_description)])}\n"
+ template += f""
+ goals.append(template)
+
+ return scenario, goals
+
+
+async def generate_craigslist_bargains_envs() -> tuple[str, list[str]]:
+ """Generate environments based on the craigslist_bargains dataset."""
+ craigslist_bargains_dataset: DatasetDict = load_dataset(
+ "craigslist_bargains"
+ )
+ all_data = craigslist_bargains_dataset["train"]
+ # sample one datum from all data
+ datum = np.random.choice(all_data)
+ scenario = generate(
+ model_name="gpt-4",
+ template="The following sentence is automatically generated with the following"
+ 'template: "One person is selling - for , another person is'
+ 'trying to buy it. Here is the description of the item: ." with item = {title}, '
+ "price={price}, and description={description} Please make the sentence"
+ "fluent and natural.",
+ input_values={
+ "title": datum["items"]["Title"][0],
+ "price": datum["items"]["Price"][0],
+ "description": datum["items"]["Description"][0],
+ },
+ output_parser=StrOutputParser(),
+ )
+
+ goals: list[str] = []
+ for i in range(2):
+ if datum["agent_info"]["Role"][i] == "seller":
+ markup_ratio = np.random.exponential(0.5)
+ datum["agent_info"]["Target"][i] = datum["items"]["Price"][0] / (
+ 1 + markup_ratio
+ )
+ goal = generate(
+ model_name="gpt-4",
+ template="The following sentence is automatically generated with the following"
+ 'template: "You want to this item. Your target price '
+ "is $ (round up to two decimals). You will get penalty if you sell or buy it "
+ "for a price that is significantly lower than (if is seller) or significantly"
+ "higher than (if is buyer) the target price, but will get bonus if you successfully "
+ "sell it higher than the target price (if is seller) or buy it for lower than"
+ 'the target price (if is buyer)." '
+ "with role = {role} and price = {price}. Please make the sentence"
+ "fluent and natural. Do not change the original meaning of the sentence.",
+ input_values={
+ "role": datum["agent_info"]["Role"][i],
+ "price": datum["agent_info"]["Target"][i],
+ },
+ output_parser=StrOutputParser(),
+ )
+ goals.append(goal)
+
+ return scenario, goals
+
+
+if __name__ == '__main__':
+ for i in range(10):
+ scenario, goals = asyncio.run(generate_mutual_friend_envs())
+ import pdb; pdb.set_trace()
\ No newline at end of file
diff --git a/llm_generate/requirments.txt b/llm_generate/requirments.txt
new file mode 100644
index 00000000..804f2e66
--- /dev/null
+++ b/llm_generate/requirments.txt
@@ -0,0 +1 @@
+sotopia
\ No newline at end of file
diff --git a/llm_generate/step1_generate_env_profile.py b/llm_generate/step1_generate_env_profile.py
new file mode 100644
index 00000000..4e9f9824
--- /dev/null
+++ b/llm_generate/step1_generate_env_profile.py
@@ -0,0 +1,51 @@
+import asyncio
+import random
+from typing import TypeVar
+from tqdm import tqdm
+
+import pandas as pd
+import rich
+from pydantic import BaseModel
+
+from sotopia.database import EnvironmentProfile
+from sotopia.generation_utils.generate import agenerate_env_profile
+
+random.seed(41)
+
+env_borrowMoney = EnvironmentProfile.find(
+ EnvironmentProfile.codename == "borrow_money"
+).all()[0]
+env_roadtrip = EnvironmentProfile.find(
+ EnvironmentProfile.codename == "take_turns"
+).all()[0]
+env_prisonerDillema = EnvironmentProfile.find(
+ EnvironmentProfile.codename == "prison_dilemma"
+).all()[0]
+
+examples = f"{env_borrowMoney.json()}\n\n{env_roadtrip.json()}\n\n{env_prisonerDillema.json()}"
+
+ins_prompts = pd.read_csv("./inspirational_prompt_for_env.csv")
+prompts = ins_prompts["prompt"].tolist()
+
+T = TypeVar("T", bound=BaseModel)
+
+
+def pydantics_to_csv(filename: str, data: list[T]) -> None:
+ pd.DataFrame([item.dict() for item in data]).to_csv(filename, index=False)
+
+
+backgrounds = []
+for prompt in tqdm(prompts):
+ rich.print(prompt)
+ background, prompt_full = asyncio.run(
+ agenerate_env_profile(
+ model_name="gpt-4",
+ inspiration_prompt=prompt,
+ examples=examples,
+ )
+ )
+ rich.print(background)
+ rich.print(prompt_full)
+ backgrounds.append(background)
+
+ pydantics_to_csv("./backgrounds.csv", backgrounds)
\ No newline at end of file
diff --git a/llm_generate/step2_push_agent_relationship_env_to_db.py b/llm_generate/step2_push_agent_relationship_env_to_db.py
new file mode 100644
index 00000000..7a8fe8e8
--- /dev/null
+++ b/llm_generate/step2_push_agent_relationship_env_to_db.py
@@ -0,0 +1,150 @@
+import ast
+import sys
+from typing import Any, cast
+
+import pandas as pd
+from redis_om import Migrator
+
+from sotopia.database.persistent_profile import (
+ AgentProfile,
+ EnvironmentProfile,
+ RelationshipProfile,
+)
+from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage
+from sotopia.samplers import ConstraintBasedSampler
+from sotopia.messages import AgentAction, Observation
+from sotopia.agents import LLMAgent
+
+
+
+def add_agent_to_database(**kwargs: dict[str, Any]) -> None:
+ agent = AgentProfile(**kwargs)
+ agent.save()
+
+
+def add_agents_to_database(agents: list[dict[str, Any]]) -> None:
+ for agent in agents:
+ add_agent_to_database(**agent)
+
+
+def retrieve_agent_by_first_name(first_name: str) -> AgentProfile:
+ result = AgentProfile.find(AgentProfile.first_name == first_name).all()
+ if len(result) == 0:
+ raise ValueError(f"Agent with first name {first_name} not found")
+ elif len(result) > 1:
+ raise ValueError(f"Multiple agents with first name {first_name} found")
+ else:
+ assert isinstance(result[0], AgentProfile)
+ return result[0]
+
+
+def add_env_profile(**kwargs: dict[str, Any]) -> None:
+ env_profile = EnvironmentProfile(**kwargs)
+ env_profile.save()
+
+
+def add_env_profiles(env_profiles: list[dict[str, Any]]) -> None:
+ for env_profile in env_profiles:
+ add_env_profile(**env_profile)
+
+
+def add_relationship_profile(**kwargs: dict[str, Any]) -> None:
+ relationship_profile = RelationshipProfile(**kwargs)
+ relationship_profile.save()
+
+
+def add_relationship_profiles(
+ relationship_profiles: list[dict[str, Any]]
+) -> None:
+ for relationship_profile in relationship_profiles:
+ add_relationship_profile(**relationship_profile)
+
+
+def delete_all_agents() -> None:
+ pks = AgentProfile.all_pks()
+ pks_list = list(pks)
+ for id in pks:
+ AgentProfile.delete(id)
+
+
+def delete_all_env_profiles() -> None:
+ pks = EnvironmentProfile.all_pks()
+ #for id in pks:
+ # EnvironmentProfile.delete(id)
+
+
+def delete_all_relationships() -> None:
+ pks = list(RelationshipProfile.all_pks())
+ #for id in pks:
+ # RelationshipProfile.delete(id)
+ pks = list(RelationshipProfile.all_pks())
+ print("Relationships deleted, all relationships: ", len(list(pks)))
+
+
+def sample_env_agent_combo_and_push_to_db(env_id: str) -> None:
+ sampler = ConstraintBasedSampler[Observation, AgentAction](
+ env_candidates=[env_id]
+ )
+ try:
+ env_agent_combo_list = list(
+ sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False)
+ )
+ except:
+ return
+ print(len(env_agent_combo_list))
+ for env, agent in env_agent_combo_list:
+ EnvAgentComboStorage(
+ env_id=env.profile.pk,
+ agent_ids=[agent[0].profile.pk, agent[1].profile.pk],
+ ).save()
+
+
+def relationship_map(relationship: str) -> int:
+ return int(eval(relationship))
+
+
+if __name__ == "__main__":
+ assert (
+ len(sys.argv) == 3
+ ), "Please provide a csv file with agent or environment profiles, and the type of profile (agent or environment)"
+ df = pd.read_csv(sys.argv[1])
+ type = sys.argv[2]
+ if type == "agent":
+ agents = cast(list[dict[str, Any]], df.to_dict(orient="records"))
+ for agent in agents:
+ agent["age"] = int(agent["age"])
+ agent["moral_values"] = agent["moral_values"].split(",")
+ agent["schwartz_personal_values"] = agent[
+ "schwartz_personal_values"
+ ].split(",")
+ add_agents_to_database(agents)
+ Migrator().run()
+ elif type == "environment":
+ df = df[
+ [
+ "codename",
+ "scenario",
+ "agent_goals",
+ "relationship",
+ "age_constraint",
+ "occupation_constraint",
+ "source",
+ ]
+ ]
+ envs = cast(list[dict[str, Any]], df.to_dict(orient="records"))
+ for env in envs:
+ env["agent_goals"] = ast.literal_eval(env["agent_goals"])
+ assert isinstance(env["relationship"], int)
+ Migrator().run()
+ elif type == "relationship":
+ relationships = cast(
+ list[dict[str, Any]], df.to_dict(orient="records")
+ )
+ for relationship in relationships:
+ assert isinstance(relationship["relationship"], int)
+ add_relationship_profiles(relationships)
+ Migrator().run()
+ elif type == 'agentenvcombo':
+ env_ids = list(EnvironmentProfile.all_pks())
+ for env_id in env_ids:
+ sample_env_agent_combo_and_push_to_db(env_id)
\ No newline at end of file
diff --git a/llm_generate/test_redis1.py b/llm_generate/test_redis1.py
new file mode 100644
index 00000000..34c18a52
--- /dev/null
+++ b/llm_generate/test_redis1.py
@@ -0,0 +1,10 @@
+import redis
+
+r = redis.Redis(
+ host='us1-normal-burro-37804.upstash.io',
+ port=37804,
+ password='a870a438f928424bb507d5895b3ab3fc'
+)
+
+r.set('foo', 'bar')
+print(r.get('foo'))
\ No newline at end of file
diff --git a/llm_generate/test_redis2.py b/llm_generate/test_redis2.py
new file mode 100644
index 00000000..06c691c9
--- /dev/null
+++ b/llm_generate/test_redis2.py
@@ -0,0 +1,11 @@
+from redis_om import JsonModel, get_redis_connection
+
+class Person(JsonModel):
+ name: str
+ age: int
+
+# Create an instance of your model
+person = Person(name="John", age=30)
+
+# Save to Redis with a specific key
+person.save()
\ No newline at end of file