Add support for images in UserMessages (#546)
This adds support for passing images as inputs to OpenAI models. They
can be embedded in `<UserMessage>`s using the existing `<Image>` tag.

By default `OpenAIChatModel` passes images for models that support it,
but this can be forced (or denied) by setting the `includeImages` prop
explicitly.
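
For example, image support can be forced on like this (a minimal sketch, assuming `OpenAIChatModel` is imported from `ai-jsx/lib/openai`; the image URL is a placeholder):

import { OpenAIChatModel } from 'ai-jsx/lib/openai';
import { UserMessage } from 'ai-jsx/core/completion';
import { Image } from 'ai-jsx/core/image-gen';

// `includeImages` overrides the per-model default; `includeImages={false}` would deny it.
const request = (
  <OpenAIChatModel model="gpt-4-turbo" includeImages>
    <UserMessage>
      Describe this image.
      <Image url="https://example.com/photo.jpg" detail="low" />
    </UserMessage>
  </OpenAIChatModel>
);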
petersalas authored May 10, 2024
1 parent e7b3e2e commit 2df9c5e
Showing 7 changed files with 138 additions and 22 deletions.
2 changes: 1 addition & 1 deletion packages/ai-jsx/package.json
@@ -4,7 +4,7 @@
"repository": "fixie-ai/ai-jsx",
"bugs": "https://github.com/fixie-ai/ai-jsx/issues",
"homepage": "https://ai-jsx.com",
"version": "0.32.0",
"version": "0.33.0",
"volta": {
"extends": "../../package.json"
},
4 changes: 2 additions & 2 deletions packages/ai-jsx/src/core/conversation.tsx
@@ -285,13 +285,13 @@ export async function renderToConversation(
render: AI.ComponentContext['render'],
logger?: AI.ComponentContext['logger'],
logType?: 'prompt' | 'completion',
- cost?: (message: ConversationMessage, render: AI.ComponentContext['render']) => Promise<number>,
+ cost?: (message: ConversationMessage) => Promise<number>,
budget?: number
) {
const cachedCosts = new WeakMap<AI.Element<any>, Promise<number>>();
function cachedCost(message: ConversationMessage): Promise<number> {
if (!cachedCosts.has(message.element)) {
- cachedCosts.set(message.element, cost!(message, render));
+ cachedCosts.set(message.element, cost!(message));
}

return cachedCosts.get(message.element)!;
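Under the new signature the `cost` callback is a function of the message alone; a minimal sketch of a caller (the flat-rate estimator and the `children`/budget values are illustrative only):

// Sketch: per-message estimators no longer receive `render`.
const flatCost = (message: ConversationMessage) => Promise.resolve(10);
await renderToConversation(children, render, logger, 'prompt', flatCost, 500);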
22 changes: 12 additions & 10 deletions packages/ai-jsx/src/core/image-gen.tsx
@@ -89,15 +89,6 @@ export function ImageGen({ children, ...props }: ImageGenPropsWithChildren, { ge
);
}

- export interface GeneratedImageProps {
-   /** The image URL. */
-   url: string;
-   /** The prompt used for generating the image. Currently only used for debugging. */
-   prompt: string;
-   /** The model used for generating the image. Currently only used for debugging. */
-   modelName: string;
- }

/**
* This component represents an image via a single `url` prop.
* It is a wrapper for the output of {@link ImageGen} to allow for first-class support of images.
@@ -106,6 +97,17 @@ export interface GeneratedImageProps {
* - In terminal-based environments, this component will be rendered as a URL.
* - In browser-based environments, this component will be rendered as an `img` tag.
*/
- export function Image(props: GeneratedImageProps) {
+ export function Image(props: {
+   /** The image URL. */
+   url: string;
+   /** The prompt used for generating the image. Currently only used for debugging. */
+   prompt?: string;
+   /** The model used for generating the image. Currently only used for debugging. */
+   modelName?: string;
+   /** The level of detail required. */
+   detail?: string;
+   /** The number of input tokens required. */
+   inputTokens?: number;
+ }) {
return props.url;
}
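
A sketch of the widened props in use (placeholder URL; per the `openai.tsx` changes below, `inputTokens` bypasses the tile-based estimate when the true token cost is known):

<Image
  url="https://example.com/diagram.png"
  detail="high"
  inputTokens={1105} // charge exactly this many input tokens
/>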
102 changes: 94 additions & 8 deletions packages/ai-jsx/src/lib/openai.tsx
@@ -40,11 +40,20 @@ export type ValidChatModel =
| 'gpt-4-0125-preview'
| 'gpt-4-0613'
| 'gpt-4-1106-preview'
| 'gpt-4-1106-vision-preview'
| 'gpt-4-32k'
| 'gpt-4-32k-0613'
| 'gpt-4-turbo'
| 'gpt-4-turbo-2024-04-09'
- | 'gpt-4-turbo-preview';
+ | 'gpt-4-turbo-preview'
+ | 'gpt-4-vision-preview';

const visionModels: Partial<Record<ValidChatModel, true>> = {
'gpt-4-1106-vision-preview': true,
'gpt-4-turbo-2024-04-09': true,
'gpt-4-turbo': true,
'gpt-4-vision-preview': true,
};

/**
* An OpenAI client that talks to the Azure OpenAI service.
@@ -230,6 +239,8 @@ function tokenLimitForChatModel(
case 'gpt-4-turbo-preview':
case 'gpt-4-turbo-2024-04-09':
case 'gpt-4-turbo':
case 'gpt-4-vision-preview':
case 'gpt-4-1106-vision-preview':
return 128_000 - functionEstimate - TOKENS_CONSUMED_BY_REPLY_PREFIX;
case 'gpt-3.5-turbo-0301':
case 'gpt-3.5-turbo-0613':
@@ -248,19 +259,70 @@
}
}

- export async function tokenCountForConversationMessage(
const imageTokenCost = Symbol('imageTokenCost');
interface ImagePartWithTokenCost extends OpenAIClient.ChatCompletionContentPartImage {
[imageTokenCost]: () => number;
}
type PromptPart = OpenAIClient.ChatCompletionContentPartText | ImagePartWithTokenCost;

async function renderWithImages(render: AI.RenderContext['render'], element: Node): Promise<PromptPart[]> {
const textAndImages = await render(element, { stop: (node) => node.tag === Image });
const content: PromptPart[] = [];

textAndImages.forEach((node: string | AI.Element<AI.PropsOfComponent<typeof Image>>) => {
if (typeof node === 'string') {
const lastContentPart = content.at(-1);
if (lastContentPart?.type === 'text') {
// Merge adjacent text nodes.
lastContentPart.text += node;
} else {
content.push({ type: 'text', text: node });
}
} else {
content.push({
type: 'image_url',
image_url: { url: node.props.url, detail: node.props.detail as any },
[imageTokenCost]() {
if (node.props.inputTokens) {
return node.props.inputTokens;
}

// https://platform.openai.com/docs/guides/vision/calculating-costs
if (node.props.detail === 'high') {
// Assume 6 tiles.
return 6 * 170 + 85;
}

// Otherwise assume low detail.
return 85;
},
});
}
});

return content;
}
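
// Worked example of the heuristic above, per the OpenAI costing guide linked
// in the code: low detail costs a flat 85 tokens; high detail costs 85 base
// tokens plus 170 per 512px tile, so the assumed 6 tiles come to
// 6 * 170 + 85 = 1105 tokens.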

async function tokenCountForConversationMessage(
message: ConversationMessage,
- render: AI.RenderContext['render']
+ render: AI.RenderContext['render'],
+ includeImages: boolean
): Promise<number> {
const TOKENS_PER_MESSAGE = 3;
const TOKENS_PER_NAME = 1;
switch (message.type) {
- case 'user':
+ case 'user': {
+ const textAndImages: PromptPart[] = includeImages
+ ? await renderWithImages(render, message.element)
+ : [{ type: 'text', text: await render(message.element) }];
return (
TOKENS_PER_MESSAGE +
- tokenizer.encode(await render(message.element)).length +
+ textAndImages
+ .map((part) => (part.type === 'text' ? tokenizer.encode(part.text).length : part[imageTokenCost]()))
+ .reduce((a, b) => a + b, 0) +
(message.element.props.name ? tokenizer.encode(message.element.props.name).length + TOKENS_PER_NAME : 0)
);
+ }
case 'assistant':
case 'system':
return TOKENS_PER_MESSAGE + tokenizer.encode(await render(message.element)).length;
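
// E.g. a user message holding one low-detail image plus the text
// "Describe this image." counts as TOKENS_PER_MESSAGE (3) + the text's
// ~4 tokenizer tokens + 85 image tokens, roughly 92 in total.
// (The text token count is approximate; only the image costs are fixed.)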
@@ -368,6 +430,7 @@ export async function* OpenAIChatModel(
props: ModelPropsWithChildren & {
model: ValidChatModel;
logitBias?: Record<string, number>;
includeImages?: boolean;
},
{ render, getContext, logger, memo }: AI.ComponentContext
): AI.RenderableStream {
@@ -383,13 +446,17 @@

const modelTokenLimit = tokenLimitForChatModel(props.model, props.functionDefinitions);
const promptTokenLimit = props.maxInputTokens ?? modelTokenLimit - (props.reservedTokens ?? props.maxTokens ?? 0);
const includeImages = props.includeImages ?? props.model in visionModels;

const tokenCostForMessage = (message: ConversationMessage) =>
tokenCountForConversationMessage(message, render, includeImages);

const conversationMessages = await renderToConversation(
props.children,
render,
logger,
'prompt',
- tokenCountForConversationMessage,
+ tokenCostForMessage,
promptTokenLimit
);

@@ -406,11 +473,30 @@
role: 'system',
content: await render(message.element),
};
- case 'user':
+ case 'user': {
+ if (includeImages) {
+ const content = await renderWithImages(render, message.element);
+ // Prefer to pass as a single string if possible.
+ if (content.some((part) => part.type !== 'text')) {
+ return {
+ role: 'user',
+ content,
+ };
+ }
+
+ return {
+ role: 'user',
+ content: (content as OpenAIClient.ChatCompletionContentPartText[])
+ .map((part: OpenAIClient.ChatCompletionContentPartText) => part.text)
+ .join(''),
+ };
+ }
+
return {
role: 'user',
content: await render(message.element),
};
+ }
case 'assistant':
return {
role: 'assistant',
@@ -656,7 +742,7 @@
}

// Render the completion conversation to log it.
- await renderToConversation(outputMessages, render, logger, 'completion', tokenCountForConversationMessage);
+ await renderToConversation(outputMessages, render, logger, 'completion', tokenCostForMessage);
return AI.AppendOnlyStream;
}

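With images enabled, a user message that contains an `<Image>` goes over the wire as an array of content parts, while purely textual messages still collapse to a single string. A sketch of the two shapes (URL abbreviated):

// Mixed content: an array of text and image_url parts.
{ role: 'user', content: [
  { type: 'text', text: 'What do the following images have in common?' },
  { type: 'image_url', image_url: { url: 'https://upload.wikimedia.org/…', detail: 'low' } },
] }

// Text only: a single string, exactly as before.
{ role: 'user', content: 'What do the following images have in common?' }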
6 changes: 5 additions & 1 deletion packages/docs/docs/changelog.md
@@ -1,6 +1,10 @@
# Changelog

- ## 0.32.0
+ ## 0.33.0

- Add support for passing `<Image>` to OpenAI models.

## [0.32.0](https://github.com/fixie-ai/ai-jsx/tree/e7b3e2e444659a49e04693337c2af023c506fbe6)

- Improve rendering performance:
- Change `<ShrinkConversation>` to cache token costs when remeasuring the same elements
1 change: 1 addition & 0 deletions packages/examples/package.json
@@ -63,6 +63,7 @@
"demo:image-generation": "yarn build && node dist/src/image-generation.js",
"demo:shrink": "yarn build && node dist/src/conversation-shrinking.js",
"demo:fixie-corpus": "yarn build && node dist/src/fixie-corpus.js",
"demo:vision": "yarn build && node dist/src/vision.js",
"view-logs": "cat ai-jsx.log | pino-pretty",
"lint": "eslint . --max-warnings 0",
"lint:fix": "eslint . --fix",
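With this script in place, the new example below can be run from `packages/examples` with `yarn demo:vision`.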
23 changes: 23 additions & 0 deletions packages/examples/src/vision.tsx
@@ -0,0 +1,23 @@
import { ChatCompletion, UserMessage } from 'ai-jsx/core/completion';
import { Image } from 'ai-jsx/core/image-gen';
import { showInspector } from 'ai-jsx/core/inspector';
import { OpenAI } from 'ai-jsx/lib/openai';

function App() {
return (
<OpenAI chatModel="gpt-4-turbo">
<ChatCompletion>
<UserMessage>
What do the following images have in common?
<Image url="https://upload.wikimedia.org/wikipedia/commons/8/89/Apollo_11_bootprint.jpg" detail="low" />
<Image
url="https://upload.wikimedia.org/wikipedia/commons/3/36/PIA00563-Viking1-FirstColorImage-19760721.jpg"
detail="low"
/>
</UserMessage>
</ChatCompletion>
</OpenAI>
);
}

showInspector(<App />);
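
Note that the example never sets `includeImages`: `gpt-4-turbo` is listed in the `visionModels` table above, so image support is on by default.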
