Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to shrink conversation based on cost/budget/importance #228

Merged
merged 4 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/ai-jsx/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"repository": "fixie-ai/ai-jsx",
"bugs": "https://github.com/fixie-ai/ai-jsx/issues",
"homepage": "https://ai-jsx.com",
"version": "0.8.0",
"version": "0.8.1",
"volta": {
"extends": "../../package.json"
},
Expand Down Expand Up @@ -342,8 +342,8 @@
"axios": "^1.4.0",
"cli-highlight": "^2.1.11",
"cli-spinners": "^2.9.0",
"gpt3-tokenizer": "^1.1.5",
"ink": "^4.2.0",
"js-tiktoken": "^1.0.7",
"js-yaml": "^4.1.0",
"langchain": "^0.0.81",
"lodash": "^4.17.21",
Expand Down
2 changes: 1 addition & 1 deletion packages/ai-jsx/src/batteries/natural-language-router.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ export async function* NaturalLanguageRouter(props: { children: Node; query: Nod
}

const props = e.props as RouteProps;
return props.unmatched ? choiceIndex === 0 : props.when === whenOptions[choiceIndex];
return props.unmatched ? choiceIndex === whenOptions.length - 1 : props.when === whenOptions[choiceIndex];
});
}

Expand Down
219 changes: 217 additions & 2 deletions packages/ai-jsx/src/core/conversation.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { ChatCompletionResponseMessage } from 'openai';
import * as AI from '../index.js';
import { Node } from '../index.js';
import { AIJSXError, ErrorCode } from '../core/errors.js';
import { debug } from './debug.js';

