From ae133b7f2f65a6d7c754ef304bb4838801f80238 Mon Sep 17 00:00:00 2001
From: Peter Marsh <118171430+pmarsh-scottlogic@users.noreply.github.com>
Date: Thu, 28 Mar 2024 09:58:31 +0000
Subject: [PATCH] 828 streamline chat model configuration info message network
 call (#875)

---
 backend/src/controller/modelController.ts     |  53 +++++-
 .../models/api/OpenAiConfigureModelRequest.ts |   4 +-
 .../src/models/api/OpenAiSetModelRequest.ts   |   4 +-
 backend/src/models/chat.ts                    |  36 ++--
 .../unit/controller/modelController.test.ts   | 156 ++++++++++++++++++
 .../components/ControlPanel/ControlPanel.tsx  |   5 +-
 .../src/components/MainComponent/MainBody.tsx |   1 +
 frontend/src/components/ModelBox/ModelBox.tsx |   6 +-
 .../ModelBox/ModelConfiguration.tsx           |  25 +--
 .../ModelBox/ModelConfigurationSlider.tsx     |   7 +-
 frontend/src/models/chat.ts                   |  33 ++--
 frontend/src/service/chatService.ts           |  14 +-
 12 files changed, 277 insertions(+), 67 deletions(-)
 create mode 100644 backend/test/unit/controller/modelController.test.ts

diff --git a/backend/src/controller/modelController.ts b/backend/src/controller/modelController.ts
index fd6bf2b7c..36934e5d6 100644
--- a/backend/src/controller/modelController.ts
+++ b/backend/src/controller/modelController.ts
@@ -2,7 +2,12 @@ import { Response } from 'express';
 
 import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
 import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest';
-import { MODEL_CONFIG } from '@src/models/chat';
+import { MODEL_CONFIG_ID, modelConfigIds } from '@src/models/chat';
+import { ChatInfoMessage } from '@src/models/chatMessage';
+import { LEVEL_NAMES } from '@src/models/level';
+import { pushMessageToHistory } from '@src/utils/chat';
+
+import { sendErrorResponse } from './handleError';
 
 function handleSetModel(req: OpenAiSetModelRequest, res: Response) {
 	const { model } = req.body;
@@ -19,16 +24,48 @@ function handleSetModel(req: OpenAiSetModelRequest, res: Response) {
 }
 
 function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) {
-	const configId = req.body.configId as MODEL_CONFIG | undefined;
+	const configId = req.body.configId as MODEL_CONFIG_ID | undefined;
 	const value = req.body.value;
 
-	const maxValue = configId === MODEL_CONFIG.TOP_P ? 1 : 2;
-	if (configId && value !== undefined && value >= 0 && value <= maxValue) {
-		req.session.chatModel.configuration[configId] = value;
-		res.status(200).send();
-	} else {
-		res.status(400).send();
+	if (!configId) {
+		sendErrorResponse(res, 400, 'Missing configId');
+		return;
+	}
+
+	if (!modelConfigIds.includes(configId)) {
+		sendErrorResponse(res, 400, 'Invalid configId');
+		return;
+	}
+
+	if (!Number.isFinite(value) || value === undefined) {
+		sendErrorResponse(res, 400, 'Missing or invalid value');
+		return;
+	}
+
+	const maxValue = configId === 'topP' ? 1 : 2;
+
+	if (value < 0 || value > maxValue) {
+		sendErrorResponse(
+			res,
+			400,
+			`Value should be between 0 and ${maxValue} for ${configId}`
+		);
+		return;
 	}
+
+	req.session.chatModel.configuration[configId] = value;
+
+	const chatInfoMessage = {
+		infoMessage: `changed ${configId} to ${value}`,
+		chatMessageType: 'GENERIC_INFO',
+	} as ChatInfoMessage;
+	req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory =
+		pushMessageToHistory(
+			req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory,
+			chatInfoMessage
+		);
+
+	res.status(200).send({ chatInfoMessage });
 }
 
 export { handleSetModel, handleConfigureModel };
diff --git a/backend/src/models/api/OpenAiConfigureModelRequest.ts b/backend/src/models/api/OpenAiConfigureModelRequest.ts
index 83d9fdf94..aa5c49fc0 100644
--- a/backend/src/models/api/OpenAiConfigureModelRequest.ts
+++ b/backend/src/models/api/OpenAiConfigureModelRequest.ts
@@ -1,8 +1,10 @@
 import { Request } from 'express';
 
+import { ChatMessage } from '@src/models/chatMessage';
+
 export type OpenAiConfigureModelRequest = Request<
 	never,
-	never,
+	null | { chatInfoMessage: ChatMessage },
 	{
 		configId?: string;
 		value?: number;
diff --git a/backend/src/models/api/OpenAiSetModelRequest.ts b/backend/src/models/api/OpenAiSetModelRequest.ts
index 5ca03f7fb..53162a677 100644
--- a/backend/src/models/api/OpenAiSetModelRequest.ts
+++ b/backend/src/models/api/OpenAiSetModelRequest.ts
@@ -1,13 +1,13 @@
 import { Request } from 'express';
 
-import { CHAT_MODELS, ChatModelConfiguration } from '@src/models/chat';
+import { CHAT_MODELS, ChatModelConfigurations } from '@src/models/chat';
 
 export type OpenAiSetModelRequest = Request<
 	never,
 	never,
 	{
 		model?: CHAT_MODELS;
-		configuration?: ChatModelConfiguration;
+		configuration?: ChatModelConfigurations;
 	},
 	never
 >;
diff --git a/backend/src/models/chat.ts b/backend/src/models/chat.ts
index fd4a6a72e..6025249b3 100644
--- a/backend/src/models/chat.ts
+++ b/backend/src/models/chat.ts
@@ -17,24 +17,23 @@ enum CHAT_MODELS {
 	GPT_3_5_TURBO_16K_0613 = 'gpt-3.5-turbo-16k-0613',
 }
 
-enum MODEL_CONFIG {
-	TEMPERATURE = 'temperature',
-	TOP_P = 'topP',
-	FREQUENCY_PENALTY = 'frequencyPenalty',
-	PRESENCE_PENALTY = 'presencePenalty',
-}
-
-interface ChatModel {
+type ChatModel = {
 	id: CHAT_MODELS;
-	configuration: ChatModelConfiguration;
-}
+	configuration: ChatModelConfigurations;
+};
 
-interface ChatModelConfiguration {
-	temperature: number;
-	topP: number;
-	frequencyPenalty: number;
-	presencePenalty: number;
-}
+const modelConfigIds = [
+	'temperature',
+	'topP',
+	'frequencyPenalty',
+	'presencePenalty',
+] as const;
+
+type MODEL_CONFIG_ID = (typeof modelConfigIds)[number];
+
+type ChatModelConfigurations = {
+	[key in MODEL_CONFIG_ID]: number;
+};
 
 interface DefenceReport {
 	blockedReason: string | null;
@@ -121,7 +120,7 @@ export type {
 	ChatGptReply,
 	ChatMalicious,
 	ChatModel,
-	ChatModelConfiguration,
+	ChatModelConfigurations,
 	ChatResponse,
 	LevelHandlerResponse,
 	ChatHttpResponse,
@@ -130,5 +129,6 @@ export type {
 	ToolCallResponse,
 	MessageTransformation,
 	SingleDefenceReport,
+	MODEL_CONFIG_ID,
 };
-export { CHAT_MODELS, MODEL_CONFIG, defaultChatModel };
+export { CHAT_MODELS, defaultChatModel, modelConfigIds };
diff --git a/backend/test/unit/controller/modelController.test.ts b/backend/test/unit/controller/modelController.test.ts
new file mode 100644
index 000000000..b7cb32239
--- /dev/null
+++ b/backend/test/unit/controller/modelController.test.ts
@@ -0,0 +1,156 @@
+import { expect, jest, test, describe } from '@jest/globals';
+import { Response } from 'express';
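[Review note, not part of the patch] The backend chat.ts change above swaps the MODEL_CONFIG enum for a single `as const` array that yields both a runtime list (used by the controller's `includes` check) and a compile-time union type. A minimal standalone sketch of the pattern; the `isModelConfigId` guard is illustrative only, the patch calls `modelConfigIds.includes(configId)` directly:

// Sketch of the `as const` + indexed-access-type pattern from chat.ts.
const modelConfigIds = [
	'temperature',
	'topP',
	'frequencyPenalty',
	'presencePenalty',
] as const;

// The union 'temperature' | 'topP' | 'frequencyPenalty' | 'presencePenalty',
// derived from the array, so runtime list and type can never drift apart.
type MODEL_CONFIG_ID = (typeof modelConfigIds)[number];

function isModelConfigId(id: string): id is MODEL_CONFIG_ID {
	return (modelConfigIds as readonly string[]).includes(id);
}

console.log(isModelConfigId('topP')); // true
console.log(isModelConfigId('maxTokens')); // false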
+
+import { handleConfigureModel } from '@src/controller/modelController';
+import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
+import { modelConfigIds } from '@src/models/chat';
+import { ChatMessage } from '@src/models/chatMessage';
+import { LEVEL_NAMES, LevelState } from '@src/models/level';
+
+function responseMock() {
+	return {
+		send: jest.fn(),
+		status: jest.fn().mockReturnThis(),
+	} as unknown as Response;
+}
+
+describe('handleConfigureModel', () => {
+	test('WHEN passed sensible parameters THEN configures model AND adds info message to chat history AND responds with info message', () => {
+		const req = {
+			body: {
+				configId: 'topP',
+				value: 0.5,
+			},
+			session: {
+				chatModel: {
+					configuration: {
+						temperature: 0.0,
+						topP: 0.0,
+						presencePenalty: 0.0,
+						frequencyPenalty: 0.0,
+					},
+				},
+				levelState: [{}, {}, {}, { chatHistory: [] } as unknown as LevelState],
+			},
+		} as OpenAiConfigureModelRequest;
+		const res = responseMock();
+
+		handleConfigureModel(req, res);
+
+		expect(res.status).toHaveBeenCalledWith(200);
+		expect(req.session.chatModel.configuration.topP).toBe(0.5);
+
+		const expectedInfoMessage = {
+			infoMessage: 'changed topP to 0.5',
+			chatMessageType: 'GENERIC_INFO',
+		} as ChatMessage;
+		expect(
+			req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory.at(-1)
+		).toEqual(expectedInfoMessage);
+		expect(res.send).toHaveBeenCalledWith({
+			chatInfoMessage: expectedInfoMessage,
+		});
+	});
+
+	test('WHEN missing configId THEN does not configure model', () => {
+		const req = {
+			body: {
+				value: 0.5,
+			},
+		} as OpenAiConfigureModelRequest;
+		const res = responseMock();
+
+		handleConfigureModel(req, res);
+
+		expect(res.status).toHaveBeenCalledWith(400);
+		expect(res.send).toHaveBeenCalledWith('Missing configId');
+	});
+
+	test('WHEN configId is invalid THEN does not configure model', () => {
+		const req = {
+			body: {
+				configId: 'invalid config id',
+				value: 0.5,
+			},
+		} as OpenAiConfigureModelRequest;
+		const res = responseMock();
+
+		handleConfigureModel(req, res);
+
+		expect(res.status).toHaveBeenCalledWith(400);
+		expect(res.send).toHaveBeenCalledWith('Invalid configId');
+	});
+
+	test('WHEN value is missing THEN does not configure model', () => {
+		const req = {
+			body: {
+				configId: 'topP',
+			},
+		} as unknown as OpenAiConfigureModelRequest;
+		const res = responseMock();
+
+		handleConfigureModel(req, res);
+
+		expect(res.status).toHaveBeenCalledWith(400);
+		expect(res.send).toHaveBeenCalledWith('Missing or invalid value');
+	});
+
+	test('WHEN value is invalid THEN does not configure model', () => {
+		const req = {
+			body: {
+				configId: 'topP',
+				value: 'invalid value',
+			},
+		} as unknown as OpenAiConfigureModelRequest;
+		const res = responseMock();
+
+		handleConfigureModel(req, res);
+
+		expect(res.status).toHaveBeenCalledWith(400);
+		expect(res.send).toHaveBeenCalledWith('Missing or invalid value');
+	});
+
+	test.each(modelConfigIds)(
+		'GIVEN configId is %s WHEN value is below range THEN does not configure model',
+		(configId) => {
+			const req = {
+				body: {
+					configId,
+					value: -1,
+				},
+			} as unknown as OpenAiConfigureModelRequest;
+			const res = responseMock();
+
+			handleConfigureModel(req, res);
+
+			const expectedMaxValue = configId === 'topP' ? 1 : 2;
+
+			expect(res.status).toHaveBeenCalledWith(400);
+			expect(res.send).toHaveBeenCalledWith(
+				`Value should be between 0 and ${expectedMaxValue} for ${configId}`
+			);
+		}
+	);
+
+	test.each(modelConfigIds)(
+		'GIVEN configId is %s WHEN value is above range THEN does not configure model',
+		(configId) => {
+			const expectedMaxValue = configId === 'topP' ? 1 : 2;
+
+			const req = {
+				body: {
+					configId,
+					value: expectedMaxValue + 1,
+				},
+			} as unknown as OpenAiConfigureModelRequest;
+			const res = responseMock();
+
+			handleConfigureModel(req, res);
+
+			expect(res.status).toHaveBeenCalledWith(400);
+			expect(res.send).toHaveBeenCalledWith(
+				`Value should be between 0 and ${expectedMaxValue} for ${configId}`
+			);
+		}
+	);
+});
diff --git a/frontend/src/components/ControlPanel/ControlPanel.tsx b/frontend/src/components/ControlPanel/ControlPanel.tsx
index e1849a29a..db096b034 100644
--- a/frontend/src/components/ControlPanel/ControlPanel.tsx
+++ b/frontend/src/components/ControlPanel/ControlPanel.tsx
@@ -3,7 +3,7 @@ import DefenceBox from '@src/components/DefenceBox/DefenceBox';
 import DocumentViewButton from '@src/components/DocumentViewer/DocumentViewButton';
 import ModelBox from '@src/components/ModelBox/ModelBox';
 import DetailElement from '@src/components/ThemedButtons/DetailElement';
-import { ChatModel } from '@src/models/chat';
+import { ChatMessage, ChatModel } from '@src/models/chat';
 import {
 	DEFENCE_ID,
 	DefenceConfigItem,
@@ -25,6 +25,7 @@ function ControlPanel({
 	setDefenceConfiguration,
 	openDocumentViewer,
 	addInfoMessage,
+	addChatMessage,
 }: {
 	currentLevel: LEVEL_NAMES;
 	defences: Defence[];
@@ -42,6 +43,7 @@ function ControlPanel({
 	) => Promise<boolean>;
 	openDocumentViewer: () => void;
 	addInfoMessage: (message: string) => void;
+	addChatMessage: (chatMessage: ChatMessage) => void;
 }) {
 	const configurableDefences =
 		currentLevel === LEVEL_NAMES.SANDBOX
@@ -100,6 +102,7 @@ function ControlPanel({
 						setChatModelId={setChatModelId}
 						chatModelOptions={chatModelOptions}
 						addInfoMessage={addInfoMessage}
+						addChatMessage={addChatMessage}
 					/>
 				)}
diff --git a/frontend/src/components/MainComponent/MainBody.tsx b/frontend/src/components/MainComponent/MainBody.tsx
index 70e937bc0..ae6066231 100644
--- a/frontend/src/components/MainComponent/MainBody.tsx
+++ b/frontend/src/components/MainComponent/MainBody.tsx
@@ -71,6 +71,7 @@ function MainBody({
 					setDefenceConfiguration={setDefenceConfiguration}
 					openDocumentViewer={openDocumentViewer}
 					addInfoMessage={addInfoMessage}
+					addChatMessage={addChatMessage}
 				/>
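[Review note, not part of the patch] The new unit tests above rely on `status: jest.fn().mockReturnThis()` so that the controller's chained `res.status(...).send(...)` resolves against the same mock object. A minimal sketch of why that works, assuming jest and the express Response type:

import { jest } from '@jest/globals';
import { Response } from 'express';

// status() returns the mock object itself, so send() can be chained on the
// result exactly as Express's real res.status(400).send('...') allows.
const res = {
	send: jest.fn(),
	status: jest.fn().mockReturnThis(),
} as unknown as Response;

res.status(400).send('Missing configId');
// Both calls are now observable on the mock:
//   res.status was called with 400
//   res.send was called with 'Missing configId'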
diff --git a/frontend/src/components/ModelBox/ModelBox.tsx b/frontend/src/components/ModelBox/ModelBox.tsx
index 09f66ba43..97ffba269 100644
--- a/frontend/src/components/ModelBox/ModelBox.tsx
+++ b/frontend/src/components/ModelBox/ModelBox.tsx
@@ -1,4 +1,4 @@
-import { ChatModel } from '@src/models/chat';
+import { ChatMessage, ChatModel } from '@src/models/chat';
 
 import ModelConfiguration from './ModelConfiguration';
 import ModelSelection from './ModelSelection';
@@ -10,11 +10,13 @@ function ModelBox({
 	setChatModelId,
 	chatModelOptions,
 	addInfoMessage,
+	addChatMessage,
 }: {
 	chatModel?: ChatModel;
 	setChatModelId: (modelId: string) => void;
 	chatModelOptions: string[];
 	addInfoMessage: (message: string) => void;
+	addChatMessage: (chatMessage: ChatMessage) => void;
 }) {
 	return (
 		<div className="model-box">
@@ -26,7 +28,7 @@ function ModelBox({
 			/>
 			<ModelConfiguration
 				chatModel={chatModel}
-				addInfoMessage={addInfoMessage}
+				addChatMessage={addChatMessage}
 			/>
 		</div>
 	);
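[Review note, not part of the patch] The prop swap in ModelBox above and in ModelConfiguration below is what removes the extra network call named in the PR title. A rough before/after of the flow, reconstructed from this patch; the second "persist" request path is not shown in the diff, so it is left unnamed here:

// BEFORE: two requests — configure the model, then persist a locally-built
// info string to chat history via a separate call.
//   POST /openai/model/configure   -> 200, empty body
//   POST (add info message call)   -> persists "changed topP to 0.5"
// AFTER: one request — the backend persists the message itself and returns it.
//   POST /openai/model/configure   -> 200, { chatInfoMessage }
type ConfigureModelResponseAfter = {
	chatInfoMessage: {
		infoMessage: string; // e.g. "changed topP to 0.5", built server-side
		chatMessageType: 'GENERIC_INFO';
	};
};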
diff --git a/frontend/src/components/ModelBox/ModelConfiguration.tsx b/frontend/src/components/ModelBox/ModelConfiguration.tsx
index 149e29335..630b35629 100644
--- a/frontend/src/components/ModelBox/ModelConfiguration.tsx
+++ b/frontend/src/components/ModelBox/ModelConfiguration.tsx
@@ -1,9 +1,10 @@
 import { useEffect, useState } from 'react';
 
 import {
+	ChatMessage,
 	ChatModel,
 	CustomChatModelConfiguration,
-	MODEL_CONFIG,
+	MODEL_CONFIG_ID,
 } from '@src/models/chat';
 import { chatService } from '@src/service';
@@ -13,7 +14,7 @@ import './ModelConfiguration.css';
 
 const DEFAULT_CONFIGS: CustomChatModelConfiguration[] = [
 	{
-		id: MODEL_CONFIG.TEMPERATURE,
+		id: 'temperature',
 		name: 'Model Temperature',
 		info: 'Controls the randomness of the model. Lower means more deterministic, higher means more surprising. Default is 1.',
 		value: 1,
 		min: 0,
 		max: 2,
 	},
 	{
-		id: MODEL_CONFIG.TOP_P,
+		id: 'topP',
 		name: 'Top P',
 		info: 'Controls how many different words or phrases the language model considers when it’s trying to predict the next word. Default is 1. ',
 		value: 1,
 		min: 0,
 		max: 1,
 	},
 	{
-		id: MODEL_CONFIG.PRESENCE_PENALTY,
+		id: 'presencePenalty',
 		name: 'Presence Penalty',
 		info: 'Controls diversity of text generation. Higher presence penalty increases likelihood of using new words. Default is 0.',
 		value: 0,
 		min: 0,
 		max: 2,
 	},
 	{
-		id: MODEL_CONFIG.FREQUENCY_PENALTY,
+		id: 'frequencyPenalty',
 		name: 'Frequency Penalty',
 		info: 'Controls diversity of text generation. Higher frequency penalty decreases likelihood of using the same words. Default is 0.',
 		value: 0,
 		min: 0,
@@ -48,15 +49,15 @@
 function ModelConfiguration({
 	chatModel,
-	addInfoMessage,
+	addChatMessage,
 }: {
 	chatModel?: ChatModel;
-	addInfoMessage: (message: string) => void;
+	addChatMessage: (chatMessage: ChatMessage) => void;
 }) {
 	const [customChatModelConfigs, setCustomChatModel] =
 		useState<CustomChatModelConfiguration[]>(DEFAULT_CONFIGS);
 
-	function setCustomChatModelByID(id: MODEL_CONFIG, value: number) {
+	function setCustomChatModelByID(id: MODEL_CONFIG_ID, value: number) {
 		setCustomChatModel((prev) => {
 			return prev.map((config) => {
 				if (config.id === id) {
@@ -67,7 +68,7 @@ function ModelConfiguration({
 		});
 	}
 
-	function updateConfigValue(id: MODEL_CONFIG, newValue: number) {
+	function updateConfigValue(id: MODEL_CONFIG_ID, newValue: number) {
 		const prevValue = customChatModelConfigs.find(
 			(config) => config.id === id
 		)?.value;
@@ -77,9 +78,9 @@ function ModelConfiguration({
 
 		setCustomChatModelByID(id, newValue);
 
-		void chatService.configureGptModel(id, newValue).then((success) => {
-			if (success) {
-				addInfoMessage(`changed ${id} to ${newValue}`);
+		void chatService.configureGptModel(id, newValue).then((chatInfoMessage) => {
+			if (chatInfoMessage) {
+				addChatMessage(chatInfoMessage);
 			} else {
 				setCustomChatModelByID(id, prevValue);
 			}
diff --git a/frontend/src/components/ModelBox/ModelConfigurationSlider.tsx b/frontend/src/components/ModelBox/ModelConfigurationSlider.tsx
index 30826ef93..6d1835e4b 100644
--- a/frontend/src/components/ModelBox/ModelConfigurationSlider.tsx
+++ b/frontend/src/components/ModelBox/ModelConfigurationSlider.tsx
@@ -2,7 +2,10 @@ import { Slider } from '@mui/material';
 import { useEffect, useState } from 'react';
 
 import DetailElement from '@src/components/ThemedButtons/DetailElement';
-import { CustomChatModelConfiguration, MODEL_CONFIG } from '@src/models/chat';
+import {
+	CustomChatModelConfiguration,
+	MODEL_CONFIG_ID,
+} from '@src/models/chat';
 
 import './ModelConfigurationSlider.css';
 
@@ -11,7 +14,7 @@ function ModelConfigurationSlider({
 	onConfigChanged,
 }: {
 	config: CustomChatModelConfiguration;
-	onConfigChanged: (id: MODEL_CONFIG, newValue: number) => void;
+	onConfigChanged: (id: MODEL_CONFIG_ID, newValue: number) => void;
 }) {
 	const [value, setValue] = useState(config.value);
diff --git a/frontend/src/models/chat.ts b/frontend/src/models/chat.ts
index 0cabfe9fb..ef418eea9 100644
--- a/frontend/src/models/chat.ts
+++ b/frontend/src/models/chat.ts
@@ -15,27 +15,26 @@ type CHAT_MESSAGE_TYPE =
 	| 'ERROR_MSG'
 	| 'RESET_LEVEL';
 
-enum MODEL_CONFIG {
-	TEMPERATURE = 'temperature',
-	TOP_P = 'topP',
-	FREQUENCY_PENALTY = 'frequencyPenalty',
-	PRESENCE_PENALTY = 'presencePenalty',
-}
-
-interface ChatModel {
+type ChatModel = {
 	id: string;
 	configuration: ChatModelConfigurations;
-}
+};
 
-interface ChatModelConfigurations {
-	temperature: number;
-	topP: number;
-	frequencyPenalty: number;
-	presencePenalty: number;
-}
+const modelConfigIds = [
+	'temperature',
+	'topP',
+	'frequencyPenalty',
+	'presencePenalty',
+] as const;
+
+type MODEL_CONFIG_ID = (typeof modelConfigIds)[number];
+
+type ChatModelConfigurations = {
+	[key in MODEL_CONFIG_ID]: number;
+};
 
 interface CustomChatModelConfiguration {
-	id: MODEL_CONFIG;
+	id: MODEL_CONFIG_ID;
 	name: string;
 	info: string;
 	value: number;
@@ -95,5 +94,5 @@ export type {
 	ChatModelConfigurations,
 	CustomChatModelConfiguration,
 	CHAT_MESSAGE_TYPE,
+	MODEL_CONFIG_ID,
 };
-export { MODEL_CONFIG };
diff --git a/frontend/src/service/chatService.ts b/frontend/src/service/chatService.ts
index becbfe5c9..2af47dd11 100644
--- a/frontend/src/service/chatService.ts
+++ b/frontend/src/service/chatService.ts
@@ -1,9 +1,10 @@
+import { ChatInfoMessageResponse } from '@src/models/apiResponse';
 import {
 	CHAT_MESSAGE_TYPE,
 	ChatMessageDTO,
 	ChatMessage,
 	ChatResponse,
-	MODEL_CONFIG,
+	MODEL_CONFIG_ID,
 } from '@src/models/chat';
 import { LEVEL_NAMES } from '@src/models/level';
 
@@ -64,15 +65,20 @@ async function setGptModel(model: string): Promise<boolean> {
 }
 
 async function configureGptModel(
-	configId: MODEL_CONFIG,
+	configId: MODEL_CONFIG_ID,
 	value: number
-): Promise<boolean> {
+): Promise<ChatMessage | null> {
 	const response = await sendRequest(`${PATH}model/configure`, {
 		method: 'POST',
 		headers: { 'Content-Type': 'application/json' },
 		body: JSON.stringify({ configId, value }),
 	});
-	return response.status === 200;
+
+	if (response.status !== 200) return null;
+
+	const { chatInfoMessage } =
+		(await response.json()) as ChatInfoMessageResponse;
+	return makeChatMessageFromDTO(chatInfoMessage);
 }
 
 async function addInfoMessageToChatHistory(
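[Review note, not part of the patch] Putting the two halves together, the round trip after this patch is a single POST: the backend validates, updates the session's chat model configuration, appends the info message to sandbox history, and returns that same message for the UI to render. A hedged sketch of the frontend flow; `chatService.configureGptModel` and `addChatMessage` are the patch's names, while `onConfigCommitted` and `rollback` are hypothetical glue written for illustration:

import { ChatMessage, MODEL_CONFIG_ID } from '@src/models/chat';
import { chatService } from '@src/service';

async function onConfigCommitted(
	configId: MODEL_CONFIG_ID,
	value: number,
	addChatMessage: (message: ChatMessage) => void,
	rollback: () => void
) {
	// One request now both updates the session config and returns the info
	// message the backend appended to sandbox chat history.
	const chatInfoMessage = await chatService.configureGptModel(configId, value);

	if (chatInfoMessage) {
		addChatMessage(chatInfoMessage); // e.g. "changed topP to 0.5"
	} else {
		rollback(); // non-200 response: restore the previous slider value
	}
}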