-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add step1 and step2 but still not correct
- Loading branch information
Showing
6 changed files
with
347 additions
and
4 deletions.
There are no files selected for viewing
118 changes: 118 additions & 0 deletions
118
llm_generate/generate_episode_constraint_based_sampling.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import asyncio | ||
import logging | ||
import sys | ||
from logging import FileHandler | ||
from typing import Literal | ||
|
||
from rich import print | ||
from rich.logging import RichHandler | ||
|
||
from sotopia.agents import LLMAgent | ||
from sotopia.database import ( | ||
AgentProfile, | ||
EnvAgentComboStorage, | ||
EnvironmentProfile, | ||
) | ||
from sotopia.envs import ParallelSotopiaEnv | ||
from sotopia.envs.evaluators import ( | ||
ReachGoalLLMEvaluator, | ||
RuleBasedTerminatedEvaluator, | ||
) | ||
from sotopia.generation_utils.generate import LLM_Name | ||
from sotopia.messages import AgentAction, Observation | ||
from sotopia.samplers import ( | ||
BaseSampler, | ||
ConstraintBasedSampler, | ||
EnvAgentCombo, | ||
) | ||
from sotopia.server import run_async_server | ||
|
||
# date and message only | ||
FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" | ||
|
||
logging.basicConfig( | ||
level=15, | ||
format=FORMAT, | ||
datefmt="[%X]", | ||
handlers=[ | ||
RichHandler(), | ||
FileHandler("./round_robin_parallel_sotopia_env_2.log"), | ||
], | ||
) | ||
|
||
model_names: dict[str, LLM_Name] = { | ||
"env": "gpt-4", | ||
"agent1": "gpt-3.5-turbo", | ||
"agent2": "gpt-3.5-turbo", | ||
} | ||
|
||
push_to_db = sys.argv[1] | ||
assert push_to_db in ["True", "False"], "push_to_db should be True or False" | ||
push_to_db_bool = push_to_db == "True" | ||
env_ids: list[str] = [] | ||
|
||
for code_name in ["secret_feeling"]: | ||
envs_with_code_name = EnvironmentProfile.find( | ||
EnvironmentProfile.codename == code_name | ||
).all() | ||
assert len(envs_with_code_name) | ||
assert (env_id := envs_with_code_name[0].pk) | ||
env_ids.append(env_id) | ||
|
||
|
||
for env_id in env_ids: | ||
assert env_id is not None, "env_id should not be None" | ||
env_agent_combo_storage_list = EnvAgentComboStorage.find( | ||
EnvAgentComboStorage.env_id == env_id | ||
).all() | ||
|
||
sampler = ( | ||
ConstraintBasedSampler[Observation, AgentAction]( | ||
env_candidates=[env_id], | ||
) | ||
if len(env_agent_combo_storage_list) == 0 | ||
else BaseSampler[Observation, AgentAction]() | ||
) | ||
|
||
env_agent_combo_list: list[EnvAgentCombo[Observation, AgentAction]] = [] | ||
|
||
import pdb; pdb.set_trace() | ||
|
||
for env_agent_combo_storage in env_agent_combo_storage_list: | ||
assert isinstance(env_agent_combo_storage, EnvAgentComboStorage) | ||
env_profile = EnvironmentProfile.get(env_agent_combo_storage.env_id) | ||
env = ParallelSotopiaEnv( | ||
env_profile=env_profile, | ||
model_name=model_names["env"], | ||
action_order="round-robin", | ||
evaluators=[ | ||
RuleBasedTerminatedEvaluator( | ||
max_turn_number=20, max_stale_turn=2 | ||
), | ||
], | ||
terminal_evaluators=[ | ||
ReachGoalLLMEvaluator(model_names["env"]), | ||
], | ||
) | ||
agent_profiles = [ | ||
AgentProfile.get(id) for id in env_agent_combo_storage.agent_ids | ||
] | ||
|
||
agents = [ | ||
LLMAgent(agent_profile=agent_profile, model_name=agent_model) | ||
for agent_profile, agent_model in zip( | ||
agent_profiles, [model_names["agent1"], model_names["agent2"]] | ||
) | ||
] | ||
|
||
env_agent_combo_list.append((env, agents)) | ||
asyncio.run( | ||
run_async_server( | ||
model_dict=model_names, | ||
action_order="round-robin", | ||
sampler=sampler, | ||
env_agent_combo_list=env_agent_combo_list, | ||
push_to_db=push_to_db_bool, | ||
using_async=False, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1 @@ | ||
sotopia | ||
datasets | ||
sotopia |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import asyncio | ||
import random | ||
from typing import TypeVar | ||
from tqdm import tqdm | ||
|
||
import pandas as pd | ||
import rich | ||
from pydantic import BaseModel | ||
|
||
from sotopia.database import EnvironmentProfile | ||
from sotopia.generation_utils.generate import agenerate_env_profile | ||
|
||
random.seed(41) | ||
|
||
env_borrowMoney = EnvironmentProfile.find( | ||
EnvironmentProfile.codename == "borrow_money" | ||
).all()[0] | ||
env_roadtrip = EnvironmentProfile.find( | ||
EnvironmentProfile.codename == "take_turns" | ||
).all()[0] | ||
env_prisonerDillema = EnvironmentProfile.find( | ||
EnvironmentProfile.codename == "prison_dilemma" | ||
).all()[0] | ||
|
||
examples = f"{env_borrowMoney.json()}\n\n{env_roadtrip.json()}\n\n{env_prisonerDillema.json()}" | ||
|
||
ins_prompts = pd.read_csv("./inspirational_prompt_for_env.csv") | ||
prompts = ins_prompts["prompt"].tolist() | ||
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
|
||
def pydantics_to_csv(filename: str, data: list[T]) -> None: | ||
pd.DataFrame([item.dict() for item in data]).to_csv(filename, index=False) | ||
|
||
|
||
backgrounds = [] | ||
for prompt in tqdm(prompts): | ||
rich.print(prompt) | ||
background, prompt_full = asyncio.run( | ||
agenerate_env_profile( | ||
model_name="gpt-4", | ||
inspiration_prompt=prompt, | ||
examples=examples, | ||
) | ||
) | ||
rich.print(background) | ||
rich.print(prompt_full) | ||
backgrounds.append(background) | ||
|
||
pydantics_to_csv("./backgrounds.csv", backgrounds) |
164 changes: 164 additions & 0 deletions
164
llm_generate/step2_push_agent_relationship_env_to_db.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import ast | ||
import sys | ||
from typing import Any, cast | ||
|
||
import pandas as pd | ||
from redis_om import Migrator | ||
|
||
from sotopia.database.persistent_profile import ( | ||
AgentProfile, | ||
EnvironmentProfile, | ||
RelationshipProfile, | ||
RelationshipType, | ||
) | ||
from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage | ||
from sotopia.samplers import ConstraintBasedSampler | ||
from sotopia.messages import AgentAction, Observation | ||
from sotopia.agents import LLMAgent | ||
import redis | ||
|
||
|
||
|
||
def add_agent_to_database(**kwargs: dict[str, Any]) -> None: | ||
agent = AgentProfile(**kwargs) | ||
agent.save() | ||
|
||
|
||
def add_agents_to_database(agents: list[dict[str, Any]]) -> None: | ||
for agent in agents: | ||
add_agent_to_database(**agent) | ||
|
||
|
||
def retrieve_agent_by_first_name(first_name: str) -> AgentProfile: | ||
result = AgentProfile.find(AgentProfile.first_name == first_name).all() | ||
if len(result) == 0: | ||
raise ValueError(f"Agent with first name {first_name} not found") | ||
elif len(result) > 1: | ||
raise ValueError(f"Multiple agents with first name {first_name} found") | ||
else: | ||
assert isinstance(result[0], AgentProfile) | ||
return result[0] | ||
|
||
|
||
def add_env_profile(**kwargs: dict[str, Any]) -> None: | ||
import pdb; pdb.set_trace() | ||
env_profile = EnvironmentProfile(**kwargs) | ||
env_profile.save(r) | ||
import pdb; pdb.set_trace() | ||
|
||
|
||
def add_env_profiles(env_profiles: list[dict[str, Any]]) -> None: | ||
for env_profile in env_profiles: | ||
add_env_profile(**env_profile) | ||
import pdb; pdb.set_trace() | ||
|
||
|
||
def add_relationship_profile(**kwargs: dict[str, Any]) -> None: | ||
relationship_profile = RelationshipProfile(**kwargs) | ||
relationship_profile.save() | ||
|
||
|
||
def add_relationship_profiles( | ||
relationship_profiles: list[dict[str, Any]] | ||
) -> None: | ||
for relationship_profile in relationship_profiles: | ||
add_relationship_profile(**relationship_profile) | ||
|
||
|
||
def delete_all_agents() -> None: | ||
pks = AgentProfile.all_pks() | ||
pks_list = list(pks) | ||
for id in pks: | ||
AgentProfile.delete(id) | ||
|
||
|
||
def delete_all_env_profiles() -> None: | ||
pks = EnvironmentProfile.all_pks() | ||
#for id in pks: | ||
# EnvironmentProfile.delete(id) | ||
|
||
|
||
def delete_all_relationships() -> None: | ||
pks = list(RelationshipProfile.all_pks()) | ||
#for id in pks: | ||
# RelationshipProfile.delete(id) | ||
pks = list(RelationshipProfile.all_pks()) | ||
print("Relationships deleted, all relationships: ", len(list(pks))) | ||
|
||
|
||
def sample_env_agent_combo_and_push_to_db(env_id: str) -> None: | ||
sampler = ConstraintBasedSampler[Observation, AgentAction]( | ||
env_candidates=[env_id] | ||
) | ||
env_agent_combo_list = list( | ||
sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False) | ||
) | ||
for env, agent in env_agent_combo_list: | ||
EnvAgentComboStorage( | ||
env_id=env.profile.pk, | ||
agent_ids=[agent[0].profile.pk, agent[1].profile.pk], | ||
).save() | ||
|
||
|
||
def relationship_map(relationship: str) -> int: | ||
return int(eval(relationship)) | ||
|
||
|
||
if __name__ == "__main__": | ||
assert ( | ||
len(sys.argv) == 3 | ||
), "Please provide a csv file with agent or environment profiles, and the type of profile (agent or environment)" | ||
df = pd.read_csv(sys.argv[1]) | ||
type = sys.argv[2] | ||
if type == "agent": | ||
import pdb; pdb.set_trace() | ||
agents = cast(list[dict[str, Any]], df.to_dict(orient="records")) | ||
for agent in agents: | ||
agent["age"] = int(agent["age"]) | ||
agent["moral_values"] = agent["moral_values"].split(",") | ||
agent["schwartz_personal_values"] = agent[ | ||
"schwartz_personal_values" | ||
].split(",") | ||
add_agents_to_database(agents) | ||
elif type == "environment": | ||
pks = EnvironmentProfile.all_pks() | ||
''' | ||
df = df[ | ||
( | ||
df["Xuhui"].astype(float).fillna(0) | ||
+ df["Leena"].astype(float).fillna(0) | ||
+ df["Hao"].astype(float).fillna(0) | ||
) | ||
> 1 | ||
] | ||
''' | ||
df = df[ | ||
[ | ||
"codename", | ||
"scenario", | ||
"agent_goals", | ||
"relationship", | ||
"age_constraint", | ||
"occupation_constraint", | ||
"source", | ||
] | ||
] | ||
envs = cast(list[dict[str, Any]], df.to_dict(orient="records")) | ||
for env in envs: | ||
env["agent_goals"] = ast.literal_eval(env["agent_goals"]) | ||
# check env['relationship'] is int | ||
assert isinstance(env["relationship"], int) | ||
|
||
add_env_profiles(envs) | ||
Migrator().run() | ||
elif type == "relationship": | ||
import pdb; pdb.set_trace() | ||
relationships = cast( | ||
list[dict[str, Any]], df.to_dict(orient="records") | ||
) | ||
for relationship in relationships: | ||
relationship["relationship"] = relationship_map( | ||
relationship["relationship"] | ||
) | ||
add_relationship_profiles(relationships) | ||
Migrator().run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import redis | ||
|
||
r = redis.Redis( | ||
host='us1-normal-burro-37804.upstash.io', | ||
port=37804, | ||
password='a870a438f928424bb507d5895b3ab3fc' | ||
) | ||
|
||
r.set('foo', 'bar') | ||
print(r.get('foo')) |