Skip to content

Commit

Permalink
Merge pull request #63 from FormulaMonks/add/open-ai-schema-constrain…
Browse files Browse the repository at this point in the history
…ed-tokens

feat: add support for schema-constrained tokens in `KurtOpenAI`
  • Loading branch information
jemc authored Dec 9, 2024
2 parents 9ba8bdb + a98877b commit c3a255d
Show file tree
Hide file tree
Showing 17 changed files with 1,019 additions and 228 deletions.
2 changes: 1 addition & 1 deletion examples/basic/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"@formula-monks/kurt-open-ai": "workspace:*",
"@formula-monks/kurt-vertex-ai": "workspace:*",
"@google-cloud/vertexai": "1.1.0",
"openai": "4.66.1",
"openai": "^4.76.0",
"zod": "^3.23.8"
}
}
2 changes: 1 addition & 1 deletion packages/kurt-open-ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
},
"dependencies": {
"@formula-monks/kurt": "^1.4.0",
"openai": "4.66.1",
"openai": "4.76.0",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.23.3"
},
Expand Down
7 changes: 4 additions & 3 deletions packages/kurt-open-ai/spec/generateNaturalLanguage.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { KurtResultLimitError } from "@formula-monks/kurt"

describe("KurtOpenAI generateNaturalLanguage", () => {
test("says hello", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Say hello!",
})
Expand All @@ -13,7 +13,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {
})

