Skip to content

Commit

Permalink
adding sliding window to remove dialogue but keep context when surpassing 2048 tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
sharonwx54 committed Oct 26, 2023
1 parent a43ffbf commit 18d4bde
Showing 1 changed file with 42 additions and 11 deletions.
53 changes: 42 additions & 11 deletions data_process/redis_data_filtering/prompt_reverse_engineering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union, cast

import transformers
import pandas as pd
import rich
from rich.console import Console
Expand All @@ -15,17 +15,14 @@
import enum

#PROMPT_PREFIX = "Prompt after formatting:\n"

MAX_TOKEN = 2048
PROMPT_TEMPLATE="""Prompt after formatting:\nImagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
You can find {agent}'s background and goal in the 'Here is the context of the interaction' field.
Note that {agent}'s secret and goal is only visible to you.
You should try your best to achieve {agent}'s goal in a way that align with their character traits.
Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
{history}.
You are at Turn #{turn_number}. Your available action types are
{action_list}.
Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
"""
You are at Turn #{turn_number}."""

#PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
FORMAT_TEMPLATE = """\nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
Expand All @@ -40,7 +37,7 @@
Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
{history}.
You are at Turn #{turn_number}. Your available action types are
{action_list}.
"none action speak non-verbal communication leave".
Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
Please only generate a JSON string including the action type and the argument.
Expand All @@ -54,6 +51,16 @@

ACTION_REVERSE_MAP = {"left ": "leave", 'did n': 'none', 'said:': 'speak'}

# Checkpoint used only for its tokenizer (token counting / truncation below).
MODEL_CHECKPOINT = "meta-llama/Llama-2-13b-chat-hf"
# SECURITY: never hard-code API credentials in source. The previous revision
# committed a live Hugging Face token here — that token must be revoked.
# Read the credential from the environment instead; `None` is fine for
# non-gated checkpoints.
HF_TOKEN = os.environ.get("HF_TOKEN")


# Loaded once at import time. Gated models (e.g. Llama-2) require HF_TOKEN
# to be set in the environment.
TOKENIZER = transformers.AutoTokenizer.from_pretrained(
    MODEL_CHECKPOINT,
    padding=False,
    truncation=False,
    token=HF_TOKEN,
)

def to_natural_language(self) -> str:
match self.action_type:
Expand Down Expand Up @@ -113,7 +120,23 @@ def generate_result(msg):

return str_result

def reverse_episode_log(epilog, later_speak=False, include_format=False):
def surpass_max_token_check(string, max_token=None, tokenizer=None):
    """Return how many tokens ``string`` exceeds the budget by (0 if within).

    Args:
        string: prompt text to measure.
        max_token: token budget; defaults to the module-level MAX_TOKEN.
        tokenizer: callable returning ``{'input_ids': [...]}``; defaults to
            the module-level TOKENIZER. Both defaults are resolved lazily so
            the heavy globals are not bound at function-definition time.

    Returns:
        Non-negative int: number of tokens over budget.
    """
    if max_token is None:
        max_token = MAX_TOKEN
    if tokenizer is None:
        tokenizer = TOKENIZER
    prompt_tokens = len(tokenizer(string)["input_ids"])
    return max(prompt_tokens - max_token, 0)

def truncate_prompt_to_length(dia_his, surpass_num, tokenizer=None):
    """Drop leading dialogue lines until at least ``surpass_num`` tokens go.

    Implements the sliding window: the oldest turns (front of the history)
    are discarded first, so the most recent context is always kept.

    Args:
        dia_his: dialogue history, one turn per "\n"-separated line.
        surpass_num: minimum number of tokens to remove.
        tokenizer: callable returning ``{'input_ids': [...]}``; defaults to
            the module-level TOKENIZER (resolved lazily).

    Returns:
        The truncated dialogue history. Returns "" when the whole history
        must be removed (the previous version raised IndexError here).
    """
    if tokenizer is None:
        tokenizer = TOKENIZER
    lines = dia_his.split("\n")
    removed_tokens = 0
    cut = 0
    # Walk from the oldest line, accumulating per-line token counts until
    # enough tokens have been removed — or the history is exhausted.
    while removed_tokens < surpass_num and cut < len(lines):
        removed_tokens += len(tokenizer(lines[cut])["input_ids"])
        cut += 1
    return "\n".join(lines[cut:])


def reverse_episode_log(epilog, later_speak=False, include_format=False, max_token=MAX_TOKEN):
episode_msg = epilog.messages
# per episode
agent_model = epilog.models[1]
Expand Down Expand Up @@ -144,7 +167,8 @@ def reverse_episode_log(epilog, later_speak=False, include_format=False):
dial_history += "\n"+tpl[2]
else:
# for the first context, we don't need \n
dial_history += tpl[2]
context = tpl[2]
dial_history += context

if tpl[0] == speaker: # if speaker is the agent, use what he said as result
str_result = generate_result(tpl[2])
Expand All @@ -153,14 +177,21 @@ def reverse_episode_log(epilog, later_speak=False, include_format=False):
# take alternative turns as we always want to predict one agent, not both
next_turn = i
prompt = promt_template.format(
agent=speaker, history=dial_history, turn_number=next_turn,
action_list=ACTION_LIST)
agent=speaker, history=dial_history, turn_number=next_turn)
over_tokens = surpass_max_token_check(prompt)
if over_tokens > 0:
all_dial = dial_history[len(context):]
#print(all_dial)
trun_dial = truncate_prompt_to_length(all_dial, over_tokens)
prompt = promt_template.format(
agent=speaker, history=context+"\n"+trun_dial, turn_number=next_turn)
turn_dic["prompt"] = prompt
turn_dic['result'] = str_result
prompt_result_instances.append(turn_dic)

return prompt_result_instances


def parse_prompt_to_json(episode, dir, init_speak):
prompt_result_instances = reverse_episode_log(episode, init_speak)

Expand Down

0 comments on commit 18d4bde

Please sign in to comment.