From a71d7d3da031c3f94b20b5df2b4cc322320260b2 Mon Sep 17 00:00:00 2001 From: Nick Walton Date: Mon, 2 Dec 2019 07:02:36 -0700 Subject: [PATCH] Develop (#10) * adds saving and loading * fixed some cmd issues --- generator/gpt2/gpt2_generator.py | 23 ++++++++++------- generator/human_dm.py | 6 +++++ play.py | 2 +- play_dm.py | 44 ++++++++++++++++++++++++++++++++ story/story_manager.py | 14 ++++------ 5 files changed, 70 insertions(+), 19 deletions(-) create mode 100644 generator/human_dm.py create mode 100644 play_dm.py diff --git a/generator/gpt2/gpt2_generator.py b/generator/gpt2/gpt2_generator.py index f64deef2..4f147c17 100644 --- a/generator/gpt2/gpt2_generator.py +++ b/generator/gpt2/gpt2_generator.py @@ -84,15 +84,7 @@ def result_replace(self, result): return result - def generate(self, prompt, options=None, seed=1): - - debug_print = False - prefix = self.prompt_replace(prompt) - - if debug_print: - print("******DEBUG******") - print("Prompt is: ", repr(prefix)) - + def generate_raw(self, prompt): context_tokens = self.enc.encode(prompt) generated = 0 for _ in range(self.samples // self.batch_size): @@ -102,6 +94,19 @@ def generate(self, prompt, options=None, seed=1): for i in range(self.batch_size): generated += 1 text = self.enc.decode(out[i]) + return text + + + def generate(self, prompt, options=None, seed=1): + + debug_print = False + prompt = self.prompt_replace(prompt) + + if debug_print: + print("******DEBUG******") + print("Prompt is: ", repr(prompt)) + + text = self.generate_raw(prompt) if debug_print: print("Generated result is: ", repr(text)) diff --git a/generator/human_dm.py b/generator/human_dm.py new file mode 100644 index 00000000..5dded391 --- /dev/null +++ b/generator/human_dm.py @@ -0,0 +1,6 @@ +from story.utils import * + +class HumanDM: + + def generate(self, prompt, options=None, seed=None): + return input() diff --git a/play.py b/play.py index 8cb98bf2..ef4b38c1 100644 --- a/play.py +++ b/play.py @@ -18,7 +18,7 @@ def select_game(): else: print_str += " (experimental)" console_print(print_str) - console_print(str(len(settings)) + ") custom expiermental") + console_print(str(len(settings)) + ") custom experimental") choice = get_num_options(len(settings)+1) if choice == len(settings): diff --git a/play_dm.py b/play_dm.py new file mode 100644 index 00000000..3947d887 --- /dev/null +++ b/play_dm.py @@ -0,0 +1,44 @@ +from story.story_manager import * +from generator.human_dm import * +from generator.gpt2.gpt2_generator import * +from story.utils import * +from termios import tcflush, TCIFLUSH +from play import * +import time, sys, os +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +class AIPlayer: + + def __init__(self, generator): + self.generator = generator + + def get_action(self, prompt): + return self.generator.generate_raw(prompt) + +def play_dm(): + generator = GPT2Generator() + story_manager = UnconstrainedStoryManager(HumanDM()) + context, prompt = select_game() + console_print(context + prompt) + story_manager.start_new_story(prompt, context=context, upload_story=False) + + player = AIPlayer(generator) + + while True: + action_prompt = story_manager.story_context() + "What do you do next? \n> You" + action = player.get_action(action_prompt) + print("\n******DEBUG FULL ACTION*******") + print(action) + print("******END DEBUG******\n") + action = action.split("\n")[0] + shown_action = "> You" + action + console_print(second_to_first_person(shown_action)) + story_manager.act(action) + + + + +if __name__ == '__main__': + play_dm() + + diff --git a/story/story_manager.py b/story/story_manager.py index f1bb3861..8ba1f133 100644 --- a/story/story_manager.py +++ b/story/story_manager.py @@ -72,7 +72,6 @@ def latest_result(self): if len(self.results) < 2: latest_result += self.story_start - while mem_ind > 0: if len(self.results) >= mem_ind: @@ -138,22 +137,19 @@ def save_to_storage(self): def load_from_storage(self, story_id): file_name = "story" + story_id + ".json" - cmd = "gsutil cp gs://aidungeonstories/" + file_name + " . >/dev/null 2>&1" + cmd = "gsutil cp gs://aidungeonstories/" + file_name + " ." os.system(cmd) exists = os.path.isfile(file_name) - with open(file_name, 'r') as fp: - game = json.load(fp) - self.init_from_dict(game) - if exists: + with open(file_name, 'r') as fp: + game = json.load(fp) + self.init_from_dict(game) return str(self) else: return "Error save not found." - - class StoryManager(): def __init__(self, generator): @@ -163,7 +159,7 @@ def __init__(self, generator): def start_new_story(self, story_prompt, context="", game_state=None, upload_story=False): block = self.generator.generate(context + story_prompt) block = cut_trailing_sentence(block) - self.story = Story(story_prompt + block, context=context, game_state=game_state, upload_story=upload_story) + self.story = Story(context + story_prompt + block, context=context, game_state=game_state, upload_story=upload_story) return self.story def load_story(self, story, from_json=False):