diff --git a/.gitignore b/.gitignore
index d5c2b8f7..e512ef7c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,5 @@
 AI-Adventure-2bb65e3a4e2f.json
 *.pyc
 *.pyo
+data/text_adventures.txt
+venv/
diff --git a/README.md b/README.md
index f94b43d1..33780c9a 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,6 @@
 # AIDungeon2
 
-## The model for AI Dungeon2 is temporarily unavailable to download due to cost. We're working on a solution!
-
+### The model for AI Dungeon2 is temporarily unavailable to download due to cost. We're working on a solution!
 Read more about AIDungeon2 and how it was built [here](https://pcc.cs.byu.edu/2019/11/21/ai-dungeon-2-creating-infinitely-generated-text-adventures-with-deep-learning-language-models/).
diff --git a/data/build_training_data.py b/data/build_training_data.py
index f53a03e8..6038e750 100644
--- a/data/build_training_data.py
+++ b/data/build_training_data.py
@@ -4,37 +4,65 @@
 def load_tree(filename):
-    with open(filename, 'r') as fp:
+    with open(filename, "r") as fp:
         tree = json.load(fp)
     return tree
 
+
 def remove_phrase(text):
     phrases = ["Years pass...", "Years pass"]
     for phrase in phrases:
         text = text.replace(phrase, "")
     return text
 
+
 def make_stories(current_story, tree):
     stories = []
     action = first_to_second_person(tree["action"])
     action_list = action.split(" ")
     first_word = action_list[0]
-    if first_word[-1] == '.':
+    if first_word[-1] == ".":
         first_word = first_word[:-1]
-    dont_add_you = ["the", "another", "next", "in", "monday", "back", "a", "years", "one",
-                    "two", "during", "months", "weeks", "seven", "three", "...", "twelve",
-                    "four","five","six", "blackness...", "you", "no", "yes", "up", "down", "onward", ]
+    dont_add_you = [
+        "the",
+        "another",
+        "next",
+        "in",
+        "monday",
+        "back",
+        "a",
+        "years",
+        "one",
+        "two",
+        "during",
+        "months",
+        "weeks",
+        "seven",
+        "three",
+        "...",
+        "twelve",
+        "four",
+        "five",
+        "six",
+        "blackness...",
+        "you",
+        "no",
+        "yes",
+        "up",
+        "down",
+        "onward",
+    ]
     if action[0] is '"':
         last_quote = action.rfind('"')
-        action = "You say " + action[:last_quote + 1]
+        action = "You say " + action[: last_quote + 1]
     elif first_word.lower() not in dont_add_you:
         action = "You " + action[0].lower() + action[1:]
     action = remove_phrase(action)
     result = remove_phrase(tree["result"])
-    current_story += ("\n> " + action + "\n" + result)
+    current_story += "\n> " + action + "\n" + result
     action_results = tree["action_results"]
     if len(action_results) == 0 or action_results[0] is None:
@@ -48,6 +76,7 @@ def make_stories(current_story, tree):
     return stories
 
+
 def get_stories(filename):
     tree = load_tree(filename)
     stories = []
@@ -57,14 +86,13 @@ def get_stories(filename):
 output_file_path = "text_adventures.txt"
-with open(output_file_path, 'w') as output_file:
-    filenames = ["stories/story" + str(i) + ".json" for i in range(0,93)]
-    #filenames = []
+with open(output_file_path, "w") as output_file:
+    filenames = ["stories/story" + str(i) + ".json" for i in range(0, 93)]
+    # filenames = []
     for filename in filenames:
         tree = load_tree(filename)
         print('"' + tree["tree_id"] + '",')
-
     filenames += ["stories/crowdsourcedstory" + str(i) + ".json" for i in range(0, 12)]
     stories = []
     for filename in filenames:
@@ -80,6 +108,3 @@ def get_stories(filename):
     print(len(raw_text))
     output_file.write(raw_text)
-
-
-
diff --git a/data/make_reddit_data.py b/data/make_reddit_data.py
index 218ac01e..964fc9a1 100644
--- a/data/make_reddit_data.py
+++ b/data/make_reddit_data.py
@@ -2,21 +2,21 @@
 from story.utils import *
 import os
 
-def 
load_stories(file): +def load_stories(file): - try: - with open(file) as fp: - stories = json.load(fp) - return stories - except: - with open(file) as fp: - stories = [] - for line in fp: - if len(line) > 10: - story = json.loads(line) - stories.append(story) - return stories + try: + with open(file) as fp: + stories = json.load(fp) + return stories + except: + with open(file) as fp: + stories = [] + for line in fp: + if len(line) > 10: + story = json.loads(line) + stories.append(story) + return stories def modify_story(story): @@ -32,10 +32,11 @@ def modify_story(story): else: return None + current = os.getcwd() files = os.listdir(current + "/writingprompts") output_file_path = "writing_prompts.txt" -with open(output_file_path, 'w') as output_file: +with open(output_file_path, "w") as output_file: filenames = ["writingprompts/" + file for file in files] cleaned_stories = [] for filename in filenames: diff --git a/data/mechturk.py b/data/mechturk.py index 8cc1fcdc..3ced0ecc 100644 --- a/data/mechturk.py +++ b/data/mechturk.py @@ -19,12 +19,13 @@ import json import os + def data_to_forest(filename): trees = [] rows = [] - with open(filename, newline='') as f: + with open(filename, newline="") as f: reader = csv.reader(f) for row in reader: rows.append(row) @@ -40,8 +41,8 @@ def data_to_forest(filename): while row_ind < len(rows): action_result = {} action_result["action"] = rows[row_ind][i] - if row_ind+1 < len(rows): - action_result["result"] = rows[row_ind+1][i] + if row_ind + 1 < len(rows): + action_result["result"] = rows[row_ind + 1][i] else: action_result["result"] = None action_result["action_results"] = [] @@ -60,11 +61,27 @@ def build_action_samples_helper(context, story_block, action_results, path, tree for i, action_result in enumerate(action_results): new_path = path[:] new_path.append(i) - if len(action_result["action_results"]) is 0 and action_result["result"] is not None: - row = [tree_id, "".join(str(x) for x in new_path), context, story_block, action_result["action"], action_result["result"]] + if ( + len(action_result["action_results"]) is 0 + and action_result["result"] is not None + ): + row = [ + tree_id, + "".join(str(x) for x in new_path), + context, + story_block, + action_result["action"], + action_result["result"], + ] samples.append(row) else: - sub_result = build_action_samples_helper(context, action_result["result"], action_result["action_results"], new_path, tree_id) + sub_result = build_action_samples_helper( + context, + action_result["result"], + action_result["action_results"], + new_path, + tree_id, + ) samples += sub_result return samples @@ -72,19 +89,38 @@ def build_action_samples_helper(context, story_block, action_results, path, tree def make_write_actions_batch(forest, filename): # Traverse to the bottom levels of each tree - with open(filename, mode='w', newline='') as file: - writer = csv.writer(file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) - writer.writerow(["tree_id", "path", "context", "story_block_1", "previous_action", "story_block_2"]) + with open(filename, mode="w", newline="") as file: + writer = csv.writer( + file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + writer.writerow( + [ + "tree_id", + "path", + "context", + "story_block_1", + "previous_action", + "story_block_2", + ] + ) for tree in forest: first_story_block = tree["first_story_block"] - samples = build_action_samples_helper(tree["context"], first_story_block, tree["action_results"], [], tree["tree_id"]) + samples = build_action_samples_helper( + 
tree["context"], + first_story_block, + tree["action_results"], + [], + tree["tree_id"], + ) for sample in samples: writer.writerow(sample) -def build_result_samples_helper(context, story_block, parent_action_result, path, tree_id): +def build_result_samples_helper( + context, story_block, parent_action_result, path, tree_id +): samples = [] action_results = parent_action_result["action_results"] @@ -93,10 +129,24 @@ def build_result_samples_helper(context, story_block, parent_action_result, path new_path = path[:] new_path.append(i) if action_result["result"] is None: - row = [tree_id, "".join(str(x) for x in new_path), context, story_block, parent_action_result["action"], parent_action_result["result"], action_result["action"]] + row = [ + tree_id, + "".join(str(x) for x in new_path), + context, + story_block, + parent_action_result["action"], + parent_action_result["result"], + action_result["action"], + ] samples.append(row) else: - sub_result = build_result_samples_helper(context, parent_action_result["result"], action_result, new_path, tree_id) + sub_result = build_result_samples_helper( + context, + parent_action_result["result"], + action_result, + new_path, + tree_id, + ) samples += sub_result return samples @@ -104,25 +154,44 @@ def build_result_samples_helper(context, story_block, parent_action_result, path def make_write_results_batch(forest, filename): - with open(filename, mode='w', newline='') as file: - writer = csv.writer(file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) - writer.writerow(["tree_id", "path", "context", "story_block_1", "previous_action_1", "story_block_2", "previous_action_2"]) + with open(filename, mode="w", newline="") as file: + writer = csv.writer( + file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + writer.writerow( + [ + "tree_id", + "path", + "context", + "story_block_1", + "previous_action_1", + "story_block_2", + "previous_action_2", + ] + ) for tree in forest: first_story_block = tree["first_story_block"] samples = [] for i, action_result in enumerate(tree["action_results"]): path = [i] - samples += build_result_samples_helper(tree["context"], first_story_block, action_result, path, tree["tree_id"]) + samples += build_result_samples_helper( + tree["context"], + first_story_block, + action_result, + path, + tree["tree_id"], + ) for sample in samples: writer.writerow(sample) def save_tree(tree, filename): - with open(filename, 'w') as fp: + with open(filename, "w") as fp: json.dump(tree, fp) + def save_forest(forest, forest_name): if not os.path.exists("./" + forest_name): @@ -130,11 +199,13 @@ def save_forest(forest, forest_name): for tree in forest: save_tree(tree, "./" + forest_name + "/" + tree["tree_id"] + ".json") + def load_tree(filename): - with open(filename, 'r') as fp: + with open(filename, "r") as fp: tree = json.load(fp) return tree + def load_forest(forest_name): files = os.listdir("./" + forest_name) @@ -148,7 +219,7 @@ def csv_to_dict(file): update_dict = {} field_names = [] - with open(file, newline='') as f: + with open(file, newline="") as f: reader = csv.reader(f) for row in reader: if len(update_dict) is 0: @@ -177,12 +248,15 @@ def update_forest_with_results(forest_name, update_file): current_action_results = tree for choice in update_dict["Input.path"][i]: choice_num = int(choice) - current_action_results = current_action_results["action_results"][choice_num] + current_action_results = current_action_results["action_results"][ + choice_num + ] current_action_results["result"] = 
update_dict["Answer.result"][i] return tree_dict.values() + def update_forest_with_actions(forest_name, update_file): update_dict = csv_to_dict(update_file) tree_dict = {} @@ -197,16 +271,28 @@ def update_forest_with_actions(forest_name, update_file): current_action_results = tree for choice in update_dict["Input.path"][i]: choice_num = int(choice) - current_action_results = current_action_results["action_results"][choice_num] - + current_action_results = current_action_results["action_results"][ + choice_num + ] current_action_results["action_results"].append( - {"action": update_dict["Answer.action_1"][i], "result": None , "action_results":[]}) + { + "action": update_dict["Answer.action_1"][i], + "result": None, + "action_results": [], + } + ) current_action_results["action_results"].append( - {"action": update_dict["Answer.action_2"][i], "result": None, "action_results": []}) + { + "action": update_dict["Answer.action_2"][i], + "result": None, + "action_results": [], + } + ) return tree_dict.values() + old_forest_name = "seed_forest_1.8" new_forest_name = "seed_forest_1.9" update_type = "results" diff --git a/data/scraper.py b/data/scraper.py index 75fb866c..4318ebe7 100644 --- a/data/scraper.py +++ b/data/scraper.py @@ -2,6 +2,7 @@ from selenium.webdriver.chrome.options import Options import time import json + """ format of tree is dict { @@ -19,59 +20,63 @@ } """ -class Scraper: +class Scraper: def __init__(self): chrome_options = Options() chrome_options.add_argument("--binary=/path/to/other/chrome/binary") chrome_options.add_argument("--incognito") chrome_options.add_argument("--window-size=1920x1080") exec_path = "/usr/bin/chromedriver" - self.driver = webdriver.Chrome(chrome_options=chrome_options, executable_path=exec_path) + self.driver = webdriver.Chrome( + chrome_options=chrome_options, executable_path=exec_path + ) self.max_depth = 10 - self.end_actions = {"End Game and Leave Comments", - "Click here to End the Game and Leave Comments", - "See How Well You Did (you can still back-page afterwards if you like)", - "You have died.", - "You have died", - "Epilogue", - "Save Game", - "Your quest might have been more successful...", - "5 - not the best, certainly not the worst", - "The End! (leave comments on game)", - "6 - it's worth every cent", - "You do not survive the journey to California", - "Quit the game.", - "7 - even better than Reeses' Cups®", - "8 - it will bring you enlightenment", - "End of game! Leave a comment!", - "Better luck next time", - "click here to continue", - "Rating And Leaving Comments", - "You do not survive your journey to California", - "Your Outlaw Career has come to an end", - "Thank you for taking the time to read my story", - "You have no further part in the story, End Game and Leave Comments", - '', - "You play no further part in this story. 
End Game and Leave Comments", - "drivers", - "Alas, poor Yorick, they slew you well", - "My heart bleeds for you", - "To End the Game and Leave Comments click here", - "Call it a day", - "Check the voicemail.", - "reset", - "There's nothing you can do anymore...it's over.", - "To Be Continued...", - "Thanks again for taking the time to read this", - "If you just want to escape this endless story you can do that by clicking here", - "Boo Hoo Hoo", - "End.", - "Pick up some money real quick", - "", - "Well you did live a decent amount of time in the Army", - "End Game", - "You have survived the Donner Party's journey to California!"} + self.end_actions = { + "End Game and Leave Comments", + "Click here to End the Game and Leave Comments", + "See How Well You Did (you can still back-page afterwards if you like)", + "You have died.", + "You have died", + "Epilogue", + "Save Game", + "Your quest might have been more successful...", + "5 - not the best, certainly not the worst", + "The End! (leave comments on game)", + "6 - it's worth every cent", + "You do not survive the journey to California", + "Quit the game.", + "7 - even better than Reeses' Cups®", + "8 - it will bring you enlightenment", + "End of game! Leave a comment!", + "Better luck next time", + "click here to continue", + "Rating And Leaving Comments", + "You do not survive your journey to California", + "Your Outlaw Career has come to an end", + "Thank you for taking the time to read my story", + "You have no further part in the story, End Game and Leave Comments", + "", + "You play no further part in this story. End Game and Leave Comments", + "drivers", + "Alas, poor Yorick, they slew you well", + "My heart bleeds for you", + "To End the Game and Leave Comments click here", + "Call it a day", + "Check the voicemail.", + "reset", + "There's nothing you can do anymore...it's over.", + "To Be Continued...", + "Thanks again for taking the time to read this", + "If you just want to escape this endless story you can do that by clicking here", + "Boo Hoo Hoo", + "End.", + "Pick up some money real quick", + "", + "Well you did live a decent amount of time in the Army", + "End Game", + "You have survived the Donner Party's journey to California!", + } self.texts = set() def GoToURL(self, url): @@ -92,7 +97,7 @@ def GoBack(self): time.sleep(0.2) def ClickAction(self, links, action_num): - links[action_num+4].click() + links[action_num + 4].click() time.sleep(0.2) def GetActions(self): @@ -110,7 +115,7 @@ def BuildTreeHelper(self, parent_story, action_num, depth, old_actions): action_result["action"] = action links = self.GetLinks() - if action_num+4 >= len(links): + if action_num + 4 >= len(links): return None self.ClickAction(links, action_num) @@ -136,7 +141,6 @@ def BuildTreeHelper(self, parent_story, action_num, depth, old_actions): self.GoBack() return action_result - def BuildStoryTree(self, url): scraper.GoToURL(url) text = scraper.GetText() @@ -157,110 +161,113 @@ def BuildStoryTree(self, url): return story_dict + def save_tree(tree, filename): - with open(filename, 'w') as fp: + with open(filename, "w") as fp: json.dump(tree, fp) + scraper = Scraper() urls = [ - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10638", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=11246", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=54639", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7397", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8041", - 
"http://chooseyourstory.com/story/viewer/default.aspx?StoryId=11545", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7393", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=13875", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=37696", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=31013", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=45375", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=41698", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10634", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=42204", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=6823", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=18988", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10359", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=5466", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=28030", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=56515", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7480", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=11274", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=53134", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=17306", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=470", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8041", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=23928", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10183", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=45866", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=60232", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=6376", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=36791", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=60128", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=52961", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=54011", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=34838", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=13349", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8038", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=56742", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=48393", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=53356", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10872", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7393", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=31013", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=43910", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=53837", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8098", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=55043", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=28838", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=11906", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8040", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=2280", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=31014", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=43744", - 
"http://chooseyourstory.com/story/viewer/default.aspx?StoryId=44543", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=56753", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=36594", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=15424", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8035", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10524", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=14899", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=9361", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=28030", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=49642", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=43573", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=38025", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7480", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7567", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=60747", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10359", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=31353", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=13875", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=56501", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=38542", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=42204", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=43993", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=1153", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=24743", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=57114", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=52887", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=21879", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=16489", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=53186", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=34849", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=26752", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7094", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8557", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=45225", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=4720", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=51926", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=45375", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=27234", - "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=60772"] + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10638", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=11246", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=54639", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7397", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8041", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=11545", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7393", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=13875", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=37696", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=31013", + 
"http://chooseyourstory.com/story/viewer/default.aspx?StoryId=45375", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=41698", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10634", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=42204", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=6823", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=18988", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10359", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=5466", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=28030", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=56515", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7480", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=11274", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=53134", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=17306", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=470", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8041", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=23928", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10183", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=45866", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=60232", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=6376", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=36791", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=60128", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=52961", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=54011", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=34838", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=13349", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8038", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=56742", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=48393", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=53356", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10872", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7393", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=31013", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=43910", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=53837", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8098", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=55043", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=28838", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=11906", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8040", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=2280", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=31014", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=43744", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=44543", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=56753", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=36594", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=15424", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8035", + 
"http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10524", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=14899", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=9361", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=28030", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=49642", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=43573", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=38025", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7480", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7567", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=60747", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=10359", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=31353", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=13875", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=56501", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=38542", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=42204", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=43993", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=1153", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=24743", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=57114", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=52887", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=21879", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=16489", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=53186", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=34849", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=26752", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=7094", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=8557", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=45225", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=4720", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=51926", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=45375", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=27234", + "http://chooseyourstory.com/story/viewer/default.aspx?StoryId=60772", +] for i in range(50, len(urls)): print("****** Extracting Adventure ", urls[i], " ***********") tree = scraper.BuildStoryTree(urls[i]) - save_tree(tree, "stories/story" + str(41+i) + ".json") + save_tree(tree, "stories/story" + str(41 + i) + ".json") print("done") diff --git a/data/sheet_to_story.py b/data/sheet_to_story.py index ab83d54d..4aab98f1 100644 --- a/data/sheet_to_story.py +++ b/data/sheet_to_story.py @@ -1,4 +1,3 @@ - """ format of tree is dict { @@ -19,12 +18,13 @@ import json import os + def data_to_forest(filename): trees = [] rows = [] - with open(filename, newline='') as f: + with open(filename, newline="") as f: reader = csv.reader(f) for row in reader: rows.append(row) @@ -40,8 +40,8 @@ def data_to_forest(filename): while row_ind < len(rows): action_result = {} action_result["action"] = rows[row_ind][i] - if row_ind+1 < len(rows): - action_result["result"] = rows[row_ind+1][i] + if row_ind + 1 < len(rows): + action_result["result"] = rows[row_ind + 1][i] else: action_result["result"] = None action_result["action_results"] = [] @@ -60,11 +60,27 @@ def 
build_action_samples_helper(context, story_block, action_results, path, tree for i, action_result in enumerate(action_results): new_path = path[:] new_path.append(i) - if len(action_result["action_results"]) is 0 and action_result["result"] is not None: - row = [tree_id, "".join(str(x) for x in new_path), context, story_block, action_result["action"], action_result["result"]] + if ( + len(action_result["action_results"]) is 0 + and action_result["result"] is not None + ): + row = [ + tree_id, + "".join(str(x) for x in new_path), + context, + story_block, + action_result["action"], + action_result["result"], + ] samples.append(row) else: - sub_result = build_action_samples_helper(context, action_result["result"], action_result["action_results"], new_path, tree_id) + sub_result = build_action_samples_helper( + context, + action_result["result"], + action_result["action_results"], + new_path, + tree_id, + ) samples += sub_result return samples @@ -72,19 +88,38 @@ def build_action_samples_helper(context, story_block, action_results, path, tree def make_write_actions_batch(forest, filename): # Traverse to the bottom levels of each tree - with open(filename, mode='w', newline='') as file: - writer = csv.writer(file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) - writer.writerow(["tree_id", "path", "context", "story_block_1", "previous_action", "story_block_2"]) + with open(filename, mode="w", newline="") as file: + writer = csv.writer( + file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + writer.writerow( + [ + "tree_id", + "path", + "context", + "story_block_1", + "previous_action", + "story_block_2", + ] + ) for tree in forest: first_story_block = tree["first_story_block"] - samples = build_action_samples_helper(tree["context"], first_story_block, tree["action_results"], [], tree["tree_id"]) + samples = build_action_samples_helper( + tree["context"], + first_story_block, + tree["action_results"], + [], + tree["tree_id"], + ) for sample in samples: writer.writerow(sample) -def build_result_samples_helper(context, story_block, parent_action_result, path, tree_id): +def build_result_samples_helper( + context, story_block, parent_action_result, path, tree_id +): samples = [] action_results = parent_action_result["action_results"] @@ -93,10 +128,24 @@ def build_result_samples_helper(context, story_block, parent_action_result, path new_path = path[:] new_path.append(i) if action_result["result"] is None: - row = [tree_id, "".join(str(x) for x in new_path), context, story_block, parent_action_result["action"], parent_action_result["result"], action_result["action"]] + row = [ + tree_id, + "".join(str(x) for x in new_path), + context, + story_block, + parent_action_result["action"], + parent_action_result["result"], + action_result["action"], + ] samples.append(row) else: - sub_result = build_result_samples_helper(context, parent_action_result["result"], action_result, new_path, tree_id) + sub_result = build_result_samples_helper( + context, + parent_action_result["result"], + action_result, + new_path, + tree_id, + ) samples += sub_result return samples @@ -104,25 +153,44 @@ def build_result_samples_helper(context, story_block, parent_action_result, path def make_write_results_batch(forest, filename): - with open(filename, mode='w', newline='') as file: - writer = csv.writer(file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) - writer.writerow(["tree_id", "path", "context", "story_block_1", "previous_action_1", "story_block_2", "previous_action_2"]) + with 
open(filename, mode="w", newline="") as file: + writer = csv.writer( + file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + writer.writerow( + [ + "tree_id", + "path", + "context", + "story_block_1", + "previous_action_1", + "story_block_2", + "previous_action_2", + ] + ) for tree in forest: first_story_block = tree["first_story_block"] samples = [] for i, action_result in enumerate(tree["action_results"]): path = [i] - samples += build_result_samples_helper(tree["context"], first_story_block, action_result, path, tree["tree_id"]) + samples += build_result_samples_helper( + tree["context"], + first_story_block, + action_result, + path, + tree["tree_id"], + ) for sample in samples: writer.writerow(sample) def save_tree(tree, filename): - with open(filename, 'w') as fp: + with open(filename, "w") as fp: json.dump(tree, fp) + def save_forest(forest, forest_name): if not os.path.exists("./" + forest_name): @@ -130,11 +198,13 @@ def save_forest(forest, forest_name): for tree in forest: save_tree(tree, "./" + forest_name + "/" + tree["tree_id"] + ".json") + def load_tree(filename): - with open(filename, 'r') as fp: + with open(filename, "r") as fp: tree = json.load(fp) return tree + def load_forest(forest_name): files = os.listdir("./" + forest_name) @@ -148,7 +218,7 @@ def csv_to_dict(file): update_dict = {} field_names = [] - with open(file, newline='') as f: + with open(file, newline="") as f: reader = csv.reader(f) for row in reader: if len(update_dict) is 0: @@ -177,12 +247,15 @@ def update_forest_with_results(forest_name, update_file): current_action_results = tree for choice in update_dict["Input.path"][i]: choice_num = int(choice) - current_action_results = current_action_results["action_results"][choice_num] + current_action_results = current_action_results["action_results"][ + choice_num + ] current_action_results["result"] = update_dict["Answer.result"][i] return tree_dict.values() + def update_forest_with_actions(forest_name, update_file): update_dict = csv_to_dict(update_file) tree_dict = {} @@ -197,17 +270,29 @@ def update_forest_with_actions(forest_name, update_file): current_action_results = tree for choice in update_dict["Input.path"][i]: choice_num = int(choice) - current_action_results = current_action_results["action_results"][choice_num] - + current_action_results = current_action_results["action_results"][ + choice_num + ] current_action_results["action_results"].append( - {"action": update_dict["Answer.action_1"][i], "result": None , "action_results":[]}) + { + "action": update_dict["Answer.action_1"][i], + "result": None, + "action_results": [], + } + ) current_action_results["action_results"].append( - {"action": update_dict["Answer.action_2"][i], "result": None, "action_results": []}) + { + "action": update_dict["Answer.action_2"][i], + "result": None, + "action_results": [], + } + ) return tree_dict.values() + tree = data_to_forest("upwork.csv") for i, story in enumerate(tree): save_tree(story, "crowdsourcedstory" + str(i) + ".json") -print("done") \ No newline at end of file +print("done") diff --git a/generator/gpt2/download_model.py b/generator/gpt2/download_model.py index 56d4e767..0fbc440b 100644 --- a/generator/gpt2/download_model.py +++ b/generator/gpt2/download_model.py @@ -4,24 +4,36 @@ from tqdm import tqdm if len(sys.argv) != 2: - print('You must enter the model name as a parameter, e.g.: download_model.py 124M') + print("You must enter the model name as a parameter, e.g.: download_model.py 124M") sys.exit(1) model = sys.argv[1] -subdir = 
os.path.join('models', model) +subdir = os.path.join("models", model) if not os.path.exists(subdir): os.makedirs(subdir) -subdir = subdir.replace('\\','/') # needed for Windows +subdir = subdir.replace("\\", "/") # needed for Windows -for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']: +for filename in [ + "checkpoint", + "encoder.json", + "hparams.json", + "model.ckpt.data-00000-of-00001", + "model.ckpt.index", + "model.ckpt.meta", + "vocab.bpe", +]: - r = requests.get("https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True) + r = requests.get( + "https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True + ) - with open(os.path.join(subdir, filename), 'wb') as f: + with open(os.path.join(subdir, filename), "wb") as f: file_size = int(r.headers["content-length"]) chunk_size = 1000 - with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar: + with tqdm( + ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True + ) as pbar: # 1k for chunk_size, since Ethernet packet size is around 1500 bytes for chunk in r.iter_content(chunk_size=chunk_size): f.write(chunk) diff --git a/generator/gpt2/gpt2_generator.py b/generator/gpt2/gpt2_generator.py index fdb0fe7a..e355b750 100644 --- a/generator/gpt2/gpt2_generator.py +++ b/generator/gpt2/gpt2_generator.py @@ -1,17 +1,19 @@ from story.utils import * import warnings + warnings.filterwarnings("ignore") import os import tensorflow as tf + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from generator.gpt2.src import sample, encoder, model import json import numpy as np -class GPT2Generator: - def __init__(self, generate_num=60, temperature=0.4, top_k=40, top_p=0.9): - self.generate_num=generate_num +class GPT2Generator: + def __init__(self, generate_num=60, temperature=0.4, top_k=40, top_p=0.9): + self.generate_num = generate_num self.temp = temperature self.top_k = top_k self.top_p = top_p @@ -26,7 +28,7 @@ def __init__(self, generate_num=60, temperature=0.4, top_k=40, top_p=0.9): self.enc = encoder.get_encoder(self.model_name, models_dir) hparams = model.default_hparams() - with open(os.path.join(models_dir, self.model_name, 'hparams.json')) as f: + with open(os.path.join(models_dir, self.model_name, "hparams.json")) as f: hparams.override_from_dict(json.load(f)) seed = np.random.randint(0, 100000) @@ -35,13 +37,16 @@ def __init__(self, generate_num=60, temperature=0.4, top_k=40, top_p=0.9): self.sess = tf.compat.v1.Session(config=config) self.context = tf.placeholder(tf.int32, [self.batch_size, None]) - #np.random.seed(seed) + # np.random.seed(seed) # tf.set_random_seed(seed) self.output = sample.sample_sequence( - hparams=hparams, length=self.generate_num, + hparams=hparams, + length=self.generate_num, context=self.context, batch_size=self.batch_size, - temperature=temperature, top_k=top_k, top_p=top_p + temperature=temperature, + top_k=top_k, + top_p=top_p, ) saver = tf.train.Saver() @@ -54,8 +59,8 @@ def prompt_replace(self, prompt): if len(prompt) > 0 and prompt[-1] == " ": prompt = prompt[:-1] - #prompt = second_to_first_person(prompt) - + # prompt = second_to_first_person(prompt) + # print("\n\nAFTER PROMPT_REPLACE") # print(repr(prompt)) return prompt @@ -72,7 +77,7 @@ def result_replace(self, result): result = result.replace("#", "") result = result.replace("*", "") result = result.replace("\n\n", "\n") - #result = 
first_to_second_person(result) + # result = first_to_second_person(result) result = remove_profanity(result) if not first_letter_capitalized: @@ -88,15 +93,17 @@ def generate_raw(self, prompt): context_tokens = self.enc.encode(prompt) generated = 0 for _ in range(self.samples // self.batch_size): - out = self.sess.run(self.output, feed_dict={ - self.context: [context_tokens for _ in range(self.batch_size)] - })[:, len(context_tokens):] + out = self.sess.run( + self.output, + feed_dict={ + self.context: [context_tokens for _ in range(self.batch_size)] + }, + )[:, len(context_tokens) :] for i in range(self.batch_size): generated += 1 text = self.enc.decode(out[i]) return text - def generate(self, prompt, options=None, seed=1): debug_print = False diff --git a/generator/gpt2/src/encoder.py b/generator/gpt2/src/encoder.py index 5f52e723..198f19d1 100644 --- a/generator/gpt2/src/encoder.py +++ b/generator/gpt2/src/encoder.py @@ -5,6 +5,7 @@ import regex as re from functools import lru_cache + @lru_cache() def bytes_to_unicode(): """ @@ -16,17 +17,22 @@ def bytes_to_unicode(): To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) cs = bs[:] n = 0 - for b in range(2**8): + for b in range(2 ** 8): if b not in bs: bs.append(b) - cs.append(2**8+n) + cs.append(2 ** 8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) + def get_pairs(word): """Return set of symbol pairs in a word. @@ -39,18 +45,21 @@ def get_pairs(word): prev_char = char return pairs + class Encoder: - def __init__(self, encoder, bpe_merges, errors='replace'): + def __init__(self, encoder, bpe_merges, errors="replace"): self.encoder = encoder - self.decoder = {v:k for k,v in self.encoder.items()} - self.errors = errors # how to handle errors in decoding + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.cache = {} # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions - self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) def bpe(self, token): if token in self.cache: @@ -62,7 +71,7 @@ def bpe(self, token): return token while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -77,8 +86,8 @@ def bpe(self, token): new_word.extend(word[i:]) break - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) i += 2 else: new_word.append(word[i]) @@ -89,29 +98,33 @@ def bpe(self, token): break else: pairs = get_pairs(word) - word = ' '.join(word) + word = " ".join(word) 
self.cache[token] = word return word def encode(self, text): bpe_tokens = [] for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) return bpe_tokens def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + "utf-8", errors=self.errors + ) return text + def get_encoder(model_name, models_dir): - with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f: + with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f: encoder = json.load(f) - with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: + with open( + os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8" + ) as f: bpe_data = f.read() - bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] - return Encoder( - encoder=encoder, - bpe_merges=bpe_merges, - ) + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] + return Encoder(encoder=encoder, bpe_merges=bpe_merges,) diff --git a/generator/gpt2/src/model.py b/generator/gpt2/src/model.py index 230b83cc..55783629 100644 --- a/generator/gpt2/src/model.py +++ b/generator/gpt2/src/model.py @@ -2,14 +2,10 @@ import tensorflow as tf from tensorflow.contrib.training import HParams + def default_hparams(): - return HParams( - n_vocab=0, - n_ctx=1024, - n_embd=768, - n_head=12, - n_layer=12, - ) + return HParams(n_vocab=0, n_ctx=1024, n_embd=768, n_head=12, n_layer=12,) + def shape_list(x): """Deal with dynamic shape in tensorflow cleanly.""" @@ -17,50 +13,64 @@ def shape_list(x): dynamic = tf.shape(x) return [dynamic[i] if s is None else s for i, s in enumerate(static)] + def softmax(x, axis=-1): x = x - tf.reduce_max(x, axis=axis, keepdims=True) ex = tf.exp(x) return ex / tf.reduce_sum(ex, axis=axis, keepdims=True) + def gelu(x): - return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) + return 0.5 * x * (1 + tf.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))) + def norm(x, scope, *, axis=-1, epsilon=1e-5): """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" with tf.variable_scope(scope): n_state = x.shape[-1].value - g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) - b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) + g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1)) + b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0)) u = tf.reduce_mean(x, axis=axis, keepdims=True) - s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) + s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True) x = (x - u) * tf.rsqrt(s + epsilon) - x = x*g + b + x = x * g + b return x + def split_states(x, n): """Reshape the last dimension of x into [n, x.shape[-1]/n].""" *start, m = shape_list(x) - return tf.reshape(x, start + [n, m//n]) + return tf.reshape(x, start + [n, m // n]) + def merge_states(x): """Smash the last two dimensions of x into a single dimension.""" 
*start, a, b = shape_list(x) - return tf.reshape(x, start + [a*b]) + return tf.reshape(x, start + [a * b]) + def conv1d(x, scope, nf, *, w_init_stdev=0.02): with tf.variable_scope(scope): *start, nx = shape_list(x) - w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev)) - b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0)) - c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf]) + w = tf.get_variable( + "w", + [1, nx, nf], + initializer=tf.random_normal_initializer(stddev=w_init_stdev), + ) + b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0)) + c = tf.reshape( + tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b, + start + [nf], + ) return c + def attention_mask(nd, ns, *, dtype): """1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. """ - i = tf.range(nd)[:,None] + i = tf.range(nd)[:, None] j = tf.range(ns) m = i >= j - ns + nd return tf.cast(m, dtype) @@ -70,7 +80,9 @@ def attn(x, scope, n_state, *, past, hparams): assert x.shape.ndims == 3 # Should be [batch, sequence, features] assert n_state % hparams.n_head == 0 if past is not None: - assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v] + assert ( + past.shape.ndims == 5 + ) # Should be [batch, 2, heads, sequence, features], where 2 is [k, v] def split_heads(x): # From [batch, sequence, features] to [batch, heads, sequence, features] @@ -85,7 +97,7 @@ def mask_attn_weights(w): _, _, nd, ns = shape_list(w) b = attention_mask(nd, ns, dtype=w.dtype) b = tf.reshape(b, [1, 1, nd, ns]) - w = w*b - tf.cast(1e10, w.dtype)*(1-b) + w = w * b - tf.cast(1e10, w.dtype) * (1 - b) return w def multihead_attn(q, k, v): @@ -99,7 +111,7 @@ def multihead_attn(q, k, v): return a with tf.variable_scope(scope): - c = conv1d(x, 'c_attn', n_state*3) + c = conv1d(x, "c_attn", n_state * 3) q, k, v = map(split_heads, tf.split(c, 3, axis=2)) present = tf.stack([k, v], axis=1) if past is not None: @@ -108,35 +120,45 @@ def multihead_attn(q, k, v): v = tf.concat([pv, v], axis=-2) a = multihead_attn(q, k, v) a = merge_heads(a) - a = conv1d(a, 'c_proj', n_state) + a = conv1d(a, "c_proj", n_state) return a, present def mlp(x, scope, n_state, *, hparams): with tf.variable_scope(scope): nx = x.shape[-1].value - h = gelu(conv1d(x, 'c_fc', n_state)) - h2 = conv1d(h, 'c_proj', nx) + h = gelu(conv1d(x, "c_fc", n_state)) + h2 = conv1d(h, "c_proj", nx) return h2 def block(x, scope, *, past, hparams): with tf.variable_scope(scope): nx = x.shape[-1].value - a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams) + a, present = attn(norm(x, "ln_1"), "attn", nx, past=past, hparams=hparams) x = x + a - m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams) + m = mlp(norm(x, "ln_2"), "mlp", nx * 4, hparams=hparams) x = x + m return x, present + def past_shape(*, hparams, batch_size=None, sequence=None): - return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head] + return [ + batch_size, + hparams.n_layer, + 2, + hparams.n_head, + sequence, + hparams.n_embd // hparams.n_head, + ] + def expand_tile(value, size): """Add a new axis of given size.""" - value = tf.convert_to_tensor(value, name='value') + value = tf.convert_to_tensor(value, name="value") ndims = value.shape.ndims - return tf.tile(tf.expand_dims(value, axis=0), [size] + 
[1]*ndims) + return tf.tile(tf.expand_dims(value, axis=0), [size] + [1] * ndims) + def positions_for(tokens, past_length): batch_size = tf.shape(tokens)[0] @@ -144,31 +166,39 @@ def positions_for(tokens, past_length): return expand_tile(past_length + tf.range(nsteps), batch_size) -def model(hparams, X, past=None, scope='model', reuse=False): +def model(hparams, X, past=None, scope="model", reuse=False): with tf.variable_scope(scope, reuse=reuse): results = {} batch, sequence = shape_list(X) - wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd], - initializer=tf.random_normal_initializer(stddev=0.01)) - wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd], - initializer=tf.random_normal_initializer(stddev=0.02)) + wpe = tf.get_variable( + "wpe", + [hparams.n_ctx, hparams.n_embd], + initializer=tf.random_normal_initializer(stddev=0.01), + ) + wte = tf.get_variable( + "wte", + [hparams.n_vocab, hparams.n_embd], + initializer=tf.random_normal_initializer(stddev=0.02), + ) past_length = 0 if past is None else tf.shape(past)[-2] h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length)) # Transformer presents = [] - pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer + pasts = ( + tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer + ) assert len(pasts) == hparams.n_layer for layer, past in enumerate(pasts): - h, present = block(h, 'h%d' % layer, past=past, hparams=hparams) + h, present = block(h, "h%d" % layer, past=past, hparams=hparams) presents.append(present) - results['present'] = tf.stack(presents, axis=1) - h = norm(h, 'ln_f') + results["present"] = tf.stack(presents, axis=1) + h = norm(h, "ln_f") # Language model loss. Do tokens = 2: - similarity = get_similarity(story_manager.story.results[-1], story_manager.story.results[-2]) + similarity = get_similarity( + story_manager.story.results[-1], story_manager.story.results[-2] + ) if similarity > 0.9: story_manager.story.actions = story_manager.story.actions[:-1] story_manager.story.results = story_manager.story.results[:-1] - console_print("Woops that action caused the model to start looping. Try a different action to prevent that.") + console_print( + "Woops that action caused the model to start looping. Try a different action to prevent that." + ) continue if player_won(result): @@ -212,9 +241,11 @@ def play_aidungeon_2(): console_print(result) console_print("YOU DIED. GAME OVER") console_print("\nOptions:") - console_print('0) Start a new game') - console_print('1) "I\'m not dead yet!" (If you didn\'t actually die) ') - console_print('Which do you choose? ') + console_print("0) Start a new game") + console_print( + "1) \"I'm not dead yet!\" (If you didn't actually die) " + ) + console_print("Which do you choose? 
") choice = get_num_options(2) if choice == 0: break @@ -226,5 +257,5 @@ def play_aidungeon_2(): console_print(result) -if __name__ == '__main__': +if __name__ == "__main__": play_aidungeon_2() diff --git a/play_dm.py b/play_dm.py index 59758e2f..3a2a3ae0 100644 --- a/play_dm.py +++ b/play_dm.py @@ -4,16 +4,18 @@ from story.utils import * from play import * import time, sys, os + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -class AIPlayer: +class AIPlayer: def __init__(self, generator): self.generator = generator def get_action(self, prompt): return self.generator.generate_raw(prompt) + def play_dm(): console_print("Initializing AI Dungeon DM Mode") @@ -35,15 +37,11 @@ def play_dm(): action = action.split("\n")[0] punc = action.rfind(".") if punc > 0: - action = action[:punc+1] + action = action[: punc + 1] shown_action = "> You" + action console_print(second_to_first_person(shown_action)) story_manager.act(action) - - -if __name__ == '__main__': +if __name__ == "__main__": play_dm() - - diff --git a/story/story_manager.py b/story/story_manager.py index 0f35b152..daa7aa29 100644 --- a/story/story_manager.py +++ b/story/story_manager.py @@ -5,9 +5,11 @@ import subprocess import os -class Story(): - def __init__(self, story_start, context ="", seed=None, game_state=None, upload_story=False): +class Story: + def __init__( + self, story_start, context="", seed=None, game_state=None, upload_story=False + ): self.story_start = story_start self.context = context self.rating = -1 @@ -34,8 +36,9 @@ def __del__(self): if self.upload_story: self.save_to_storage() console_print("Game saved.") - console_print("To load the game, type 'load' and enter the following ID: " + self.uuid) - + console_print( + "To load the game, type 'load' and enter the following ID: " + self.uuid + ) def init_from_dict(self, story_dict): self.story_start = story_dict["story_start"] @@ -53,7 +56,6 @@ def init_from_dict(self, story_dict): else: self.rating = -1 - def initialize_from_json(self, json_string): story_dict = json.loads(json_string) self.init_from_dict(story_dict) @@ -72,7 +74,7 @@ def latest_result(self): while mem_ind > 0: if len(self.results) >= mem_ind: - latest_result += (self.actions[-mem_ind] + self.results[-mem_ind]) + latest_result += self.actions[-mem_ind] + self.results[-mem_ind] mem_ind -= 1 @@ -113,22 +115,25 @@ def load_from_local(self, save_name): file_name = "AIDungeonSave_" + save_name + ".json" print("Save ID that can be used to load game is: ", self.uuid) - with open(file_name, 'r') as fp: + with open(file_name, "r") as fp: game = json.load(fp) self.init_from_dict(game) def save_to_storage(self): self.uuid = str(uuid.uuid1()) - story_json = self.to_json() file_name = "story" + str(self.uuid) + ".json" f = open(file_name, "w") f.write(story_json) f.close() - FNULL = open(os.devnull, 'w') - p = Popen(['gsutil', 'cp', file_name, 'gs://aidungeonstories'], stdout=FNULL, stderr=subprocess.STDOUT) + FNULL = open(os.devnull, "w") + p = Popen( + ["gsutil", "cp", file_name, "gs://aidungeonstories"], + stdout=FNULL, + stderr=subprocess.STDOUT, + ) return self.uuid def load_from_storage(self, story_id): @@ -139,7 +144,7 @@ def load_from_storage(self, story_id): exists = os.path.isfile(file_name) if exists: - with open(file_name, 'r') as fp: + with open(file_name, "r") as fp: game = json.load(fp) self.init_from_dict(game) return str(self) @@ -147,16 +152,22 @@ def load_from_storage(self, story_id): return "Error save not found." 
-class StoryManager(): - +class StoryManager: def __init__(self, generator): self.generator = generator self.story = None - def start_new_story(self, story_prompt, context="", game_state=None, upload_story=False): + def start_new_story( + self, story_prompt, context="", game_state=None, upload_story=False + ): block = self.generator.generate(context + story_prompt) block = cut_trailing_sentence(block) - self.story = Story(context + story_prompt + block, context=context, game_state=game_state, upload_story=upload_story) + self.story = Story( + context + story_prompt + block, + context=context, + game_state=game_state, + upload_story=upload_story, + ) return self.story def load_new_story(self, story_id): @@ -166,7 +177,7 @@ def load_new_story(self, story_id): exists = os.path.isfile(file_name) if exists: - with open(file_name, 'r') as fp: + with open(file_name, "r") as fp: game = json.load(fp) self.story = Story("") self.story.init_from_dict(game) @@ -190,7 +201,6 @@ def story_context(self): class UnconstrainedStoryManager(StoryManager): - def act(self, action_choice): result = self.generate_result(action_choice) @@ -203,7 +213,6 @@ def generate_result(self, action): class ConstrainedStoryManager(StoryManager): - def __init__(self, generator, action_verbs_key="classic"): super().__init__(generator) self.action_phrases = get_action_verbs(action_verbs_key) @@ -211,7 +220,9 @@ def __init__(self, generator, action_verbs_key="classic"): self.cacher = None self.seed = None - def enable_caching(self, credentials_file=None, seed=0, bucket_name="dungeon-cache"): + def enable_caching( + self, credentials_file=None, seed=0, bucket_name="dungeon-cache" + ): self.cache = True self.cacher = Cacher(credentials_file, bucket_name) self.seed = seed @@ -220,7 +231,9 @@ def start_new_story(self, story_prompt, context="", game_state=None): if self.cache: return self.start_new_story_cache(story_prompt, game_state=game_state) else: - return super().start_new_story(story_prompt, context=context, game_state=game_state) + return super().start_new_story( + story_prompt, context=context, game_state=game_state + ) def start_new_story_generate(self, story_prompt, game_state=None): super().start_new_story(story_prompt, game_state=game_state) @@ -235,7 +248,9 @@ def start_new_story_cache(self, story_prompt, game_state=None): self.story = Story(story_start, seed=self.seed) self.story.possible_action_results = self.get_action_results() else: - story_start = self.start_new_story_generate(story_prompt, game_state=game_state) + story_start = self.start_new_story_generate( + story_prompt, game_state=game_state + ) self.story.seed = self.seed self.cacher.cache_file(self.seed, [], story_start, "story") @@ -249,7 +264,9 @@ def get_possible_actions(self): if self.story.possible_action_results is None: self.story.possible_action_results = self.get_action_results() - return [action_result[0] for action_result in self.story.possible_action_results] + return [ + action_result[0] for action_result in self.story.possible_action_results + ] def act(self, action_choice_str): @@ -276,11 +293,16 @@ def get_action_results(self): return self.get_action_results_generate() def get_action_results_generate(self): - action_results = [self.generate_action_result(self.story_context(), phrase) for phrase in self.action_phrases] + action_results = [ + self.generate_action_result(self.story_context(), phrase) + for phrase in self.action_phrases + ] return action_results def get_action_results_cache(self): - response = 
self.cacher.retrieve_from_cache(self.story.seed, self.story.choices, "choices") + response = self.cacher.retrieve_from_cache( + self.story.seed, self.story.choices, "choices" + ) if response is not None: print("Retrieved from cache") @@ -289,11 +311,15 @@ def get_action_results_cache(self): print("Didn't receive from cache") action_results = self.get_action_results_generate() response = json.dumps(action_results) - self.cacher.cache_file(self.story.seed, self.story.choices, response, "choices") + self.cacher.cache_file( + self.story.seed, self.story.choices, response, "choices" + ) return action_results def generate_action_result(self, prompt, phrase, options=None): - action_result = phrase + " " + self.generator.generate(prompt + " " + phrase, options) + action_result = ( + phrase + " " + self.generator.generate(prompt + " " + phrase, options) + ) action, result = split_first_sentence(action_result) return action, result diff --git a/story/utils.py b/story/utils.py index b1ef2cf2..ad2aee80 100644 --- a/story/utils.py +++ b/story/utils.py @@ -1,4 +1,4 @@ - # coding: utf-8 +# coding: utf-8 import re import yaml from difflib import SequenceMatcher @@ -6,11 +6,13 @@ YAML_FILE = "story/story_data.yaml" from profanityfilter import ProfanityFilter + with open("story/extra_censored_words.txt", "r") as f: more_words = [l.replace("\n", "") for l in f.readlines()] pf = ProfanityFilter(extra_censor_list=more_words) + def console_print(text, width=75): last_newline = 0 i = 0 @@ -25,9 +27,11 @@ def console_print(text, width=75): i += 1 print(text) + def get_similarity(a, b): return SequenceMatcher(None, a, b).ratio() + def get_num_options(num): while True: @@ -52,13 +56,25 @@ def player_died(text): # if len(matches) > 0: # return True - dead_phrases = ["you die", "You die", "you died", "you are dead", "You died", "You are dead", "You're dead", - "you're dead", "you have died", "You have died", "you bleed out"] + dead_phrases = [ + "you die", + "You die", + "you died", + "you are dead", + "You died", + "You are dead", + "You're dead", + "you're dead", + "you have died", + "You have died", + "you bleed out", + ] for phrase in dead_phrases: if phrase in text: return True return False + def player_won(text): won_phrases = ["live happily ever after", "you live forever"] @@ -67,6 +83,7 @@ def player_won(text): return True return False + def remove_profanity(text): return pf.censor(text) @@ -79,40 +96,47 @@ def cut_trailing_quotes(text): final_ind = text.rfind('"') return text[:final_ind] - + def split_first_sentence(text): - first_period = text.find('.') - first_exclamation = text.find('!') - + first_period = text.find(".") + first_exclamation = text.find("!") + if first_exclamation < first_period and first_exclamation > 0: - split_point = first_exclamation+1 + split_point = first_exclamation + 1 elif first_period > 0: - split_point = first_period+1 + split_point = first_period + 1 else: split_point = text[0:20] - + return text[0:split_point], text[split_point:] + def cut_trailing_action(text): lines = text.split("\n") last_line = lines[-1] - if "you ask" in last_line or "You ask" in last_line or "you say" in last_line or "You say" in last_line: + if ( + "you ask" in last_line + or "You ask" in last_line + or "you say" in last_line + or "You say" in last_line + ): text = "\n".join(lines[0:-1]) return text - + + def cut_trailing_sentence(text): text = standardize_punctuation(text) - last_punc = max(text.rfind('.'), text.rfind("!"), text.rfind("?")) + last_punc = max(text.rfind("."), text.rfind("!"), 
text.rfind("?")) if last_punc <= 0: - last_punc = len(text)-1 + last_punc = len(text) - 1 et_token = text.find("<") if et_token > 0: - last_punc = min(last_punc, et_token-1) + last_punc = min(last_punc, et_token - 1) act_token = text.find(">") if act_token > 0: - last_punc = min(last_punc, act_token-1) + last_punc = min(last_punc, act_token - 1) text = text[:last_punc] @@ -163,19 +187,21 @@ def is_second_person(text): def capitalize(word): return word[0].upper() + word[1:] - + def mapping_variation_pairs(mapping): mapping_list = [] - mapping_list.append((" " + mapping[0]+" ", " " + mapping[1]+" ")) - mapping_list.append((" " + capitalize(mapping[0]) + " ", " " + capitalize(mapping[1]) + " ")) + mapping_list.append((" " + mapping[0] + " ", " " + mapping[1] + " ")) + mapping_list.append( + (" " + capitalize(mapping[0]) + " ", " " + capitalize(mapping[1]) + " ") + ) # Change you it's before a punctuation if mapping[0] is "you": mapping = ("you", "me") - mapping_list.append((" " + mapping[0]+",", " " + mapping[1]+",")) - mapping_list.append((" " + mapping[0]+"\?", " " + mapping[1]+"\?")) - mapping_list.append((" " + mapping[0]+"\!", " " + mapping[1]+"\!")) + mapping_list.append((" " + mapping[0] + ",", " " + mapping[1] + ",")) + mapping_list.append((" " + mapping[0] + "\?", " " + mapping[1] + "\?")) + mapping_list.append((" " + mapping[0] + "\!", " " + mapping[1] + "\!")) mapping_list.append((" " + mapping[0] + "\.", " " + mapping[1] + ".")) return mapping_list @@ -202,14 +228,14 @@ def mapping_variation_pairs(mapping): ("I've", "you've"), ("I was", "you were"), ("my", "your"), - ("we","you"), + ("we", "you"), ("we're", "you're"), - ("mine","yours"), + ("mine", "yours"), ("me", "you"), ("us", "you"), ("our", "your"), ("I'll", "you'll"), - ("myself", "yourself") + ("myself", "yourself"), ] second_to_first_mappings = [ @@ -222,9 +248,10 @@ def mapping_variation_pairs(mapping): ("you", "me"), ("you'll", "I'll"), ("yourself", "myself"), - ("you've", "I've") + ("you've", "I've"), ] + def capitalize_helper(string): string_list = list(string) string_list[0] = string_list[0].upper() @@ -232,19 +259,20 @@ def capitalize_helper(string): def capitalize_first_letters(text): - first_letters_regex = re.compile(r'((?<=[\.\?!]\s)(\w+)|(^\w+))') + first_letters_regex = re.compile(r"((?<=[\.\?!]\s)(\w+)|(^\w+))") def cap(match): - return (capitalize_helper(match.group())) + return capitalize_helper(match.group()) result = first_letters_regex.sub(cap, text) return result + def standardize_punctuation(text): text = text.replace("’", "'") text = text.replace("`", "'") - text = text.replace('“', '"') - text = text.replace('”', '"') + text = text.replace("“", '"') + text = text.replace("”", '"') return text @@ -258,6 +286,7 @@ def first_to_second_person(text): return capitalize_first_letters(text[1:]) + def second_to_first_person(text): text = " " + text text = standardize_punctuation(text) @@ -266,4 +295,4 @@ def second_to_first_person(text): for variation in variations: text = replace_outside_quotes(text, variation[0], variation[1]) - return capitalize_first_letters(text[1:]) \ No newline at end of file + return capitalize_first_letters(text[1:])