Add stop_regex support for OpenAI models, and fix an issue when caching is off in Transformers
slundberg committed May 20, 2023
1 parent aefb45f commit 699c8e8
Showing 10 changed files with 263 additions and 58 deletions.
4 changes: 2 additions & 2 deletions guidance/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.0.49"
__version__ = "0.0.50"

import types
import sys
@@ -17,7 +17,7 @@

# This makes the guidance module callable
class Guidance(types.ModuleType):
def __call__(self, template, llm=None, cache_seed=0, logprobs=None, silent='auto', async_mode=False, stream='auto', caching=None, await_missing=False, **kwargs):
def __call__(self, template, llm=None, cache_seed=0, logprobs=None, silent='auto', async_mode=False, stream=None, caching=None, await_missing=False, **kwargs):
return Program(template, llm=llm, cache_seed=cache_seed, logprobs=logprobs, silent=silent, async_mode=async_mode, stream=stream, caching=caching, await_missing=await_missing, **kwargs)
sys.modules[__name__].__class__ = Guidance

7 changes: 4 additions & 3 deletions guidance/_program.py
@@ -33,7 +33,7 @@ class Program:
the generated output to mark where template tags used to be.
'''

def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent='auto', async_mode=False, stream='auto', caching=None, await_missing=False, **kwargs):
def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent='auto', async_mode=False, stream=None, caching=None, await_missing=False, **kwargs):
""" Create a new Program object from a program string.
Parameters
@@ -56,8 +56,9 @@ def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent='auto', a
async_mode : bool (default False)
If True, the program will be executed asynchronously. This is useful for programs that
take a long time to run, or that need to be run in parallel.
stream : bool (default False)
If True, the program will try to stream all the results from the LLM token by token.
stream : bool (default None)
If True, the program will try to stream all the results from the LLM token by token. If None,
streaming will be enabled only when it is needed for functionality. (Warning: this parameter may change in the future.)
caching : bool (default None)
If True, the program will cache the results of the LLM. If False, it will not cache the results.
If None, it will use the default caching setting from the LLM.
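
For reference, a minimal sketch of how the new stream default looks from the caller's side; the template and model name are purely illustrative:

import guidance

# stream now defaults to None, which lets guidance enable streaming only when a
# feature (such as stop_regex) requires it; stream=True still forces token-by-token streaming.
program = guidance("""Tell me a joke. {{gen 'joke' max_tokens=30}}""",
                   llm=guidance.llms.OpenAI("text-davinci-003"))
executed = program()
print(executed["joke"])
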
30 changes: 18 additions & 12 deletions guidance/library/_gen.py
@@ -2,14 +2,15 @@
import re
import uuid
import logging
import types
from .._grammar import grammar
from .._utils import escape_template_block

log = logging.getLogger(__name__)

async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens=500, n=1, temperature=0.0, top_p=1.0,
logprobs=None, pattern=None, hidden=False, parse=False, list_append=False, save_prompt=False,
token_healing=None, _parser_context=None):
async def gen(variable_name="generated", stop=None, stop_regex=None, save_stop_text=False, max_tokens=500, n=1,
temperature=0.0, top_p=1.0, logprobs=None, pattern=None, hidden=False, parse=False, list_append=False,
save_prompt=False, token_healing=None, _parser_context=None):
''' Use the LLM to generate a completion.
Parameters
@@ -22,6 +23,10 @@ async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens=
the generated value.
stop_regex : str
A regular expression to use for stopping generation. If not provided, the stop string will be used.
save_stop_text : str or bool
If set to a string, the exact stop text used will be saved in a variable with the given name. If set to
True, the stop text will be saved in a variable named `variable_name+"_stop_text"`. If set to False,
the stop text will not be saved.
max_tokens : int
The maximum number of tokens to generate in this completion.
n : int
@@ -113,14 +118,8 @@ async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens=
else:
cache_seed = 0

# see if we should stream the results
if n == 1: # we can't stream batches right now
if parser.program.stream == "auto":
stream_generation = not parser.program.silent or parser.program.async_mode
else:
stream_generation = parser.program.stream
else:
stream_generation = False
# we can't stream batches right now
stream_generation = parser.program.stream if n == 1 else False

# save the prompt if requested
if save_prompt:
@@ -140,7 +139,7 @@ async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens=
generated_value = prefix
partial_output(prefix)
logprobs_out = []
if not stream_generation:
if not isinstance(gen_obj, types.GeneratorType):
gen_obj = [gen_obj]
if list_append:
value_list = parser.get_variable(variable_name, [])
@@ -169,6 +168,13 @@ async def gen(variable_name="generated", stop=None, stop_regex=None, max_tokens=
parser.set_variable(variable_name, generated_value)
if logprobs is not None:
parser.set_variable(variable_name+"_logprobs", logprobs_out)

# save the final stopping text if requested
if save_stop_text is not False:
if save_stop_text is True:
save_stop_text = variable_name+"_stop_text"
parser.set_variable(save_stop_text, resp["choices"][0].get('stop_text', None))

if hasattr(gen_obj, 'close'):
gen_obj.close()
generated_value += suffix
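
For reference, a minimal sketch of the new gen arguments inside a guidance template; the prompt, regular expression, and model name are illustrative:

import guidance

# Stop at the first sentence-ending character and record which one fired.
# save_stop_text=True stores it under variable_name + "_stop_text", i.e. "answer_stop_text".
program = guidance("""Q: Why is the sky blue?
A: {{gen 'answer' stop_regex='[.!?]' save_stop_text=True max_tokens=60}}""",
                   llm=guidance.llms.OpenAI("text-davinci-003"))
executed = program()
print(executed["answer"], executed["answer_stop_text"])
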
2 changes: 1 addition & 1 deletion guidance/library/_select.py
@@ -80,7 +80,7 @@ async def recursive_select(current_prefix, allow_token_extension=True):
option_tokens = parser.program.llm.encode(parser_prefix[-50:] + option)

# if we extended the last token to a longer one
if option_tokens[len(tmp_prefix_tokens)-1] != tmp_prefix_tokens[-1]:
if len(tmp_prefix_tokens) > 0 and option_tokens[len(tmp_prefix_tokens)-1] != tmp_prefix_tokens[-1]:
if allow_token_extension: # this is a valid extension only if we are allowed to extend the token
logit_bias1[option_tokens[len(tmp_prefix_tokens)-1]] = 100

180 changes: 161 additions & 19 deletions guidance/llms/_openai.py
@@ -9,17 +9,14 @@
import collections
import json
import re
import regex
from ._llm import LLM, LLMSession, SyncSession


class MalformedPromptException(Exception):
pass
def prompt_to_messages(prompt):
messages = []
start_tags = re.findall(r'<\|im_start\|>', prompt)
end_tags = re.findall(r'<\|im_end\|>', prompt)
# if len(start_tags) != len(end_tags):
# raise MalformedPromptException("Malformed prompt: start and end tags are not properly paired")

assert prompt.endswith("<|im_start|>assistant\n"), "When calling OpenAI chat models you must generate only directly inside the assistant role! The OpenAI API does not currently support partial assistant prompting."

@@ -56,9 +53,6 @@ def add_text_to_chat_mode(chat_mode):
for c in chat_mode['choices']:
c['text'] = c['message']['content']
return chat_mode


# c['text'] = f'<|im_start|>{c["message"]["role"]}\n{c["message"]["content"]}<|im_end|>'

# models that need to use the chat completion API
chat_models = [
@@ -149,11 +143,88 @@ def role_end(self, role=None):
return "<|im_end|>"

@classmethod
def stream_then_save(cls, gen, key):
def stream_then_save(cls, gen, key, stop_regex, n):
list_out = []
cached_out = None

# init stop_regex variables
if stop_regex is not None:
if isinstance(stop_regex, str):
stop_patterns = [regex.compile(stop_regex)]
else:
stop_patterns = [regex.compile(pattern) for pattern in stop_regex]

current_strings = ["" for _ in range(n)]
# last_out_pos = ["" for _ in range(n)]

# iterate through the stream
all_done = False
for out in gen:
list_out.append(out)
yield out

# if we have a cached output, extend it with the current output
if cached_out is not None:
out = merge_stream_chunks(cached_out, out)

# check if we have stop_regex matches
found_partial = False
if stop_regex is not None:

# keep track of the generated text so far
for i,choice in enumerate(out['choices']):
current_strings[i] += choice['text']

# check if all of the strings match a stop string (and hence we can stop the batch inference)
all_done = True
for i in range(len(current_strings)):
found = False
for s in stop_patterns:
if s.search(current_strings[i]):
found = True
if not found:
all_done = False
break

# find where to trim off the stop regex matches if needed (and look for partial matches)
stop_pos = [1e10 for _ in range(n)]
stop_text = [None for _ in range(n)]
for i in range(len(current_strings)):
for s in stop_patterns:
m = s.search(current_strings[i], partial=True)
if m:
span = m.span()
if span[1] > span[0]:
if m.partial: # we might be starting a stop sequence, so we can't emit anything yet
found_partial = True
break
else:
stop_text[i] = current_strings[i][span[0]:span[1]]
stop_pos[i] = min(span[0], stop_pos[i])
if stop_pos[i] != 1e10:
stop_pos[i] = stop_pos[i] - len(current_strings[i]) # convert to relative position from the end

# if we might be starting a stop sequence, we need to cache the output and continue to wait and see
if found_partial:
cached_out = out
continue

# if we get here, we are not starting a stop sequence, so we can emit the output
else:
cached_out = None

if stop_regex is not None:
for i in range(len(out['choices'])):
if stop_pos[i] < len(out['choices'][i]['text']):
out['choices'][i] = out['choices'][i].to_dict() # because sometimes we might need to set the text to the empty string (and OpenAI's object does not like that)
out['choices'][i]['text'] = out['choices'][i]['text'][:stop_pos[i]]
out['choices'][i]['stop_text'] = stop_text[i]
out['choices'][i]['finish_reason'] = "stop"

list_out.append(out)
yield out
if all_done:
gen.close()
break

cls.cache[key] = list_out
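
The partial-match bookkeeping above relies on the third-party regex package (imported at the top of this file), which, unlike the standard library re module, can report matches that are cut off at the end of the string. A small standalone sketch of that behavior:

import regex

pattern = regex.compile(r"</answer>")

# The stream has only produced a prefix of the stop sequence so far:
m = pattern.search("The answer is 42 </ans", partial=True)
print(m.partial)  # True -> buffer the chunk and wait for more tokens

# Once the full stop sequence arrives, the match is no longer partial:
m = pattern.search("The answer is 42 </answer>", partial=True)
print(m.partial)  # False -> trim at m.span() and set finish_reason to "stop"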

def _stream_completion(self):
@@ -258,13 +329,88 @@ def decode(self, tokens, fragment=True):
return self._tokenizer.decode(tokens)


def merge_stream_chunks(first_chunk, second_chunk):
""" This merges two stream responses together.
"""

out = copy.deepcopy(first_chunk)

# merge the choices
for i in range(len(out['choices'])):
out_choice = out['choices'][i]
second_choice = second_chunk['choices'][i]
out_choice['text'] += second_choice['text']
if 'index' in second_choice:
out_choice['index'] = second_choice['index']
if 'finish_reason' in second_choice:
out_choice['finish_reason'] = second_choice['finish_reason']
if out_choice.get('logprobs', None) is not None:
out_choice['logprobs']['token_logprobs'] += second_choice['logprobs']['token_logprobs']
out_choice['logprobs']['top_logprobs'] += second_choice['logprobs']['top_logprobs']
out_choice['logprobs']['text_offset'] = second_choice['logprobs']['text_offset']

return out
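
A quick sketch of what merge_stream_chunks produces for two buffered chunks; the field values are illustrative and logprobs are left out:

first = {"choices": [{"text": "Hello", "index": 0, "logprobs": None}]}
second = {"choices": [{"text": ", world", "index": 0, "finish_reason": "stop", "logprobs": None}]}

merged = merge_stream_chunks(first, second)
print(merged["choices"][0]["text"])           # "Hello, world"
print(merged["choices"][0]["finish_reason"])  # "stop"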


class OpenAIStreamer():
def __init__(self, stop_regex, n):
self.stop_regex = stop_regex
self.n = n
self.current_strings = ["" for _ in range(n)]
self.current_length = 0

class RegexStopChecker():
def __init__(self, stop_pattern, decode, prefix_length):
if isinstance(stop_pattern, str):
self.stop_patterns = [regex.compile(stop_pattern)]
else:
self.stop_patterns = [regex.compile(pattern) for pattern in stop_pattern]
self.prefix_length = prefix_length
self.decode = decode
self.current_strings = None
self.current_length = 0

def __call__(self, input_ids, scores, **kwargs):

# extend our current strings
if self.current_strings is None:
self.current_strings = ["" for _ in range(len(input_ids))]
for i in range(len(self.current_strings)):
self.current_strings[i] += self.decode(input_ids[i][self.current_length:])

# trim off the prefix string so we don't look for stop matches in the prompt
if self.current_length == 0:
for i in range(len(self.current_strings)):
self.current_strings[i] = self.current_strings[i][self.prefix_length:]

self.current_length = len(input_ids[0])

# check if all of the strings match a stop string (and hence we can stop the batch inference)
all_done = True
for i in range(len(self.current_strings)):
found = False
for s in self.stop_patterns:
if s.search(self.current_strings[i]):
found = True
if not found:
all_done = False
break

return all_done
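
RegexStopChecker's __call__ signature matches the stopping-criteria convention used by Hugging Face transformers (a callable taking input_ids and scores). As an illustration only, here is one way such a checker could be plugged into transformers generation; the model, tokenizer, stop pattern, and prefix handling are assumptions, not necessarily how guidance wires it up internally:

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Q: Name one primary color.\nA:"
inputs = tokenizer(prompt, return_tensors="pt")

# Stop as soon as a newline is generated after the prompt; prefix_length strips
# the decoded prompt so the regex only sees newly generated characters.
checker = RegexStopChecker(
    stop_pattern=r"\n",
    decode=tokenizer.decode,
    prefix_length=len(tokenizer.decode(inputs["input_ids"][0])),
)

output_ids = model.generate(
    **inputs,
    max_new_tokens=30,
    stopping_criteria=StoppingCriteriaList([checker]),
)
print(tokenizer.decode(output_ids[0]))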

# Define a deque to store the timestamps of the calls
class OpenAISession(LLMSession):
async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, max_tokens=1000, logprobs=None, top_p=1.0, echo=False, logit_bias=None, token_healing=None, pattern=None, stream=False, cache_seed=0, caching=None):
async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, max_tokens=1000, logprobs=None, top_p=1.0, echo=False, logit_bias=None, token_healing=None, pattern=None, stream=None, cache_seed=0, caching=None):
""" Generate a completion of the given prompt.
"""

assert token_healing is None or token_healing is False, "The OpenAI API does not support token healing! Please either switch to an endpoint that does, or don't use the `token_healing` argument to `gen`."
# we need to stream in order to support stop_regex
if stream is None:
stream = stop_regex is not None
assert stop_regex is None or stream, "We can only support stop_regex for the OpenAI API when stream=True!"
assert stop_regex is None or n == 1, "We don't yet support stop_regex combined with n > 1 with the OpenAI API!"

assert token_healing is None or token_healing is False, "The OpenAI API does not yet support token healing! Please either switch to an endpoint that does, or don't use the `token_healing` argument to `gen`."

# set defaults
if temperature is None:
@@ -274,7 +420,7 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n
args = locals().copy()

assert not pattern, "The OpenAI API does not support Guidance pattern controls! Please either switch to an endpoint that does, or don't use the `pattern` argument to `gen`."
assert not stop_regex, "The OpenAI API does not support Guidance stop_regex controls! Please either switch to an endpoint that does, or don't use the `stop_regex` argument to `gen`."
# assert not stop_regex, "The OpenAI API does not support Guidance stop_regex controls! Please either switch to an endpoint that does, or don't use the `stop_regex` argument to `gen`."

# define the key for the cache
key = self._cache_key(args)
@@ -326,7 +472,7 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n
raise Exception(f"Too many (more than {self.llm.max_retries}) OpenAI API RateLimitError's in a row!")

if stream:
return self.llm.stream_then_save(out, key)
return self.llm.stream_then_save(out, key, stop_regex, n)
else:
self.llm.__class__.cache[key] = out

Expand All @@ -336,8 +482,4 @@ async def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n
return self.llm.__class__.cache[key]
return [self.llm.__class__.cache[key]]

return self.llm.__class__.cache[key]

# class OpenAISession(AsyncOpenAISession):
# def __call__(self, *args, **kwargs):
# return self._loop.run_until_complete(super().__call__(*args, **kwargs))
return self.llm.__class__.cache[key]