828 streamline change chat model request #883

Merged (14 commits) on Apr 2, 2024
Changes from all commits
41 changes: 32 additions & 9 deletions backend/src/controller/modelController.ts
@@ -2,7 +2,11 @@ import { Response } from 'express';

import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest';
-import { MODEL_CONFIG_ID, modelConfigIds } from '@src/models/chat';
+import {
+	MODEL_CONFIG_ID,
+	chatModelIds,
+	modelConfigIds,
+} from '@src/models/chat';
import { ChatInfoMessage } from '@src/models/chatMessage';
import { LEVEL_NAMES } from '@src/models/level';
import { pushMessageToHistory } from '@src/utils/chat';
@@ -13,14 +17,32 @@ function handleSetModel(req: OpenAiSetModelRequest, res: Response) {
const { model } = req.body;

if (model === undefined) {
-		res.status(400).send();
-	} else {
-		const configuration =
-			req.body.configuration ?? req.session.chatModel.configuration;
-		req.session.chatModel = { id: model, configuration };
-		console.debug('GPT model set:', JSON.stringify(req.session.chatModel));
-		res.status(200).send();
+		sendErrorResponse(res, 400, 'Missing model');
+		return;
}

if (!chatModelIds.includes(model)) {
sendErrorResponse(res, 400, 'Invalid model');
return;
}

const configuration =
req.body.configuration ?? req.session.chatModel.configuration;
req.session.chatModel = { id: model, configuration };
console.debug('GPT model set:', JSON.stringify(req.session.chatModel));

const chatInfoMessage = {
infoMessage: `changed model to ${model}`,
chatMessageType: 'GENERIC_INFO',
} as ChatInfoMessage;
// for now, the chat model only changes in the sandbox level
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory =
pushMessageToHistory(
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory,
chatInfoMessage
);

res.send({ chatInfoMessage });
}

function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) {
@@ -59,13 +81,14 @@ function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) {
infoMessage: `changed ${configId} to ${value}`,
chatMessageType: 'GENERIC_INFO',
} as ChatInfoMessage;
// for now, the chat model only changes in the sandbox level
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory =
pushMessageToHistory(
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory,
chatInfoMessage
);

-	res.status(200).send({ chatInfoMessage });
+	res.send({ chatInfoMessage });
}

export { handleSetModel, handleConfigureModel };
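
The guard clauses above call a sendErrorResponse helper rather than res.status(...).send() inline. Judging by the unit tests further down (status 400, body 'Missing model' or 'Invalid model'), it presumably reduces to something like this sketch; the actual helper lives elsewhere in the repo:

```ts
import { Response } from 'express';

// Hedged sketch, inferred from the tests in this PR; not the repo's actual helper.
function sendErrorResponse(res: Response, statusCode: number, message: string) {
	res.status(statusCode).send(message);
}
```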
8 changes: 3 additions & 5 deletions backend/src/langchain.ts
@@ -4,7 +4,7 @@ import { OpenAI } from 'langchain/llms/openai';
import { PromptTemplate } from 'langchain/prompts';

import { getDocumentVectors } from './document';
-import { CHAT_MODELS } from './models/chat';
+import { CHAT_MODEL_ID } from './models/chat';
import { PromptEvaluationChainReply, QaChainReply } from './models/langchain';
import { LEVEL_NAMES } from './models/level';
import { getOpenAIKey, getValidOpenAIModels } from './openai';
@@ -30,10 +30,8 @@ function makePromptTemplate(
return PromptTemplate.fromTemplate(fullPrompt);
}

-function getChatModel() {
-	return getValidOpenAIModels().includes(CHAT_MODELS.GPT_4)
-		? CHAT_MODELS.GPT_4
-		: CHAT_MODELS.GPT_3_5_TURBO;
+function getChatModel(): CHAT_MODEL_ID {
+	return getValidOpenAIModels().includes('gpt-4') ? 'gpt-4' : 'gpt-3.5-turbo';
}

function initQAModel(level: LEVEL_NAMES, Prompt: string) {
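
A side note on the rewritten getChatModel: with the explicit CHAT_MODEL_ID return type, the string literals are checked against the union, so a typo such as 'gpt4' would fail to compile. A minimal self-contained sketch of the same shape, with a local stand-in for getValidOpenAIModels:

```ts
type ChatModelId = 'gpt-4' | 'gpt-3.5-turbo';

// Stand-in for getValidOpenAIModels(), which is populated from the OpenAI API.
const validModels: ChatModelId[] = ['gpt-3.5-turbo'];

function getChatModel(): ChatModelId {
	// A typo like 'gpt4' here would be a compile error: not assignable to ChatModelId.
	return validModels.includes('gpt-4') ? 'gpt-4' : 'gpt-3.5-turbo';
}
```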
9 changes: 6 additions & 3 deletions backend/src/models/api/OpenAiSetModelRequest.ts
@@ -1,12 +1,15 @@
import { Request } from 'express';

-import { CHAT_MODELS, ChatModelConfigurations } from '@src/models/chat';
+import { CHAT_MODEL_ID, ChatModelConfigurations } from '@src/models/chat';
+import { ChatMessage } from '@src/models/chatMessage';

export type OpenAiSetModelRequest = Request<
never,
-	never,
	{
-		model?: CHAT_MODELS;
+		chatInfoMessage: ChatMessage;
+	},
+	{
+		model?: CHAT_MODEL_ID;
configuration?: ChatModelConfigurations;
},
never
24 changes: 12 additions & 12 deletions backend/src/models/chat.ts
@@ -7,18 +7,17 @@ import { ChatInfoMessage, ChatMessage } from './chatMessage';
import { DEFENCE_ID } from './defence';
import { EmailInfo } from './email';

-enum CHAT_MODELS {
-	GPT_4_TURBO = 'gpt-4-1106-preview',
-	GPT_4 = 'gpt-4',
-	GPT_4_0613 = 'gpt-4-0613',
-	GPT_3_5_TURBO = 'gpt-3.5-turbo',
-	GPT_3_5_TURBO_0613 = 'gpt-3.5-turbo-0613',
-	GPT_3_5_TURBO_16K = 'gpt-3.5-turbo-16k',
-	GPT_3_5_TURBO_16K_0613 = 'gpt-3.5-turbo-16k-0613',
-}
+const chatModelIds = [
+	'gpt-4-1106-preview',
+	'gpt-4',
+	'gpt-4-0613',
+	'gpt-3.5-turbo',
+] as const;

type CHAT_MODEL_ID = (typeof chatModelIds)[number];

type ChatModel = {
-	id: CHAT_MODELS;
+	id: CHAT_MODEL_ID;
configuration: ChatModelConfigurations;
};

@@ -106,7 +105,7 @@ interface LevelHandlerResponse {
}

const defaultChatModel: ChatModel = {
-	id: CHAT_MODELS.GPT_3_5_TURBO,
+	id: 'gpt-3.5-turbo',
configuration: {
temperature: 1,
topP: 1,
@@ -131,4 +130,5 @@ export type {
SingleDefenceReport,
MODEL_CONFIG_ID,
};
-export { CHAT_MODELS, defaultChatModel, modelConfigIds };
+export { defaultChatModel, modelConfigIds, chatModelIds };
+export type { CHAT_MODEL_ID };
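
The central change in this file: the CHAT_MODELS enum is replaced by a const-asserted tuple plus a union type derived from it, so a single definition serves as both the compile-time type and the runtime whitelist used in handleSetModel. A minimal sketch of the pattern (the isChatModelId guard is illustrative, not part of the PR):

```ts
const chatModelIds = ['gpt-4', 'gpt-3.5-turbo'] as const;

// Derives the union 'gpt-4' | 'gpt-3.5-turbo' from the tuple.
type CHAT_MODEL_ID = (typeof chatModelIds)[number];

// Unlike an enum, the tuple doubles as a runtime whitelist.
function isChatModelId(value: string): value is CHAT_MODEL_ID {
	return (chatModelIds as readonly string[]).includes(value);
}
```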
13 changes: 7 additions & 6 deletions backend/src/openai.ts
@@ -9,12 +9,13 @@ import { getQAPromptFromConfig } from './defence';
import { sendEmail } from './email';
import { queryDocuments } from './langchain';
import {
-	CHAT_MODELS,
+	CHAT_MODEL_ID,
ChatGptReply,
ChatModel,
ChatResponse,
FunctionCallResponse,
ToolCallResponse,
chatModelIds,
} from './models/chat';
import { ChatMessage } from './models/chatMessage';
import { QaLlmDefence } from './models/defence';
@@ -84,10 +85,10 @@ const chatGptTools: ChatCompletionTool[] = [

// list of valid chat models for the api key
const validOpenAiModels = (() => {
-	let validModels: CHAT_MODEL_ID[] = [];
+	let validModels: CHAT_MODEL_ID[] = [];
return {
get: () => validModels,
-		set: (models: string[]) => {
+		set: (models: CHAT_MODEL_ID[]) => {
validModels = models;
},
};
@@ -117,8 +118,8 @@ async function getValidModelsFromOpenAI() {

// get the model ids that are supported by our app. Non-chat models like Dall-e and whisper are not supported.
const validModels = models.data
-		.map((model) => model.id)
-		.filter((id) => Object.values(CHAT_MODELS).includes(id as CHAT_MODELS))
+		.map((model) => model.id as CHAT_MODEL_ID)
+		.filter((id) => chatModelIds.includes(id))
.sort();

validOpenAiModels.set(validModels);
@@ -283,7 +284,7 @@

function getChatCompletionsInContextWindow(
chatHistory: ChatMessage[],
-	gptModel: CHAT_MODELS
+	gptModel: CHAT_MODEL_ID
): ChatCompletionMessageParam[] {
const completions = chatHistory
.map((chatMessage) =>
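
One nitpick on getValidModelsFromOpenAI: the id is cast to CHAT_MODEL_ID before the list check, which briefly asserts that ids like 'da-vinci-1' are valid models. A type-guard predicate achieves the same narrowing without the up-front cast; a self-contained sketch, assuming the same chatModelIds tuple:

```ts
const chatModelIds = ['gpt-4', 'gpt-3.5-turbo'] as const;
type ChatModelId = (typeof chatModelIds)[number];

// Stand-in for the models.list() response data from the OpenAI client.
const data = [{ id: 'gpt-4' }, { id: 'da-vinci-1' }];

const validModels = data
	.map((model) => model.id)
	// Type-guard predicate: narrows string to ChatModelId with no cast.
	.filter((id): id is ChatModelId =>
		(chatModelIds as readonly string[]).includes(id)
	)
	.sort(); // ['gpt-4']
```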
17 changes: 8 additions & 9 deletions backend/src/utils/token.ts
@@ -4,18 +4,17 @@ import {
} from 'openai/resources/chat/completions';
import { promptTokensEstimate, stringTokens } from 'openai-chat-tokens';

-import { CHAT_MODELS } from '@src/models/chat';
+import { CHAT_MODEL_ID } from '@src/models/chat';
import { chatGptTools } from '@src/openai';

// The size of each model's context window in number of tokens. https://platform.openai.com/docs/models
-const chatModelMaxTokens = {
-	[CHAT_MODELS.GPT_4_TURBO]: 128000,
-	[CHAT_MODELS.GPT_4]: 8192,
-	[CHAT_MODELS.GPT_4_0613]: 8192,
-	[CHAT_MODELS.GPT_3_5_TURBO]: 4097,
-	[CHAT_MODELS.GPT_3_5_TURBO_0613]: 4097,
-	[CHAT_MODELS.GPT_3_5_TURBO_16K]: 16385,
-	[CHAT_MODELS.GPT_3_5_TURBO_16K_0613]: 16385,
+const chatModelMaxTokens: {
+	[key in CHAT_MODEL_ID]: number;
+} = {
+	'gpt-4': 8192,
+	'gpt-4-1106-preview': 128000,
+	'gpt-4-0613': 8192,
+	'gpt-3.5-turbo': 16385,
};

const TOKENS_PER_TOOL_CALL = 4;
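
The mapped type on chatModelMaxTokens does real work: every member of CHAT_MODEL_ID must appear exactly once, so adding a model id in chat.ts without a matching token limit becomes a compile error instead of a silent undefined lookup. A self-contained sketch:

```ts
type ModelId = 'gpt-4' | 'gpt-3.5-turbo';

// Exhaustive table: omitting either key, or adding an unknown one,
// is rejected by the compiler.
const maxTokens: { [key in ModelId]: number } = {
	'gpt-4': 8192,
	'gpt-3.5-turbo': 16385,
};
```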
4 changes: 2 additions & 2 deletions backend/test/integration/openai.test.ts
@@ -1,6 +1,6 @@
import { expect, jest, test, describe } from '@jest/globals';

-import { CHAT_MODELS, ChatModel } from '@src/models/chat';
+import { ChatModel } from '@src/models/chat';
import { ChatMessage } from '@src/models/chatMessage';
import { chatGptSendMessage } from '@src/openai';

@@ -54,7 +54,7 @@ describe('OpenAI Integration Tests', () => {
},
];
const chatModel: ChatModel = {
-		id: CHAT_MODELS.GPT_4,
+		id: 'gpt-4',
configuration: {
temperature: 1,
topP: 1,
65 changes: 63 additions & 2 deletions backend/test/unit/controller/modelController.test.ts
@@ -1,8 +1,12 @@
import { expect, jest, test, describe } from '@jest/globals';
import { Response } from 'express';

-import { handleConfigureModel } from '@src/controller/modelController';
+import {
+	handleConfigureModel,
+	handleSetModel,
+} from '@src/controller/modelController';
import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest';
import { modelConfigIds } from '@src/models/chat';
import { ChatMessage } from '@src/models/chatMessage';
import { LEVEL_NAMES, LevelState } from '@src/models/level';
@@ -14,6 +18,64 @@ function responseMock() {
} as unknown as Response;
}

describe('handleSetModel', () => {
test('WHEN passed sensible parameters THEN sets model AND adds info message to chat history AND responds with info message', () => {
const req = {
body: {
model: 'gpt-4',
},
session: {
chatModel: {
id: 'gpt-3.5-turbo',
},
levelState: [{}, {}, {}, { chatHistory: [] } as unknown as LevelState],
},
} as OpenAiSetModelRequest;
const res = responseMock();

handleSetModel(req, res);

expect(req.session.chatModel.id).toBe('gpt-4');

const expectedInfoMessage = {
infoMessage: 'changed model to gpt-4',
chatMessageType: 'GENERIC_INFO',
} as ChatMessage;
expect(
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory.at(-1)
).toEqual(expectedInfoMessage);
expect(res.send).toHaveBeenCalledWith({
chatInfoMessage: expectedInfoMessage,
});
});

test('WHEN missing model THEN does not set model', () => {
const req = {
body: {},
} as OpenAiSetModelRequest;
const res = responseMock();

handleSetModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith('Missing model');
});

test('WHEN model is invalid THEN does not set model', () => {
const req = {
body: {
model: 'invalid model',
},
} as unknown as OpenAiSetModelRequest;
const res = responseMock();

handleSetModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith('Invalid model');
});
});

describe('handleConfigureModel', () => {
test('WHEN passed sensible parameters THEN configures model AND adds info message to chat history AND responds with info message', () => {
const req = {
@@ -37,7 +99,6 @@

handleConfigureModel(req, res);

-	expect(res.status).toHaveBeenCalledWith(200);
expect(req.session.chatModel.configuration.topP).toBe(0.5);

const expectedInfoMessage = {
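
A small payoff of the union type shows up in the invalid-model test: { model: 'invalid model' } is no longer assignable to the request body, so the test has to opt out of type checking with as unknown as OpenAiSetModelRequest. A minimal sketch of why the double cast is needed:

```ts
type SetModelBody = { model?: 'gpt-4' | 'gpt-3.5-turbo' };

// Rejected by the compiler: 'invalid model' is not in the union.
// const direct: SetModelBody = { model: 'invalid model' };

// Deliberate opt-out for negative tests: widen to unknown, then assert.
const forced = { model: 'invalid model' } as unknown as SetModelBody;
```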
7 changes: 1 addition & 6 deletions backend/test/unit/openai.test.ts
@@ -42,12 +42,7 @@ describe('getValidModelsFromOpenAI', () => {
{ id: 'da-vinci-1' },
{ id: 'da-vinci-2' },
];
-	const expectedValidModels = [
-		'gpt-3.5-turbo',
-		'gpt-3.5-turbo-0613',
-		'gpt-4',
-		'gpt-4-0613',
-	];
+	const expectedValidModels = ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-0613'];

mockListFn.mockResolvedValueOnce({
data: mockModelList,
7 changes: 2 additions & 5 deletions frontend/src/components/ControlPanel/ControlPanel.tsx
@@ -3,7 +3,7 @@ import DefenceBox from '@src/components/DefenceBox/DefenceBox';
import ModelBox from '@src/components/ModelBox/ModelBox';
import DetailElement from '@src/components/ThemedButtons/DetailElement';
import ThemedButton from '@src/components/ThemedButtons/ThemedButton';
-import { ChatMessage, ChatModel } from '@src/models/chat';
+import { CHAT_MODEL_ID, ChatMessage, ChatModel } from '@src/models/chat';
import {
DEFENCE_ID,
DefenceConfigItem,
@@ -24,13 +24,12 @@ function ControlPanel({
resetDefenceConfiguration,
setDefenceConfiguration,
openDocumentViewer,
-	addInfoMessage,
addChatMessage,
}: {
currentLevel: LEVEL_NAMES;
defences: Defence[];
chatModel?: ChatModel;
-	setChatModelId: (modelId: string) => void;
+	setChatModelId: (modelId: CHAT_MODEL_ID) => void;
chatModelOptions: string[];
toggleDefence: (defence: Defence) => void;
resetDefenceConfiguration: (
@@ -42,7 +41,6 @@
config: DefenceConfigItem[]
) => Promise<boolean>;
openDocumentViewer: () => void;
-	addInfoMessage: (message: string) => void;
addChatMessage: (chatMessage: ChatMessage) => void;
}) {
const configurableDefences =
@@ -101,7 +99,6 @@
chatModel={chatModel}
setChatModelId={setChatModelId}
chatModelOptions={chatModelOptions}
-				addInfoMessage={addInfoMessage}
addChatMessage={addChatMessage}
/>
)}