From 1c715b623e9c5c6414608bce62988c9e726648bd Mon Sep 17 00:00:00 2001
From: Sharon Zhang <123585394+sharonwx54@users.noreply.github.com>
Date: Thu, 26 Oct 2023 16:14:44 -0400
Subject: [PATCH] Make prompt shorter to fit 2048 window size (#78)

* remove format in prompt

* small bug in filtering

* remove one line of format

* adding sliding window to remove dialogue but keep context when surpassing 2048 tokens
---
 .../prompt_reverse_engineering.py             | 74 +++++++++++++++----
 .../redis_data_filtering/redis_filtering.py   |  2 +-
 2 files changed, 60 insertions(+), 16 deletions(-)

diff --git a/data_process/redis_data_filtering/prompt_reverse_engineering.py b/data_process/redis_data_filtering/prompt_reverse_engineering.py
index aab51079..92eb9296 100644
--- a/data_process/redis_data_filtering/prompt_reverse_engineering.py
+++ b/data_process/redis_data_filtering/prompt_reverse_engineering.py
@@ -2,7 +2,7 @@
 import os
 from collections import defaultdict
 from typing import Any, Dict, List, Tuple, Union, cast
-
+import transformers
 import pandas as pd
 import rich
 from rich.console import Console
@@ -15,21 +15,14 @@
 import enum
 
 #PROMPT_PREFIX = "Prompt after formatting:\n"
-
+MAX_TOKEN = 2048
 
 PROMPT_TEMPLATE="""Prompt after formatting:\nImagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
 You can find {agent}'s background and goal in the 'Here is the context of the interaction' field.
 Note that {agent}'s secret and goal is only visible to you.
 You should try your best to achieve {agent}'s goal in a way that align with their character traits.
 Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
 {history}.
-You are at Turn #{turn_number}. Your available action types are
-{action_list}.
-Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
-
-Please only generate a JSON string including the action type and the argument.
-Your action should follow the given format:
-{format_instructions}
-"""
+You are at Turn #{turn_number}."""
 
 #PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
 FORMAT_TEMPLATE = """\nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
@@ -37,11 +30,37 @@
 \nHere is the output schema:\n```\n{\"description\": \"An interface for messages.\\nThere is only one required method: to_natural_language\", \"properties\": {\"action_type\": {\"title\": \"Action Type\", \"description\": \"whether to speak at this turn or choose to not do anything\", \"enum\": [\"none\", \"speak\", \"non-verbal communication\", \"action\", \"leave\"], \"type\": \"string\"}, \"argument\": {\"title\": \"Argument\", \"description\": \"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\", \"type\": \"string\"}}, \"required\": [\"action_type\", \"argument\"]}\n```\u001b[0m"""
 
 
+PROMPT_TEMPLATE_W_FORMAT="""Prompt after formatting:\nImagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
+You can find {agent}'s background and goal in the 'Here is the context of the interaction' field.
+Note that {agent}'s secret and goal is only visible to you.
+You should try your best to achieve {agent}'s goal in a way that align with their character traits.
+Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
+{history}.
+You are at Turn #{turn_number}. Your available action types are
+"none action speak non-verbal communication leave".
+Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
+
+Please only generate a JSON string including the action type and the argument.
+Your action should follow the given format:
+\nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
+the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.
+\nHere is the output schema:\n```\n{\"description\": \"An interface for messages.\\nThere is only one required method: to_natural_language\", \"properties\": {\"action_type\": {\"title\": \"Action Type\", \"description\": \"whether to speak at this turn or choose to not do anything\", \"enum\": [\"none\", \"speak\", \"non-verbal communication\", \"action\", \"leave\"], \"type\": \"string\"}, \"argument\": {\"title\": \"Argument\", \"description\": \"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\", \"type\": \"string\"}}, \"required\": [\"action_type\", \"argument\"]}\n```\u001b[0m
+"""
 
 # static
 ACTION_LIST = "none action speak non-verbal communication leave" #" ".join(ActionType)
 
 ACTION_REVERSE_MAP = {"left ": "leave", 'did n': 'none', 'said:': 'speak'}
 
+MODEL_CHECKPOINT = "meta-llama/Llama-2-13b-chat-hf"
+HF_TOKEN = "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"
+
+
+TOKENIZER = transformers.AutoTokenizer.from_pretrained(
+    MODEL_CHECKPOINT,
+    padding=False,
+    truncation=False,
+    token=HF_TOKEN,
+)
 
 def to_natural_language(self) -> str:
     match self.action_type:
@@ -101,10 +120,27 @@ def generate_result(msg):
     return str_result
 
 
-def reverse_episode_log(epilog, later_speak=False):
+def surpass_max_token_check(string, max_token=MAX_TOKEN, tokenizer=TOKENIZER):
+    prompt_tokens = len(tokenizer(string)['input_ids'])
+    return max(prompt_tokens - max_token, 0)
+
+def truncate_prompt_to_length(dia_his, surpass_num, tokenizer=TOKENIZER):
+    # walk from the oldest dialogue line, accumulating token counts to remove
+    dia_sen = dia_his.split("\n")
+    remove_len = 0
+    i = 0
+    while remove_len < surpass_num and i < len(dia_sen):
+        remove_len += len(tokenizer(dia_sen[i])['input_ids'])
+        i += 1
+    trunc_dia = "\n".join(dia_sen[i:])
+    return trunc_dia
+
+
+def reverse_episode_log(epilog, later_speak=False, include_format=False, max_token=MAX_TOKEN):
     episode_msg = epilog.messages
     # per episode
     agent_model = epilog.models[1]
+    prompt_template = PROMPT_TEMPLATE_W_FORMAT if include_format else PROMPT_TEMPLATE
 
     if len(episode_msg) > 0:
         init_loop = episode_msg[0]
@@ -131,7 +167,8 @@ def reverse_episode_log(epilog, later_speak=False):
             dial_history += "\n"+tpl[2]
         else:
             # for the first context, we don't need \n
-            dial_history += tpl[2]
+            context = tpl[2]
+            dial_history += context
 
         if tpl[0] == speaker: # if speaker is the agent, use what he said as result
             str_result = generate_result(tpl[2])
 
@@ -139,15 +176,22 @@
         if i%2 == turn_div: # take alternative turns as we always want to predict one agent, not both
             next_turn = i
-            prompt = PROMPT_TEMPLATE.format(
-                agent=speaker, history=dial_history, turn_number=next_turn,
-                action_list=ACTION_LIST, format_instructions=FORMAT_TEMPLATE)
+            prompt = prompt_template.format(
+                agent=speaker, history=dial_history, turn_number=next_turn)
+            over_tokens = surpass_max_token_check(prompt, max_token)
+            if over_tokens > 0:
+                # keep the scenario context, slide the window over the dialogue only
+                all_dial = dial_history[len(context):]
+                trun_dial = truncate_prompt_to_length(all_dial, over_tokens)
+                prompt = prompt_template.format(
+                    agent=speaker, history=context+"\n"+trun_dial, turn_number=next_turn)
             turn_dic["prompt"] = prompt
             turn_dic['result'] = str_result
             prompt_result_instances.append(turn_dic)
 
     return prompt_result_instances
 
+
 def parse_prompt_to_json(episode, dir, init_speak):
     prompt_result_instances = reverse_episode_log(episode, init_speak)
 
diff --git a/data_process/redis_data_filtering/redis_filtering.py b/data_process/redis_data_filtering/redis_filtering.py
index b4c95a74..690b7997 100644
--- a/data_process/redis_data_filtering/redis_filtering.py
+++ b/data_process/redis_data_filtering/redis_filtering.py
@@ -95,7 +95,7 @@ def goal_filter_per_env_agent(episodes):
             env_tpls.append((episodes[agent1_rank[i]], 0))
             env_tpls.append((episodes[agent2_rank[i]], 1))
         else:
-            if goal_score['agent1'][agent1_rank[i]] >= min(GOAL_KEEP_THRESHOD, agent1_avg) and (goal_score['agent2'][agent2_rank[i]] >= min(KEEP_THRESHOD, agent2_avg)):
+            if goal_score['agent1'][agent1_rank[i]] >= min(GOAL_KEEP_THRESHOD, agent1_avg) and (goal_score['agent2'][agent2_rank[i]] >= min(GOAL_KEEP_THRESHOD, agent2_avg)):
                 env_tpls.append((episodes[agent1_rank[i]], 0))
                 env_tpls.append((episodes[agent1_rank[i]], 1))
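
The heart of this change is the sliding-window truncation: when the formatted prompt exceeds MAX_TOKEN (2048), truncate_prompt_to_length drops the oldest dialogue lines one at a time, accumulating their token counts, until the overflow is covered, and the scenario context at the head of the history is always re-attached. Below is a minimal standalone sketch of that logic for review. All names in it (count_tokens, truncate_dialogue, budget, the sample strings) are illustrative rather than taken from this repository, and a whitespace split stands in for the Llama-2 tokenizer so the sketch runs without a Hugging Face token:

    # Sketch of the sliding-window truncation in this patch. count_tokens is a
    # stand-in for len(TOKENIZER(text)["input_ids"]); budget plays the role of
    # MAX_TOKEN. Hypothetical names, not repository code.
    def count_tokens(text: str) -> int:
        return len(text.split())

    def truncate_dialogue(context: str, dialogue: str, budget: int) -> str:
        # Drop the oldest dialogue lines (a sliding window) until the prompt
        # fits the budget; the scenario context is always kept.
        turns = dialogue.split("\n")
        overflow = count_tokens(context + "\n" + dialogue) - budget
        i = 0
        removed = 0
        while removed < overflow and i < len(turns):
            removed += count_tokens(turns[i])
            i += 1
        return context + "\n" + "\n".join(turns[i:])

    if __name__ == "__main__":
        context = "Scenario: two agents negotiate over a used car."
        dialogue = "\n".join(f"Turn {t}: some utterance here" for t in range(20))
        # The context line survives; the earliest turns are dropped.
        print(truncate_dialogue(context, dialogue, budget=60))

Truncating whole dialogue lines rather than raw tokens can cut slightly more than strictly necessary, but it keeps turns intact and guarantees that the context block and the most recent turns, the parts the model needs to produce the next action, stay inside the 2048-token window.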