diff --git a/guidance/library/__init__.py b/guidance/library/__init__.py
index c7a96b124..59d52df98 100644
--- a/guidance/library/__init__.py
+++ b/guidance/library/__init__.py
@@ -1,5 +1,6 @@
 # import functions that can be called directly
 from ._gen import gen, call_tool, will_gen
+from ._image import image
 
 # core grammar functions
 from .._grammar import select
diff --git a/guidance/library/_image.py b/guidance/library/_image.py
new file mode 100644
index 000000000..b1c759b9d
--- /dev/null
+++ b/guidance/library/_image.py
@@ -0,0 +1,30 @@
+import guidance
+import urllib.request
+import typing
+import http
+import re
+
+@guidance
+def image(lm, src, allow_local=True):
+
+    # load the image bytes
+    # ...from a url
+    if isinstance(src, str) and re.match(r'^[^:/]+://', src):
+        with urllib.request.urlopen(src) as response:
+            response = typing.cast(http.client.HTTPResponse, response)
+            bytes = response.read()
+
+    # ...from a local path
+    elif allow_local and isinstance(src, str):
+        with open(src, "rb") as f:
+            bytes = f.read()
+
+    else:
+        raise Exception(f"Unable to load image bytes from {src}!")
+
+    bytes_id = str(id(bytes))
+
+    # set the image bytes
+    lm = lm.set(bytes_id, bytes)
+    lm += f'<|_image:{bytes_id}|>'
+    return lm
\ No newline at end of file
diff --git a/guidance/models/__init__.py b/guidance/models/__init__.py
index 2c954deff..fa6d20c14 100644
--- a/guidance/models/__init__.py
+++ b/guidance/models/__init__.py
@@ -1,5 +1,5 @@
 from ._model import Model, Chat
-from .vertexai._vertexai import VertexAI, VertexAIChat
+from .vertexai._vertexai import VertexAI, VertexAIChat, VertexAICompletion, VertexAIInstruct
 from ._azure_openai import AzureOpenAI, AzureOpenAIChat, AzureOpenAICompletion, AzureOpenAIInstruct
 from ._openai import OpenAI, OpenAIChat, OpenAIInstruct, OpenAICompletion
 from .transformers._transformers import Transformers, TransformersChat
diff --git a/guidance/models/_model.py b/guidance/models/_model.py
index 925b12368..4a66689de 100644
--- a/guidance/models/_model.py
+++ b/guidance/models/_model.py
@@ -14,6 +14,8 @@ import time
 import numpy as np
 import logging
+import base64
+
 logger = logging.getLogger(__name__)
 
 try:
     from .. import cpp
@@ -29,6 +31,7 @@
 format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)
 nodisp_pattern = re.compile(r"<\|\|_#NODISP_\|\|>.*?<\|\|_/NODISP_\|\|>", flags=re.DOTALL)
 html_pattern = re.compile(r"<\|\|_html:(.*?)_\|\|>", flags=re.DOTALL)
+image_pattern = re.compile(r"<\|_image:(.*?)\|>")
 
 class Model:
     '''A guidance model object, which represents a sequence model in a given state.
@@ -122,7 +125,8 @@ def _html(self):
         display_out = html.escape(display_out)
         display_out = nodisp_pattern.sub("", display_out)
         display_out = html_pattern.sub(lambda x: html.unescape(x.group(1)), display_out)
-        display_out = "<pre style='margin: 0px; padding: 0px; vertical-align: middle; padding-left: 8px; margin-left: 8px; border-radius: 0px; white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>"+display_out+"</pre>"
+        display_out = image_pattern.sub(lambda x: '<img src="data:image/png;base64,' + base64.b64encode(self[x.group(1)]).decode() + '" style="max-width: 400px; vertical-align: middle; margin: 4px;">', display_out)
+        display_out = "<pre style='margin: 0px; padding: 0px; vertical-align: middle; padding-left: 8px; margin-left: 8px; border-radius: 0px; white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>"+display_out+"</pre>"
" return display_out def _send_to_event_queue(self, value): diff --git a/guidance/models/vertexai/_Gemini.py b/guidance/models/vertexai/_Gemini.py new file mode 100644 index 000000000..0d5fb881e --- /dev/null +++ b/guidance/models/vertexai/_Gemini.py @@ -0,0 +1,138 @@ +import os +from pathlib import Path +import multiprocessing +from itertools import takewhile +import operator +import threading +import numpy as np +import queue +import time +import tiktoken +import re + +from ._vertexai import VertexAICompletion, VertexAIInstruct, VertexAIChat +_image_token_pattern = re.compile(r'<\|_image:(.*)\|>') + +try: + from vertexai.language_models import TextGenerationModel, ChatModel, InputOutputTextPair + from vertexai.preview.generative_models import GenerativeModel, Content, Part, Image + import vertexai + + # def get_chat_response(message): + # vertexai.init(project="PROJECT_ID", location="us-central1") + # model = GenerativeModel("gemini-pro") + # chat = model.start_chat() + # response = chat.send_message(message) + # return response.text + + # print(get_chat_response("Hello")) + # print(get_chat_response("What are all the colors in a rainbow?")) + # print(get_chat_response("Why does it appear when it rains?")) + is_vertexai = True +except ImportError: + is_vertexai = False + +# class GeminiCompletion(VertexAICompletion): +# def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0.0, max_streaming_tokens=None, **kwargs): + +# if isinstance(model, str): +# self.model_name = model +# self.model_obj = TextGenerationModel.from_pretrained(self.model_name) + +# # Gemini does not have a public tokenizer, so we pretend it tokenizes like gpt2... +# if tokenizer is None: +# tokenizer = tiktoken.get_encoding("gpt2") + +# # the superclass does all the work +# super().__init__( +# model, +# tokenizer=tokenizer, +# echo=echo, +# caching=caching, +# temperature=temperature, +# max_streaming_tokens=max_streaming_tokens, +# **kwargs +# ) + +# class GeminiInstruct(VertexAIInstruct): +# def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0.0, max_streaming_tokens=None, **kwargs): + +# if isinstance(model, str): +# self.model_name = model +# self.model_obj = TextGenerationModel.from_pretrained(self.model_name) + +# # Gemini does not have a public tokenizer, so we pretend it tokenizes like gpt2... +# if tokenizer is None: +# tokenizer = tiktoken.get_encoding("gpt2") + +# # the superclass does all the work +# super().__init__( +# model, +# tokenizer=tokenizer, +# echo=echo, +# caching=caching, +# temperature=temperature, +# max_streaming_tokens=max_streaming_tokens, +# **kwargs +# ) + +class GeminiChat(VertexAIChat): + def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0.0, max_streaming_tokens=None, **kwargs): + + if isinstance(model, str): + self.model_name = model + self.model_obj = GenerativeModel(self.model_name) + + # Gemini does not have a public tokenizer, so we pretend it tokenizes like gpt2... + if tokenizer is None: + tokenizer = tiktoken.get_encoding("gpt2") + + # the superclass does all the work + super().__init__( + model, + tokenizer=tokenizer, + echo=echo, + caching=caching, + temperature=temperature, + max_streaming_tokens=max_streaming_tokens, + **kwargs + ) + + def _start_chat(self, system_text, messages): + assert system_text == "", "We don't support passing system text to Gemini models (yet?)!" 
+
+        out = self.model_obj.start_chat(
+            history=messages
+        )
+        return out
+
+    def _start_generator(self, system_text, messages, temperature):
+        # last_user_text = messages[-1]["content"]
+        formated_messages = []
+        for m in messages:
+            raw_parts = _image_token_pattern.split(m["content"])
+            parts = []
+            for i in range(0, len(raw_parts), 2):
+
+                # append the text portion
+                if len(raw_parts[i]) > 0:
+                    parts.append(Part.from_text(raw_parts[i]))
+
+                # append any image
+                if i + 1 < len(raw_parts):
+                    parts.append(Part.from_image(Image.from_bytes(self[raw_parts[i+1]])))
+            formated_messages.append(Content(role=m["role"], parts=parts))
+        last_user_parts = formated_messages.pop() # remove the last user stuff that goes in send_message (and not history)
+
+        chat_session = self.model_obj.start_chat(
+            history=formated_messages,
+        )
+
+        generation_config = {
+            "temperature": temperature
+        }
+        if self.max_streaming_tokens is not None:
+            generation_config["max_output_tokens"] = self.max_streaming_tokens
+        generator = chat_session.send_message(last_user_parts, generation_config=generation_config, stream=True)
+
+        for chunk in generator:
+            yield chunk.candidates[0].content.parts[0].text.encode("utf8")
\ No newline at end of file
diff --git a/guidance/models/vertexai/__init__.py b/guidance/models/vertexai/__init__.py
index 59a72e88b..e4854fbf5 100644
--- a/guidance/models/vertexai/__init__.py
+++ b/guidance/models/vertexai/__init__.py
@@ -1,2 +1,3 @@
 from ._PaLM2 import PaLM2Completion, PaLM2Chat, PaLM2Instruct
-from ._Codey import CodeyCompletion, CodeyInstruct, CodeyChat
\ No newline at end of file
+from ._Codey import CodeyCompletion, CodeyInstruct, CodeyChat
+from ._Gemini import GeminiChat
\ No newline at end of file
diff --git a/guidance/models/vertexai/_vertexai.py b/guidance/models/vertexai/_vertexai.py
index f23bc4ae9..aaf67af49 100644
--- a/guidance/models/vertexai/_vertexai.py
+++ b/guidance/models/vertexai/_vertexai.py
@@ -51,6 +51,10 @@ def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0
             # PaLM2Chat
             elif re.match("chat-bison(@[0-9]+)?", model_name):
                 found_subclass = vertexai.PaLM2Chat
+
+            # GeminiChat
+            elif re.match("gemini-pro(@[0-9]+)?", model_name):
+                found_subclass = vertexai.GeminiChat
 
         # convert to any found subclass
         if found_subclass is not None:
@@ -149,8 +153,8 @@ def _generator(self, prompt, temperature):
                 end_pos = prompt[pos:].find(role_end)
                 if end_pos < 0:
                     break
-                messages.append(vertexai.language_models.ChatMessage(
-                    author="user",
+                messages.append(dict(
+                    role="user",
                     content=prompt[pos:pos+end_pos].decode("utf8"),
                 ))
                 pos += end_pos + len(role_end)
@@ -160,23 +164,39 @@
                 if end_pos < 0:
                     valid_end = True
                     break
-                messages.append(vertexai.language_models.ChatMessage(
-                    author="assistant",
+                messages.append(dict(
+                    role="assistant",
                     content=prompt[pos:pos+end_pos].decode("utf8"),
                 ))
                 pos += end_pos + len(role_end)
+            else:
+                raise Exception("It looks like your prompt is not a well formed chat prompt! Please enclose all model state appends inside chat role blocks like `user()` or `assistant()`.")
 
         self._shared_state["data"] = prompt[:pos]
 
         assert len(messages) > 0, "Bad chat format! No chat blocks were defined."
-        assert messages[-1].author == "user", "Bad chat format! There must be a user() role before the last assistant() role."
+        assert messages[-1]["role"] == "user", "Bad chat format! There must be a user() role before the last assistant() role."
         assert valid_end, "Bad chat format! You must generate inside assistant() roles."
 
         # TODO: don't make a new session on every call
-        last_user_text = messages.pop().content
+        # last_user_text = messages.pop().content
+        return self._start_generator(system_text.decode("utf8"), messages, temperature)
+
+        # kwargs = {}
+        # if self.max_streaming_tokens is not None:
+        #     kwargs["max_output_tokens"] = self.max_streaming_tokens
+        # generator = chat_session.send_message_streaming(last_user_text, temperature=temperature, **kwargs)
+
+        # for chunk in generator:
+        #     yield chunk.text.encode("utf8")
+
+    def _start_generator(self, system_text, messages, temperature):
+        messages = [vertexai.language_models.ChatMessage(author=m["role"], content=m["content"]) for m in messages]
+        last_user_text = messages.pop().content
+
         chat_session = self.model_obj.start_chat(
-            context=system_text.decode("utf8"),
+            context=system_text,
             message_history=messages,
         )
diff --git a/tests/models/test_llama_cpp.py b/tests/models/test_llama_cpp.py
index c9e070c7b..33bbbc967 100644
--- a/tests/models/test_llama_cpp.py
+++ b/tests/models/test_llama_cpp.py
@@ -36,7 +36,7 @@ def test_repeat_calls():
     a = []
     lm = llama2 + 'How much is 2 + 2? ' + gen(name='test', max_tokens=10)
     a.append(lm['test'])
-    lm = llama2 + 'How much is 2 + 2? ' + gen(name='test',max_tokens=10, pattern=r'\d+')
+    lm = llama2 + 'How much is 2 + 2? ' + gen(name='test',max_tokens=10, regex=r'\d+')
     a.append(lm['test'])
     lm = llama2 + 'How much is 2 + 2? ' + gen(name='test', max_tokens=10)
     a.append(lm['test'])
diff --git a/tests/models/test_vertexai.py b/tests/models/test_vertexai.py
index e5af2cd7b..20679ecbc 100644
--- a/tests/models/test_vertexai.py
+++ b/tests/models/test_vertexai.py
@@ -6,7 +6,7 @@ def test_palm2_instruct():
     try:
         vmodel = models.VertexAI("text-bison@001")
     except:
-        pytest.skip("Skipping OpenAI test because we can't load the model!")
+        pytest.skip("Skipping VertexAI test because we can't load the model!")
 
     with instruction():
         lm = vmodel + "this is a test about"
@@ -19,7 +19,7 @@ def test_palm2_chat():
     try:
         vmodel = models.VertexAI("chat-bison@001")
     except:
-        pytest.skip("Skipping OpenAI test because we can't load the model!")
+        pytest.skip("Skipping VertexAI test because we can't load the model!")
 
     with system():
         lm = vmodel + "You are an always-happy agent no matter what."
@@ -43,6 +43,50 @@ def test_palm2_chat():
 
     with system():
         lm = vmodel + "You are an always-happy agent no matter what."
 
+    with user():
+        lm += "The economy is crashing!"
+
+    with assistant():
+        lm += gen("test1", max_tokens=100)
+
+    with user():
+        lm += "What is the best again?"
+
+    with assistant():
+        lm += gen("test2", max_tokens=100)
+
+    assert len(lm["test1"]) > 0
+    assert len(lm["test2"]) > 0
+    assert lm["test1"].find("<|im_end|>") < 0
+
+def test_gemini_chat():
+    from guidance import models, gen, system, user, assistant
+
+    try:
+        vmodel = models.VertexAI("gemini-pro")
+    except:
+        pytest.skip("Skipping VertexAI test because we can't load the model!")
+
+    lm = vmodel
+
+    with user():
+        lm += "The economy is crashing!"
+
+    with assistant():
+        lm += gen("test1", max_tokens=100)
+
+    with user():
+        lm += "What is the best again?"
+
+    with assistant():
+        lm += gen("test2", max_tokens=100)
+
+    assert len(lm["test1"]) > 0
+    assert len(lm["test2"]) > 0
+
+    # second time to make sure cache reuse is okay
+    lm = vmodel
+
     with user():
         lm += "The economy is crashing!"
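
Usage sketch (illustrative only, not part of the patch): the new `image()` helper stores the image bytes on the model state and appends an `<|_image:...|>` placeholder, which `GeminiChat._start_generator` later converts into a `Part.from_image(...)`, while `models.VertexAI("gemini-pro")` now routes to `GeminiChat`. The prompt text and the `photo.jpg` path below are placeholders; attaching an image assumes a Gemini variant that accepts image input, that Vertex AI credentials are configured, and that `image` is re-exported at the package top level like the other library functions.

```python
from guidance import models, gen, user, assistant, image

# "gemini-pro" now matches the new branch in _vertexai.py and returns a GeminiChat model
lm = models.VertexAI("gemini-pro")

with user():
    lm += "What is shown in this picture? "
    # hypothetical local file; image() loads the bytes, stores them on the
    # model state, and emits the <|_image:...|> token handled by GeminiChat
    lm += image("photo.jpg")

with assistant():
    lm += gen("answer", max_tokens=100)

print(lm["answer"])
```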