Skip to content

Commit

Permalink
Merge branch 'feature/image'
Browse files Browse the repository at this point in the history
  • Loading branch information
amoffat committed Aug 3, 2024
2 parents b916f74 + e3b7257 commit bd7d27e
Show file tree
Hide file tree
Showing 14 changed files with 201 additions and 29 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.2.0 - 8/2/24

- Support for `tuple` serde datatype

# 0.1.2 - 8/2/24

- Package metadata tweaks
Expand Down
26 changes: 16 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ from manifest import ai
@ai
def is_optimistic(text: str) -> bool:
""" Determines if the text is optimistic"""
...

assert is_optimistic("This is amazing!")
```
Expand All @@ -40,25 +39,24 @@ from manifest import ai
@ai
def translate(english_text: str, target_lang: str) -> str:
""" Translates text from english into a target language """
...

assert translate("Hello", "fr") == "Bonjour"
```

## Image analysis

You can pass in bytes to make use of a model's multimodal abilities. COMING SOON
You can pass in a file to make use of a model's multimodal abilities. It
supports `Path`, `io.BytesIO`, and `io.BufferedReader`.

```python
import io
from pathlib import Path
from manifest import ai

@ai
def breed_of_dog(image: io.BytesIO) -> str:
""" Determines the breed of dog from a photo """
...
def breed_of_dog(image: Path) -> str:
"""Determines the breed of dog from a photo"""

image = open("/path/to/terrier.jpg", "r")
image = Path("path/to/dog.jpg")
print(breed_of_dog(image))
```

Expand Down Expand Up @@ -86,7 +84,6 @@ class Movie:
def similar_movie(movie: str, before_year: int | None=None) -> Movie:
"""Discovers a similar movie, before a certain year, if the year is
provided."""
...

like_inception = similar_movie("Inception")
print(like_inception)
Expand All @@ -109,15 +106,24 @@ To get the most out the `@ai` decorator:

# Limitations

## REPL

Manifest doesn't work from the REPL because it needs access to the source code
of the functions it decorates.

## Types

You can only pass in and return the following types:

- Dataclasses
- `Enum` subclasses
- primitives (str, int, bool, None, etc)
- basic container types (list, dict)
- basic container types (list, dict, tuple)
- unions
- Any combination of the above

## Prompts

The prompt templates are also a little fiddly sometimes. They can be improved.

# Initialization
Expand Down
30 changes: 28 additions & 2 deletions manifest/decorator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import inspect
import json
from functools import wraps
from io import BytesIO
from typing import Any, Callable

from manifest import initialize, parser, serde, tmpl
from manifest.llm.base import LLM
from manifest.types.service import Service
from manifest.utils.args import get_args_dict
from manifest.utils.types import extract_type_registry
from manifest.utils.asset import get_asset_data
from manifest.utils.types import extract_type_registry, is_asset

# Lazily-initialized singleton LLM client. This must be lazy because we don't
# know when the user will have initialized the client, either automatically or
Expand Down Expand Up @@ -68,16 +70,36 @@ def inner(*exec_args, **exec_kwargs) -> Any:
kwargs=exec_kwargs,
)

images: list[BytesIO] = []

# Aggregate our arguments into a format that the LLM can understand
args = []
for arg_name, arg_value in call_args.items():

try:
arg_type = ants[arg_name]
except KeyError:
raise ValueError(
f"Function '{name}' is missing an annotation for '{arg_name}'"
)

if is_asset(arg_type):
images.append(get_asset_data(arg_value))

args.append(
{
"name": arg_name,
"value": "Uploaded image",
}
)
continue

elif not isinstance(arg_value, arg_type):
raise TypeError(
f"Argument '{arg_name}' is of type {type(arg_value)}, "
f"but should be of type {arg_type}"
)

# Collect the source code of the argument type, if it's not a
# builtin type
src = None
Expand Down Expand Up @@ -112,7 +134,11 @@ def inner(*exec_args, **exec_kwargs) -> Any:
prompt = call_tmpl.render(**tmpl_params)
system_msg = system_tmpl.render()

