diff --git a/backend/src/controller/modelController.ts b/backend/src/controller/modelController.ts
index 36934e5d6..7204cc302 100644
--- a/backend/src/controller/modelController.ts
+++ b/backend/src/controller/modelController.ts
@@ -2,7 +2,11 @@ import { Response } from 'express';
 
 import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
 import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest';
-import { MODEL_CONFIG_ID, modelConfigIds } from '@src/models/chat';
+import {
+	MODEL_CONFIG_ID,
+	chatModelIds,
+	modelConfigIds,
+} from '@src/models/chat';
 import { ChatInfoMessage } from '@src/models/chatMessage';
 import { LEVEL_NAMES } from '@src/models/level';
 import { pushMessageToHistory } from '@src/utils/chat';
@@ -13,14 +17,32 @@ function handleSetModel(req: OpenAiSetModelRequest, res: Response) {
 	const { model } = req.body;
 
 	if (model === undefined) {
-		res.status(400).send();
-	} else {
-		const configuration =
-			req.body.configuration ?? req.session.chatModel.configuration;
-		req.session.chatModel = { id: model, configuration };
-		console.debug('GPT model set:', JSON.stringify(req.session.chatModel));
-		res.status(200).send();
+		sendErrorResponse(res, 400, 'Missing model');
+		return;
+	}
+
+	if (!chatModelIds.includes(model)) {
+		sendErrorResponse(res, 400, 'Invalid model');
+		return;
 	}
+
+	const configuration =
+		req.body.configuration ?? req.session.chatModel.configuration;
+	req.session.chatModel = { id: model, configuration };
+	console.debug('GPT model set:', JSON.stringify(req.session.chatModel));
+
+	const chatInfoMessage = {
+		infoMessage: `changed model to ${model}`,
+		chatMessageType: 'GENERIC_INFO',
+	} as ChatInfoMessage;
+	// for now, the chat model only changes in the sandbox level
+	req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory =
+		pushMessageToHistory(
+			req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory,
+			chatInfoMessage
+		);
+
+	res.send({ chatInfoMessage });
 }
 
 function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) {
@@ -59,13 +81,14 @@ function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) {
 		infoMessage: `changed ${configId} to ${value}`,
 		chatMessageType: 'GENERIC_INFO',
 	} as ChatInfoMessage;
+	// for now, the model config only changes in the sandbox level
 	req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory =
 		pushMessageToHistory(
 			req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory,
 			chatInfoMessage
 		);
 
-	res.status(200).send({ chatInfoMessage });
+	res.send({ chatInfoMessage });
 }
 
 export { handleSetModel, handleConfigureModel };
diff --git a/backend/src/langchain.ts b/backend/src/langchain.ts
index 2e487e780..21a9ca396 100644
--- a/backend/src/langchain.ts
+++ b/backend/src/langchain.ts
@@ -4,7 +4,7 @@ import { OpenAI } from 'langchain/llms/openai';
 import { PromptTemplate } from 'langchain/prompts';
 
 import { getDocumentVectors } from './document';
-import { CHAT_MODELS } from './models/chat';
+import { CHAT_MODEL_ID } from './models/chat';
 import { PromptEvaluationChainReply, QaChainReply } from './models/langchain';
 import { LEVEL_NAMES } from './models/level';
 import { getOpenAIKey, getValidOpenAIModels } from './openai';
@@ -30,10 +30,8 @@ function makePromptTemplate(
 	return PromptTemplate.fromTemplate(fullPrompt);
 }
 
