From a9016186db68c188465b6e968a90520d57a63ff4 Mon Sep 17 00:00:00 2001 From: sharonwx54 Date: Thu, 14 Mar 2024 23:18:58 -0700 Subject: [PATCH] convert notebook to py --- data_process/utils/interactive_filtering.py | 121 ++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 data_process/utils/interactive_filtering.py diff --git a/data_process/utils/interactive_filtering.py b/data_process/utils/interactive_filtering.py new file mode 100644 index 00000000..f289ed51 --- /dev/null +++ b/data_process/utils/interactive_filtering.py @@ -0,0 +1,121 @@ +""" +converted from interactive filtering notebook. +The code could be run in notebook for visualizing the score distribution + +BELOW are instruction in notebook +Please do NOT change anything under Global Fix Variables +

+Please MAKE changes under Global Flex Variables to fit your need +

+Once you set everything in Global Flex Variables, click Run All and view the results. +


+--------------------------------------------------------------------------------- +
+Range for FILTER_THRESHOLD should be between 0 and 10. +

+You can enter a list of environment pks under SELECTED_ENV_LIST, so the notebook would only look at conversations under these pks. +

+If SELECTED_ENV_LIST is NOT empty, USE_ONLY_GEN and USE_ONLY_SOTOPIA would be ignore. +

+If SELECTED_ENV_LIST is empty, the environment pks would include either all non-sotopia pks, or all sotopia pks, or both, depending on USE_ONLY_GEN and USE_ONLY_SOTOPIA value. +

+For IF_BALANCE, this applies to filtering threshold < 10. If set to True, the filtering would automatically balance the number of dialogues for agent 1 and agent 2, and only keep the smaller subset. If set to False, the filtering would keep all dialogues that lead to reward above threshold, so agent 1 and agent 2 could have different total number of filtered dialogues. +

+Option for FILTER_SCORE include: +* 'believability' +* 'relationship' +* 'knowledge' +* 'secret' +* 'social_rules' +* 'financial_and_material_benefits' +* 'goal' +* 'overall_score' + +""" +import sys +import os +os.environ["REDIS_OM_URL"] = "redis://:password@server_name:port_num" + +import json +from tqdm.notebook import tqdm +from sotopia.database.persistent_profile import AgentProfile, EnvironmentProfile, RelationshipProfile +from sotopia.database.logs import EpisodeLog +from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage +from redis_om import Migrator +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from redis_filtering import get_sotopia_scenarios, get_generated_scenarios, get_episode_by_env +from redis_filtering import goal_reward_by_env_agent, get_env_mean_var, get_threshold_by_keep_rate, filter_pks_to_prompts +from redis_visualization import plot_agent_reward_distribution, plot_env_reward_distribution, make_pretty + +# Global Var +SOTOPIA_HARD_SCENARIO = set( + ['01H7VFHNV13MHN97GAH73E3KM8', '01H7VFHN5WVC5HKKVBHZBA553R', '01H7VFHNN7XTR99319DS8KZCQM', + '01H7VFHN9W0WAFZCBT09PKJJNK', '01H7VFHPDZVVCDZR3AARA547CY', '01H7VFHPQQQY6H4DNC6NBQ8XTG', + '01H7VFHPQQQY6H4DNC6NBQ8XTG', '01H7VFHN7WJK7VWVRZZTQ6DX9T', '01H7VFHN7A1ZX5KSMT2YN9RXC4', + '01H7VFHPS5WJW2694R1MNC8JFY', '01H7VFHPS5WJW2694R1MNC8JFY', '01H7VFHNN7XTR99319DS8KZCQM', + '01H7VFHQ11NAMZS4A2RDGDB01V', '01H7VFHQ11NAMZS4A2RDGDB01V', '01H7VFHPSWGDGEYRP63H2DJKV0', + '01H7VFHPSWGDGEYRP63H2DJKV0', '01H7VFHNF4G18PC9JHGRC8A1R6', '01H7VFHNNYH3W0VRWVY178K2TK', + '01H7VFHP8AN5643B0NR0NP00VE', '01H7VFHN7A1ZX5KSMT2YN9RXC4']) +SOTOPIA_SCENARIO_PK = get_sotopia_scenarios() +GEN_SCENARIO_PK = get_generated_scenarios(exclude=SOTOPIA_SCENARIO_PK) +OTHER_SCENARIO_PK = ['01H4EFKY8CXBETGNDNDD08X2MS', '01H4EFKY8H8A4P12S4YYE2DNCC', + '01H4EFKY8VCAJJJM8WACW3KYWE', '01H4EFKY91BHEG5GD3VRGGY9YE', + '01H6S9W1B2QRVBT0JTZTX8DVEM', '01H6S9W1B9KPXARNHRFTXBAQ6A', + '01H6S9W1BFRRT091TCP4E4E66J', '01H6S9W1BMGR7MFRPH0V55J2TD', + '01H6S9W1BSQK5WT5Y9RAKSH45J', '01H6S9W1BZVYXYSDG4KR2Z3F4X'] + +# Global Adjustable Var +TAG_LIST = [] # default to all tags +USE_ONLY_SOTOPIA = False +USE_ONLY_GEN = True +SELECTED_ENV_LIST = OTHER_SCENARIO_PK # [] leave it empty if you want all sotopia, all non-sotopia, or everything +FILTER_SCORE = "goal" +FILTER_THRESHOLD = 7 +KEEP_RATE = 0.75 # this is the amount of data we want to keep, after filtering +IF_BALANCE = True # when filtering, do you want to keep # agent 1 and # agent 2 balance? + +# set to False if you want to mix agent1 and agent2 and only see distribution of rewards under an environment +SPLIT_ENV_AGENT_DISPLAY=True + +# ONLY set to true if, after reviewing the filtered result, you want to convert the episodelogs into completion format +TO_PROMPT = False +# path and folder name to save json of prompts +SAVE_DIR = "" +# most case we ignore formatting, so set to False +INCLUDE_FORMAT = False + +# mean, var for full data without filtering +env_episodes = get_episode_by_env(TAG_LIST, USE_ONLY_SOTOPIA, USE_ONLY_GEN, OTHER_SCENARIO_PK) +env_rewards, env_pks = goal_reward_by_env_agent(env_episodes, FILTER_SCORE) +env_mean_var = get_env_mean_var(env_rewards) + +"""Threshold Checking""" +approx_threshold = get_threshold_by_keep_rate(env_rewards, KEEP_RATE, IF_BALANCE) +print(approx_threshold) + +"""Running Filtering""" +# first select in-range episodes and the scores +filter_env_rewards, filter_env_pks = goal_reward_by_env_agent( + env_episodes, FILTER_SCORE, FILTER_THRESHOLD, balance=IF_BALANCE) +filter_env_mean_var = get_env_mean_var(filter_env_rewards) + +# display strats +filter_stats_df = pd.concat({k: pd.DataFrame(v) for k, v in filter_env_mean_var.items()}).unstack(1) +filter_stats_df.sort_index(inplace=True) +filter_agent1_num = filter_stats_df['agent1']['count'].sum() +filter_agent2_num = filter_stats_df['agent2']['count'].sum() +print("Total {} + {} Conversation Across {} Environments.".format( + filter_agent1_num, filter_agent2_num, len(filter_stats_df), )) +print("Agent 1 reduced by {} ({}% of original). Agent 2 reduced by {} ({}% of original).".format( + agent1_num - filter_agent1_num, 100*round((filter_agent1_num)/(agent1_num), 2), + agent2_num - filter_agent2_num, 100*round((filter_agent2_num)/(agent2_num ), 2))) +if filter_agent1_num != filter_agent2_num: + print("If we want to keep conversations for agent 1 and agent 2 balance, go to the above cell and change BALANCE = True") + print("Resulting number of conversations would be no more than {} for each agent.".format(min(filter_agent1_num, filter_agent2_num))) +filter_stats_df.style.pipe(make_pretty, max(filter_stats_df['agent1']['var']+filter_stats_df['agent2']['var'])) + + +if TO_PROMPT: + filter_pks_to_prompts(filter_env_pks, SAVE_DIR, INCLUDE_FORMAT) \ No newline at end of file