resp = llm.call(prompt=prompt, system_msg=system_msg)
resp = llm.call(
prompt=prompt,
system_msg=system_msg,
images=images,
)

data_spec = parser.parse_return_value(resp)
ret_val = serde.deserialize(
Expand Down
9 changes: 8 additions & 1 deletion manifest/llm/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from io import BytesIO
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -9,7 +10,13 @@ class LLM(ABC):
"""The base class for all LLM clients."""

@abstractmethod
def call(self, *, prompt: str, system_msg: str) -> str:
def call(
self,
*,
prompt: str,
system_msg: str,
images: list[BytesIO] | None = None,
) -> str:
"""Run an LLM completion"""
...

Expand Down
47 changes: 41 additions & 6 deletions manifest/llm/openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
from io import BytesIO
from typing import TYPE_CHECKING

import openai
Expand All @@ -6,7 +8,10 @@
from manifest.types.service import Service

if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat import (
ChatCompletionContentPartParam,
ChatCompletionMessageParam,
)


class OpenAILLM(LLM):
Expand All @@ -18,21 +23,51 @@ def __init__(self, *, api_key: str, model: str):
def service() -> Service:
return Service.OPENAI

def call(self, *, prompt: str, system_msg: str) -> str:
def call(
self,
*,
prompt: str,
system_msg: str,
images: list[BytesIO] | None = None,
) -> str:

user_message: "ChatCompletionMessageParam" = {
"role": "user",
"content": prompt,
}

if images:
content: list["ChatCompletionContentPartParam"] = [
{"type": "text", "text": prompt}
]
for image in images:
b64_data = base64.b64encode(image.getvalue()).decode("utf-8")
img_block: "ChatCompletionContentPartParam" = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64_data}",
},
}
content.append(img_block)

user_message = {
"role": "user",
"content": content,
}

messages: list["ChatCompletionMessageParam"] = [
{
"role": "system",
"content": system_msg,
},
{
"role": "user",
"content": prompt,
},
user_message,
]

completion = self.client.chat.completions.create(
model=self.model,
messages=messages,
# Somehow produces worse outcomes
# response_format={"type": "json_object"},
)

resp = completion.choices[0].message.content or ""
Expand Down
13 changes: 9 additions & 4 deletions manifest/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ def parse_document(s: str) -> etree._Element:
return root


def find_and_parse_xml(s: str, root_tag="root") -> etree._Element:
def find_and_parse_xml(s: str, root_tag="root") -> etree._Element | None:
"""For a given string, find the first XML document and parse it."""
start_tag = f"<{root_tag}>"
end_tag = f"</{root_tag}>"

start_index = s.find(start_tag)
end_index = s.find(end_tag)

# Couldn't find the root tag
if start_index == -1 or end_index == -1 or start_index > end_index:
raise ValueError("No valid XML document found: " + s)
return None

xml_document = s[start_index : end_index + len(end_tag)]

Expand Down Expand Up @@ -68,6 +69,10 @@ def render_inner(node: etree._Element) -> str:

def parse_return_value(s: str) -> str:
root = find_and_parse_xml(s, root_tag="return-value")
json_str = render_inner(root)
data_spec = json.loads(json_str)
if root is not None:
json_str = render_inner(root)
data_spec = json.loads(json_str)
else:
data_spec = json.loads(s)

return data_spec
24 changes: 22 additions & 2 deletions manifest/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ def serialize(data_type: Type | UnionType) -> dict:
"items": serialize(item_type),
}

if get_origin(data_type) is tuple:
item_types = get_args(data_type)
return {
"type": "array",
"prefixItems": [serialize(t) for t in item_types],
}

