
Commit

Add multi-model gemini-pro support
slundberg committed Dec 14, 2023
1 parent 96f2e02 commit 3e0a976
Showing 9 changed files with 251 additions and 13 deletions.
1 change: 1 addition & 0 deletions guidance/library/__init__.py
@@ -1,5 +1,6 @@
# import functions that can be called directly
from ._gen import gen, call_tool, will_gen
from ._image import image

# core grammar functions
from .._grammar import select
30 changes: 30 additions & 0 deletions guidance/library/_image.py
@@ -0,0 +1,30 @@
import guidance
import urllib
import typing
import http
import re

@guidance
def image(lm, src, allow_local=True):

# load the image bytes
# ...from a url
if isinstance(src, str) and re.match(r'^[^:/]+://', src):
with urllib.request.urlopen(src) as response:
response = typing.cast(http.client.HTTPResponse, response)
bytes = response.read()

# ...from a local path
elif allow_local and isinstance(src, str):
with open(src, "rb") as f:
bytes = f.read()

else:
raise Exception(f"Unable to load image bytes from {src}!")

bytes_id = str(id(bytes))

# set the image bytes
lm = lm.set(bytes_id, bytes)
lm += f'<|_image:{bytes_id}|>'
return lm
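For orientation, a hedged usage sketch of the new image() helper (not part of the diff): the model name and URL below are placeholders, and a vision-capable backend is an assumption, since this commit only registers the text-only gemini-pro route.

```python
from guidance import models, gen, image, user, assistant

# hypothetical usage of the new image() helper: it stores the image bytes on
# the model state and appends a placeholder token that a vision-capable
# backend can later resolve (model name and URL are placeholders)
lm = models.VertexAI("gemini-pro-vision")

with user():
    lm += "What is shown in this picture?"
    lm += image("https://example.com/photo.png")

with assistant():
    lm += gen("answer", max_tokens=50)
```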
2 changes: 1 addition & 1 deletion guidance/models/__init__.py
@@ -1,5 +1,5 @@
from ._model import Model, Chat
from .vertexai._vertexai import VertexAI, VertexAIChat
from .vertexai._vertexai import VertexAI, VertexAIChat, VertexAICompletion, VertexAIInstruct
from ._azure_openai import AzureOpenAI, AzureOpenAIChat, AzureOpenAICompletion, AzureOpenAIInstruct
from ._openai import OpenAI, OpenAIChat, OpenAIInstruct, OpenAICompletion
from .transformers._transformers import Transformers, TransformersChat
6 changes: 5 additions & 1 deletion guidance/models/_model.py
@@ -14,6 +14,8 @@
import time
import numpy as np
import logging
import base64

logger = logging.getLogger(__name__)
try:
from .. import cpp
@@ -29,6 +31,7 @@
format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)
nodisp_pattern = re.compile(r"&lt;\|\|_#NODISP_\|\|&gt;.*?&lt;\|\|_/NODISP_\|\|&gt;", flags=re.DOTALL)
html_pattern = re.compile(r"&lt;\|\|_html:(.*?)_\|\|&gt;", flags=re.DOTALL)
image_pattern = re.compile(r"&lt;\|_image:(.*?)\|&gt;")

class Model:
'''A guidance model object, which represents a sequence model in a given state.
@@ -122,7 +125,8 @@ def _html(self):
display_out = html.escape(display_out)
display_out = nodisp_pattern.sub("", display_out)
display_out = html_pattern.sub(lambda x: html.unescape(x.group(1)), display_out)
display_out = "<pre style='margin: 0px; padding: 0px; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>"+display_out+"</pre>"
display_out = image_pattern.sub(lambda x: '<img src="data:image/png;base64,' + base64.b64encode(self[x.groups(1)[0]]).decode() + '" style="max-width: 400px; vertical-align: middle; margin: 4px;">', display_out)
display_out = "<pre style='margin: 0px; padding: 0px; vertical-align: middle; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>"+display_out+"</pre>"
return display_out

def _send_to_event_queue(self, value):
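The _html() change above inlines stored image bytes as a base64 data URI. A minimal sketch (not part of the diff) of that substitution, with an illustrative dict standing in for the model's key/value state:

```python
import base64
import html
import re

# sketch of the new image_pattern substitution in Model._html(): the id
# captured from the escaped placeholder is looked up in the model state and
# inlined as a base64-encoded data URI
image_pattern = re.compile(r"&lt;\|_image:(.*?)\|&gt;")

state = {"140234": b"\x89PNG\r\n\x1a\n..."}  # stands in for self[bytes_id]
display_out = html.escape("Describe this: <|_image:140234|>")

display_out = image_pattern.sub(
    lambda m: '<img src="data:image/png;base64,'
              + base64.b64encode(state[m.group(1)]).decode()
              + '" style="max-width: 400px;">',
    display_out,
)
print(display_out)
```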
138 changes: 138 additions & 0 deletions guidance/models/vertexai/_Gemini.py
@@ -0,0 +1,138 @@
import os
from pathlib import Path
import multiprocessing
from itertools import takewhile
import operator
import threading
import numpy as np
import queue
import time
import tiktoken
import re

from ._vertexai import VertexAICompletion, VertexAIInstruct, VertexAIChat
_image_token_pattern = re.compile(r'<\|_image:(.*?)\|>')

try:
from vertexai.language_models import TextGenerationModel, ChatModel, InputOutputTextPair
from vertexai.preview.generative_models import GenerativeModel, Content, Part, Image
import vertexai

# def get_chat_response(message):
# vertexai.init(project="PROJECT_ID", location="us-central1")
# model = GenerativeModel("gemini-pro")
# chat = model.start_chat()
# response = chat.send_message(message)
# return response.text

# print(get_chat_response("Hello"))
# print(get_chat_response("What are all the colors in a rainbow?"))
# print(get_chat_response("Why does it appear when it rains?"))
is_vertexai = True
except ImportError:
is_vertexai = False

# class GeminiCompletion(VertexAICompletion):
# def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0.0, max_streaming_tokens=None, **kwargs):

# if isinstance(model, str):
# self.model_name = model
# self.model_obj = TextGenerationModel.from_pretrained(self.model_name)

# # Gemini does not have a public tokenizer, so we pretend it tokenizes like gpt2...
# if tokenizer is None:
# tokenizer = tiktoken.get_encoding("gpt2")

# # the superclass does all the work
# super().__init__(
# model,
# tokenizer=tokenizer,
# echo=echo,
# caching=caching,
# temperature=temperature,
# max_streaming_tokens=max_streaming_tokens,
# **kwargs
# )

# class GeminiInstruct(VertexAIInstruct):
# def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0.0, max_streaming_tokens=None, **kwargs):

# if isinstance(model, str):
# self.model_name = model
# self.model_obj = TextGenerationModel.from_pretrained(self.model_name)

# # Gemini does not have a public tokenizer, so we pretend it tokenizes like gpt2...
# if tokenizer is None:
# tokenizer = tiktoken.get_encoding("gpt2")

# # the superclass does all the work
# super().__init__(
# model,
# tokenizer=tokenizer,
# echo=echo,
# caching=caching,
# temperature=temperature,
# max_streaming_tokens=max_streaming_tokens,
# **kwargs
# )

class GeminiChat(VertexAIChat):
def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0.0, max_streaming_tokens=None, **kwargs):

if isinstance(model, str):
self.model_name = model
self.model_obj = GenerativeModel(self.model_name)

# Gemini does not have a public tokenizer, so we pretend it tokenizes like gpt2...
if tokenizer is None:
tokenizer = tiktoken.get_encoding("gpt2")

# the superclass does all the work
super().__init__(
model,
tokenizer=tokenizer,
echo=echo,
caching=caching,
temperature=temperature,
max_streaming_tokens=max_streaming_tokens,
**kwargs
)

def _start_chat(self, system_text, messages):
assert system_text == "", "We don't support passing system text to Gemini models (yet?)!"
out = self.model_obj.start_chat(
history=messages
)
return out

def _start_generator(self, system_text, messages, temperature):
# last_user_text = messages[-1]["content"]
formatted_messages = []
for m in messages:
raw_parts = _image_token_pattern.split(m["content"])
parts = []
for i in range(0, len(raw_parts), 2):

# append the text portion
if len(raw_parts[i]) > 0:
parts.append(Part.from_text(raw_parts[i]))

# append any image
if i + 1 < len(raw_parts):
parts.append(Part.from_image(Image.from_bytes(self[raw_parts[i+1]])))
formatted_messages.append(Content(role=m["role"], parts=parts))
last_user_parts = formatted_messages.pop() # the last user message is passed to send_message, not included in the history

chat_session = self.model_obj.start_chat(
history=formatted_messages,
)

generation_config = {
"temperature": temperature
}
if self.max_streaming_tokens is not None:
generation_config["max_output_tokens"] = self.max_streaming_tokens
generator = chat_session.send_message(last_user_parts, generation_config=generation_config, stream=True)

for chunk in generator:
yield chunk.candidates[0].content.parts[0].text.encode("utf8")
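A small illustration (not from the diff) of the message parsing in _start_generator above: splitting on the capturing image-token pattern yields alternating text and image-id chunks, which the loop then converts into text and image Parts.

```python
import re

# re.split with a capturing group keeps the captured ids in the result,
# so even indices are text and odd indices are the ids of image bytes
# stored on the model state
_image_token_pattern = re.compile(r'<\|_image:(.*?)\|>')

content = "What is in this picture? <|_image:140234|> Keep it short."
raw_parts = _image_token_pattern.split(content)
print(raw_parts)  # ['What is in this picture? ', '140234', ' Keep it short.']
```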
3 changes: 2 additions & 1 deletion guidance/models/vertexai/__init__.py
@@ -1,2 +1,3 @@
from ._PaLM2 import PaLM2Completion, PaLM2Chat, PaLM2Instruct
from ._Codey import CodeyCompletion, CodeyInstruct, CodeyChat
from ._Codey import CodeyCompletion, CodeyInstruct, CodeyChat
from ._Gemini import GeminiChat
34 changes: 27 additions & 7 deletions guidance/models/vertexai/_vertexai.py
@@ -51,6 +51,10 @@ def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0
# PaLM2Chat
elif re.match("chat-bison(@[0-9]+)?", model_name):
found_subclass = vertexai.PaLM2Chat

# GeminiChat
elif re.match("gemini-pro(@[0-9]+)?", model_name):
found_subclass = vertexai.GeminiChat

# convert to any found subclass
if found_subclass is not None:
@@ -149,8 +153,8 @@ def _generator(self, prompt, temperature):
end_pos = prompt[pos:].find(role_end)
if end_pos < 0:
break
messages.append(vertexai.language_models.ChatMessage(
author="user",
messages.append(dict(
role="user",
content=prompt[pos:pos+end_pos].decode("utf8"),
))
pos += end_pos + len(role_end)
@@ -160,23 +164,39 @@
if end_pos < 0:
valid_end = True
break
messages.append(vertexai.language_models.ChatMessage(
author="assistant",
messages.append(dict(
role="assistant",
content=prompt[pos:pos+end_pos].decode("utf8"),
))
pos += end_pos + len(role_end)
else:
raise Exception("It looks like your prompt is not a well formed chat prompt! Please enclose all model state appends inside chat role blocks like `user()` or `assistant()`.")

self._shared_state["data"] = prompt[:pos]

assert len(messages) > 0, "Bad chat format! No chat blocks were defined."
assert messages[-1].author == "user", "Bad chat format! There must be a user() role before the last assistant() role."
assert messages[-1]["role"] == "user", "Bad chat format! There must be a user() role before the last assistant() role."
assert valid_end, "Bad chat format! You must generate inside assistant() roles."

# TODO: don't make a new session on every call
last_user_text = messages.pop().content
# last_user_text = messages.pop().content

return self._start_generator(system_text.decode("utf8"), messages, temperature)

# kwargs = {}
# if self.max_streaming_tokens is not None:
# kwargs["max_output_tokens"] = self.max_streaming_tokens
# generator = chat_session.send_message_streaming(last_user_text, temperature=temperature, **kwargs)

# for chunk in generator:
# yield chunk.text.encode("utf8")

def _start_generator(self, system_text, messages, temperature):
messages = [vertexai.language_models.ChatMessage(author=m["role"], content=m["content"]) for m in messages]
last_user_text = messages.pop().content

chat_session = self.model_obj.start_chat(
context=system_text.decode("utf8"),
context=system_text,
message_history=messages,
)

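A hedged sketch of the model-name dispatch added at the top of this file; GCP project/location setup (e.g. via vertexai.init) is assumed to have been done elsewhere.

```python
from guidance import models

# constructing the generic VertexAI wrapper with a Gemini model name is
# expected to resolve to the new GeminiChat subclass via the re.match
# dispatch above (credentials and vertexai.init are assumed)
lm = models.VertexAI("gemini-pro")
print(type(lm).__name__)  # expected to report the GeminiChat subclass
```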
2 changes: 1 addition & 1 deletion tests/models/test_llama_cpp.py
@@ -36,7 +36,7 @@ def test_repeat_calls():
a = []
lm = llama2 + 'How much is 2 + 2? ' + gen(name='test', max_tokens=10)
a.append(lm['test'])
lm = llama2 + 'How much is 2 + 2? ' + gen(name='test',max_tokens=10, pattern=r'\d+')
lm = llama2 + 'How much is 2 + 2? ' + gen(name='test',max_tokens=10, regex=r'\d+')
a.append(lm['test'])
lm = llama2 + 'How much is 2 + 2? ' + gen(name='test', max_tokens=10)
a.append(lm['test'])
48 changes: 46 additions & 2 deletions tests/models/test_vertexai.py
@@ -6,7 +6,7 @@ def test_palm2_instruct():
try:
vmodel = models.VertexAI("text-bison@001")
except:
pytest.skip("Skipping OpenAI test because we can't load the model!")
pytest.skip("Skipping VertexAI test because we can't load the model!")

with instruction():
lm = vmodel + "this is a test about"
@@ -19,7 +19,7 @@ def test_palm2_chat():
try:
vmodel = models.VertexAI("chat-bison@001")
except:
pytest.skip("Skipping OpenAI test because we can't load the model!")
pytest.skip("Skipping VertexAI test because we can't load the model!")

with system():
lm = vmodel + "You are an always-happy agent no matter what."
@@ -43,6 +43,50 @@ def test_palm2_chat():
with system():
lm = vmodel + "You are an always-happy agent no matter what."

with user():
lm += "The economy is crashing!"

with assistant():
lm += gen("test1", max_tokens=100)

with user():
lm += "What is the best again?"

with assistant():
lm += gen("test2", max_tokens=100)

assert len(lm["test1"]) > 0
assert len(lm["test2"]) > 0
assert lm["test1"].find("<|im_end|>") < 0

def test_gemini_chat():
from guidance import models, gen, system, user, assistant

try:
vmodel = models.VertexAI("gemini-pro")
except:
pytest.skip("Skipping VertexAI test because we can't load the model!")

lm = vmodel

with user():
lm += "The economy is crashing!"

with assistant():
lm += gen("test1", max_tokens=100)

with user():
lm += "What is the best again?"

with assistant():
lm += gen("test2", max_tokens=100)

assert len(lm["test1"]) > 0
assert len(lm["test2"]) > 0

# second time to make sure cache reuse is okay
lm = vmodel

with user():
lm += "The economy is crashing!"

