Skip to content

Commit

Permalink
finished code for hard scenario eval
Browse files Browse the repository at this point in the history
  • Loading branch information
zqi2cmu authored and ruiyiw committed Nov 14, 2023
1 parent ce6344e commit 1463d61
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 1 deletion.
Empty file removed eval/dummyfile
Empty file.
75 changes: 75 additions & 0 deletions eval/llm_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json
from tqdm import tqdm

from langchain.llms import OpenAI
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import BaseModel, Field


class QuantitativeEval(BaseModel):
agent1_name: str = Field(description="Agent 1's name")
agent1_gain: int = Field(description="Agent 1's gain/loss")
agent2_name: str = Field(description="Agent 2's name")
agent2_gain: int = Field(description="Agent 2's gain/loss")


def get_model_parser(model_name='text-davinci-003') -> (PromptTemplate, PydanticOutputParser):
model = OpenAI(model_name=model_name, temperature=0.0)
parser = PydanticOutputParser(pydantic_object=QuantitativeEval)

prompt_text = (
"Try to understand the following situation and answer the question in the end. "
"\n Situation: {situation}"
"\n Question: {question}"
"\n Please represent loss as negative values. {format_instructions}\n "
)

prompt = PromptTemplate(
template=prompt_text,
input_variables=["situation", "question"],
partial_variables={"format_instructions": parser.get_format_instructions()}
)

prompt_and_model = prompt | model

return prompt_and_model, parser


def evaluate(environment_episode_map, environment_question_map, model_name='text-davinci-003'):
results = {}
model, response_parser = get_model_parser(model_name=model_name)

for environment_id, episodes in tqdm(environment_episode_map.items()):
results_for_env = []

for episode in episodes:
situation = episode["messages_and_rewards"]
question = environment_question_map.get(environment_id)

if question:
model_response = model.invoke({"situation": situation, "question": question})
parsed_output = response_parser.parse(model_response)
episode["output"] = parsed_output.dict()

results_for_env.append(episode)

results[environment_id] = results_for_env

return results


def main():
with open("human_readable_eps_by_env.json", "r") as f:
env_eps_map = json.load(f)

with open("env_specific_eval.json", "r") as f:
env_question_map = json.load(f)

res = evaluate(env_eps_map, env_question_map)

with open("env_specific_eval_with_output.json", "w") as f:
json.dump(res, f)

if __name__ == "__main__":
main()
25 changes: 25 additions & 0 deletions eval/pull_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from sotopia.database.logs import EpisodeLog
from sotopia.database.persistent_profile import EnvironmentProfile
from sotopia.database.persistent_profile import AgentProfile
import json

TAG = "ft-llama-2-13b-chat_baseline_ruiyi_1010_7" # Baseline tag

HARD_ENVS = ["01H7VFHNV13MHN97GAH73E3KM8", "01H7VFHN5WVC5HKKVBHZBA553R", "01H7VFHN9W0WAFZCBT09PKJJNK", "01H7VFHPDZVVCDZR3AARA547CY", "01H7VFHPQQQY6H4DNC6NBQ8XTG", "01H7VFHN7WJK7VWVRZZTQ6DX9T", "01H7VFHPS5WJW2694R1MNC8JFY",
"01H7VFHNN7XTR99319DS8KZCQM", "01H7VFHQ11NAMZS4A2RDGDB01V", "01H7VFHPSWGDGEYRP63H2DJKV0", "01H7VFHNF4G18PC9JHGRC8A1R6", "01H7VFHNNYH3W0VRWVY178K2TK", "01H7VFHP8AN5643B0NR0NP00VE", "01H7VFHN7A1ZX5KSMT2YN9RXC4"]

envs = []
eps_by_env = dict()
human_readable_eps_by_env = dict()

for env_profile_id in HARD_ENVS:
eps = list(EpisodeLog.find(EpisodeLog.tag == TAG,
EpisodeLog.environment == env_profile_id))
eps_by_env[env_profile_id] = eps
human_readable_eps_by_env[env_profile_id] = []
for ep in eps:
agent_profiles, messages_and_rewards = ep.render_for_humans()
human_readable_eps_by_env[env_profile_id].append({"env_pk": env_profile_id, "ep_pk": ep.pk, "agents": ep.agents, "messages_and_rewards": "\n".join(messages_and_rewards)})

with open("human_readable_eps_by_env.json", "w") as f:
json.dump(human_readable_eps_by_env, f)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ datasets
names
together
pydantic==1.10.12

sotopia

0 comments on commit 1463d61

Please sign in to comment.