if get_origin(data_type) is dict:
_, value_type = get_args(data_type)
return {
Expand Down Expand Up @@ -82,8 +89,8 @@ def deserialize(
"""
Deserialize a JSON-friendly data structure into a dataclass instance.
:param spec: The JSON-friendly data structure that describes the dataclass.
:param data: The JSON-friendly data structure to deserialize.
:param spec: The JSON-encodable data structure that describes the dataclass.
:param data: The JSON-encodable data structure to deserialize.
:param registry: A dictionary mapping type names to the corresponding types.
:return: An instance of the dataclass.
Expand Down Expand Up @@ -139,6 +146,19 @@ def deserialize(

# list container type
if type == "array":

# If the schema defines "prefixItems", we have a tuple
if "prefixItems" in schema:
return tuple(
deserialize(
schema=schema["prefixItems"][i],
data=d,
registry=registry,
)
for i, d in enumerate(data)
)

# Otherwise, we have a normal list
return [
deserialize(
schema=schema["items"],
Expand Down
6 changes: 4 additions & 2 deletions manifest/templates/openai/call.j2
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ The function is being called with the following arguments:
{%- for arg in args %}
<argument>
<name>{{ arg.name }}</name>
{%- if 'schema' in arg %}
<jsonschema>
{{ arg.schema | trim | indent(12) }}
</jsonschema>
{%- endif %}
<value>
{{ arg.value | trim | indent(12) }}
</value>
{%- if arg.src %}
{%- if 'src' in arg %}
<python-src>
{{ arg.src | trim | indent(12) }}
</python-src>
Expand Down Expand Up @@ -44,4 +46,4 @@ Imagine calling the function with the above arguments and returning a value of t
...
</return-value>

The <return-value> tag should contain the json-formatted return value of the function call, conforming to the jsonschema defined in the <return-type>'s <jsonschema> tag.
The <return-value> tag should contain only the json-formatted return value of the function call, conforming to the jsonschema defined in the <return-type>'s <jsonschema> tag. Even if the return type is simple, like {"type": "string"}, the return value should be formatted as jsonschema value.
2 changes: 1 addition & 1 deletion manifest/templates/openai/system.j2
Original file line number Diff line number Diff line change
@@ -1 +1 @@
You are a bot that executes function calls. You'll receive some metadata about a function, such as the function signature, details about its arguments, and details about its return type. You will imagine calling the function and producing a sensible return value. It is very important that your result value always be jsonschema-conformant.
You are a bot that executes function calls. You'll receive some metadata about a function, such as the function signature, details about its arguments, and details about its return type. You will imagine calling the function and producing a sensible return value. It is very important that your result value always be json that conforms to a given jsonschema.
21 changes: 21 additions & 0 deletions manifest/utils/asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import io
from pathlib import Path
from typing import Any


def get_asset_data(arg_value: Any) -> io.BytesIO:
    """Return the bytes of an uploadable asset as an in-memory buffer.

    :param arg_value: The asset to load. An ``io.BytesIO`` is returned as-is,
        an ``io.BufferedReader`` is read from its current position to EOF, and
        a ``pathlib.Path`` has its whole file read.
    :return: A ``BytesIO`` holding the asset's bytes.
    :raises TypeError: If ``arg_value`` is not one of the supported types.
    """
    if isinstance(arg_value, io.BytesIO):
        return arg_value

    if isinstance(arg_value, io.BufferedReader):
        # Pipes, sockets and similar streams can neither report nor restore a
        # position, so just drain them instead of failing on tell()/seek().
        if not arg_value.seekable():
            return io.BytesIO(arg_value.read())

        pos = arg_value.tell()
        try:
            return io.BytesIO(arg_value.read())
        finally:
            # Leave the caller's stream where we found it, even if the read
            # itself raised.
            arg_value.seek(pos)

    if isinstance(arg_value, Path):
        return io.BytesIO(arg_value.read_bytes())

    raise TypeError(f"Unsupported asset type: {type(arg_value)}")
10 changes: 10 additions & 0 deletions manifest/utils/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import io
from dataclasses import fields, is_dataclass
from enum import Enum
from pathlib import Path
from types import UnionType
from typing import Any, Type, get_args, get_origin

Expand Down Expand Up @@ -33,3 +35,11 @@ def extract_type_registry(
extract_type_registry(type_registry, value_type)

return type_registry


def is_asset(arg_type: Type) -> bool:
"""Check if a type is an uploadable asset"""
if issubclass(arg_type, (io.BytesIO, io.BufferedReader, Path)):
return True

return False
Loading

0 comments on commit bd7d27e

Please sign in to comment.