Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to shrink conversation based on cost/budget/importance #228

Merged
merged 4 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/ai-jsx/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"repository": "fixie-ai/ai-jsx",
"bugs": "https://github.com/fixie-ai/ai-jsx/issues",
"homepage": "https://ai-jsx.com",
"version": "0.8.0",
"version": "0.8.1",
"volta": {
"extends": "../../package.json"
},
Expand Down Expand Up @@ -351,8 +351,8 @@
"axios": "^1.4.0",
"cli-highlight": "^2.1.11",
"cli-spinners": "^2.9.0",
"gpt3-tokenizer": "^1.1.5",
"ink": "^4.2.0",
"js-tiktoken": "^1.0.7",
"js-yaml": "^4.1.0",
"langchain": "^0.0.81",
"lodash": "^4.17.21",
Expand Down
2 changes: 1 addition & 1 deletion packages/ai-jsx/src/batteries/natural-language-router.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ export async function* NaturalLanguageRouter(props: { children: Node; query: Nod
}

const props = e.props as RouteProps;
return props.unmatched ? choiceIndex === 0 : props.when === whenOptions[choiceIndex];
return props.unmatched ? choiceIndex === whenOptions.length - 1 : props.when === whenOptions[choiceIndex];
});
}

Expand Down
245 changes: 243 additions & 2 deletions packages/ai-jsx/src/core/conversation.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { ChatCompletionResponseMessage } from 'openai';
import * as AI from '../index.js';
import { Node } from '../index.js';
import { AIJSXError, ErrorCode } from '../core/errors.js';
import { debug } from './debug.js';
import _ from 'lodash';