test("writes a haiku with high temperature", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Compose a haiku about a mountain stream at night.",
sampling: {
Expand All @@ -34,6 +34,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {

test("throws a limit error", async () => {
await snapshotAndMockWithError(
"gpt-4o-2024-05-13",
(kurt) =>
kurt.generateNaturalLanguage({
prompt: "Compose a haiku about content length limitations.",
Expand All @@ -50,7 +51,7 @@ describe("KurtOpenAI generateNaturalLanguage", () => {
})

test("describes a base64-encoded image", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateNaturalLanguage({
prompt: "Describe this emoji, in two words.",
extraMessages: [
Expand Down
62 changes: 60 additions & 2 deletions packages/kurt-open-ai/spec/generateStructuredData.spec.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { describe, test, expect } from "@jest/globals"
import { z } from "zod"
import { snapshotAndMock, snapshotAndMockWithError } from "./snapshots"
import { KurtResultValidateError } from "@formula-monks/kurt"
import {
KurtCapabilityError,
KurtResultValidateError,
} from "@formula-monks/kurt"

describe("KurtOpenAI generateStructuredData", () => {
test("says hello", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
schema: z
Expand All @@ -18,8 +21,63 @@ describe("KurtOpenAI generateStructuredData", () => {
expect(result.data).toEqual({ say: "hello" })
})

test("says hello with system prompt", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
systemPrompt: "Be nice.",
prompt: "Say hello!",
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
})
)
expect(result.data).toEqual({ say: "hello" })
})

test("says hello with schema constrained tokens", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
sampling: { forceSchemaConstrainedTokens: true },
})
)
expect(result.data).toEqual({ say: "hello" })
})

test("throws a capability error for schema constrained tokens in an older model", async () => {
await snapshotAndMockWithError(
"gpt-4o-2024-05-13",
(kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
schema: z
.object({
say: z.string().describe("A single word to say"),
})
.describe("Say a word"),
sampling: { forceSchemaConstrainedTokens: true },
}),
(errorAny) => {
expect(errorAny).toBeInstanceOf(KurtCapabilityError)
const error = errorAny as KurtCapabilityError
expect(error.missingCapability).toEqual(
"forceSchemaConstrainedTokens is not available for older models, including gpt-4o-2024-05-13"
)
expect(error.message).toContain(error.missingCapability)
}
)
})

test("throws a validate error from an impossible schema", async () => {
await snapshotAndMockWithError(
"gpt-4o-mini-2024-07-18",
(kurt) =>
kurt.generateStructuredData({
prompt: "Say hello!",
Expand Down
26 changes: 21 additions & 5 deletions packages/kurt-open-ai/spec/generateWithOptionalTools.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const calculatorTools = {

describe("KurtOpenAI generateWithOptionalTools", () => {
test("calculator (with tool call)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
Expand All @@ -34,8 +34,24 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
expect(result.additionalData).toBeUndefined() // no parallel tool calls
})

test("calculator (with strict tool call)", async () => {
const result = await snapshotAndMock("gpt-4o-mini-2024-07-18", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
tools: calculatorTools,
sampling: { forceSchemaConstrainedTokens: true },
})
)
expect(result.data).toEqual({
name: "divide",
args: { dividend: 9876356, divisor: 30487 },
})
expect(result.additionalData).toBeUndefined() // no parallel tool calls
})

test("calculator (after tool call)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt:
"What's 9876356 divided by 30487, rounded to the nearest integer?",
Expand All @@ -58,7 +74,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
})

test("calculator (with parallel tool calls)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt: [
"Calculate each of the following:",
Expand Down Expand Up @@ -86,7 +102,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
})

test("calculator (after parallel tool calls)", async () => {
const result = await snapshotAndMock((kurt) =>
const result = await snapshotAndMock("gpt-4o-2024-05-13", (kurt) =>
kurt.generateWithOptionalTools({
prompt: [
"Calculate each of the following:",
Expand Down Expand Up @@ -125,7 +141,7 @@ describe("KurtOpenAI generateWithOptionalTools", () => {
)
expect(result.text).toEqual(
[
"Here are the results of the calculations:",
"Here are the results:",
"",
"1. 8026256882 divided by 3402398 is 2359.",
"2. 1185835515 divided by 348263 is 3405.",
Expand Down
8 changes: 5 additions & 3 deletions packages/kurt-open-ai/spec/snapshots.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import type {
OpenAIResponse,
OpenAIResponseChunk,
} from "../src/OpenAI.types"
import { KurtOpenAI } from "../src/KurtOpenAI"
import { KurtOpenAI, type KurtOpenAISupportedModel } from "../src/KurtOpenAI"

function snapshotFilenameFor(testName: string | undefined) {
return `${__dirname}/snapshots/${testName?.replace(/ /g, "_")}.yaml`
Expand All @@ -29,6 +29,7 @@ function dumpYaml(filename: string, data: object) {
}

export async function snapshotAndMock<T>(
model: KurtOpenAISupportedModel,
testCaseFn: (kurt: Kurt) => KurtStream<T>
) {
// Here's the data structure we will use to snapshot a request/response cycle.
Expand Down Expand Up @@ -91,7 +92,7 @@ export async function snapshotAndMock<T>(
} as unknown as OpenAI

// Run the test case function with a new instance of Kurt.
const kurt = new Kurt(new KurtOpenAI({ openAI, model: "gpt-4o-2024-05-13" }))
const kurt = new Kurt(new KurtOpenAI({ openAI, model }))
const stream = testCaseFn(kurt)

// Save the final stream of Kurt events.
Expand All @@ -114,11 +115,12 @@ export async function snapshotAndMock<T>(
}

export async function snapshotAndMockWithError<T>(
model: KurtOpenAISupportedModel,
testCaseFn: (kurt: Kurt) => KurtStream<T>,
errorCheckFn: (error: Error) => void
) {
try {
await snapshotAndMock(testCaseFn)
await snapshotAndMock(model, testCaseFn)
expectedErrorToBeThrownBeforeThisPoint()
} catch (error: unknown) {
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ step1Request:
stream: true
stream_options:
include_usage: true
model: gpt-4o-2024-05-13
model: gpt-4o-mini-2024-07-18
max_tokens: 4096
temperature: 0.5
top_p: 0.95
messages:
- role: system
content:
- type: text
text: Respond with JSON.
- role: user
content:
- type: text
Expand All @@ -31,6 +35,8 @@ step1Request:
type: function
function:
name: structured_data
response_format:
type: json_object
step2RawChunks:
- choices:
- index: 0
Expand All @@ -39,14 +45,15 @@ step2RawChunks:
content: null
tool_calls:
- index: 0
id: call_oZj1FnPJSZCVbFYtpNYPAm7P
id: call_9x7qX8eO6DgWYP8h1xc5kHsl
type: function
function:
name: structured_data
arguments: ""
refusal: null
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -57,7 +64,7 @@ step2RawChunks:
arguments: '{"'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -68,7 +75,7 @@ step2RawChunks:
arguments: say
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -79,7 +86,7 @@ step2RawChunks:
arguments: '":"'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -90,7 +97,7 @@ step2RawChunks:
arguments: hello
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
Expand All @@ -101,21 +108,29 @@ step2RawChunks:
arguments: '"}'
logprobs: null
finish_reason: null
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices:
- index: 0
delta: {}
logprobs: null
finish_reason: stop
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage: null
- choices: []
system_fingerprint: fp_5bf7397cd3
system_fingerprint: fp_bba3c8e70b
usage:
prompt_tokens: 66
prompt_tokens: 70
completion_tokens: 5
total_tokens: 71
total_tokens: 75
prompt_tokens_details:
cached_tokens: 0
audio_tokens: 0
completion_tokens_details:
reasoning_tokens: 0
audio_tokens: 0
accepted_prediction_tokens: 0
rejected_prediction_tokens: 0
step3KurtEvents:
- chunk: '{"'
- chunk: say
Expand All @@ -127,6 +142,6 @@ step3KurtEvents:
data:
say: hello
metadata:
totalInputTokens: 66
totalInputTokens: 70
totalOutputTokens: 5
systemFingerprint: fp_5bf7397cd3
systemFingerprint: fp_bba3c8e70b
Loading

0 comments on commit c3a255d

Please sign in to comment.