-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
finished code for hard scenario eval
- Loading branch information
Showing
4 changed files
with
101 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import json | ||
from tqdm import tqdm | ||
|
||
from langchain.llms import OpenAI | ||
from langchain.output_parsers import PydanticOutputParser | ||
from langchain.prompts import PromptTemplate | ||
from langchain.pydantic_v1 import BaseModel, Field | ||
|
||
|
||
class QuantitativeEval(BaseModel): | ||
agent1_name: str = Field(description="Agent 1's name") | ||
agent1_gain: int = Field(description="Agent 1's gain/loss") | ||
agent2_name: str = Field(description="Agent 2's name") | ||
agent2_gain: int = Field(description="Agent 2's gain/loss") | ||
|
||
|
||
def get_model_parser(model_name='text-davinci-003') -> (PromptTemplate, PydanticOutputParser): | ||
model = OpenAI(model_name=model_name, temperature=0.0) | ||
parser = PydanticOutputParser(pydantic_object=QuantitativeEval) | ||
|
||
prompt_text = ( | ||
"Try to understand the following situation and answer the question in the end. " | ||
"\n Situation: {situation}" | ||
"\n Question: {question}" | ||
"\n Please represent loss as negative values. {format_instructions}\n " | ||
) | ||
|
||
prompt = PromptTemplate( | ||
template=prompt_text, | ||
input_variables=["situation", "question"], | ||
partial_variables={"format_instructions": parser.get_format_instructions()} | ||
) | ||
|
||
prompt_and_model = prompt | model | ||
|
||
return prompt_and_model, parser | ||
|
||
|
||
def evaluate(environment_episode_map, environment_question_map, model_name='text-davinci-003'): | ||
results = {} | ||
model, response_parser = get_model_parser(model_name=model_name) | ||
|
||
for environment_id, episodes in tqdm(environment_episode_map.items()): | ||
results_for_env = [] | ||
|
||
for episode in episodes: | ||
situation = episode["messages_and_rewards"] | ||
question = environment_question_map.get(environment_id) | ||
|
||
if question: | ||
model_response = model.invoke({"situation": situation, "question": question}) | ||
parsed_output = response_parser.parse(model_response) | ||
episode["output"] = parsed_output.dict() | ||
|
||
results_for_env.append(episode) | ||
|
||
results[environment_id] = results_for_env | ||
|
||
return results | ||
|
||
|
||
def main(): | ||
with open("human_readable_eps_by_env.json", "r") as f: | ||
env_eps_map = json.load(f) | ||
|
||
with open("env_specific_eval.json", "r") as f: | ||
env_question_map = json.load(f) | ||
|
||
res = evaluate(env_eps_map, env_question_map) | ||
|
||
with open("env_specific_eval_with_output.json", "w") as f: | ||
json.dump(res, f) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from sotopia.database.logs import EpisodeLog | ||
from sotopia.database.persistent_profile import EnvironmentProfile | ||
from sotopia.database.persistent_profile import AgentProfile | ||
import json | ||
|
||
TAG = "ft-llama-2-13b-chat_baseline_ruiyi_1010_7" # Baseline tag | ||
|
||
HARD_ENVS = ["01H7VFHNV13MHN97GAH73E3KM8", "01H7VFHN5WVC5HKKVBHZBA553R", "01H7VFHN9W0WAFZCBT09PKJJNK", "01H7VFHPDZVVCDZR3AARA547CY", "01H7VFHPQQQY6H4DNC6NBQ8XTG", "01H7VFHN7WJK7VWVRZZTQ6DX9T", "01H7VFHPS5WJW2694R1MNC8JFY", | ||
"01H7VFHNN7XTR99319DS8KZCQM", "01H7VFHQ11NAMZS4A2RDGDB01V", "01H7VFHPSWGDGEYRP63H2DJKV0", "01H7VFHNF4G18PC9JHGRC8A1R6", "01H7VFHNNYH3W0VRWVY178K2TK", "01H7VFHP8AN5643B0NR0NP00VE", "01H7VFHN7A1ZX5KSMT2YN9RXC4"] | ||
|
||
envs = [] | ||
eps_by_env = dict() | ||
human_readable_eps_by_env = dict() | ||
|
||
for env_profile_id in HARD_ENVS: | ||
eps = list(EpisodeLog.find(EpisodeLog.tag == TAG, | ||
EpisodeLog.environment == env_profile_id)) | ||
eps_by_env[env_profile_id] = eps | ||
human_readable_eps_by_env[env_profile_id] = [] | ||
for ep in eps: | ||
agent_profiles, messages_and_rewards = ep.render_for_humans() | ||
human_readable_eps_by_env[env_profile_id].append({"env_pk": env_profile_id, "ep_pk": ep.pk, "agents": ep.agents, "messages_and_rewards": "\n".join(messages_and_rewards)}) | ||
|
||
with open("human_readable_eps_by_env.json", "w") as f: | ||
json.dump(human_readable_eps_by_env, f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,4 +16,4 @@ datasets | |
names | ||
together | ||
pydantic==1.10.12 | ||
|
||
sotopia |