From 699c8e8987f719bdfb35b38f89a6b540a3e9017d Mon Sep 17 00:00:00 2001 From: Scott Lundberg Date: Sat, 20 May 2023 15:20:31 -0700 Subject: [PATCH] Add stop_regex support for OpenAI models, and fix an issue when caching is off in Transformers --- guidance/__init__.py | 4 +- guidance/_program.py | 7 +- guidance/library/_gen.py | 30 +++--- guidance/library/_select.py | 2 +- guidance/llms/_openai.py | 180 ++++++++++++++++++++++++++++---- guidance/llms/_transformers.py | 24 ++++- tests/library/test_gen.py | 29 ++++- tests/library/test_select.py | 26 ++++- tests/llms/test_transformers.py | 11 -- tests/utils.py | 8 ++ 10 files changed, 263 insertions(+), 58 deletions(-) diff --git a/guidance/__init__.py b/guidance/__init__.py index 3f5e26612..55e9795d8 100644 --- a/guidance/__init__.py +++ b/guidance/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.49" +__version__ = "0.0.50" import types import sys @@ -17,7 +17,7 @@ # This makes the guidance module callable class Guidance(types.ModuleType): - def __call__(self, template, llm=None, cache_seed=0, logprobs=None, silent='auto', async_mode=False, stream='auto', caching=None, await_missing=False, **kwargs): + def __call__(self, template, llm=None, cache_seed=0, logprobs=None, silent='auto', async_mode=False, stream=None, caching=None, await_missing=False, **kwargs): return Program(template, llm=llm, cache_seed=cache_seed, logprobs=logprobs, silent=silent, async_mode=async_mode, stream=stream, caching=caching, await_missing=await_missing, **kwargs) sys.modules[__name__].__class__ = Guidance diff --git a/guidance/_program.py b/guidance/_program.py index d87753560..260be00d3 100644 --- a/guidance/_program.py +++ b/guidance/_program.py @@ -33,7 +33,7 @@ class Program: the generated output to mark where template tags used to be. ''' - def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent='auto', async_mode=False, stream='auto', caching=None, await_missing=False, **kwargs): + def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent='auto', async_mode=False, stream=None, caching=None, await_missing=False, **kwargs): """ Create a new Program object from a program string. Parameters ---------- @@ -56,8 +56,9 @@ def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent='auto', a async_mode : bool (default False) If True, the program will be executed asynchronously. This is useful for programs that take a long time to run, or that need to be run in parallel. - stream : bool (default False) - If True, the program will try to stream all the results from the LLM token by token. + stream : bool (default None) + If True, the program will try to stream all the results from the LLM token by token. If None, + streaming will be enabled if it is needed for functionality. (Warning: this param may change a bit in the future) caching : bool (default None) If True, the program will cache the results of the LLM. If False, it will not cache the results. If None, it will use the default caching setting from the LLM. 
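For illustration only (not part of the patch): a minimal sketch of the new stream default. With stream=None a program turns streaming on only when a feature needs it (for example stop_regex), instead of the old 'auto' string. The sketch uses the Mock LLM that the test suite already relies on; the mocked completion text, template, and variable name are arbitrary examples.

    import guidance

    llm = guidance.llms.Mock(" Sue")  # mocked completion, as used in tests/library/test_gen.py
    # stream is left at its new default of None, so streaming is only enabled when required
    program = guidance("My name is{{gen 'name' max_tokens=5}}", llm=llm)
    out = program()
    print(out["name"])  # prints the mocked completion
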
diff --git a/guidance/library/_gen.py b/guidance/library/_gen.py index 028f3949a..21a76ede0 100644 --- a/guidance/library/_gen.py +++ b/guidance/library/_gen.py @@ -2,14 +2,15 @@ import re import uuid import logging +import types from .._grammar import grammar from .._utils import escape_template_block log = logging.getLogger(__name__) -async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens=500, n=1, temperature=0.0, top_p=1.0, - logprobs=None, pattern=None, hidden=False, parse=False, list_append=False, save_prompt=False, - token_healing=None, _parser_context=None): +async def gen(variable_name="generated", stop=None, stop_regex=None, save_stop_text=False, max_tokens=500, n=1, + temperature=0.0, top_p=1.0, logprobs=None, pattern=None, hidden=False, parse=False, list_append=False, + save_prompt=False, token_healing=None, _parser_context=None): ''' Use the LLM to generate a completion. Parameters @@ -22,6 +23,10 @@ async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens= the generated value. stop_regex : str A regular expression to use for stopping generation. If not provided, the stop string will be used. + save_stop_text : str or bool + If set to a string, the exact stop text used will be saved in a variable with the given name. If set to + True, the stop text will be saved in a variable named `variable_name+"_stop_text"`. If set to False, + the stop text will not be saved. max_tokens : int The maximum number of tokens to generate in this completion. n : int @@ -113,14 +118,8 @@ async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens= else: cache_seed = 0 - # see if we should stream the results - if n == 1: # we can't stream batches right now - if parser.program.stream == "auto": - stream_generation = not parser.program.silent or parser.program.async_mode - else: - stream_generation = parser.program.stream - else: - stream_generation = False + # we can't stream batches right now + stream_generation = parser.program.stream if n == 1 else False # save the prompt if requested if save_prompt: @@ -140,7 +139,7 @@ async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens= generated_value = prefix partial_output(prefix) logprobs_out = [] - if not stream_generation: + if not isinstance(gen_obj, types.GeneratorType): gen_obj = [gen_obj] if list_append: value_list = parser.get_variable(variable_name, []) @@ -169,6 +168,13 @@ async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens= parser.set_variable(variable_name, generated_value) if logprobs is not None: parser.set_variable(variable_name+"_logprobs", logprobs_out) + + # save the final stopping text if requested + if save_stop_text is not False: + if save_stop_text is True: + save_stop_text = variable_name+"_stop_text" + parser.set_variable(save_stop_text, resp["choices"][0].get('stop_text', None)) + if hasattr(gen_obj, 'close'): gen_obj.close() generated_value += suffix diff --git a/guidance/library/_select.py b/guidance/library/_select.py index 9f6122211..fc3d7c051 100644 --- a/guidance/library/_select.py +++ b/guidance/library/_select.py @@ -80,7 +80,7 @@ async def recursive_select(current_prefix, allow_token_extension=True): option_tokens = parser.program.llm.encode(parser_prefix[-50:] + option) # if we extended the last token to a longer one - if option_tokens[len(tmp_prefix_tokens)-1] != tmp_prefix_tokens[-1]: + if len(tmp_prefix_tokens) > 0 and option_tokens[len(tmp_prefix_tokens)-1] != tmp_prefix_tokens[-1]: if 
allow_token_extension: # this is a valid extension only if we are allowed to extend the token logit_bias1[option_tokens[len(tmp_prefix_tokens)-1]] = 100 diff --git a/guidance/llms/_openai.py b/guidance/llms/_openai.py index 9bf68f482..9285785fb 100644 --- a/guidance/llms/_openai.py +++ b/guidance/llms/_openai.py @@ -9,6 +9,7 @@ import collections import json import re +import regex from ._llm import LLM, LLMSession, SyncSession @@ -16,10 +17,6 @@ class MalformedPromptException(Exception): pass def prompt_to_messages(prompt): messages = [] - start_tags = re.findall(r'<\|im_start\|>', prompt) - end_tags = re.findall(r'<\|im_end\|>', prompt) - # if len(start_tags) != len(end_tags): - # raise MalformedPromptException("Malformed prompt: start and end tags are not properly paired") assert prompt.endswith("<|im_start|>assistant\n"), "When calling OpenAI chat models you must generate only directly inside the assistant role! The OpenAI API does not currently support partial assistant prompting." @@ -56,9 +53,6 @@ def add_text_to_chat_mode(chat_mode): for c in chat_mode['choices']: c['text'] = c['message']['content'] return chat_mode - - - # c['text'] = f'<|im_start|>{c["message"]["role"]}\n{c["message"]["content"]}<|im_end|>' # models that need to use the chat completion API chat_models = [ @@ -149,11 +143,88 @@ def role_end(self, role=None): return "<|im_end|>" @classmethod - def stream_then_save(cls, gen, key): + def stream_then_save(cls, gen, key, stop_regex, n): list_out = [] + cached_out = None + + # init stop_regex variables + if stop_regex is not None: + if isinstance(stop_regex, str): + stop_patterns = [regex.compile(stop_regex)] + else: + stop_patterns = [regex.compile(pattern) for pattern in stop_regex] + + current_strings = ["" for _ in range(n)] + # last_out_pos = ["" for _ in range(n)] + + # iterate through the stream + all_done = False for out in gen: - list_out.append(out) - yield out + + # if we have a cached output, extend it with the current output + if cached_out is not None: + out = merge_stream_chunks(cached_out, out) + + # check if we have stop_regex matches + found_partial = False + if stop_regex is not None: + + # keep track of the generated text so far + for i,choice in enumerate(out['choices']): + current_strings[i] += choice['text'] + + # check if all of the strings match a stop string (and hence we can stop the batch inference) + all_done = True + for i in range(len(current_strings)): + found = False + for s in stop_patterns: + if s.search(current_strings[i]): + found = True + if not found: + all_done = False + break + + # find where to trim off the stop regex matches if needed (and look for partial matches) + stop_pos = [1e10 for _ in range(n)] + stop_text = [None for _ in range(n)] + for i in range(len(current_strings)): + for s in stop_patterns: + m = s.search(current_strings[i], partial=True) + if m: + span = m.span() + if span[1] > span[0]: + if m.partial: # we might be starting a stop sequence, so we can't emit anything yet + found_partial = True + break + else: + stop_text[i] = current_strings[i][span[0]:span[1]] + stop_pos[i] = min(span[0], stop_pos[i]) + if stop_pos[i] != 1e10: + stop_pos[i] = stop_pos[i] - len(current_strings[i]) # convert to relative position from the end + + # if we might be starting a stop sequence, we need to cache the output and continue to wait and see + if found_partial: + cached_out = out + continue + + # if we get here, we are not starting a stop sequence, so we can emit the output + else: + cached_out = None + + if stop_regex is not 
None: + for i in range(len(out['choices'])): + if stop_pos[i] < len(out['choices'][i]['text']): + out['choices'][i] = out['choices'][i].to_dict() # because sometimes we might need to set the text to the empty string (and OpenAI's object does not like that) + out['choices'][i]['text'] = out['choices'][i]['text'][:stop_pos[i]] + out['choices'][i]['stop_text'] = stop_text[i] + out['choices'][i]['finish_reason'] = "stop" + + list_out.append(out) + yield out + if all_done: + gen.close() + break + cls.cache[key] = list_out def _stream_completion(self): @@ -258,13 +329,88 @@ def decode(self, tokens, fragment=True): return self._tokenizer.decode(tokens) +def merge_stream_chunks(first_chunk, second_chunk): + """ This merges two stream responses together. + """ + + out = copy.deepcopy(first_chunk) + + # merge the choices + for i in range(len(out['choices'])): + out_choice = out['choices'][i] + second_choice = second_chunk['choices'][i] + out_choice['text'] += second_choice['text'] + if 'index' in second_choice: + out_choice['index'] = second_choice['index'] + if 'finish_reason' in second_choice: + out_choice['finish_reason'] = second_choice['finish_reason'] + if out_choice.get('logprobs', None) is not None: + out_choice['logprobs']['token_logprobs'] += second_choice['logprobs']['token_logprobs'] + out_choice['logprobs']['top_logprobs'] += second_choice['logprobs']['top_logprobs'] + out_choice['logprobs']['text_offset'] = second_choice['logprobs']['text_offset'] + + return out + + +class OpenAIStreamer(): + def __init__(self, stop_regex, n): + self.stop_regex = stop_regex + self.n = n + self.current_strings = ["" for _ in range(n)] + self.current_length = 0 + +class RegexStopChecker(): + def __init__(self, stop_pattern, decode, prefix_length): + if isinstance(stop_pattern, str): + self.stop_patterns = [regex.compile(stop_pattern)] + else: + self.stop_patterns = [regex.compile(pattern) for pattern in stop_pattern] + self.prefix_length = prefix_length + self.decode = decode + self.current_strings = None + self.current_length = 0 + + def __call__(self, input_ids, scores, **kwargs): + + # extend our current strings + if self.current_strings is None: + self.current_strings = ["" for _ in range(len(input_ids))] + for i in range(len(self.current_strings)): + self.current_strings[i] += self.decode(input_ids[i][self.current_length:]) + + # trim off the prefix string so we don't look for stop matches in the prompt + if self.current_length == 0: + for i in range(len(self.current_strings)): + self.current_strings[i] = self.current_strings[i][self.prefix_length:] + + self.current_length = len(input_ids[0]) + + # check if all of the strings match a stop string (and hence we can stop the batch inference) + all_done = True + for i in range(len(self.current_strings)): + found = False + for s in self.stop_patterns: + if s.search(self.current_strings[i]): + found = True + if not found: + all_done = False + break + + return all_done + # Define a deque to store the timestamps of the calls class OpenAISession(LLMSession): - async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, max_tokens=1000, logprobs=None, top_p=1.0, echo=False, logit_bias=None, token_healing=None, pattern=None, stream=False, cache_seed=0, caching=None): + async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, max_tokens=1000, logprobs=None, top_p=1.0, echo=False, logit_bias=None, token_healing=None, pattern=None, stream=None, cache_seed=0, caching=None): """ Generate a completion of the given 
prompt. """ - assert token_healing is None or token_healing is False, "The OpenAI API does not support token healing! Please either switch to an endpoint that does, or don't use the `token_healing` argument to `gen`." + # we need to stream in order to support stop_regex + if stream is None: + stream = stop_regex is not None + assert stop_regex is None or stream, "We can only support stop_regex for the OpenAI API when stream=True!" + assert stop_regex is None or n == 1, "We don't yet support stop_regex combined with n > 1 with the OpenAI API!" + + assert token_healing is None or token_healing is False, "The OpenAI API does not yet support token healing! Please either switch to an endpoint that does, or don't use the `token_healing` argument to `gen`." # set defaults if temperature is None: @@ -274,7 +420,7 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n args = locals().copy() assert not pattern, "The OpenAI API does not support Guidance pattern controls! Please either switch to an endpoint that does, or don't use the `pattern` argument to `gen`." - assert not stop_regex, "The OpenAI API does not support Guidance stop_regex controls! Please either switch to an endpoint that does, or don't use the `stop_regex` argument to `gen`." + # assert not stop_regex, "The OpenAI API does not support Guidance stop_regex controls! Please either switch to an endpoint that does, or don't use the `stop_regex` argument to `gen`." # define the key for the cache key = self._cache_key(args) @@ -326,7 +472,7 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n raise Exception(f"Too many (more than {self.llm.max_retries}) OpenAI API RateLimitError's in a row!") if stream: - return self.llm.stream_then_save(out, key) + return self.llm.stream_then_save(out, key, stop_regex, n) else: self.llm.__class__.cache[key] = out @@ -336,8 +482,4 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n return self.llm.__class__.cache[key] return [self.llm.__class__.cache[key]] - return self.llm.__class__.cache[key] - -# class OpenAISession(AsyncOpenAISession): -# def __call__(self, *args, **kwargs): -# return self._loop.run_until_complete(super().__call__(*args, **kwargs)) \ No newline at end of file + return self.llm.__class__.cache[key] \ No newline at end of file diff --git a/guidance/llms/_transformers.py b/guidance/llms/_transformers.py index 6c94b203d..20fe8e53e 100644 --- a/guidance/llms/_transformers.py +++ b/guidance/llms/_transformers.py @@ -232,9 +232,13 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n stop_regex.append(regex.escape(self.llm._tokenizer.eos_token)) # make sure the end of sequence token is always included # handle caching - if key not in self.llm.cache or (caching is not True and not self.llm.caching) or caching is False: + in_cache = key in self.llm.cache + not_caching = (caching is not True and not self.llm.caching) or caching is False + if not in_cache or not_caching: import transformers - # import torch + + assert prompt != "", "You must provide a non-zero length prompt to the Transformers language model!" 
+ # encode the prompt encoded = self.llm.encode([prompt for _ in range(n)], return_tensors="pt", fragment=False) if self.llm.device is not None: @@ -277,12 +281,19 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n max_tokens = max_context - len(input_ids[0]) # find how much of the prompt is cached - for prefix_match_len, token in enumerate(input_ids[0]): + prefix_match_len = 0 + for token in input_ids[0]: if prefix_match_len >= len(self._prefix_cache) or token != self._prefix_cache[prefix_match_len]: break + else: + prefix_match_len += 1 + + # we always need to run the model on at least one token so transformers is happy + if prefix_match_len == len(input_ids[0]): + prefix_match_len -= 1 # trim the cache to what we can use - if prefix_match_len > 0 and prefix_match_len < len(self._prefix_cache): + if prefix_match_len < len(self._prefix_cache): # prefix_match_len > 0 and self._past_key_values = tuple((key[:,:,:prefix_match_len,:],value[:,:,:prefix_match_len,:]) for key,value in self._past_key_values) # TODO: this is specific to the GPT2 tensor layout self._prefix_cache = self._prefix_cache[:prefix_match_len] @@ -363,6 +374,7 @@ def _stream_then_save(self, streamer, key, thread): thread.join() # clean up the thread self.llm.cache[key] = list_out self._update_prefix_cache(streamer) + self._last_computed_key = key def __exit__(self, exc_type, exc_value, traceback): """ Restore the model to its original state by removing monkey patches. @@ -647,6 +659,7 @@ def put(self, token_obj): # trim off the stop regex matches if needed found_partial = False + stop_text = None if self.stop_regex is not None:# and (finish_reason is None or len(self.input_ids) > 1): stop_regex_obj = [regex.compile(s) for s in self.stop_regex] for s in stop_regex_obj: @@ -658,7 +671,9 @@ def put(self, token_obj): found_partial = True break else: + stop_text = val[span[0]:span[1]] stop_pos = min(span[0], stop_pos) + break # record the reason we stopped (if we have stopped) if stop_pos <= len(val): @@ -668,6 +683,7 @@ def put(self, token_obj): out["choices"][i] = { "text": val[:stop_pos], "finish_reason": finish_reason, + "stop_text": stop_text, "logprobs": {"token_healing_prefix": self.last_token_str, "top_logprobs": display_logprobs} } self.str_pos[i] = len(self.generated_string[i]) diff --git a/tests/library/test_gen.py b/tests/library/test_gen.py index 981c8887f..be23ea492 100644 --- a/tests/library/test_gen.py +++ b/tests/library/test_gen.py @@ -1,8 +1,9 @@ import guidance -from ..utils import get_transformers_llm +import pytest +from ..utils import get_llm def test_gen(): - """ Test that LM geneation works. + """ Test that LM generation works. 
""" llm = guidance.llms.Mock(" Sue") @@ -28,7 +29,7 @@ def aggregate(best): def test_pattern(): import re - llm = get_transformers_llm("gpt2") + llm = get_llm("transformers:gpt2") out = guidance('''On a scale of 1-10 I would say it is: {{gen 'score' pattern="[0-9]+"}}''', llm=llm)() assert re.match(r'[0-9]+', out["score"]) @@ -46,7 +47,7 @@ def test_pattern2(): - GPT {{gen 'chapter' pattern='[0-9]' max_tokens=1}}:{{gen 'verse' pattern='[0-9]+' stop='\\n'}} '''[1:-1] - llm = get_transformers_llm("gpt2") + llm = get_llm("transformers:gpt2") program = guidance(prompt, llm=llm) executed_program = program( proverb="Where there is no guidance, a people falls,\nbut in an abundance of counselors there is safety.", @@ -56,4 +57,22 @@ def test_pattern2(): ) assert re.fullmatch(r"[0-9]", executed_program["chapter"]) - assert re.fullmatch(r"[0-9]+", executed_program["verse"]) \ No newline at end of file + assert re.fullmatch(r"[0-9]+", executed_program["verse"]) + +@pytest.mark.parametrize("llm", ["transformers:gpt2", "openai:text-curie-001"]) +def test_stop(llm): + """ Test that the stop argument works as expected. + """ + llm = get_llm(llm) + program = guidance("""Write "repeat this. " 10 times: repeat this. repeat this. repeat this. repeat this. repeat this. repeat this.{{gen stop="this" max_tokens=10}}""", llm=llm) + out = program() + assert str(out) == "Write \"repeat this. \" 10 times: repeat this. repeat this. repeat this. repeat this. repeat this. repeat this. repeat " + +@pytest.mark.parametrize("llm", ["transformers:gpt2", "openai:text-curie-001"]) +def test_stop_regex(llm): + """ Test that the stop_regex argument works as expected. + """ + llm = get_llm(llm) + program = guidance("""Write "repeat this. " 10 times: repeat this. repeat this. repeat this. repeat this. repeat this. repeat this.{{gen stop_regex="th.s" max_tokens=10}}""", llm=llm) + out = program() + assert str(out) == "Write \"repeat this. \" 10 times: repeat this. repeat this. repeat this. repeat this. repeat this. repeat this. repeat " \ No newline at end of file diff --git a/tests/library/test_select.py b/tests/library/test_select.py index c027ffb6b..714d5a4dc 100644 --- a/tests/library/test_select.py +++ b/tests/library/test_select.py @@ -59,4 +59,28 @@ def test_select_odd_spacing(): Sentence: {{example}} Answer: {{#select "answer" logprobs='logprobs'}} Yes{{or}} Nein{{or}} Maybe{{/select}}''', llm=llm) prompt = prompt(example='I hate tacos.') - assert prompt["answer"] in [" Yes", " Nein", " Maybe"] \ No newline at end of file + assert prompt["answer"] in [" Yes", " Nein", " Maybe"] + + +def test_overlapping_options(): + """ Test the behavior of `select` when one option is a prefix of another. + """ + + llm = get_transformers_llm("gpt2") + options = ['a', 'aa'] + program = guidance("'{{select options=options}}", llm=llm) + out = program(options=options) + assert out["selected"] in options + +# TODO: fix this next +# def test_unexpected_tokens(): +# """ Test the behavior of `select` when the next tokens are hard to predict. 
+# """ + +# llm = get_transformers_llm("gpt2") +# options = ['a', 'b'] +# program = guidance("some word xy{{select options=options}}", llm=llm) +# out = program(options=options) +# assert out["selected"] in options + +# TODO: test when we have few starting tokens \ No newline at end of file diff --git a/tests/llms/test_transformers.py b/tests/llms/test_transformers.py index 8c275d546..b9aa62708 100644 --- a/tests/llms/test_transformers.py +++ b/tests/llms/test_transformers.py @@ -14,18 +14,7 @@ def test_repeat(): out2 = s("this is a test like another", max_tokens=5) print(out2) -def test_stop(): - llm = get_transformers_llm('gpt2') - program = guidance("""Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this.{{gen stop="this" max_tokens=10}}""", llm=llm) - out = program() - assert str(out) == "Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat " -def test_pattern(): - import re - llm = get_transformers_llm('gpt2') - program = guidance("""Repeat this. Repeat this. Repeat this. Repeat this. {{gen pattern="[0-9]+" max_tokens=1}}""", llm=llm) - out = program() - assert re.match("^Repeat this. Repeat this. Repeat this. Repeat this. [0-9]+$", str(out)) def test_select(): llm = get_transformers_llm('gpt2') diff --git a/tests/utils.py b/tests/utils.py index f062edd41..b974614d7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,14 @@ opanai_model_cache = {} +def get_llm(model_name): + """ Get an LLM by name. + """ + if model_name.startswith("openai:"): + return get_openai_llm(model_name[7:]) + elif model_name.startswith("transformers:"): + return get_transformers_llm(model_name[13:]) + def get_openai_llm(model_name, caching=False): """ Get an OpenAI LLM with model reuse and smart test skipping. """
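Usage sketch for the new gen arguments added by this patch (illustrative only, not part of the diff). It assumes the OpenAI text-curie-001 completion endpoint used in the new tests and a configured API key; the prompt mirrors test_stop_regex, and the 'tail' variable name is an arbitrary example. Setting save_stop_text=True stores the matched stop string under the variable name plus "_stop_text".

    import guidance

    llm = guidance.llms.OpenAI("text-curie-001")  # completion model used by the new tests
    program = guidance(
        "Write \"repeat this. \" 10 times: repeat this. repeat this. repeat this."
        "{{gen 'tail' stop_regex='th.s' save_stop_text=True max_tokens=10}}",
        llm=llm,
    )
    out = program()
    print(out["tail"])            # generated text, trimmed where the stop regex matched
    print(out["tail_stop_text"])  # the exact text that matched the stop regex (e.g. "this")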