/**
* Provide a System Message to the LLM, for use within a {@link ChatCompletion}.
Expand Down Expand Up @@ -183,8 +185,21 @@ function toConversationMessages(partialRendering: AI.PartiallyRendered[]): Conve
}

/** @hidden */
export async function renderToConversation(conversation: AI.Node, render: AI.ComponentContext['render']) {
return toConversationMessages(await render(conversation, { stop: isConversationalComponent }));
export async function renderToConversation(
petersalas marked this conversation as resolved.
Show resolved Hide resolved
conversation: AI.Node,
render: AI.ComponentContext['render'],
cost?: (message: ConversationMessage, render: AI.ComponentContext['render']) => Promise<number>,
budget?: number
) {
const conversationToUse =
cost && budget ? (
<ShrinkConversation cost={cost} budget={budget}>
{conversation}
</ShrinkConversation>
) : (
conversation
);
return toConversationMessages(await render(conversationToUse, { stop: isConversationalComponent }));
}

/**
Expand Down Expand Up @@ -309,3 +324,229 @@ export async function* ShowConversation(
// we can indicate that we've already yielded the final frame.
return AI.AppendOnlyStream;
}

/**
 * @hidden
 * "Shrinks" a conversation's messages according to a cost function (e.g. token length),
 * a budget (e.g. context window size), and the `importance` prop set on any `<Shrinkable>`
 * components within the conversation.
 *
 * While the aggregate cost of the conversation exceeds `budget`, the least important
 * `<Shrinkable>` (ties are broken by evicting the more expensive one first) is swapped for
 * its `replacement`. Shrinking stops as soon as the conversation fits or when there is nothing
 * left to replace, so the result may still exceed `budget` if the immutable content alone is
 * too large. If the conversation contains no `<Shrinkable>` at all, the cost function is never
 * invoked.
 *
 * Currently, `<Shrinkable>` components must wrap conversational components and do not allow
 * content to shrink _within_ conversational components. For example this:
 *
 * @example
 * ```tsx
 *    // Do not do this!
 *    <UserMessage>
 *      Content
 *      <Shrinkable importance={0}>Not shrinkable!</Shrinkable>
 *      Content
 *    </UserMessage>
 * ```
 *
 * is not shrinkable. Instead, do this:
 *
 * @example
 * ```tsx
 *    <Shrinkable importance={0} replacement={<UserMessage>Content Content</UserMessage>}>
 *      <UserMessage>
 *        Content
 *        Shrinkable!
 *        Content
 *      </UserMessage>
 *    </Shrinkable>
 * ```
 *
 * @param props.cost Computes the cost of a single conversation message; it is also handed the
 *    component's `render` function.
 * @param props.budget The aggregate cost the conversation is shrunk towards.
 * @param props.children The conversation to shrink.
 * @returns The conversational elements that remain after shrinking. Top-level strings that are
 *    not wrapped in a conversational component are dropped.
 */
export async function ShrinkConversation(
  {
    cost: costFn,
    budget,
    children,
  }: {
    cost: (message: ConversationMessage, render: AI.RenderContext['render']) => Promise<number>;
    budget: number;
    children: Node;
  },
  { render, memo, logger }: AI.ComponentContext
) {
  /**
   * We construct a tree of immutable and shrinkable nodes such that shrinkable nodes
   * can contain other nodes.
   */
  type TreeNode = ImmutableTreeNode | ShrinkableTreeNode;

  /** A conversational element that can never be evicted. */
  interface ImmutableTreeNode {
    type: 'immutable';
    element: AI.Element<any>;
    cost: number;
  }

  /** A `<Shrinkable>` region; its cost is the sum of the costs of its children. */
  interface ShrinkableTreeNode {
    type: 'shrinkable';
    element: AI.Element<AI.PropsOfComponent<typeof InternalShrinkable>>;
    cost: number;
    children: TreeNode[];
  }

  /** Converts a conversational `AI.Node` into a shrinkable tree. */
  async function conversationToTreeRoots(conversation: AI.Node): Promise<TreeNode[]> {
    const rendered = await render(conversation, {
      stop: (e) => isConversationalComponent(e) || e.tag === InternalShrinkable,
    });

    const asTreeNodes = await Promise.all(
      rendered.map<Promise<TreeNode | null>>(async (value) => {
        // Bare strings are not part of any conversational message, so they carry no cost.
        if (typeof value === 'string') {
          return null;
        }

        if (value.tag === InternalShrinkable) {
          const children = await conversationToTreeRoots(value.props.children);
          return { type: 'shrinkable', element: value, cost: aggregateCost(children), children };
        }

        return {
          type: 'immutable',
          element: value,
          cost: await costFn(toConversationMessages([value])[0], render),
        };
      })
    );

    return asTreeNodes.filter((n): n is TreeNode => n !== null);
  }

  /** Finds the least important node in any of the trees, considering cost as a second factor. */
  function leastImportantNode(roots: TreeNode[]): ShrinkableTreeNode | undefined {
    function compareImportance(nodeA: ShrinkableTreeNode, nodeB: ShrinkableTreeNode) {
      // If the two nodes are of the same importance, consider the higher cost node less important.
      return nodeA.element.props.importance - nodeB.element.props.importance || nodeB.cost - nodeA.cost;
    }

    let current: ShrinkableTreeNode | undefined;
    for (const node of roots) {
      if (node.type !== 'shrinkable') {
        continue;
      }

      if (current === undefined || compareImportance(node, current) < 0) {
        current = node;
      }

      // A nested shrinkable may be less important than the shrinkable that contains it.
      const leastImportantDescendant = leastImportantNode(node.children);
      if (leastImportantDescendant !== undefined && compareImportance(leastImportantDescendant, current) < 0) {
        current = leastImportantDescendant;
      }
    }

    return current;
  }

  /** Sums the (already computed) costs of the given nodes. */
  function aggregateCost(roots: TreeNode[]): number {
    return _.sumBy(roots, (node) => node.cost);
  }

  /** Replaces a single ShrinkableTreeNode in the tree. */
  async function replaceNode(roots: TreeNode[], nodeToReplace: ShrinkableTreeNode): Promise<TreeNode[]> {
    // Each root expands to zero or more nodes (a replacement may render to several elements,
    // or to nothing), so every callback yields an array which is flattened below.
    const newRoots = await Promise.all(
      roots.map(async (root): Promise<TreeNode[]> => {
        if (root === nodeToReplace) {
          return conversationToTreeRoots(root.element.props.replacement);
        }

        if (root.type !== 'shrinkable') {
          return [root];
        }

        // Look for a replacement among the children and recalculate the cost.
        const replacementChildren = await replaceNode(root.children, nodeToReplace);
        return [
          {
            type: 'shrinkable',
            element: root.element,
            cost: aggregateCost(replacementChildren),
            children: replacementChildren,
          },
        ];
      })
    );

    return newRoots.flat(1);
  }

  /** Converts the shrinkable tree into a single AI.Node for rendering. */
  function treeRootsToNode(roots: TreeNode[]): AI.Node {
    return roots.map((root) => (root.type === 'immutable' ? root.element : treeRootsToNode(root.children)));
  }

  const memoized = memo(children);

  // If there are no shrinkable elements, there's no need to evaluate the cost.
  const shrinkableOrConversationElements = (
    await render(memoized, {
      stop: (e) => isConversationalComponent(e) || e.tag === InternalShrinkable,
    })
  ).filter(AI.isElement);
  if (!shrinkableOrConversationElements.some((value) => value.tag === InternalShrinkable)) {
    return shrinkableOrConversationElements;
  }

  let roots = await conversationToTreeRoots(shrinkableOrConversationElements);
  while (aggregateCost(roots) > budget) {
    const nodeToReplace = leastImportantNode(roots);
    if (nodeToReplace === undefined) {
      // Nothing left to replace.
      break;
    }

    logger.debug(
      {
        node: debug(nodeToReplace.element.props.children, true),
        importance: nodeToReplace.element.props.importance,
        replacement: debug(nodeToReplace.element.props.replacement, true),
        nodeCost: nodeToReplace.cost,
        totalCost: aggregateCost(roots),
        budget,
      },
      'Replacing shrinkable content'
    );

    // N.B. This is currently quadratic in that each time we replace a node we search the entire
    // tree for the least important node (and then search again to replace it). If we end up
    // doing many replacements we should be smarter about this.
    roots = await replaceNode(roots, nodeToReplace);
  }

  return treeRootsToNode(roots);
}

/**
 * @hidden
 * Marks a portion of a conversation as "shrinkable".
 *
 * @param props.children The content to show while it fits.
 * @param props.importance Relative importance of the content.
 * @param props.replacement Optional content to substitute for `children` when shrunk.
 */
export function Shrinkable(
  props: { children: Node; importance: number; replacement?: Node },
  { memo }: AI.ComponentContext
) {
  const { children, importance, replacement } = props;

  // The work is delegated to a separate component so that:
  //
  //  a) memoization happens in the expected context (that of the <Shrinkable>), and
  //  b) memoization is applied directly to the replacement and the children.
  //
  // As a result `children` and `replacement` can be read off the props of
  // <InternalShrinkable> already memoized, which would not be the case if only the
  // <Shrinkable> or <InternalShrinkable> element itself were memoized.
  const memoizedReplacement = replacement && memo(replacement);
  const memoizedChildren = children && memo(children);

  return (
    <InternalShrinkable importance={importance} replacement={memoizedReplacement}>
      {memoizedChildren}
    </InternalShrinkable>
  );
}

/**
 * @hidden
 * Internal pass-through component that exists to facilitate prop memoization; it simply
 * yields its children. See the comment in {@link Shrinkable}.
 */
function InternalShrinkable(props: { children: Node; importance: number; replacement: Node }) {
  // `importance` and `replacement` are not used here; they are only carried on the element's props.
  const { children } = props;
  return children;
}
31 changes: 22 additions & 9 deletions packages/ai-jsx/src/core/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,30 @@ export function makeIndirectNode<T extends object>(value: T, node: Node): T & In

/** @hidden */
export function withContext(renderable: Renderable, context: RenderContext): Element<any> {
function SwitchContext() {
return renderable;
if (isElement(renderable)) {
if (renderable[attachedContextSymbol]) {
// It's already been bound to a context; don't replace it.
return renderable;
}

const elementWithContext = {
...renderable,
[attachedContextSymbol]: context,
};
Object.freeze(elementWithContext);
return elementWithContext;
}

const elementWithContext = {
...(isElement(renderable) ? renderable : createElement(SwitchContext, null)),
[attachedContextSymbol]: context,
};

Object.freeze(elementWithContext);
return elementWithContext;
// Wrap it in an element and bind to that.
return withContext(
createElement(
function SwitchContext({ children }) {
return children;
},
{ children: renderable }
),
context
);
}

/** @hidden */
Expand Down
1 change: 1 addition & 0 deletions packages/ai-jsx/src/lib/anthropic.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ export async function* AnthropicChatModel(
}
yield AI.AppendOnlyStream;
const messages = await Promise.all(
// TODO: Support token budget/conversation shrinking
(
await renderToConversation(props.children, render)
)
Expand Down
Loading