-function getChatModel() {
-	return getValidOpenAIModels().includes(CHAT_MODELS.GPT_4)
-		? CHAT_MODELS.GPT_4
-		: CHAT_MODELS.GPT_3_5_TURBO;
+function getChatModel(): CHAT_MODEL_ID {
+	return getValidOpenAIModels().includes('gpt-4') ? 'gpt-4' : 'gpt-3.5-turbo';
 }
 
 function initQAModel(level: LEVEL_NAMES, Prompt: string) {
diff --git a/backend/src/models/api/OpenAiSetModelRequest.ts b/backend/src/models/api/OpenAiSetModelRequest.ts
index 53162a677..20143473f 100644
--- a/backend/src/models/api/OpenAiSetModelRequest.ts
+++ b/backend/src/models/api/OpenAiSetModelRequest.ts
@@ -1,12 +1,15 @@
 import { Request } from 'express';
 
-import { CHAT_MODELS, ChatModelConfigurations } from '@src/models/chat';
+import { CHAT_MODEL_ID, ChatModelConfigurations } from '@src/models/chat';
+import { ChatMessage } from '@src/models/chatMessage';
 
 export type OpenAiSetModelRequest = Request<
 	never,
-	never,
 	{
-		model?: CHAT_MODELS;
+		chatInfoMessage: ChatMessage;
+	},
+	{
+		model?: CHAT_MODEL_ID;
 		configuration?: ChatModelConfigurations;
 	},
 	never
diff --git a/backend/src/models/chat.ts b/backend/src/models/chat.ts
index 6025249b3..4b0f79a0c 100644
--- a/backend/src/models/chat.ts
+++ b/backend/src/models/chat.ts
@@ -7,18 +7,17 @@ import { ChatInfoMessage, ChatMessage } from './chatMessage';
 import { DEFENCE_ID } from './defence';
 import { EmailInfo } from './email';
 
-enum CHAT_MODELS {
-	GPT_4_TURBO = 'gpt-4-1106-preview',
-	GPT_4 = 'gpt-4',
-	GPT_4_0613 = 'gpt-4-0613',
-	GPT_3_5_TURBO = 'gpt-3.5-turbo',
-	GPT_3_5_TURBO_0613 = 'gpt-3.5-turbo-0613',
-	GPT_3_5_TURBO_16K = 'gpt-3.5-turbo-16k',
-	GPT_3_5_TURBO_16K_0613 = 'gpt-3.5-turbo-16k-0613',
-}
+const chatModelIds = [
+	'gpt-4-1106-preview',
+	'gpt-4',
+	'gpt-4-0613',
+	'gpt-3.5-turbo',
+] as const;
+
+type CHAT_MODEL_ID = (typeof chatModelIds)[number];
 
 type ChatModel = {
-	id: CHAT_MODELS;
+	id: CHAT_MODEL_ID;
 	configuration: ChatModelConfigurations;
 };
 
@@ -106,7 +105,7 @@ interface LevelHandlerResponse {
 }
 
 const defaultChatModel: ChatModel = {
-	id: CHAT_MODELS.GPT_3_5_TURBO,
+	id: 'gpt-3.5-turbo',
 	configuration: {
 		temperature: 1,
 		topP: 1,
@@ -131,4 +130,5 @@ export type {
 	SingleDefenceReport,
 	MODEL_CONFIG_ID,
 };
-export { CHAT_MODELS, defaultChatModel, modelConfigIds };
+export { defaultChatModel, modelConfigIds, chatModelIds };
+export type { CHAT_MODEL_ID };
diff --git a/backend/src/openai.ts b/backend/src/openai.ts
index 57dfa3594..ce122a508 100644
--- a/backend/src/openai.ts
+++ b/backend/src/openai.ts
@@ -9,12 +9,13 @@ import { getQAPromptFromConfig } from './defence';
 import { sendEmail } from './email';
 import { queryDocuments } from './langchain';
 import {
-	CHAT_MODELS,
+	CHAT_MODEL_ID,
 	ChatGptReply,
 	ChatModel,
 	ChatResponse,
 	FunctionCallResponse,
 	ToolCallResponse,
+	chatModelIds,
 } from './models/chat';
 import { ChatMessage } from './models/chatMessage';
 import { QaLlmDefence } from './models/defence';
@@ -84,10 +85,10 @@ const chatGptTools: ChatCompletionTool[] = [
 
 // list of valid chat models for the api key
 const validOpenAiModels = (() => {
-	let validModels: string[] = [];
+	let validModels: CHAT_MODEL_ID[] = [];
 	return {
 		get: () => validModels,
-		set: (models: string[]) => {
+		set: (models: CHAT_MODEL_ID[]) => {
 			validModels = models;
 		},
 	};
@@ -117,8 +118,8 @@ async function getValidModelsFromOpenAI() {
 	// get the model ids that are supported by our app. Non-chat models like Dall-e and whisper are not supported.
 	const validModels = models.data
-		.map((model) => model.id)
-		.filter((id) => Object.values(CHAT_MODELS).includes(id as CHAT_MODELS))
+		.map((model) => model.id as CHAT_MODEL_ID)
+		.filter((id) => chatModelIds.includes(id))
 		.sort();
 
 	validOpenAiModels.set(validModels);
 
@@ -283,7 +284,7 @@ async function chatGptChatCompletion(
 
 function getChatCompletionsInContextWindow(
 	chatHistory: ChatMessage[],
-	gptModel: CHAT_MODELS
+	gptModel: CHAT_MODEL_ID
 ): ChatCompletionMessageParam[] {
 	const completions = chatHistory
 		.map((chatMessage) =>
diff --git a/backend/src/utils/token.ts b/backend/src/utils/token.ts
index d431bd6c8..598043dd2 100644
--- a/backend/src/utils/token.ts
+++ b/backend/src/utils/token.ts
@@ -4,18 +4,17 @@ import {
 } from 'openai/resources/chat/completions';
 import { promptTokensEstimate, stringTokens } from 'openai-chat-tokens';
 
-import { CHAT_MODELS } from '@src/models/chat';
+import { CHAT_MODEL_ID } from '@src/models/chat';
 import { chatGptTools } from '@src/openai';
 
 // The size of each model's context window in number of tokens. https://platform.openai.com/docs/models
-const chatModelMaxTokens = {
-	[CHAT_MODELS.GPT_4_TURBO]: 128000,
-	[CHAT_MODELS.GPT_4]: 8192,
-	[CHAT_MODELS.GPT_4_0613]: 8192,
-	[CHAT_MODELS.GPT_3_5_TURBO]: 4097,
-	[CHAT_MODELS.GPT_3_5_TURBO_0613]: 4097,
-	[CHAT_MODELS.GPT_3_5_TURBO_16K]: 16385,
-	[CHAT_MODELS.GPT_3_5_TURBO_16K_0613]: 16385,
+const chatModelMaxTokens: {
+	[key in CHAT_MODEL_ID]: number;
+} = {
+	'gpt-4': 8192,
+	'gpt-4-1106-preview': 128000,
+	'gpt-4-0613': 8192,
+	'gpt-3.5-turbo': 16385,
 };
 
 const TOKENS_PER_TOOL_CALL = 4;
diff --git a/backend/test/integration/openai.test.ts b/backend/test/integration/openai.test.ts
index 840b0c724..3b9e91942 100644
--- a/backend/test/integration/openai.test.ts
+++ b/backend/test/integration/openai.test.ts
@@ -1,6 +1,6 @@
 import { expect, jest, test, describe } from '@jest/globals';
 
-import { CHAT_MODELS, ChatModel } from '@src/models/chat';
+import { ChatModel } from '@src/models/chat';
 import { ChatMessage } from '@src/models/chatMessage';
 import { chatGptSendMessage } from '@src/openai';
 
@@ -54,7 +54,7 @@ describe('OpenAI Integration Tests', () => {
 		},
 	];
 	const chatModel: ChatModel = {
-		id: CHAT_MODELS.GPT_4,
+		id: 'gpt-4',
 		configuration: {
 			temperature: 1,
 			topP: 1,
diff --git a/backend/test/unit/controller/modelController.test.ts b/backend/test/unit/controller/modelController.test.ts
index b7cb32239..141e42e16 100644
--- a/backend/test/unit/controller/modelController.test.ts
+++ b/backend/test/unit/controller/modelController.test.ts
@@ -1,8 +1,12 @@
 import { expect, jest, test, describe } from '@jest/globals';
 import { Response } from 'express';
 
-import { handleConfigureModel } from '@src/controller/modelController';
+import {
+	handleConfigureModel,
+	handleSetModel,
+} from '@src/controller/modelController';
 import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
+import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest';
 import { modelConfigIds } from '@src/models/chat';
 import { ChatMessage } from '@src/models/chatMessage';
 import { LEVEL_NAMES, LevelState } from '@src/models/level';
@@ -14,6 +18,64 @@ function responseMock() {
 	} as unknown as Response;
 }
 
+describe('handleSetModel', () => {
+	test('WHEN passed sensible parameters THEN sets model AND adds info message to chat history AND responds with info message', () => {
+		const req = {
+			body: {
+				model: 'gpt-4',
+			},
+			session: {
+				chatModel: {
+					id: 'gpt-3.5-turbo',
+				},
+				levelState: [{}, {}, {}, { chatHistory: [] } as unknown as LevelState],
+			},
+		} as OpenAiSetModelRequest;
+		const res = responseMock();
+
+		handleSetModel(req, res);
+
+		expect(req.session.chatModel.id).toBe('gpt-4');
+
+		const expectedInfoMessage = {
+			infoMessage: 'changed model to gpt-4',
+			chatMessageType: 'GENERIC_INFO',
+		} as ChatMessage;
+		expect(
+			req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory.at(-1)
+		).toEqual(expectedInfoMessage);
+		expect(res.send).toHaveBeenCalledWith({
+			chatInfoMessage: expectedInfoMessage,
+		});
+	});
+
+	test('WHEN missing model THEN does not set model', () => {
+		const req = {
+			body: {},
+		} as OpenAiSetModelRequest;
+		const res = responseMock();
+
+		handleSetModel(req, res);
+
+		expect(res.status).toHaveBeenCalledWith(400);
+		expect(res.send).toHaveBeenCalledWith('Missing model');
+	});
+
+	test('WHEN model is invalid THEN does not set model', () => {
+		const req = {
+			body: {
+				model: 'invalid model',
+			},
+		} as unknown as OpenAiSetModelRequest;
+		const res = responseMock();
+
+		handleSetModel(req, res);
+
+		expect(res.status).toHaveBeenCalledWith(400);
+		expect(res.send).toHaveBeenCalledWith('Invalid model');
+	});
+});
+
 describe('handleConfigureModel', () => {
 	test('WHEN passed sensible parameters THEN configures model AND adds info message to chat history AND responds with info message', () => {
 		const req = {
@@ -37,7 +99,6 @@ describe('handleConfigureModel', () => {
 
 		handleConfigureModel(req, res);
 
-		expect(res.status).toHaveBeenCalledWith(200);
 		expect(req.session.chatModel.configuration.topP).toBe(0.5);
 
 		const expectedInfoMessage = {
diff --git a/backend/test/unit/openai.test.ts b/backend/test/unit/openai.test.ts
index ca0ab8161..b46ff38e1 100644
--- a/backend/test/unit/openai.test.ts
+++ b/backend/test/unit/openai.test.ts
@@ -42,12 +42,7 @@ describe('getValidModelsFromOpenAI', () => {
 			{ id: 'da-vinci-1' },
 			{ id: 'da-vinci-2' },
 		];
-		const expectedValidModels = [
-			'gpt-3.5-turbo',
-			'gpt-3.5-turbo-0613',
-			'gpt-4',
-			'gpt-4-0613',
-		];
+		const expectedValidModels = ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-0613'];
 
 		mockListFn.mockResolvedValueOnce({
 			data: mockModelList,
diff --git a/frontend/src/components/ControlPanel/ControlPanel.tsx b/frontend/src/components/ControlPanel/ControlPanel.tsx
index f3f038863..c8633ce44 100644
--- a/frontend/src/components/ControlPanel/ControlPanel.tsx
+++ b/frontend/src/components/ControlPanel/ControlPanel.tsx
@@ -3,7 +3,7 @@ import DefenceBox from '@src/components/DefenceBox/DefenceBox';
 import ModelBox from '@src/components/ModelBox/ModelBox';
 import DetailElement from '@src/components/ThemedButtons/DetailElement';
 import ThemedButton from '@src/components/ThemedButtons/ThemedButton';
-import { ChatMessage, ChatModel } from '@src/models/chat';
+import { CHAT_MODEL_ID, ChatMessage, ChatModel } from '@src/models/chat';
 import {
 	DEFENCE_ID,
 	DefenceConfigItem,
@@ -24,13 +24,12 @@ function ControlPanel({
 	resetDefenceConfiguration,
 	setDefenceConfiguration,
 	openDocumentViewer,
-	addInfoMessage,
 	addChatMessage,
 }: {
 	currentLevel: LEVEL_NAMES;
 	defences: Defence[];
 	chatModel?: ChatModel;
-	setChatModelId: (modelId: string) => void;
+	setChatModelId: (modelId: CHAT_MODEL_ID) => void;
 	chatModelOptions: string[];
 	toggleDefence: (defence: Defence) => void;
 	resetDefenceConfiguration: (
@@ -42,7 +41,6 @@ function ControlPanel({
 		config: DefenceConfigItem[]
 	) => Promise<boolean>;
 	openDocumentViewer: () => void;
-	addInfoMessage: (message: string) => void;
 	addChatMessage: (chatMessage: ChatMessage) => void;
 }) {
 	const configurableDefences =
@@ -101,7 +99,6 @@ function ControlPanel({
 				chatModel={chatModel}
 				setChatModelId={setChatModelId}
 				chatModelOptions={chatModelOptions}
-				addInfoMessage={addInfoMessage}
 				addChatMessage={addChatMessage}
 			/>
 		)}
diff --git a/frontend/src/components/MainComponent/MainBody.tsx b/frontend/src/components/MainComponent/MainBody.tsx
index ae6066231..620288196 100644
--- a/frontend/src/components/MainComponent/MainBody.tsx
+++ b/frontend/src/components/MainComponent/MainBody.tsx
@@ -1,7 +1,7 @@
 import ChatBox from '@src/components/ChatBox/ChatBox';
 import ControlPanel from '@src/components/ControlPanel/ControlPanel';
 import EmailBox from '@src/components/EmailBox/EmailBox';
-import { ChatMessage, ChatModel } from '@src/models/chat';
+import { CHAT_MODEL_ID, ChatMessage, ChatModel } from '@src/models/chat';
 import {
 	DEFENCE_ID,
 	DefenceConfigItem,
@@ -22,7 +22,6 @@ function MainBody({
 	chatModels,
 	setChatModelId,
 	addChatMessage,
-	addInfoMessage,
 	addSentEmails,
 	resetDefenceConfiguration,
 	toggleDefence,
@@ -37,10 +36,9 @@ function MainBody({
 	emails: EmailInfo[];
 	messages: ChatMessage[];
 	chatModel?: ChatModel;
-	setChatModelId: (modelId: string) => void;
+	setChatModelId: (modelId: CHAT_MODEL_ID) => void;
 	chatModels: string[];
 	addChatMessage: (message: ChatMessage) => void;
-	addInfoMessage: (message: string) => void;
 	addSentEmails: (emails: EmailInfo[]) => void;
 	resetDefenceConfiguration: (
 		defenceId: DEFENCE_ID,
@@ -70,7 +68,6 @@ function MainBody({
 				resetDefenceConfiguration={resetDefenceConfiguration}
 				setDefenceConfiguration={setDefenceConfiguration}
 				openDocumentViewer={openDocumentViewer}
-				addInfoMessage={addInfoMessage}
 				addChatMessage={addChatMessage}
 			/>
 		</div>
diff --git a/frontend/src/components/MainComponent/MainComponent.tsx b/frontend/src/components/MainComponent/MainComponent.tsx
index 1c096cd11..4912a1cfb 100644
--- a/frontend/src/components/MainComponent/MainComponent.tsx
+++ b/frontend/src/components/MainComponent/MainComponent.tsx
@@ -5,7 +5,7 @@ import HandbookOverlay from '@src/components/HandbookOverlay/HandbookOverlay';
 import LevelMissionInfoBanner from '@src/components/LevelMissionInfoBanner/LevelMissionInfoBanner';
 import ResetLevelOverlay from '@src/components/Overlay/ResetLevel';
 import ResetProgressOverlay from '@src/components/Overlay/ResetProgress';
-import { ChatMessage, ChatModel } from '@src/models/chat';
+import { CHAT_MODEL_ID, ChatMessage, ChatModel } from '@src/models/chat';
 import {
 	DEFENCE_ID,
 	DefenceConfigItem,
@@ -290,7 +290,7 @@ function MainComponent({
 		}
 	}
 
-	function setChatModelId(modelId: string) {
+	function setChatModelId(modelId: CHAT_MODEL_ID) {
 		if (!chatModel) {
 			console.error(
 				'You are trying to change the id of the chatModel but it has not been loaded yet'
@@ -364,7 +364,6 @@ function MainComponent({
 				emails={emails}
 				messages={messages}
 				addChatMessage={addChatMessage}
-				addInfoMessage={addInfoMessage}
 				addSentEmails={addSentEmails}
 				resetDefenceConfiguration={(
 					defenceId: DEFENCE_ID,
diff --git a/frontend/src/components/ModelBox/ModelBox.tsx b/frontend/src/components/ModelBox/ModelBox.tsx
index 97ffba269..f3e2743a8 100644
--- a/frontend/src/components/ModelBox/ModelBox.tsx
+++ b/frontend/src/components/ModelBox/ModelBox.tsx
@@ -1,4 +1,4 @@
-import { ChatMessage, ChatModel } from '@src/models/chat';
+import { CHAT_MODEL_ID, ChatMessage, ChatModel } from '@src/models/chat';
 
 import ModelConfiguration from './ModelConfiguration';
 import ModelSelection from './ModelSelection';
@@ -9,13 +9,11 @@ function ModelBox({
 	chatModel,
 	setChatModelId,
 	chatModelOptions,
-	addInfoMessage,
 	addChatMessage,
 }: {
 	chatModel?: ChatModel;
-	setChatModelId: (modelId: string) => void;
+	setChatModelId: (modelId: CHAT_MODEL_ID) => void;
 	chatModelOptions: string[];
-	addInfoMessage: (message: string) => void;
 	addChatMessage: (chatMessage: ChatMessage) => void;
 }) {
 	return (
@@ -24,7 +22,7 @@ function ModelBox({
 			<ModelSelection
 				chatModel={chatModel}
 				setChatModelId={setChatModelId}
 				chatModelOptions={chatModelOptions}
-				addInfoMessage={addInfoMessage}
+				addChatMessage={addChatMessage}
 			/>
 			<ModelConfiguration
diff --git a/frontend/src/components/ModelBox/ModelSelection.tsx b/frontend/src/components/ModelBox/ModelSelection.tsx
--- a/frontend/src/components/ModelBox/ModelSelection.tsx
+++ b/frontend/src/components/ModelBox/ModelSelection.tsx
@@ -9,14 +9,14 @@ function ModelSelection({
 	chatModel,
 	setChatModelId,
 	chatModelOptions,
-	addInfoMessage,
+	addChatMessage,
 }: {
 	chatModel?: ChatModel;
-	setChatModelId: (modelId: string) => void;
+	setChatModelId: (modelId: CHAT_MODEL_ID) => void;
 	chatModelOptions: string[];
-	addInfoMessage: (message: string) => void;
+	addChatMessage: (message: ChatMessage) => void;
 }) {
 	// model currently selected in the dropdown
-	const [selectedModel, setSelectedModel] = useState<string | undefined>(
+	const [selectedModel, setSelectedModel] = useState<CHAT_MODEL_ID | undefined>(
 		undefined
 	);
@@ -34,11 +34,13 @@ function ModelSelection({
 		const currentSelectedModel = selectedModel;
 		console.log(`selected model: ${currentSelectedModel}`);
 		setIsSettingModel(true);
-		const modelUpdated = await chatService.setGptModel(currentSelectedModel);
+		const chatInfoMessage = await chatService.setGptModel(
+			currentSelectedModel
+		);
 		setIsSettingModel(false);
-		if (modelUpdated) {
+		if (chatInfoMessage) {
 			setErrorChangingModel(false);
-			addInfoMessage(`changed model to ${currentSelectedModel}`);
+			addChatMessage(chatInfoMessage);
 			setChatModelId(currentSelectedModel);
 		} else {
 			setErrorChangingModel(true);
@@ -64,7 +66,7 @@
 				aria-label="model-select"
 				value={selectedModel ?? 0} // default to the first model
 				onChange={(e) => {
-					setSelectedModel(e.target.value);
+					setSelectedModel(e.target.value as CHAT_MODEL_ID);
 				}}
 			>
 				{chatModelOptions.map((model) => (
diff --git a/frontend/src/models/chat.ts b/frontend/src/models/chat.ts
index ef418eea9..cb0b348f0 100644
--- a/frontend/src/models/chat.ts
+++ b/frontend/src/models/chat.ts
@@ -15,8 +15,20 @@ type CHAT_MESSAGE_TYPE =
 	| 'ERROR_MSG'
 	| 'RESET_LEVEL';
 
+const chatModelIds = [
+	'gpt-4-1106-preview',
+	'gpt-4',
+	'gpt-4-0613',
+	'gpt-3.5-turbo',
+	'gpt-3.5-turbo-0613',
+	'gpt-3.5-turbo-16k',
+	'gpt-3.5-turbo-16k-0613',
+] as const;
+
+type CHAT_MODEL_ID = (typeof chatModelIds)[number];
+
 type ChatModel = {
-	id: string;
+	id: CHAT_MODEL_ID;
 	configuration: ChatModelConfigurations;
 };
 
@@ -95,4 +107,5 @@ export type {
 	CustomChatModelConfiguration,
 	CHAT_MESSAGE_TYPE,
 	MODEL_CONFIG_ID,
+	CHAT_MODEL_ID,
 };
diff --git a/frontend/src/service/chatService.ts b/frontend/src/service/chatService.ts
index 2af47dd11..25e242fee 100644
--- a/frontend/src/service/chatService.ts
+++ b/frontend/src/service/chatService.ts
@@ -55,13 +55,18 @@ function getChatMessagesFromDTOResponse(chatMessageDTOs: ChatMessageDTO[]) {
 		.map(makeChatMessageFromDTO);
 }
 
-async function setGptModel(model: string): Promise<boolean> {
+async function setGptModel(model: string): Promise<ChatMessage | null> {
 	const response = await sendRequest(`${PATH}model`, {
 		method: 'POST',
 		headers: { 'Content-Type': 'application/json' },
 		body: JSON.stringify({ model }),
 	});
-	return response.status === 200;
+
+	if (response.status !== 200) return null;
+
+	const { chatInfoMessage } =
+		(await response.json()) as ChatInfoMessageResponse;
+	return makeChatMessageFromDTO(chatInfoMessage);
 }
 
 async function configureGptModel(