/**
* Provide a System Message to the LLM, for use within a {@link ChatCompletion}.
Expand Down Expand Up @@ -183,8 +184,21 @@ function toConversationMessages(partialRendering: AI.PartiallyRendered[]): Conve
}

/** @hidden */
export async function renderToConversation(conversation: AI.Node, render: AI.ComponentContext['render']) {
return toConversationMessages(await render(conversation, { stop: isConversationalComponent }));
export async function renderToConversation(
petersalas marked this conversation as resolved.
Show resolved Hide resolved
conversation: AI.Node,
render: AI.ComponentContext['render'],
cost?: (message: ConversationMessage, render: AI.ComponentContext['render']) => Promise<number>,
budget?: number
) {
const conversationToUse =
cost && budget ? (
<ShrinkConversation cost={cost} budget={budget}>
{conversation}
</ShrinkConversation>
) : (
conversation
);
return toConversationMessages(await render(conversationToUse, { stop: isConversationalComponent }));
}

/**
Expand Down Expand Up @@ -309,3 +323,204 @@ export async function* ShowConversation(
// we can indicate that we've already yielded the final frame.
return AI.AppendOnlyStream;
}

/**
 * @hidden
 * "Shrinks" the messages of a conversation so that their combined cost fits within a budget.
 *
 * The conversation is rendered down to conversational components and `<Shrinkable>` wrappers.
 * While the summed cost of all messages exceeds `budget`, the shrinkable with the lowest
 * `importance` (ties broken in favor of removing the more expensive one) is swapped for its
 * `replacement`. Shrinking stops once the budget is met or no shrinkable content remains.
 *
 * Currently, `<Shrinkable>` components must wrap conversational messages and have no
 * effect within the messages themselves.
 *
 * @param cost Computes the cost (e.g. token length) of a single conversation message.
 * @param budget The maximum total cost (e.g. context window size) to aim for.
 * @param children The conversation to shrink.
 * @returns The conversation elements unchanged when nothing is shrinkable; otherwise a node
 *   built from the messages that remain after shrinking.
 */
export async function ShrinkConversation(
  {
    cost: costFn,
    budget,
    children,
  }: {
    cost: (message: ConversationMessage, render: AI.RenderContext['render']) => Promise<number>;
    budget: number;
    children: Node;
  },
  { render, memo, logger }: AI.ComponentContext
) {
  /** A message (or other conversational element) that is never replaced. */
  interface ImmutableTreeNode {
    type: 'immutable';
    element: AI.Element<any>;
    cost: number;
  }

  /** A `<Shrinkable>` region; its cost is the sum of the costs of its children. */
  interface ShrinkableTreeNode {
    type: 'shrinkable';
    element: AI.Element<AI.PropsOfComponent<typeof InternalShrinkable>>;
    cost: number;
    children: TreeNode[];
  }

  /**
   * The conversation is modeled as a tree of immutable and shrinkable nodes, where
   * shrinkable nodes may contain further nodes of either kind.
   */
  type TreeNode = ImmutableTreeNode | ShrinkableTreeNode;

  // Rendering halts at whole messages and at shrinkable wrappers; both are handled here.
  const stopAtMessageOrShrinkable = (e: AI.Element<any>) =>
    isConversationalComponent(e) || e.tag === InternalShrinkable;

  /** Sums the cost of the given sibling nodes. */
  function aggregateCost(roots: TreeNode[]): number {
    let total = 0;
    for (const node of roots) {
      total += node.cost;
    }
    return total;
  }

  /** Builds the shrinkable tree for a conversational `AI.Node`. */
  async function conversationToTreeRoots(conversation: AI.Node): Promise<TreeNode[]> {
    const pieces = await render(conversation, { stop: stopAtMessageOrShrinkable });

    // Costs of sibling messages are computed concurrently.
    const maybeNodes = await Promise.all(
      pieces.map(async (piece): Promise<TreeNode | null> => {
        // Bare strings outside of any message carry no cost and are dropped.
        if (typeof piece === 'string') {
          return null;
        }

        if (piece.tag !== InternalShrinkable) {
          const messageCost = await costFn(toConversationMessages([piece])[0], render);
          return { type: 'immutable', element: piece, cost: messageCost };
        }

        const nested = await conversationToTreeRoots(piece.props.children);
        return { type: 'shrinkable', element: piece, cost: aggregateCost(nested), children: nested };
      })
    );

    return maybeNodes.filter((n): n is TreeNode => n !== null);
  }

  /**
   * Searches the whole tree for the shrinkable node with the lowest importance. Among nodes
   * of equal importance the more expensive one is preferred; remaining ties keep the node
   * encountered first.
   */
  function leastImportantNode(roots: TreeNode[]): ShrinkableTreeNode | undefined {
    const isLessImportant = (candidate: ShrinkableTreeNode, incumbent: ShrinkableTreeNode) => {
      const importanceDelta = candidate.element.props.importance - incumbent.element.props.importance;
      // Equal importance falls through to cost: the costlier node counts as less important.
      const ordering = importanceDelta || incumbent.cost - candidate.cost;
      return ordering < 0;
    };

    let least: ShrinkableTreeNode | undefined;
    for (const node of roots) {
      if (node.type !== 'shrinkable') {
        continue;
      }

      if (least === undefined || isLessImportant(node, least)) {
        least = node;
      }

      const descendant = leastImportantNode(node.children);
      if (descendant !== undefined && isLessImportant(descendant, least)) {
        least = descendant;
      }
    }

    return least;
  }

  /**
   * Returns a copy of the tree in which `nodeToReplace` has been swapped for the tree built
   * from its `replacement` prop. Costs of enclosing shrinkable nodes are recalculated.
   */
  async function replaceNode(roots: TreeNode[], nodeToReplace: ShrinkableTreeNode): Promise<TreeNode[]> {
    const expand = async (root: TreeNode): Promise<TreeNode[]> => {
      if (root === nodeToReplace) {
        return conversationToTreeRoots(nodeToReplace.element.props.replacement);
      }

      if (root.type === 'shrinkable') {
        // The target may be nested below this node, so rebuild it from its updated children.
        const updatedChildren = await replaceNode(root.children, nodeToReplace);
        return [
          {
            type: 'shrinkable',
            element: root.element,
            cost: aggregateCost(updatedChildren),
            children: updatedChildren,
          },
        ];
      }

      return [root];
    };

    const expanded = await Promise.all(roots.map((root) => expand(root)));
    return expanded.flat(1);
  }

  /** Flattens the tree back into a single `AI.Node` made up of the remaining messages. */
  function treeRootsToNode(roots: TreeNode[]): AI.Node {
    return roots.map((root) => {
      if (root.type === 'shrinkable') {
        return treeRootsToNode(root.children);
      }
      return root.element;
    });
  }

  const memoized = memo(children);

  // Without any shrinkable elements there is nothing to replace, so skip evaluating costs.
  const renderedPieces = await render(memoized, { stop: stopAtMessageOrShrinkable });
  const topLevelElements = renderedPieces.filter(AI.isElement);
  const containsShrinkable = topLevelElements.some((el) => el.tag === InternalShrinkable);
  if (!containsShrinkable) {
    return topLevelElements;
  }

  let roots = await conversationToTreeRoots(topLevelElements);
  let totalCost = aggregateCost(roots);
  while (totalCost > budget) {
    const nodeToReplace = leastImportantNode(roots);
    if (nodeToReplace === undefined) {
      // Nothing left to replace.
      break;
    }

    const shrinkableProps = nodeToReplace.element.props;
    logger.debug(
      {
        node: debug(shrinkableProps.children, true),
        importance: shrinkableProps.importance,
        replacement: debug(shrinkableProps.replacement, true),
        nodeCost: nodeToReplace.cost,
        totalCost,
        budget,
      },
      'Replacing shrinkable content'
    );

    // N.B. This is currently quadratic: each replacement searches the entire tree for the
    // least important node and then searches it _again_ to perform the replacement. If we
    // end up doing many replacements we should be smarter about this.
    roots = await replaceNode(roots, nodeToReplace);
    totalCost = aggregateCost(roots);
  }

  return treeRootsToNode(roots);
}

/**
 * @hidden
 * Marks a portion of a conversation as "shrinkable".
 *
 * @param children The content to show while it fits.
 * @param importance Relative importance of the content.
 * @param replacement Optional content to substitute for `children` when it is shrunk.
 */
export function Shrinkable(
  { children, importance, replacement }: { children: Node; importance: number; replacement?: Node },
  { memo }: AI.ComponentContext
) {
  // We render to a separate component so that:
  //
  // a) The memoization happens in the expected context (that of the <Shrinkable>)
  // b) The memoization can be applied directly to the replacement and children
  //
  // This allows them to be taken off the props correctly memoized.
  const memoizedReplacement = replacement && memo(replacement);
  const memoizedChildren = children && memo(children);

  return (
    <InternalShrinkable importance={importance} replacement={memoizedReplacement}>
      {memoizedChildren}
    </InternalShrinkable>
  );
}

/**
 * @hidden
 * An internal component to facilitate prop memoization. See comment in {@link Shrinkable}.
 *
 * It renders its children unchanged; `importance` and `replacement` are only carried as props.
 */
function InternalShrinkable(props: { children: Node; importance: number; replacement: Node }) {
  return props.children;
}
31 changes: 22 additions & 9 deletions packages/ai-jsx/src/core/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,30 @@ export function makeIndirectNode<T extends object>(value: T, node: Node): T & In

/** @hidden */
export function withContext(renderable: Renderable, context: RenderContext): Element<any> {
function SwitchContext() {
return renderable;
if (isElement(renderable)) {
if (renderable[attachedContextSymbol]) {
// It's already been bound to a context; don't replace it.
return renderable;
}

const elementWithContext = {
...renderable,
[attachedContextSymbol]: context,
};
Object.freeze(elementWithContext);
return elementWithContext;
}

const elementWithContext = {
...(isElement(renderable) ? renderable : createElement(SwitchContext, null)),
[attachedContextSymbol]: context,
};

Object.freeze(elementWithContext);
return elementWithContext;
// Wrap it in an element and bind to that.
return withContext(
createElement(
function SwitchContext({ children }) {
return children;
},
{ children: renderable }
),
context
);
}

/** @hidden */
Expand Down
Loading