Skip to content

Commit

Permalink
support scenario and social goal generation based on new inspirational prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
lwaekfjlk committed Nov 21, 2023
1 parent cf27a32 commit 3312bef
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,5 @@ cython_debug/
./llm_rl/preprocess/GPT4-4_Redis_Easy_No_Slide

llm_rl/*cache/
data_generate/**/*.jsonl
data_generate/**/*.tsv
7 changes: 4 additions & 3 deletions data_generate/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,12 @@ async def agenerate_env_profile(
"""
return await agenerate(
model_name=model_name,
template="""Please generate scenarios and goals following the examples below.
template="""Please generate scenarios and goals following those examples below:
Examples:
{examples}
Additionally, generate creative scenarios based on one or more inspirational prompt. The scenario and social goal is motivated by them but not very related to those prompts, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it.
Inspirational prompt: {inspiration_prompt}
Generate creative scenarios and social goals based on one or more inspirational prompt listed below. The scenario and social goal should be related to at least one of those inspirational prompts, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it.
Inspirational prompt:
{inspiration_prompt}
Please use the following format and follow that format strictly:
{format_instructions}
""",
Expand Down
1 change: 1 addition & 0 deletions data_generate/requirments.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
sotopia
convokit
rejson
evaluate
redis==3.5.3 # for step3 and step4
Expand Down
66 changes: 66 additions & 0 deletions data_generate/step0_create_new_inspirational_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import jsonlines
import csv
from tqdm import tqdm

def collect_social_iqa(inspirational_prompt_data):
    """Append every Social IQa training context as an inspirational prompt.

    Reads ./social_src/social_iqa_train.jsonl and appends one
    {'prompt': ..., 'source': 'social_iqa'} entry per record, then returns
    the (mutated) accumulator list.
    """
    with jsonlines.open('./social_src/social_iqa_train.jsonl', 'r') as reader:
        records = list(reader)
    inspirational_prompt_data.extend(
        {'prompt': record['context'], 'source': 'social_iqa'}
        for record in records
    )
    return inspirational_prompt_data


def collect_normbank(inspirational_prompt_data):
    """Append NormBank situations as inspirational prompts.

    Reads ./social_src/NormBank.csv, skips the header row, and appends one
    {'prompt': <second column>, 'source': 'normbank'} entry per data row.
    Returns the (mutated) accumulator list.
    """
    # Keep all reading inside the `with` block so the file handle is closed
    # before we return; the old unused `rows`/`header` locals are dropped.
    with open("./social_src/NormBank.csv", 'r') as file:
        csvreader = csv.reader(file)
        next(csvreader)  # skip header row
        for row in csvreader:
            # column 1 holds the situation text — TODO confirm against the
            # NormBank release schema
            inspirational_prompt_data.append({'prompt': row[1], 'source': 'normbank'})
    return inspirational_prompt_data


def collect_social_chemistry(inspirational_prompt_data):
    """Append Social-Chem-101 entries as inspirational prompts.

    Reads ./social_src/social-chem-101.v1.0.tsv (tab-delimited), skips the
    header row, and appends one {'prompt': <row[-6]>, 'source':
    'social_chemistry'} entry per data row. Returns the (mutated)
    accumulator list.
    """
    with open("./social_src/social-chem-101.v1.0.tsv", 'r') as file:
        csvreader = csv.reader(file, delimiter="\t")
        next(csvreader)  # skip header row
        for row in csvreader:
            # BUGFIX: these rows were mislabeled 'normbank' (copy-paste from
            # collect_normbank); tag them with their actual source.
            # row[-6] — presumably the 'action'/situation column of the
            # social-chem-101 v1.0 TSV; TODO confirm against release schema.
            inspirational_prompt_data.append({'prompt': row[-6], 'source': 'social_chemistry'})
    return inspirational_prompt_data

def collect_persuation_for_good(inspirational_prompt_data):
    """Load the Persuasion-for-Good corpus as a prompt source.

    Downloads the ConvoKit 'persuasionforgood-corpus' on first use.

    TODO(review): prompt extraction from the corpus is not implemented yet —
    the corpus is loaded but no entries are appended. The previous version
    left an `import pdb; pdb.set_trace()` here (which halts execution) and
    returned None (which would wipe the accumulator if this collector were
    enabled in the pipeline); both are removed.
    """
    from convokit import Corpus, download
    corpus = Corpus(filename=download("persuasionforgood-corpus"))
    # TODO: map corpus conversations into
    # {'prompt': ..., 'source': 'persuasion_for_good'} entries.
    return inspirational_prompt_data

def delete_sotopia_data(inspirational_prompt_data):
    """Remove prompts that already exist in the original Sotopia prompt set.

    Reads the known prompts from
    ./social_src/inspirational_prompt_for_sotopia.csv (first column, header
    skipped) and returns a new list containing only the entries whose
    'prompt' is not in that set. The input list is not mutated.
    """
    # Use a set for O(1) membership tests instead of a list scan per entry.
    sotopia_prompts = set()
    with open("./social_src/inspirational_prompt_for_sotopia.csv", 'r') as file:
        csvreader = csv.reader(file)
        next(csvreader)  # skip header row
        for row in csvreader:
            sotopia_prompts.add(row[0])
    # BUGFIX: the previous version called list.remove() while iterating the
    # same list, which skips the element after each removal — adjacent
    # duplicates were silently kept. Building a new list is correct.
    filtered = [
        entry for entry in inspirational_prompt_data
        if entry['prompt'] not in sotopia_prompts
    ]
    print(len(filtered))
    return filtered


# Pool inspirational prompts from each source, then drop any that already
# appear in the original Sotopia prompt set.
inspirational_prompt_data = []
# TODO: persuasion-for-good collection is disabled until implemented:
# inspirational_prompt_data = collect_persuation_for_good(inspirational_prompt_data)
inspirational_prompt_data = collect_social_chemistry(inspirational_prompt_data)
inspirational_prompt_data = collect_social_iqa(inspirational_prompt_data)
inspirational_prompt_data = collect_normbank(inspirational_prompt_data)
inspirational_prompt_data = delete_sotopia_data(inspirational_prompt_data)

# All entries share the same keys ('prompt', 'source'); use the first as header.
fieldnames = inspirational_prompt_data[0].keys()

# Persist the merged prompt list as CSV for the downstream generation steps.
with open('./social_src/inspirational_prompt.csv', 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(inspirational_prompt_data)
10 changes: 5 additions & 5 deletions data_generate/step1_generate_env_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def pydantics_to_csv(filename: str, data: list[T]) -> None:
#random.seed(41)

envs = EnvironmentProfile.find().all()
ins_prompts = pd.read_csv("./inspirational_prompt_for_env.csv")
ins_prompts = pd.read_csv("./social_src/inspirational_prompt.csv")
prompts = [prompt.strip().replace('\"', '') for prompt in ins_prompts["prompt"].tolist()]

# randomly choose 3 prompts
Expand All @@ -29,10 +29,10 @@ def pydantics_to_csv(filename: str, data: list[T]) -> None:
target_num = 500

for i in range(target_num):
sampled_envs = random.sample(envs, 5)
sampled_prompt = random.sample(prompts, 5)
sampled_examples.append(f"1.{sampled_envs[0].json()}\n2.{sampled_envs[1].json()}\n3.{sampled_envs[2].json()}\n4.{sampled_envs[3].json()}\n5.{sampled_envs[4].json()}")
sampled_prompts.append(f"1.{sampled_prompt[0]}\n2.{sampled_prompt[1]}\n3.{sampled_prompt[2]}\n4.{sampled_prompt[3]}\n5.{sampled_prompt[4]}")
sampled_envs = random.sample(envs, 1)
sampled_prompt = random.sample(prompts, 3)
sampled_examples.append(f"{sampled_envs[0].json()}")
sampled_prompts.append(f"1.{sampled_prompt[0]}\n2.{sampled_prompt[1]}\n3.{sampled_prompt[2]}")

assert len(sampled_examples) == target_num
assert len(sampled_prompts) == target_num
Expand Down

0 comments on commit 3312bef

Please sign in to comment.