-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* support inference on the whole dataset * add initial code for scenario and social goal generation * modify readme * add step1 and step2 but still not correct * add a test * support generating env and match it with existing agents to be a combo * add readme (cherry picked from commit 09adb35)
- Loading branch information
Showing
8 changed files
with
369 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Data Generation | ||
|
||
For the first step, we generate envProfile (including scenario / social goal / relationship restriction) based on inspiring prompt. | ||
|
||
For the second step, we put the original agentProfile and relationshipProfile into our new redis database | ||
|
||
For the third step, we combine them together to be combos based on conditiona sampling (the restriction is the relationship) | ||
|
||
All the EnvProfile (new generated), AgentProfile (sotopia original), RelationshipProfile (sotopia original), and envagentcombo are on the redis database that is new created. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
"""This file is used to generate specific environments based on existing | ||
datasets. The generation functions below should call agenerate_env_profile | ||
in `sotopia/generation_utils/generate.py` with the appropriate parameters. | ||
Here are the datasets we have so far: | ||
1. Mutual-Friend (https://huggingface.co/datasets/mutual_friends) | ||
""" | ||
import asyncio | ||
from typing import Hashable | ||
|
||
import datasets | ||
import names | ||
import numpy as np | ||
from datasets import DatasetDict, load_dataset | ||
|
||
from generate import ( | ||
ListOfStrOutputParser, | ||
StrOutputParser, | ||
agenerate, | ||
generate, | ||
) | ||
|
||
|
||
async def generate_mutual_friend_envs() -> tuple[str, list[str]]: | ||
"""Generate environments based on the mutual-friend dataset.""" | ||
mutual_friend_dataset: DatasetDict = load_dataset("mutual_friends") | ||
all_data = mutual_friend_dataset["train"] | ||
# sample one datum from all data | ||
datum = np.random.choice(all_data) | ||
friends = datum["scenario_kbs"] | ||
num_of_friends_in_total = sum(map(len, friends)) | ||
# generate names for the friends | ||
set_of_names = set() | ||
for _ in range(num_of_friends_in_total): | ||
name = names.get_first_name() | ||
while name in set_of_names: | ||
name = names.get_first_name() | ||
set_of_names.add(name) | ||
list_of_names = list(set_of_names) | ||
friend_map: dict[tuple[str, ...], str] = {} | ||
friend_list_map: list[list[str]] = [[] for _ in range(len(friends))] | ||
friend_description_keys: list[str] = datum["scenario_attributes"]["name"] | ||
name_pointer = 0 | ||
for i, friends_array in enumerate(friends): | ||
for friend in friends_array: | ||
assert ( | ||
len(friend) == 2 | ||
) # in [[key1, key2, ...], [value1, value2, ...]] format | ||
if not tuple(friend[1]) in friend_map: | ||
friend_map[tuple(friend[1])] = list_of_names[name_pointer] | ||
name_pointer += 1 | ||
friend_list_map[i].append(friend_map[tuple(friend[1])]) | ||
friend_set_map: list[set[str]] = [ | ||
set(friend_list) for friend_list in friend_list_map | ||
] | ||
common_friends = [] | ||
for friend_description, friend_name in friend_map.items(): | ||
if all([friend_name in friend_set for friend_set in friend_set_map]): | ||
common_friends.append(friend_name) | ||
scenario = ( | ||
f'{len(friends)} strangers are meeting at a party. <p viewer="environment">They have {len(common_friends)} common friends: ' | ||
f"{', '.join(common_friends[:-1])}" | ||
+ (" and " if len(common_friends) > 1 else "") | ||
+ common_friends[-1] | ||
+ ".</p>" | ||
) | ||
goals: list[str] = [] | ||
for friends_array in friends: | ||
template = f"You are trying to figure out whether you have a mutual friend with the other person. \n" | ||
template += f"<extra_info> You know the following friends" | ||
for friend in friends_array: | ||
friend_name = friend_map[tuple(friend[1])] | ||
friend_description = friend[1] | ||
template += f" {friend_name}: {' '.join([(i + ': ' + j + ' ') if i != 'Name' else '' for i, j in zip(friend[0], friend_description)])}\n" | ||
template += f"</extra_info>" | ||
goals.append(template) | ||
|
||
return scenario, goals | ||
|
||
|
||
async def generate_craigslist_bargains_envs() -> tuple[str, list[str]]: | ||
"""Generate environments based on the craigslist_bargains dataset.""" | ||
craigslist_bargains_dataset: DatasetDict = load_dataset( | ||
"craigslist_bargains" | ||
) | ||
all_data = craigslist_bargains_dataset["train"] | ||
# sample one datum from all data | ||
datum = np.random.choice(all_data) | ||
scenario = generate( | ||
model_name="gpt-4", | ||
template="The following sentence is automatically generated with the following" | ||
'template: "One person is selling <item> for <price>, another person is' | ||
'trying to buy it. Here is the description of the item: <description>." with item = {title}, ' | ||
"price={price}, and description={description} Please make the sentence" | ||
"fluent and natural.", | ||
input_values={ | ||
"title": datum["items"]["Title"][0], | ||
"price": datum["items"]["Price"][0], | ||
"description": datum["items"]["Description"][0], | ||
}, | ||
output_parser=StrOutputParser(), | ||
) | ||
|
||
goals: list[str] = [] | ||
for i in range(2): | ||
if datum["agent_info"]["Role"][i] == "seller": | ||
markup_ratio = np.random.exponential(0.5) | ||
datum["agent_info"]["Target"][i] = datum["items"]["Price"][0] / ( | ||
1 + markup_ratio | ||
) | ||
goal = generate( | ||
model_name="gpt-4", | ||
template="The following sentence is automatically generated with the following" | ||
'template: "You want to <role> this item. Your target price ' | ||
"is $<price> (round up to two decimals). You will get penalty if you sell or buy it " | ||
"for a price that is significantly lower than (if <role> is seller) or significantly" | ||
"higher than (if <role> is buyer) the target price, but will get bonus if you successfully " | ||
"sell it higher than the target price (if <role> is seller) or buy it for lower than" | ||
'the target price (if <role> is buyer)." ' | ||
"with role = {role} and price = {price}. Please make the sentence" | ||
"fluent and natural. Do not change the original meaning of the sentence.", | ||
input_values={ | ||
"role": datum["agent_info"]["Role"][i], | ||
"price": datum["agent_info"]["Target"][i], | ||
}, | ||
output_parser=StrOutputParser(), | ||
) | ||
goals.append(goal) | ||
|
||
return scenario, goals | ||
|
||
|
||
if __name__ == '__main__': | ||
for i in range(10): | ||
scenario, goals = asyncio.run(generate_mutual_friend_envs()) | ||
import pdb; pdb.set_trace() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
sotopia |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import asyncio | ||
import random | ||
from typing import TypeVar | ||
from tqdm import tqdm | ||
|
||
import pandas as pd | ||
import rich | ||
from pydantic import BaseModel | ||
|
||
from sotopia.database import EnvironmentProfile | ||
from sotopia.generation_utils.generate import agenerate_env_profile | ||
|
||
random.seed(41) | ||
|
||
env_borrowMoney = EnvironmentProfile.find( | ||
EnvironmentProfile.codename == "borrow_money" | ||
).all()[0] | ||
env_roadtrip = EnvironmentProfile.find( | ||
EnvironmentProfile.codename == "take_turns" | ||
).all()[0] | ||
env_prisonerDillema = EnvironmentProfile.find( | ||
EnvironmentProfile.codename == "prison_dilemma" | ||
).all()[0] | ||
|
||
examples = f"{env_borrowMoney.json()}\n\n{env_roadtrip.json()}\n\n{env_prisonerDillema.json()}" | ||
|
||
ins_prompts = pd.read_csv("./inspirational_prompt_for_env.csv") | ||
prompts = ins_prompts["prompt"].tolist() | ||
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
|
||
def pydantics_to_csv(filename: str, data: list[T]) -> None: | ||
pd.DataFrame([item.dict() for item in data]).to_csv(filename, index=False) | ||
|
||
|
||
backgrounds = [] | ||
for prompt in tqdm(prompts): | ||
rich.print(prompt) | ||
background, prompt_full = asyncio.run( | ||
agenerate_env_profile( | ||
model_name="gpt-4", | ||
inspiration_prompt=prompt, | ||
examples=examples, | ||
) | ||
) | ||
rich.print(background) | ||
rich.print(prompt_full) | ||
backgrounds.append(background) | ||
|
||
pydantics_to_csv("./backgrounds.csv", backgrounds) |
150 changes: 150 additions & 0 deletions
150
llm_generate/step2_push_agent_relationship_env_to_db.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import ast | ||
import sys | ||
from typing import Any, cast | ||
|
||
import pandas as pd | ||
from redis_om import Migrator | ||
|
||
from sotopia.database.persistent_profile import ( | ||
AgentProfile, | ||
EnvironmentProfile, | ||
RelationshipProfile, | ||
) | ||
from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage | ||
from sotopia.samplers import ConstraintBasedSampler | ||
from sotopia.messages import AgentAction, Observation | ||
from sotopia.agents import LLMAgent | ||
|
||
|
||
|
||
def add_agent_to_database(**kwargs: dict[str, Any]) -> None: | ||
agent = AgentProfile(**kwargs) | ||
agent.save() | ||
|
||
|
||
def add_agents_to_database(agents: list[dict[str, Any]]) -> None: | ||
for agent in agents: | ||
add_agent_to_database(**agent) | ||
|
||
|
||
def retrieve_agent_by_first_name(first_name: str) -> AgentProfile: | ||
result = AgentProfile.find(AgentProfile.first_name == first_name).all() | ||
if len(result) == 0: | ||
raise ValueError(f"Agent with first name {first_name} not found") | ||
elif len(result) > 1: | ||
raise ValueError(f"Multiple agents with first name {first_name} found") | ||
else: | ||
assert isinstance(result[0], AgentProfile) | ||
return result[0] | ||
|
||
|
||
def add_env_profile(**kwargs: dict[str, Any]) -> None: | ||
env_profile = EnvironmentProfile(**kwargs) | ||
env_profile.save() | ||
|
||
|
||
def add_env_profiles(env_profiles: list[dict[str, Any]]) -> None: | ||
for env_profile in env_profiles: | ||
add_env_profile(**env_profile) | ||
|
||
|
||
def add_relationship_profile(**kwargs: dict[str, Any]) -> None: | ||
relationship_profile = RelationshipProfile(**kwargs) | ||
relationship_profile.save() | ||
|
||
|
||
def add_relationship_profiles( | ||
relationship_profiles: list[dict[str, Any]] | ||
) -> None: | ||
for relationship_profile in relationship_profiles: | ||
add_relationship_profile(**relationship_profile) | ||
|
||
|
||
def delete_all_agents() -> None: | ||
pks = AgentProfile.all_pks() | ||
pks_list = list(pks) | ||
for id in pks: | ||
AgentProfile.delete(id) | ||
|
||
|
||
def delete_all_env_profiles() -> None: | ||
pks = EnvironmentProfile.all_pks() | ||
#for id in pks: | ||
# EnvironmentProfile.delete(id) | ||
|
||
|
||
def delete_all_relationships() -> None: | ||
pks = list(RelationshipProfile.all_pks()) | ||
#for id in pks: | ||
# RelationshipProfile.delete(id) | ||
pks = list(RelationshipProfile.all_pks()) | ||
print("Relationships deleted, all relationships: ", len(list(pks))) | ||
|
||
|
||
def sample_env_agent_combo_and_push_to_db(env_id: str) -> None: | ||
sampler = ConstraintBasedSampler[Observation, AgentAction]( | ||
env_candidates=[env_id] | ||
) | ||
try: | ||
env_agent_combo_list = list( | ||
sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False) | ||
) | ||
except: | ||
return | ||
print(len(env_agent_combo_list)) | ||
for env, agent in env_agent_combo_list: | ||
EnvAgentComboStorage( | ||
env_id=env.profile.pk, | ||
agent_ids=[agent[0].profile.pk, agent[1].profile.pk], | ||
).save() | ||
|
||
|
||
def relationship_map(relationship: str) -> int: | ||
return int(eval(relationship)) | ||
|
||
|
||
if __name__ == "__main__": | ||
assert ( | ||
len(sys.argv) == 3 | ||
), "Please provide a csv file with agent or environment profiles, and the type of profile (agent or environment)" | ||
df = pd.read_csv(sys.argv[1]) | ||
type = sys.argv[2] | ||
if type == "agent": | ||
agents = cast(list[dict[str, Any]], df.to_dict(orient="records")) | ||
for agent in agents: | ||
agent["age"] = int(agent["age"]) | ||
agent["moral_values"] = agent["moral_values"].split(",") | ||
agent["schwartz_personal_values"] = agent[ | ||
"schwartz_personal_values" | ||
].split(",") | ||
add_agents_to_database(agents) | ||
Migrator().run() | ||
elif type == "environment": | ||
df = df[ | ||
[ | ||
"codename", | ||
"scenario", | ||
"agent_goals", | ||
"relationship", | ||
"age_constraint", | ||
"occupation_constraint", | ||
"source", | ||
] | ||
] | ||
envs = cast(list[dict[str, Any]], df.to_dict(orient="records")) | ||
for env in envs: | ||
env["agent_goals"] = ast.literal_eval(env["agent_goals"]) | ||
assert isinstance(env["relationship"], int) | ||
Migrator().run() | ||
elif type == "relationship": | ||
relationships = cast( | ||
list[dict[str, Any]], df.to_dict(orient="records") | ||
) | ||
for relationship in relationships: | ||
assert isinstance(relationship["relationship"], int) | ||
add_relationship_profiles(relationships) | ||
Migrator().run() | ||
elif type == 'agentenvcombo': | ||
env_ids = list(EnvironmentProfile.all_pks()) | ||
for env_id in env_ids: | ||
sample_env_agent_combo_and_push_to_db(env_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import redis | ||
|
||
r = redis.Redis( | ||
host='us1-normal-burro-37804.upstash.io', | ||
port=37804, | ||
password='a870a438f928424bb507d5895b3ab3fc' | ||
) | ||
|
||
r.set('foo', 'bar') | ||
print(r.get('foo')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from redis_om import JsonModel, get_redis_connection | ||
|
||
class Person(JsonModel): | ||
name: str | ||
age: int | ||
|
||
# Create an instance of your model | ||
person = Person(name="John", age=30) | ||
|
||
# Save to Redis with a specific key | ||
person.save() |