Skip to content

Commit

Permalink
update filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
sharonwx54 committed Nov 21, 2023
1 parent 924f518 commit 3b688e8
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions data_process/redis_data_filtering/redis_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

OVERALL_REWARD_FILTER = 3.2
GOAL_AVG_THRESHOLD = 7
GOAL_KEEP_THRESHOD = 7
GOAL_KEEP_THRESHOD = 10

SELECTED_TAG = [#"gpt-3.5-turbo_gpt-3.5-turbo_v0.0.1_clean",
#"gpt-4_gpt-3.5-turbo_v0.0.1_clean",
Expand Down Expand Up @@ -72,13 +72,16 @@ def goal_reward_by_env_agent(env_epi_dic):

return reward_dic

def goal_filter_per_env_agent(episodes):
def goal_filter_per_env_agent(episodes, apply_filter=True):
# filter using goal reward scores for each agent position given scenario
goal_score = {'agent1':[], 'agent2':[]}
env_tpls = []
# at least need to have half of the total len of the dialogue amount per scenario
# then add the filtering by score
min_threshold_amt = len(episodes) // 2
if apply_filter:
min_threshold_amt = len(episodes) // 2
else:
min_threshold_amt = -1
for episode in episodes:
rewards = episode.rewards
goal_score['agent1'].append(rewards[0][1]['goal'])
Expand All @@ -95,37 +98,37 @@ def goal_filter_per_env_agent(episodes):
env_tpls.append((episodes[agent1_rank[i]], 0))
env_tpls.append((episodes[agent2_rank[i]], 1))
else:
if goal_score['agent1'][agent1_rank[i]] >= min(GOAL_KEEP_THRESHOD, agent1_avg) and (goal_score['agent2'][agent2_rank[i]] >= min(GOAL_KEEP_THRESHOD, agent2_avg)):
if goal_score['agent1'][agent1_rank[i]] > min(GOAL_KEEP_THRESHOD, agent1_avg) and (goal_score['agent2'][agent2_rank[i]] > min(GOAL_KEEP_THRESHOD, agent2_avg)):
env_tpls.append((episodes[agent1_rank[i]], 0))
env_tpls.append((episodes[agent1_rank[i]], 1))

return env_tpls


def goal_filter_all_env_agent(env_episode_dic):
def goal_filter_all_env_agent(env_episode_dic, apply_filter=True):
filter_env_dic = {}
for env, episodes in env_episode_dic.items():
env_agent_episode = goal_filter_per_env_agent(episodes)
env_agent_episode = goal_filter_per_env_agent(episodes, apply_filter)
filter_env_dic[env] = env_agent_episode

return filter_env_dic


def run_filtered_episodes_to_prompt(filter_env_agent_episodes, json_dir, level="Easy"):
def run_filtered_episodes_to_prompt(filter_env_agent_episodes, json_dir, level="Easy", include_format=False):
if not os.path.exists(json_dir):
os.makedirs(json_dir)
parse_count = 0
for env, tpls in filter_env_agent_episodes.items():
if (level == 'Easy' and env in HARD_SCENARIO) or (level == 'Hard' and env not in HARD_SCENARIO):
continue
for tpl in tpls:
parse_prompt_to_json(tpl[0], json_dir, tpl[1])
parse_prompt_to_json(tpl[0], json_dir, tpl[1], include_format)
parse_count+=1

print(parse_count)


"""----------->Functions that were used for different approaches of DP, mostly depreciated<------------ """
"""----------->DEPRECIATED: Functions that were used for different approaches of DP, mostly DEPRECIATED<------------ """

def overall_reward_by_env(episode_env_dict):
# for each scenario, append all episode's overall score
Expand Down

0 comments on commit 3b688e8

Please sign in to comment.