912: Update available GPT models #915

Merged · 3 commits · Jul 5, 2024
9 changes: 8 additions & 1 deletion backend/src/langchain.ts
@@ -30,7 +30,14 @@ function makePromptTemplate(
}

function getChatModel(): CHAT_MODEL_ID {
return getValidOpenAIModels().includes('gpt-4') ? 'gpt-4' : 'gpt-3.5-turbo';
const validModels = getValidOpenAIModels();
// GPT-4 is the most expensive model by a long way, avoid at all costs!
return (
validModels.find((model) => model === 'gpt-4o') ??
validModels.find((model) => model === 'gpt-4-turbo') ??
validModels.find((model) => model === 'gpt-3.5-turbo') ??
validModels[0]
);
}

function initQAModel(level: LEVEL_NAMES, Prompt: string) {
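For context, the new getChatModel prefers gpt-4o, then gpt-4-turbo, then gpt-3.5-turbo, and only falls back to the first valid model if none of those are available. A minimal standalone sketch of that fallback chain (pickChatModel and PREFERRED_MODELS are illustrative names, not part of the PR):

```ts
// Minimal sketch of the fallback chain above; illustrative only.
// Assumes validModels is non-empty (the PR now throws earlier if it is not).
const PREFERRED_MODELS = ['gpt-4o', 'gpt-4-turbo', 'gpt-3.5-turbo'] as const;

function pickChatModel(validModels: readonly string[]): string {
  // Take the first preferred model the API key supports; plain 'gpt-4' is
  // deliberately skipped because it is the most expensive option.
  for (const preferred of PREFERRED_MODELS) {
    if (validModels.includes(preferred)) return preferred;
  }
  return validModels[0];
}

// Example: a key with only gpt-4 and gpt-4-turbo resolves to gpt-4-turbo.
console.log(pickChatModel(['gpt-4', 'gpt-4-turbo'])); // 'gpt-4-turbo'
```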
29 changes: 20 additions & 9 deletions backend/src/models/chat.ts
@@ -7,14 +7,20 @@ import { ChatInfoMessage, ChatMessage } from './chatMessage';
import { DEFENCE_ID } from './defence';
import { EmailInfo } from './email';

const chatModelIds = [
'gpt-4-1106-preview',
'gpt-4',
'gpt-4-0613',
'gpt-3.5-turbo',
] as const;
// Size of each model's context window in number of tokens
// https://platform.openai.com/docs/models
const chatModelContextWindow = {
'gpt-4o': 128000,
'gpt-4-turbo': 128000,
'gpt-4': 8192,
'gpt-3.5-turbo': 16385,
} as const;

type CHAT_MODEL_ID = keyof typeof chatModelContextWindow;

type CHAT_MODEL_ID = (typeof chatModelIds)[number];
const chatModelIds = Object.freeze(
Object.keys(chatModelContextWindow)
) as readonly [CHAT_MODEL_ID];

type ChatModel = {
id: CHAT_MODEL_ID;
@@ -103,6 +109,7 @@ const defaultChatModel: ChatModel = {
};

export type {
CHAT_MODEL_ID,
DefenceReport,
ChatModel,
ChatModelConfigurations,
@@ -116,5 +123,9 @@ export type {
SingleDefenceReport,
MODEL_CONFIG_ID,
};
export { defaultChatModel, modelConfigIds, chatModelIds };
export type { CHAT_MODEL_ID };
export {
defaultChatModel,
modelConfigIds,
chatModelIds,
chatModelContextWindow,
};
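The chatModelContextWindow map now doubles as the single source of truth for valid model IDs: the CHAT_MODEL_ID union and the runtime chatModelIds list are both derived from its keys. A hedged sketch of that derivation pattern (simplified from the diff; the console.log is illustrative):

```ts
// Context window sizes (in tokens) keyed by model ID, as in the diff above.
const chatModelContextWindow = {
  'gpt-4o': 128000,
  'gpt-4-turbo': 128000,
  'gpt-4': 8192,
  'gpt-3.5-turbo': 16385,
} as const;

// Both the ID union type and the runtime ID list are derived from the map,
// so adding a new model in one place keeps everything in sync.
type CHAT_MODEL_ID = keyof typeof chatModelContextWindow;
const chatModelIds = Object.keys(chatModelContextWindow) as CHAT_MODEL_ID[];

console.log(chatModelIds); // ['gpt-4o', 'gpt-4-turbo', 'gpt-4', 'gpt-3.5-turbo']
```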
14 changes: 9 additions & 5 deletions backend/src/openai.ts
@@ -12,6 +12,7 @@ import { queryDocuments } from './langchain';
import {
CHAT_MODEL_ID,
ChatModel,
chatModelContextWindow,
ChatModelReply,
FunctionCallResponse,
ToolCallResponse,
@@ -27,9 +28,8 @@
} from './models/openai';
import { pushMessageToHistory } from './utils/chat';
import {
chatModelMaxTokens,
countTotalPromptTokens,
filterChatHistoryByMaxTokens,
truncateChatHistoryToContextWindow,
} from './utils/token';

// tools available to OpenAI models
@@ -121,12 +121,15 @@ async function getValidModelsFromOpenAI() {
.map((model) => model.id as CHAT_MODEL_ID)
.filter((id) => chatModelIds.includes(id))
.sort();
if (!validModels.length) {
throw new Error('No chat models found');
}

validOpenAiModels.set(validModels);
console.debug('Valid OpenAI models:', validModels);
return validModels;
} catch (error) {
console.error('Error getting valid models: ', error);
console.error('Error getting valid models:', error);
throw error;
}
}
@@ -289,8 +292,9 @@ function getChatCompletionsInContextWindow(
countTotalPromptTokens(completions)
);

const maxTokens = chatModelMaxTokens[gptModel] * 0.95; // 95% of max tokens to allow for response tokens
const reducedCompletions = filterChatHistoryByMaxTokens(
// 95% of max tokens to allow for response tokens - bit crude :(
const maxTokens = chatModelContextWindow[gptModel] * 0.95;
const reducedCompletions = truncateChatHistoryToContextWindow(
completions,
maxTokens
);
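The prompt budget here is deliberately crude: reserve 5% of the model's context window for the response and spend the rest on the prompt. A tiny illustrative helper showing the arithmetic (promptTokenBudget is a made-up name, not part of the PR):

```ts
// Illustrative only: compute the prompt-token budget for a model by
// reserving a fraction of its context window for the response.
function promptTokenBudget(contextWindow: number, responseReserve = 0.05): number {
  return Math.floor(contextWindow * (1 - responseReserve));
}

// e.g. gpt-3.5-turbo's 16385-token window leaves 15565 tokens for the prompt.
console.log(promptTokenBudget(16385)); // 15565
```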
31 changes: 7 additions & 24 deletions backend/src/utils/token.ts
@@ -4,19 +4,8 @@ import {
} from 'openai/resources/chat/completions';
import { promptTokensEstimate, stringTokens } from 'openai-chat-tokens';

import { CHAT_MODEL_ID } from '@src/models/chat';
import { chatModelTools } from '@src/openai';

// The size of each model's context window in number of tokens. https://platform.openai.com/docs/models
const chatModelMaxTokens: {
[key in CHAT_MODEL_ID]: number;
} = {
'gpt-4': 8192,
'gpt-4-1106-preview': 128000,
'gpt-4-0613': 8192,
'gpt-3.5-turbo': 16385,
};

const TOKENS_PER_TOOL_CALL = 4;
const OFFSET_OPENAI_TOKENS = 5; // there is an offset of 5 between openai completion prompt_tokens

@@ -53,22 +42,20 @@ function countTotalPromptTokens(chatHistory: ChatCompletionMessageParam[]) {
);
}

function filterChatHistoryByMaxTokens(
function truncateChatHistoryToContextWindow(
chatHistory: ChatCompletionMessageParam[],
maxNumTokens: number
contextWindowSize: number
): ChatCompletionMessageParam[] {
if (chatHistory.length === 0) {
return chatHistory;
}
const estimatedTokens = countTotalPromptTokens(chatHistory);
if (estimatedTokens <= maxNumTokens) {
if (estimatedTokens <= contextWindowSize) {
return chatHistory;
}
console.log(
'Filtering chat history to fit inside context window. estimated_tokens= ',
estimatedTokens,
'maxNumTokens=',
maxNumTokens
'Truncating chat history to fit inside context window:',
`estimated_tokens=${estimatedTokens} context_window=${contextWindowSize}`
);
const newList: ChatCompletionMessageParam[] = [];

@@ -95,7 +82,7 @@

const currentTokens = countTotalPromptTokens(tempList);
// if it fits add it to the list
if (currentTokens <= maxNumTokens) {
if (currentTokens <= contextWindowSize) {
newList.splice(i, 0, message);
} else {
console.debug('Max tokens reached on completion=', message);
Expand All @@ -110,8 +97,4 @@ function filterChatHistoryByMaxTokens(
return newList.reverse();
}

export {
chatModelMaxTokens,
filterChatHistoryByMaxTokens,
countTotalPromptTokens,
};
export { truncateChatHistoryToContextWindow, countTotalPromptTokens };
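For reference, the renamed truncateChatHistoryToContextWindow keeps the system message and then retains as many of the most recent messages as still fit the budget. A simplified, self-contained sketch of that idea, assuming a crude stand-in token counter (the real code estimates tokens with openai-chat-tokens):

```ts
type Message = { role: 'system' | 'user' | 'assistant'; content: string };

// Crude stand-in for the openai-chat-tokens estimator used in the real code.
function countTokens(messages: Message[]): number {
  return messages.reduce((total, m) => total + Math.ceil(m.content.length / 4), 0);
}

// Keep the system message (if any) plus as many of the most recent messages
// as fit inside the context window budget, preserving original order.
function truncateToWindow(history: Message[], contextWindowSize: number): Message[] {
  if (countTokens(history) <= contextWindowSize) return history;

  const systemMessage = history[0]?.role === 'system' ? [history[0]] : [];
  const rest = history.slice(systemMessage.length);
  const kept: Message[] = [];

  // Walk from newest to oldest, stopping once the budget would be exceeded.
  for (let i = rest.length - 1; i >= 0; i--) {
    if (countTokens([...systemMessage, rest[i], ...kept]) > contextWindowSize) break;
    kept.unshift(rest[i]);
  }
  return [...systemMessage, ...kept];
}
```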
@@ -60,21 +60,35 @@ test('GIVEN the prompt evaluation model is not initialised WHEN it is asked to e
expect(result).toEqual(false);
});

test('GIVEN the users api key supports GPT-4 WHEN the prompt evaluation model is initialised THEN it is initialised with GPT-4', async () => {
mockValidModels = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3'];
test('GIVEN the users api key supports gpt-4o WHEN the prompt evaluation model is initialised THEN it is initialised with gpt-4o', async () => {
mockValidModels = ['gpt-4o', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-3'];

const prompt = 'this is a test prompt. ';

await evaluatePrompt('some input', prompt);

expect(OpenAI).toHaveBeenCalledWith({
modelName: 'gpt-4',
modelName: 'gpt-4o',
temperature: 0,
openAIApiKey: 'sk-12345',
});
});

test('GIVEN the users api key does not support GPT-4 WHEN the prompt evaluation model is initialised THEN it is initialised with gpt-3.5-turbo', async () => {
test('GIVEN the users api key supports gpt-4-turbo but not gpt-4o WHEN the prompt evaluation model is initialised THEN it is initialised with gpt-4-turbo', async () => {
mockValidModels = ['gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-3'];

const prompt = 'this is a test prompt. ';

await evaluatePrompt('some input', prompt);

expect(OpenAI).toHaveBeenCalledWith({
modelName: 'gpt-4-turbo',
temperature: 0,
openAIApiKey: 'sk-12345',
});
});

test('GIVEN the users api key does not support gpt-4o or gpt-4-turbo WHEN the prompt evaluation model is initialised THEN it is initialised with gpt-3.5-turbo', async () => {
mockValidModels = ['gpt-2', 'gpt-3.5-turbo', 'gpt-3'];

const prompt = 'this is a test prompt. ';
23 changes: 19 additions & 4 deletions backend/test/unit/langchain.ts/initialiseQAModel.test.ts
@@ -123,22 +123,37 @@ test('GIVEN the QA LLM WHEN a question is asked THEN it is initialised AND it an
expect(answer).toEqual(expectedAnswer);
});

test('GIVEN the users api key supports GPT-4 WHEN the QA model is initialised THEN it is initialised with GPT-4', async () => {
mockValidModels.push('gpt-4', 'gpt-3.5-turbo', 'gpt-3');
test('GIVEN the users api key supports gpt-4o WHEN the QA model is initialised THEN it is initialised with gpt-4o', async () => {
mockValidModels.push('gpt-4o', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-3');

const level = LEVEL_NAMES.LEVEL_1;
const prompt = 'this is a test prompt. ';

await queryDocuments('some question', prompt, level);

expect(ChatOpenAI).toHaveBeenCalledWith({
modelName: 'gpt-4',
modelName: 'gpt-4o',
streaming: true,
openAIApiKey: 'sk-12345',
});
});

test('GIVEN the users api key does not support GPT-4 WHEN the QA model is initialised THEN it is initialised with gpt-3.5-turbo', async () => {
test('GIVEN the users api key supports gpt-4-turbo but not gpt-4o WHEN the QA model is initialised THEN it is initialised with gpt-4-turbo', async () => {
mockValidModels.push('gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-3');

const level = LEVEL_NAMES.LEVEL_1;
const prompt = 'this is a test prompt. ';

await queryDocuments('some question', prompt, level);

expect(ChatOpenAI).toHaveBeenCalledWith({
modelName: 'gpt-4-turbo',
streaming: true,
openAIApiKey: 'sk-12345',
});
});

test('GIVEN the users api key does not support gpt-4o or gpt-4-turbo WHEN the QA model is initialised THEN it is initialised with gpt-3.5-turbo', async () => {
mockValidModels.push('gpt-2', 'gpt-3.5-turbo', 'gpt-3');

const level = LEVEL_NAMES.LEVEL_1;
22 changes: 20 additions & 2 deletions backend/test/unit/openai.ts/getValidModelsFromOpenai.test.ts
@@ -38,11 +38,11 @@ describe('getValidModelsFromOpenAI', () => {
{ id: 'gpt-3.5-turbo' },
{ id: 'gpt-3.5-turbo-0613' },
{ id: 'gpt-4' },
{ id: 'gpt-4-0613' },
{ id: 'gpt-4o' },
{ id: 'da-vinci-1' },
{ id: 'da-vinci-2' },
];
const expectedValidModels = ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-0613'];
const expectedValidModels = ['gpt-3.5-turbo', 'gpt-4', 'gpt-4o'];

mockListFn.mockResolvedValueOnce({
data: mockModelList,
Expand All @@ -52,4 +52,22 @@ describe('getValidModelsFromOpenAI', () => {

expect(validModels).toEqual(expectedValidModels);
});

test('GIVEN the user has no valid chat models available WHEN getValidModelsFromOpenAI is called THEN an error is thrown', async () => {
process.env.OPENAI_API_KEY = 'sk-12345';
const mockModelList = [
{ id: 'gpt-3' },
{ id: 'davinci-001' },
{ id: 'davinci-002' },
{ id: 'text-moderation-stable' },
{ id: 'whisper-1' },
];
mockListFn.mockResolvedValueOnce({
data: mockModelList,
} as OpenAI.ModelsPage);

await expect(getValidModelsFromOpenAI()).rejects.toThrow(
'No chat models found'
);
});
});