fix: #3698 - o1 preview models do not work with max_tokens #3728

Merged 1 commit on Sep 24, 2024
9 changes: 9 additions & 0 deletions extensions/inference-openai-extension/jest.config.js
@@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
  preset: 'ts-jest',
  testEnvironment: 'node',
  transform: {
    'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
  },
  transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}
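
A note on the design choice here, as a likely reading of the config: Jest normally skips everything under node_modules, but @janhq/core appears to ship sources Jest cannot parse as-is, so transformIgnorePatterns carves out that one package with a negative lookahead and the transform entry routes it through ts-jest. Assuming jest and ts-jest are installed as dev dependencies of the extension, the suite runs with:

npx jest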
54 changes: 54 additions & 0 deletions extensions/inference-openai-extension/src/OpenAIExtension.test.ts
@@ -0,0 +1,54 @@
/**
 * @jest-environment jsdom
 */
jest.mock('@janhq/core', () => ({
  ...jest.requireActual('@janhq/core/node'),
  RemoteOAIEngine: jest.fn().mockImplementation(() => ({
    onLoad: jest.fn(),
    registerSettings: jest.fn(),
    registerModels: jest.fn(),
    getSetting: jest.fn(),
    onSettingUpdate: jest.fn(),
  })),
}))
import JanInferenceOpenAIExtension, { Settings } from '.'

describe('JanInferenceOpenAIExtension', () => {
  let extension: JanInferenceOpenAIExtension

  beforeEach(() => {
    // @ts-ignore
    extension = new JanInferenceOpenAIExtension()
  })

  it('should initialize with settings and models', async () => {
    await extension.onLoad()
    // Assuming there are some default SETTINGS and MODELS being registered
    expect(extension.apiKey).toBe(undefined)
    expect(extension.inferenceUrl).toBe('')
  })

  it('should transform the payload for preview models', () => {
    const payload: any = {
      max_tokens: 100,
      model: 'o1-mini',
      // Add other required properties...
    }

    const transformedPayload = extension.transformPayload(payload)
    expect(transformedPayload.max_completion_tokens).toBe(payload.max_tokens)
    expect(transformedPayload).not.toHaveProperty('max_tokens')
    expect(transformedPayload).toHaveProperty('max_completion_tokens')
  })

  it('should not transform the payload for non-preview models', () => {
    const payload: any = {
      max_tokens: 100,
      model: 'non-preview-model',
      // Add other required properties...
    }

    const transformedPayload = extension.transformPayload(payload)
    expect(transformedPayload).toEqual(payload)
  })
})
28 changes: 25 additions & 3 deletions extensions/inference-openai-extension/src/index.ts
@@ -6,16 +6,17 @@
  * @module inference-openai-extension/src/index
  */
 
-import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core'
+import { ModelRuntimeParams, PayloadType, RemoteOAIEngine } from '@janhq/core'
 
 declare const SETTINGS: Array<any>
 declare const MODELS: Array<any>
 
-enum Settings {
+export enum Settings {
   apiKey = 'openai-api-key',
   chatCompletionsEndPoint = 'chat-completions-endpoint',
 }
-
+type OpenAIPayloadType = PayloadType &
+  ModelRuntimeParams & { max_completion_tokens: number }
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -24,6 +24,7 @@ enum Settings {
 export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
   inferenceUrl: string = ''
   provider: string = 'openai'
+  previewModels = ['o1-mini', 'o1-preview']
 
   override async onLoad(): Promise<void> {
     super.onLoad()
@@ -63,4 +65,24 @@ export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
       }
     }
   }
+
+  /**
+   * Transform the payload before sending it to the inference endpoint.
+   * The new preview models such as o1-mini and o1-preview replace the
+   * max_tokens parameter with max_completion_tokens; other models do not.
+   * @param payload
+   * @returns
+   */
+  transformPayload = (payload: OpenAIPayloadType): OpenAIPayloadType => {
+    // Transform the payload for preview models
+    if (this.previewModels.includes(payload.model)) {
+      const { max_tokens, ...params } = payload
+      return {
+        ...params,
+        max_completion_tokens: max_tokens,
+      }
+    }
+    // Pass through for non-preview models
+    return payload
+  }
 }
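
For illustration, a minimal sketch of what transformPayload does to an outgoing request; the payloads below are simplified and hypothetical (a real OpenAIPayloadType carries more fields, and gpt-4o stands in for any non-preview model):

// Preview model: max_tokens is renamed to max_completion_tokens.
extension.transformPayload({ model: 'o1-mini', max_tokens: 100 } as any)
// => { model: 'o1-mini', max_completion_tokens: 100 }

// Non-preview model: the payload passes through unchanged.
extension.transformPayload({ model: 'gpt-4o', max_tokens: 100 } as any)
// => { model: 'gpt-4o', max_tokens: 100 }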
3 changes: 2 additions & 1 deletion extensions/inference-openai-extension/tsconfig.json
@@ -10,5 +10,6 @@
     "skipLibCheck": true,
     "rootDir": "./src"
   },
-  "include": ["./src"]
+  "include": ["./src"],
+  "exclude": ["**/*.test.ts"]
 }
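
The new exclude entry keeps the spec file out of the compiled extension output; Jest still discovers it through its own default testMatch patterns rather than this tsconfig's include list.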