diff --git a/backend/src/controller/chatController.ts b/backend/src/controller/chatController.ts index a9388c6e5..537d34c5f 100644 --- a/backend/src/controller/chatController.ts +++ b/backend/src/controller/chatController.ts @@ -1,5 +1,6 @@ import { Response } from 'express'; +import { defaultDefences } from '@src/defaultDefences'; import { transformMessage, detectTriggeredInputDefences, @@ -21,7 +22,7 @@ import { ChatInfoMessage, chatInfoMessageTypes, } from '@src/models/chatMessage'; -import { Defence } from '@src/models/defence'; +import { DEFENCE_ID, Defence, QaLlmDefence } from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; import { LEVEL_NAMES } from '@src/models/level'; import { chatGptSendMessage } from '@src/openai'; @@ -94,8 +95,7 @@ async function handleChatWithoutDefenceDetection( chatResponse: ChatHttpResponse, currentLevel: LEVEL_NAMES, chatModel: ChatModel, - chatHistory: ChatMessage[], - defences: Defence[] + chatHistory: ChatMessage[] ): Promise { console.log(`User message: '${message}'`); @@ -106,7 +106,6 @@ async function handleChatWithoutDefenceDetection( const openAiReply = await chatGptSendMessage( updatedChatHistory, - defences, chatModel, currentLevel ); @@ -149,11 +148,15 @@ async function handleChatWithDefenceDetection( }'` ); + const qaLlmDefence = defences.find( + (defence) => defence.id === DEFENCE_ID.QA_LLM + ) as QaLlmDefence | undefined; + const openAiReplyPromise = chatGptSendMessage( chatHistoryWithNewUserMessages, - defences, chatModel, - currentLevel + currentLevel, + qaLlmDefence ); // run input defence detection and chatGPT concurrently @@ -240,14 +243,15 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) { ? req.session.chatModel : defaultChatModel; + const defences = + req.session.levelState[currentLevel].defences ?? defaultDefences; + const currentChatHistory = setSystemRoleInChatHistory( currentLevel, - req.session.levelState[currentLevel].defences, + defences, req.session.levelState[currentLevel].chatHistory ); - const defences = [...req.session.levelState[currentLevel].defences]; - let levelResult: LevelHandlerResponse; try { if (currentLevel < LEVEL_NAMES.LEVEL_3) { @@ -256,8 +260,7 @@ async function handleChatToGPT(req: OpenAiChatRequest, res: Response) { initChatResponse, currentLevel, chatModel, - currentChatHistory, - defences + currentChatHistory ); } else { levelResult = await handleChatWithDefenceDetection( diff --git a/backend/src/controller/defenceController.ts b/backend/src/controller/defenceController.ts index 5389045c4..92dab94e1 100644 --- a/backend/src/controller/defenceController.ts +++ b/backend/src/controller/defenceController.ts @@ -7,10 +7,10 @@ import { resetDefenceConfig, } from '@src/defence'; import { DefenceActivateRequest } from '@src/models/api/DefenceActivateRequest'; -import { DefenceConfigResetRequest } from '@src/models/api/DefenceConfigResetRequest'; +import { DefenceConfigItemResetRequest } from '@src/models/api/DefenceConfigResetRequest'; import { DefenceConfigureRequest } from '@src/models/api/DefenceConfigureRequest'; import { DefenceConfigItem } from '@src/models/defence'; -import { LEVEL_NAMES } from '@src/models/level'; +import { LEVEL_NAMES, isValidLevel } from '@src/models/level'; import { sendErrorResponse } from './handleError'; @@ -24,46 +24,99 @@ function configValueExceedsCharacterLimit(config: DefenceConfigItem[]) { } function handleDefenceActivation(req: DefenceActivateRequest, res: Response) { - const defenceId = req.body.defenceId; - const level = req.body.level; - - if (defenceId && level !== undefined) { - // activate the defence - req.session.levelState[level].defences = activateDefence( - defenceId, - req.session.levelState[level].defences - ); - res.status(200).send(); - } else { - res.status(400).send(); + const { defenceId, level } = req.body; + + if (!defenceId) { + sendErrorResponse(res, 400, 'Missing defenceId'); + return; + } + + if (!Number.isFinite(level) || level === undefined) { + sendErrorResponse(res, 400, 'Missing level'); + return; + } + + if (!isValidLevel(level)) { + sendErrorResponse(res, 400, 'Invalid level'); + return; + } + + const currentDefences = req.session.levelState[level].defences; + + if (!currentDefences) { + sendErrorResponse(res, 400, 'You cannot activate defences on this level'); + return; + } + + if (!currentDefences.some((defence) => defence.id === defenceId)) { + sendErrorResponse(res, 400, `Defence with id ${defenceId} not found`); + return; } + + req.session.levelState[level].defences = activateDefence( + defenceId, + currentDefences + ); + res.status(200).send(); } function handleDefenceDeactivation(req: DefenceActivateRequest, res: Response) { - // id of the defence - const defenceId = req.body.defenceId; - const level = req.body.level; - if (defenceId && level !== undefined) { - // deactivate the defence - req.session.levelState[level].defences = deactivateDefence( - defenceId, - req.session.levelState[level].defences - ); - res.send(); - } else { - res.status(400); - res.send(); + const { defenceId, level } = req.body; + + if (!defenceId) { + sendErrorResponse(res, 400, 'Missing defenceId'); + return; } + + if (!Number.isFinite(level) || level === undefined) { + sendErrorResponse(res, 400, 'Missing level'); + return; + } + + if (!isValidLevel(level)) { + sendErrorResponse(res, 400, 'Invalid level'); + return; + } + + const currentDefences = req.session.levelState[level].defences; + + if (!currentDefences) { + sendErrorResponse(res, 400, 'You cannot deactivate defences on this level'); + return; + } + + if (!currentDefences.some((defence) => defence.id === defenceId)) { + sendErrorResponse(res, 400, `Defence with id ${defenceId} not found`); + return; + } + + req.session.levelState[level].defences = deactivateDefence( + defenceId, + currentDefences + ); + res.status(200).send(); } function handleConfigureDefence(req: DefenceConfigureRequest, res: Response) { - // id of the defence - const defenceId = req.body.defenceId; - const config = req.body.config; - const level = req.body.level; + const { defenceId, level, config } = req.body; + + if (!defenceId) { + sendErrorResponse(res, 400, 'Missing defenceId'); + return; + } + + if (!Number.isFinite(level) || level === undefined) { + sendErrorResponse(res, 400, 'Missing level'); + return; + } - if (!defenceId || !config || level === undefined) { - sendErrorResponse(res, 400, 'Missing defenceId, config or level'); + if (!isValidLevel(level)) { + sendErrorResponse(res, 400, 'Invalid level'); + return; + } + + if (!config) { + sendErrorResponse(res, 400, 'Missing config'); return; } @@ -72,48 +125,106 @@ function handleConfigureDefence(req: DefenceConfigureRequest, res: Response) { return; } - // configure the defence + const currentDefences = req.session.levelState[level].defences; + + if (!currentDefences || level !== LEVEL_NAMES.SANDBOX) { + sendErrorResponse(res, 400, 'You cannot configure defences on this level'); + return; + } + + const defence = currentDefences.find((defence) => defence.id === defenceId); + + if (defence === undefined) { + sendErrorResponse(res, 400, `Defence with id ${defenceId} not found`); + return; + } + req.session.levelState[level].defences = configureDefence( defenceId, - req.session.levelState[level].defences, + currentDefences, config ); res.send(); } -function handleResetSingleDefence( - req: DefenceConfigResetRequest, +function handleResetDefenceConfigItem( + req: DefenceConfigItemResetRequest, res: Response ) { - const defenceId = req.body.defenceId; - const configId = req.body.configId; - const level = LEVEL_NAMES.SANDBOX; //configuration only available in sandbox - if (defenceId && configId) { - req.session.levelState[level].defences = resetDefenceConfig( - defenceId, - configId, - req.session.levelState[level].defences + const { defenceId, configItemId, level } = req.body; + + if (!defenceId) { + sendErrorResponse(res, 400, 'Missing defenceId'); + return; + } + + if (!configItemId) { + sendErrorResponse(res, 400, 'Missing configId'); + return; + } + + if (!Number.isFinite(level) || level === undefined) { + sendErrorResponse(res, 400, 'Missing level'); + return; + } + + if (!isValidLevel(level)) { + sendErrorResponse(res, 400, 'Invalid level'); + return; + } + + const currentDefences = req.session.levelState[level].defences; + + if (!currentDefences || level !== LEVEL_NAMES.SANDBOX) { + sendErrorResponse( + res, + 400, + 'You cannot reset defence config items on this level' ); - const updatedDefenceConfig: DefenceConfigItem | undefined = - req.session.levelState[level].defences - .find((defence) => defence.id === defenceId) - ?.config.find((config) => config.id === configId); - - if (updatedDefenceConfig) { - res.send(updatedDefenceConfig); - } else { - res.status(400); - res.send(); - } - } else { - res.status(400); - res.send(); + return; } + + const defence = currentDefences.find((defence) => defence.id === defenceId); + + if (defence === undefined) { + sendErrorResponse(res, 400, `Defence with id ${defenceId} not found`); + return; + } + + const configItem = defence.config.find( + (configItem) => configItem.id === configItemId + ); + + if (configItem === undefined) { + sendErrorResponse( + res, + 400, + `Config item with id ${configItemId} not found for defence with id ${defenceId}` + ); + return; + } + + req.session.levelState[level].defences = resetDefenceConfig( + defenceId, + configItemId, + currentDefences + ); + + const updatedDefences = req.session.levelState[level].defences; + if (updatedDefences === undefined) { + sendErrorResponse(res, 500, 'Something went whacky'); + return; + } + const updatedDefenceConfig = updatedDefences + .find((defence) => defence.id === defenceId) + ?.config.find((config) => config.id === configItemId); + + res.send(updatedDefenceConfig); } export { handleDefenceActivation, handleDefenceDeactivation, handleConfigureDefence, - handleResetSingleDefence, + handleResetDefenceConfigItem, }; diff --git a/backend/src/models/api/DefenceConfigResetRequest.ts b/backend/src/models/api/DefenceConfigResetRequest.ts index 8da066f7b..241f48bb2 100644 --- a/backend/src/models/api/DefenceConfigResetRequest.ts +++ b/backend/src/models/api/DefenceConfigResetRequest.ts @@ -5,13 +5,15 @@ import { DEFENCE_ID, DefenceConfigItem, } from '@src/models/defence'; +import { LEVEL_NAMES } from '@src/models/level'; -export type DefenceConfigResetRequest = Request< +export type DefenceConfigItemResetRequest = Request< never, DefenceConfigItem, { defenceId?: DEFENCE_ID; - configId?: DEFENCE_CONFIG_ITEM_ID; + configItemId?: DEFENCE_CONFIG_ITEM_ID; + level?: LEVEL_NAMES; }, never >; diff --git a/backend/src/models/defence.ts b/backend/src/models/defence.ts index bae137c86..3875a4db8 100644 --- a/backend/src/models/defence.ts +++ b/backend/src/models/defence.ts @@ -23,6 +23,10 @@ type DefenceConfigItem = { value: string; }; +type QaLlmDefence = Defence & { + id: DEFENCE_ID.QA_LLM; +}; + type Defence = { id: DEFENCE_ID; config: DefenceConfigItem[]; @@ -30,4 +34,9 @@ type Defence = { }; export { DEFENCE_ID }; -export type { Defence, DefenceConfigItem, DEFENCE_CONFIG_ITEM_ID }; +export type { + Defence, + DefenceConfigItem, + DEFENCE_CONFIG_ITEM_ID, + QaLlmDefence, +}; diff --git a/backend/src/models/level.ts b/backend/src/models/level.ts index 59f818e19..636f6b90c 100644 --- a/backend/src/models/level.ts +++ b/backend/src/models/level.ts @@ -20,17 +20,24 @@ function isValidLevel(levelValue: unknown) { interface LevelState { level: LEVEL_NAMES; chatHistory: ChatMessage[]; - defences: Defence[]; + defences?: Defence[]; sentEmails: EmailInfo[]; } function getInitialLevelStates() { + const levelsWithDefences: number[] = [ + LEVEL_NAMES.LEVEL_3, + LEVEL_NAMES.SANDBOX, + ]; + return Object.values(LEVEL_NAMES).map( (level) => ({ level, chatHistory: [], - defences: defaultDefences, + defences: levelsWithDefences.includes(level) + ? defaultDefences + : undefined, sentEmails: [], } as LevelState) ); diff --git a/backend/src/openai.ts b/backend/src/openai.ts index a4a5433c2..663953e1c 100644 --- a/backend/src/openai.ts +++ b/backend/src/openai.ts @@ -5,7 +5,7 @@ import { ChatCompletionMessageToolCall, } from 'openai/resources/chat/completions'; -import { isDefenceActive, getQAPromptFromConfig } from './defence'; +import { getQAPromptFromConfig } from './defence'; import { sendEmail } from './email'; import { queryDocuments } from './langchain'; import { @@ -17,7 +17,7 @@ import { ToolCallResponse, } from './models/chat'; import { ChatMessage } from './models/chatMessage'; -import { DEFENCE_ID, Defence } from './models/defence'; +import { QaLlmDefence } from './models/defence'; import { EmailResponse } from './models/email'; import { LEVEL_NAMES } from './models/level'; import { @@ -141,14 +141,14 @@ function isChatGptFunction(functionName: string) { async function handleAskQuestionFunction( functionCallArgs: string | undefined, currentLevel: LEVEL_NAMES, - defences: Defence[] + qaLlmDefence?: QaLlmDefence ) { if (functionCallArgs) { const params = JSON.parse(functionCallArgs) as FunctionAskQuestionParams; console.debug(`Asking question: ${params.question}`); // if asking a question, call the queryDocuments - const configQAPrompt = isDefenceActive(DEFENCE_ID.QA_LLM, defences) - ? getQAPromptFromConfig(defences) + const configQAPrompt = qaLlmDefence?.isActive + ? getQAPromptFromConfig([qaLlmDefence]) : ''; return await queryDocuments(params.question, configQAPrompt, currentLevel); } else { @@ -191,11 +191,11 @@ function handleSendEmailFunction( } async function chatGptCallFunction( - defences: Defence[], toolCallId: string, functionCall: ChatCompletionMessageToolCall.Function, // default to sandbox - currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX + currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX, + qaLlmDefence?: QaLlmDefence ): Promise { const functionName = functionCall.name; let functionReply = ''; @@ -220,7 +220,7 @@ async function chatGptCallFunction( functionReply = await handleAskQuestionFunction( functionCall.arguments, currentLevel, - defences + qaLlmDefence ); } } else { @@ -326,18 +326,18 @@ function getChatCompletionsInContextWindow( async function performToolCalls( toolCalls: ChatCompletionMessageToolCall[], chatHistory: ChatMessage[], - defences: Defence[], - currentLevel: LEVEL_NAMES + currentLevel: LEVEL_NAMES, + qaLlmDefence?: QaLlmDefence ): Promise { for (const toolCall of toolCalls) { // only tool type supported by openai is function // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition if (toolCall.type === 'function') { const functionCallReply = await chatGptCallFunction( - defences, toolCall.id, toolCall.function, - currentLevel + currentLevel, + qaLlmDefence ); // We assume only one function call in toolCalls, and so we return after getting function reply. @@ -358,9 +358,9 @@ async function performToolCalls( async function getFinalReplyAfterAllToolCalls( chatHistory: ChatMessage[], - defences: Defence[], chatModel: ChatModel, - currentLevel: LEVEL_NAMES + currentLevel: LEVEL_NAMES, + qaLlmDefence?: QaLlmDefence ) { let updatedChatHistory = [...chatHistory]; const sentEmails = []; @@ -385,8 +385,8 @@ async function getFinalReplyAfterAllToolCalls( const toolCallReply = await performToolCalls( gptReply.completion.tool_calls, updatedChatHistory, - defences, - currentLevel + currentLevel, + qaLlmDefence ); updatedChatHistory = toolCallReply.chatHistory; @@ -408,17 +408,17 @@ async function getFinalReplyAfterAllToolCalls( async function chatGptSendMessage( chatHistory: ChatMessage[], - defences: Defence[], chatModel: ChatModel, - currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX + currentLevel: LEVEL_NAMES = LEVEL_NAMES.SANDBOX, + qaLlmDefence?: QaLlmDefence ) { // this method just calls getFinalReplyAfterAllToolCalls then reformats the output. Does it need to exist? const finalToolCallResponse = await getFinalReplyAfterAllToolCalls( chatHistory, - defences, chatModel, - currentLevel + currentLevel, + qaLlmDefence ); const chatResponse: ChatResponse = { diff --git a/backend/src/sessionRoutes.ts b/backend/src/sessionRoutes.ts index 1949bd1c9..ef01632dc 100644 --- a/backend/src/sessionRoutes.ts +++ b/backend/src/sessionRoutes.ts @@ -12,7 +12,7 @@ import { handleConfigureDefence, handleDefenceActivation, handleDefenceDeactivation, - handleResetSingleDefence, + handleResetDefenceConfigItem, } from './controller/defenceController'; import { handleClearEmails } from './controller/emailController'; import { handleLoadLevel } from './controller/levelController'; @@ -91,7 +91,7 @@ router.get('/level', handleLoadLevel); router.post('/defence/activate', handleDefenceActivation); router.post('/defence/deactivate', handleDefenceDeactivation); router.post('/defence/configure', handleConfigureDefence); -router.post('/defence/resetConfig', handleResetSingleDefence); +router.post('/defence/resetConfig', handleResetDefenceConfigItem); // emails router.post('/email/clear', handleClearEmails); diff --git a/backend/test/integration/defenceController.test.ts b/backend/test/integration/defenceController.test.ts new file mode 100644 index 000000000..fedba817a --- /dev/null +++ b/backend/test/integration/defenceController.test.ts @@ -0,0 +1,270 @@ +import { expect, test, jest, describe } from '@jest/globals'; +import { Response } from 'express'; + +import { + handleConfigureDefence, + handleDefenceActivation, + handleDefenceDeactivation, + handleResetDefenceConfigItem, +} from '@src/controller/defenceController'; +import { DefenceActivateRequest } from '@src/models/api/DefenceActivateRequest'; +import { DefenceConfigItemResetRequest } from '@src/models/api/DefenceConfigResetRequest'; +import { DefenceConfigureRequest } from '@src/models/api/DefenceConfigureRequest'; +import { DEFENCE_ID } from '@src/models/defence'; +import { LEVEL_NAMES, getInitialLevelStates } from '@src/models/level'; + +function responseMock() { + return { + send: jest.fn(), + status: jest.fn().mockReturnThis(), + } as unknown as Response; +} + +describe('The correct levels can have their defences changed', () => { + describe('activate defence', () => { + test.each([LEVEL_NAMES.LEVEL_1, LEVEL_NAMES.LEVEL_2])( + `GIVEN level [%s] WHEN attempt to activate a defence THEN defence is not activated`, + (level) => { + const req = { + body: { + defenceId: DEFENCE_ID.CHARACTER_LIMIT, + level, + }, + session: { + levelState: getInitialLevelStates(), + }, + } as unknown as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceActivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'You cannot activate defences on this level' + ); + } + ); + + test.each([LEVEL_NAMES.LEVEL_3, LEVEL_NAMES.SANDBOX])( + `GIVEN level [%s] WHEN attempt to activate a defence THEN defence is activated`, + (level) => { + const req = { + body: { + defenceId: DEFENCE_ID.CHARACTER_LIMIT, + level, + }, + session: { + levelState: getInitialLevelStates(), + }, + } as unknown as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceActivation(req, res); + + const newLevelState = getInitialLevelStates().map((levelState) => + levelState.level === level + ? { + ...levelState, + defences: levelState.defences?.map((defence) => + defence.id === DEFENCE_ID.CHARACTER_LIMIT + ? { ...defence, isActive: true } + : defence + ), + } + : levelState + ); + + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalled(); + expect(req.session.levelState).toEqual(newLevelState); + } + ); + }); + + describe('deactivate defence', () => { + test.each([LEVEL_NAMES.LEVEL_1, LEVEL_NAMES.LEVEL_2])( + `GIVEN level [%s] WHEN attempt to deactivate a defence THEN defence is not activated`, + (level) => { + const req = { + body: { + defenceId: DEFENCE_ID.CHARACTER_LIMIT, + level, + }, + session: { + levelState: getInitialLevelStates(), + }, + } as unknown as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceDeactivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'You cannot deactivate defences on this level' + ); + } + ); + + test.each([LEVEL_NAMES.LEVEL_3, LEVEL_NAMES.SANDBOX])( + `GIVEN level [%s] WHEN attempt to deactivate a defence THEN defence is deactivated`, + (level) => { + const initialLevelStatesButWithCharacterLimitActive = + getInitialLevelStates().map((levelState) => + levelState.level === level + ? { + ...levelState, + defences: levelState.defences?.map((defence) => + defence.id === DEFENCE_ID.CHARACTER_LIMIT + ? { ...defence, isActive: true } + : defence + ), + } + : levelState + ); + + const req = { + body: { + defenceId: DEFENCE_ID.CHARACTER_LIMIT, + level, + }, + session: { + levelState: initialLevelStatesButWithCharacterLimitActive, + }, + } as unknown as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceDeactivation(req, res); + + expect(res.status).toHaveBeenCalledWith(200); + expect(res.send).toHaveBeenCalled(); + expect(req.session.levelState).toEqual(getInitialLevelStates()); + } + ); + }); + + describe('configure a defence', () => { + test.each([LEVEL_NAMES.LEVEL_1, LEVEL_NAMES.LEVEL_2, LEVEL_NAMES.LEVEL_3])( + `GIVEN level [%s] WHEN attempt to configure a defence THEN defence is not configured`, + (level) => { + const req = { + body: { + defenceId: DEFENCE_ID.CHARACTER_LIMIT, + level, + config: [], + }, + session: { + levelState: getInitialLevelStates(), + }, + } as unknown as DefenceConfigureRequest; + + const res = responseMock(); + + handleConfigureDefence(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'You cannot configure defences on this level' + ); + } + ); + + test(`GIVEN level sandbox WHEN attempt to configure a defence THEN defence is configured`, () => { + const req = { + body: { + defenceId: DEFENCE_ID.CHARACTER_LIMIT, + level: LEVEL_NAMES.SANDBOX, + config: [{ id: 'MAX_MESSAGE_LENGTH', value: '1' }], + }, + session: { + levelState: getInitialLevelStates(), + }, + } as DefenceConfigureRequest; + + const res = responseMock(); + + handleConfigureDefence(req, res); + + const updatedDefenceConfig = req.session.levelState + .find((levelState) => levelState.level === LEVEL_NAMES.SANDBOX) + ?.defences?.find( + (defence) => defence.id === DEFENCE_ID.CHARACTER_LIMIT + )?.config; + + const expectedDefenceConfig = [{ id: 'MAX_MESSAGE_LENGTH', value: '1' }]; + + expect(res.send).toHaveBeenCalled(); + expect(updatedDefenceConfig).toEqual(expectedDefenceConfig); + }); + }); + + describe("reset a defence's config item", () => { + test.each([LEVEL_NAMES.LEVEL_1, LEVEL_NAMES.LEVEL_2, LEVEL_NAMES.LEVEL_3])( + `GIVEN level [%s] WHEN attempt to reset a defence config item THEN defence config item is not reset`, + (level) => { + const req = { + body: { + defenceId: DEFENCE_ID.CHARACTER_LIMIT, + level, + configItemId: 'MAX_MESSAGE_LENGTH', + }, + session: { + levelState: getInitialLevelStates(), + }, + } as DefenceConfigItemResetRequest; + + const res = responseMock(); + + handleResetDefenceConfigItem(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'You cannot reset defence config items on this level' + ); + } + ); + + test(`GIVEN level Sandbox WHEN attempt to reset a defence config item THEN defence config item is reset`, () => { + const initialLevelStatesButWithCharacterLimitConfigured = + getInitialLevelStates().map((levelState) => + levelState.level === LEVEL_NAMES.SANDBOX + ? { + ...levelState, + defences: levelState.defences?.map((defence) => + defence.id === DEFENCE_ID.CHARACTER_LIMIT + ? { + ...defence, + config: [{ id: 'MAX_MESSAGE_LENGTH', value: '1' }], + } + : defence + ), + } + : levelState + ); + + const req = { + body: { + defenceId: DEFENCE_ID.CHARACTER_LIMIT, + level: LEVEL_NAMES.SANDBOX, + configItemId: 'MAX_MESSAGE_LENGTH', + }, + session: { + levelState: initialLevelStatesButWithCharacterLimitConfigured, + }, + } as DefenceConfigItemResetRequest; + + const res = responseMock(); + + handleResetDefenceConfigItem(req, res); + + expect(req.session.levelState).toEqual(getInitialLevelStates()); + expect(res.send).toHaveBeenCalledWith({ + id: 'MAX_MESSAGE_LENGTH', + value: '280', + }); + }); + }); +}); diff --git a/backend/test/integration/openai.test.ts b/backend/test/integration/openai.test.ts index be1d5e5cf..840b0c724 100644 --- a/backend/test/integration/openai.test.ts +++ b/backend/test/integration/openai.test.ts @@ -1,9 +1,7 @@ import { expect, jest, test, describe } from '@jest/globals'; -import { defaultDefences } from '@src/defaultDefences'; import { CHAT_MODELS, ChatModel } from '@src/models/chat'; import { ChatMessage } from '@src/models/chatMessage'; -import { Defence } from '@src/models/defence'; import { chatGptSendMessage } from '@src/openai'; const mockCreateChatCompletion = @@ -55,7 +53,6 @@ describe('OpenAI Integration Tests', () => { }, }, ]; - const defences: Defence[] = defaultDefences; const chatModel: ChatModel = { id: CHAT_MODELS.GPT_4, configuration: { @@ -68,11 +65,7 @@ describe('OpenAI Integration Tests', () => { mockCreateChatCompletion.mockResolvedValueOnce(chatResponseAssistant('Hi')); - const reply = await chatGptSendMessage( - chatHistoryWithMessage, - defences, - chatModel - ); + const reply = await chatGptSendMessage(chatHistoryWithMessage, chatModel); expect(reply).toBeDefined(); expect(reply.chatResponse.completion).toBeDefined(); diff --git a/backend/test/unit/controller/chatController.test.ts b/backend/test/unit/controller/chatController.test.ts index 308556406..f040ca76a 100644 --- a/backend/test/unit/controller/chatController.test.ts +++ b/backend/test/unit/controller/chatController.test.ts @@ -16,7 +16,7 @@ import { MessageTransformation, } from '@src/models/chat'; import { ChatMessage } from '@src/models/chatMessage'; -import { DEFENCE_ID, Defence } from '@src/models/defence'; +import { DEFENCE_ID, Defence, QaLlmDefence } from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; import { LEVEL_NAMES, LevelState } from '@src/models/level'; import { chatGptSendMessage } from '@src/openai'; @@ -497,7 +497,6 @@ describe('handleChatToGPT unit tests', () => { expect(mockChatGptSendMessage).toHaveBeenCalledWith( [...existingHistory, newUserChatMessage], - [], mockChatModel, LEVEL_NAMES.LEVEL_1 ); @@ -558,10 +557,18 @@ describe('handleChatToGPT unit tests', () => { }, } as ChatMessage; + const qaLllmDefence = { + id: DEFENCE_ID.QA_LLM, + isActive: true, + config: [{ id: 'PROMPT', value: 'query them documents!' }], + } as QaLlmDefence; + const req = openAiChatRequestMock( 'send an email to bob@example.com saying hi', LEVEL_NAMES.SANDBOX, - existingHistory + existingHistory, + [], + [qaLllmDefence] ); const res = responseMock(); @@ -590,9 +597,9 @@ describe('handleChatToGPT unit tests', () => { expect(mockChatGptSendMessage).toHaveBeenCalledWith( [...existingHistory, newUserChatMessage], - [], mockChatModel, - LEVEL_NAMES.SANDBOX + LEVEL_NAMES.SANDBOX, + qaLllmDefence ); expect(res.send).toHaveBeenCalledWith({ @@ -655,10 +662,18 @@ describe('handleChatToGPT unit tests', () => { }, } as ChatMessage; + const qaLllmDefence = { + id: DEFENCE_ID.QA_LLM, + isActive: true, + config: [{ id: 'PROMPT', value: 'query them documents!' }], + } as QaLlmDefence; + const req = openAiChatRequestMock( 'hello bot', LEVEL_NAMES.SANDBOX, - existingHistory + existingHistory, + [], + [qaLllmDefence] ); const res = responseMock(); @@ -690,9 +705,9 @@ describe('handleChatToGPT unit tests', () => { expect(mockChatGptSendMessage).toHaveBeenCalledWith( [...existingHistory, ...newTransformationChatMessages], - [], mockChatModel, - LEVEL_NAMES.SANDBOX + LEVEL_NAMES.SANDBOX, + qaLllmDefence ); expect(res.send).toHaveBeenCalledWith({ diff --git a/backend/test/unit/controller/defenceController.test.ts b/backend/test/unit/controller/defenceController.test.ts index 753bb0907..d4e406c0d 100644 --- a/backend/test/unit/controller/defenceController.test.ts +++ b/backend/test/unit/controller/defenceController.test.ts @@ -1,18 +1,31 @@ import { expect, test, jest, describe } from '@jest/globals'; import { Response } from 'express'; -import { handleConfigureDefence } from '@src/controller/defenceController'; -import { configureDefence } from '@src/defence'; +import { + handleConfigureDefence, + handleDefenceActivation, + handleDefenceDeactivation, + handleResetDefenceConfigItem, +} from '@src/controller/defenceController'; +import { + activateDefence, + configureDefence, + deactivateDefence, + resetDefenceConfig, +} from '@src/defence'; +import { DefenceActivateRequest } from '@src/models/api/DefenceActivateRequest'; +import { DefenceConfigItemResetRequest } from '@src/models/api/DefenceConfigResetRequest'; import { DefenceConfigureRequest } from '@src/models/api/DefenceConfigureRequest'; import { ChatMessage } from '@src/models/chatMessage'; -import { DEFENCE_ID, Defence } from '@src/models/defence'; +import { + DEFENCE_CONFIG_ITEM_ID, + DEFENCE_ID, + Defence, +} from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; import { LEVEL_NAMES } from '@src/models/level'; jest.mock('@src/defence'); -const mockConfigureDefence = configureDefence as jest.MockedFunction< - typeof configureDefence ->; function responseMock() { return { @@ -21,12 +34,346 @@ function responseMock() { } as unknown as Response; } -describe('handleConfigureDefence', () => { - test('WHEN passed a sensible config value THEN configures defences', () => { - const req: DefenceConfigureRequest = { +describe('handleDefenceActivation', () => { + test('WHEN passed sensible parameters THEN activates defence', () => { + const mockActivateDefence = activateDefence as jest.MockedFunction< + typeof activateDefence + >; + + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + level: LEVEL_NAMES.SANDBOX, + }, + session: { + levelState: [ + {}, + {}, + {}, + { + level: LEVEL_NAMES.SANDBOX, + defences: [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: false, + config: [], + }, + ] as Defence[], + }, + ], + }, + } as DefenceActivateRequest; + + const configuredDefences: Defence[] = [ + { + id: 'PROMPT_EVALUATION_LLM', + config: [], + isActive: true, + } as Defence, + ]; + mockActivateDefence.mockReturnValueOnce(configuredDefences); + + handleDefenceActivation(req, responseMock()); + + expect(mockActivateDefence).toHaveBeenCalledTimes(1); + expect(mockActivateDefence).toHaveBeenCalledWith( + DEFENCE_ID.PROMPT_EVALUATION_LLM, + [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: false, + config: [], + } as Defence, + ] + ); + expect(req.session.levelState[LEVEL_NAMES.SANDBOX].defences).toEqual( + configuredDefences + ); + }); + + test('WHEN missing defenceId THEN does not activate defence', () => { + const req = { + body: { + level: LEVEL_NAMES.SANDBOX, + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceActivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Missing defenceId'); + }); + + test('WHEN missing level THEN does not activate defence', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceActivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Missing level'); + }); + + test('WHEN level is invalid THEN does not activate defence', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + level: 5 as LEVEL_NAMES, + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceActivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Invalid level'); + }); + + test('WHEN level does not have configurable defences THEN does not activate defence', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + level: LEVEL_NAMES.LEVEL_1, + }, + session: { + levelState: [ + { + level: LEVEL_NAMES.LEVEL_1, + defences: undefined, + }, + ], + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceActivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'You cannot activate defences on this level' + ); + }); + + test('WHEN defence does not exist THEN does not activate defence', () => { + const req = { + body: { + defenceId: 'badDefenceID' as DEFENCE_ID, + level: LEVEL_NAMES.SANDBOX, + }, + session: { + levelState: [ + {}, + {}, + {}, + { + level: LEVEL_NAMES.SANDBOX, + defences: [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [], + }, + ] as Defence[], + }, + ], + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceActivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'Defence with id badDefenceID not found' + ); + }); +}); + +describe('handleDefenceDeactivation', () => { + test('WHEN passed sensible parameters THEN deactivates defence', () => { + const mockDeactivateDefence = deactivateDefence as jest.MockedFunction< + typeof deactivateDefence + >; + + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + level: LEVEL_NAMES.SANDBOX, + }, + session: { + levelState: [ + {}, + {}, + {}, + { + level: LEVEL_NAMES.SANDBOX, + defences: [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [], + }, + ] as Defence[], + }, + ], + }, + } as DefenceActivateRequest; + + const configuredDefences: Defence[] = [ + { + id: 'PROMPT_EVALUATION_LLM', + config: [], + isActive: false, + } as Defence, + ]; + mockDeactivateDefence.mockReturnValueOnce(configuredDefences); + + handleDefenceDeactivation(req, responseMock()); + + expect(mockDeactivateDefence).toHaveBeenCalledTimes(1); + expect(mockDeactivateDefence).toHaveBeenCalledWith( + DEFENCE_ID.PROMPT_EVALUATION_LLM, + [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [], + } as Defence, + ] + ); + expect(req.session.levelState[LEVEL_NAMES.SANDBOX].defences).toEqual( + configuredDefences + ); + }); + + test('WHEN missing defenceId THEN does not deactivate defence', () => { + const req = { + body: { + level: LEVEL_NAMES.SANDBOX, + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceDeactivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Missing defenceId'); + }); + + test('WHEN missing level THEN does not deactivate defence', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceDeactivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Missing level'); + }); + + test('WHEN level is invalid THEN does not deactivate defence', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + level: 5 as LEVEL_NAMES, + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceDeactivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Invalid level'); + }); + + test('WHEN level does not have configurable defences THEN does not deactivate defence', () => { + const req = { body: { defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, level: LEVEL_NAMES.LEVEL_1, + }, + session: { + levelState: [ + { + level: LEVEL_NAMES.LEVEL_1, + defences: undefined, + }, + ], + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceDeactivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'You cannot deactivate defences on this level' + ); + }); + + test('WHEN defence does not exist THEN does not deactivate defence', () => { + const req = { + body: { + defenceId: 'badDefenceID' as DEFENCE_ID, + level: LEVEL_NAMES.SANDBOX, + }, + session: { + levelState: [ + {}, + {}, + {}, + { + level: LEVEL_NAMES.SANDBOX, + defences: [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [], + }, + ] as Defence[], + }, + ], + }, + } as DefenceActivateRequest; + + const res = responseMock(); + + handleDefenceDeactivation(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'Defence with id badDefenceID not found' + ); + }); +}); + +describe('handleConfigureDefence', () => { + test('WHEN passed a sensible config value THEN configures defence', () => { + const mockConfigureDefence = configureDefence as jest.MockedFunction< + typeof configureDefence + >; + + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + level: LEVEL_NAMES.SANDBOX, config: [ { id: 'PROMPT', @@ -36,11 +383,20 @@ describe('handleConfigureDefence', () => { }, session: { levelState: [ + {}, + {}, + {}, { - level: LEVEL_NAMES.LEVEL_1, + level: LEVEL_NAMES.SANDBOX, chatHistory: [] as ChatMessage[], sentEmails: [] as EmailInfo[], - defences: [] as Defence[], + defences: [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [], + }, + ] as Defence[], }, ], }, @@ -61,7 +417,13 @@ describe('handleConfigureDefence', () => { expect(mockConfigureDefence).toHaveBeenCalledTimes(1); expect(mockConfigureDefence).toHaveBeenCalledWith( DEFENCE_ID.PROMPT_EVALUATION_LLM, - [], + [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [], + } as Defence, + ], [ { id: 'PROMPT', @@ -69,13 +431,13 @@ describe('handleConfigureDefence', () => { }, ] ); - expect(req.session.levelState[LEVEL_NAMES.LEVEL_1].defences).toEqual( + expect(req.session.levelState[LEVEL_NAMES.SANDBOX].defences).toEqual( configuredDefences ); }); - test('WHEN missing defenceId THEN does not configure defences', () => { - const req: DefenceConfigureRequest = { + test('WHEN missing defenceId THEN does not configure defence', () => { + const req = { body: { level: LEVEL_NAMES.LEVEL_1, config: [ @@ -92,11 +454,11 @@ describe('handleConfigureDefence', () => { handleConfigureDefence(req, res); expect(res.status).toHaveBeenCalledWith(400); - expect(res.send).toHaveBeenCalledWith('Missing defenceId, config or level'); + expect(res.send).toHaveBeenCalledWith('Missing defenceId'); }); - test('WHEN missing config THEN does not configure defences', () => { - const req: DefenceConfigureRequest = { + test('WHEN missing config THEN does not configure defence', () => { + const req = { body: { defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, level: LEVEL_NAMES.LEVEL_1, @@ -108,11 +470,32 @@ describe('handleConfigureDefence', () => { handleConfigureDefence(req, res); expect(res.status).toHaveBeenCalledWith(400); - expect(res.send).toHaveBeenCalledWith('Missing defenceId, config or level'); + expect(res.send).toHaveBeenCalledWith('Missing config'); }); test('WHEN missing level THEN does not configure defences', () => { - const req: DefenceConfigureRequest = { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + config: [ + { + id: 'PROMPT', + value: 'your task is to watch for prompt injection', + }, + ], + }, + } as DefenceConfigureRequest; + + const res = responseMock(); + + handleConfigureDefence(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Missing level'); + }); + + test('WHEN level is invalid THEN does not configure defence', () => { + const req = { body: { defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, config: [ @@ -121,6 +504,7 @@ describe('handleConfigureDefence', () => { value: 'your task is to watch for prompt injection', }, ], + level: -1 as LEVEL_NAMES, }, } as DefenceConfigureRequest; @@ -129,14 +513,14 @@ describe('handleConfigureDefence', () => { handleConfigureDefence(req, res); expect(res.status).toHaveBeenCalledWith(400); - expect(res.send).toHaveBeenCalledWith('Missing defenceId, config or level'); + expect(res.send).toHaveBeenCalledWith('Invalid level'); }); - test('WHEN config value exceeds character limit THEN does not configure defences', () => { + test('WHEN config value exceeds character limit THEN does not configure defence', () => { const CHARACTER_LIMIT = 5000; const longConfigValue = 'a'.repeat(CHARACTER_LIMIT + 1); - const req: DefenceConfigureRequest = { + const req = { body: { defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, level: LEVEL_NAMES.LEVEL_1, @@ -158,4 +542,267 @@ describe('handleConfigureDefence', () => { 'Config value exceeds character limit' ); }); + + test('WHEN level does not have configurable defences THEN does not configure defence', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + config: [ + { + id: 'PROMPT', + value: 'your task is to watch for prompt injection', + }, + ], + level: LEVEL_NAMES.LEVEL_1, + }, + session: { + levelState: [ + { + level: LEVEL_NAMES.LEVEL_1, + chatHistory: [] as ChatMessage[], + sentEmails: [] as EmailInfo[], + defences: undefined, + }, + ], + }, + } as DefenceConfigureRequest; + + const res = responseMock(); + + handleConfigureDefence(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + 'You cannot configure defences on this level' + ); + }); + + test('WHEN defence does not exist THEN does not configure defences', () => { + const req = { + body: { + defenceId: 'badDefenceID' as DEFENCE_ID, + config: [ + { + id: 'PROMPT', + value: 'your task is to watch for prompt injection', + }, + ], + level: LEVEL_NAMES.SANDBOX, + }, + session: { + levelState: [ + {}, + {}, + {}, + { + level: LEVEL_NAMES.SANDBOX, + chatHistory: [] as ChatMessage[], + sentEmails: [] as EmailInfo[], + defences: [] as Defence[], + }, + ], + }, + } as DefenceConfigureRequest; + + const res = responseMock(); + + handleConfigureDefence(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + `Defence with id badDefenceID not found` + ); + }); +}); + +describe('handleResetDefenceConfigItem', () => { + test('WHEN passed sensible parameters THEN resets config item', () => { + const mockResetDefenceConfig = resetDefenceConfig as jest.MockedFunction< + typeof resetDefenceConfig + >; + + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + configItemId: 'PROMPT', + level: LEVEL_NAMES.SANDBOX, + }, + session: { + levelState: [ + {}, + {}, + {}, + { + level: LEVEL_NAMES.SANDBOX, + chatHistory: [] as ChatMessage[], + sentEmails: [] as EmailInfo[], + defences: [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [{ id: 'PROMPT', value: 'old value' }], + }, + ] as Defence[], + }, + ], + }, + } as DefenceConfigItemResetRequest; + + const configuredDefences: Defence[] = [ + { + id: 'PROMPT_EVALUATION_LLM', + config: [{ id: 'PROMPT', value: 'reset value' }], + } as Defence, + ]; + mockResetDefenceConfig.mockReturnValueOnce(configuredDefences); + + handleResetDefenceConfigItem(req, responseMock()); + + expect(mockResetDefenceConfig).toHaveBeenCalledTimes(1); + expect(mockResetDefenceConfig).toHaveBeenCalledWith( + DEFENCE_ID.PROMPT_EVALUATION_LLM, + 'PROMPT', + [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [{ id: 'PROMPT', value: 'old value' }], + } as Defence, + ] + ); + expect(req.session.levelState[LEVEL_NAMES.SANDBOX].defences).toEqual( + configuredDefences + ); + }); + + test('WHEN missing defenceId THEN does not reset config item', () => { + const req = { + body: { + configItemId: 'PROMPT', + level: LEVEL_NAMES.SANDBOX, + }, + } as DefenceConfigItemResetRequest; + + const res = responseMock(); + + handleResetDefenceConfigItem(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Missing defenceId'); + }); + + test('WHEN missing config item id THEN does not reset config item', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + level: LEVEL_NAMES.SANDBOX, + }, + } as DefenceConfigItemResetRequest; + + const res = responseMock(); + + handleResetDefenceConfigItem(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Missing configId'); + }); + + test('WHEN missing level THEN does not reset config item', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + configItemId: 'PROMPT', + }, + } as DefenceConfigItemResetRequest; + + const res = responseMock(); + + handleResetDefenceConfigItem(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Missing level'); + }); + + test('WHEN invalid level THEN does not reset config item', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + configItemId: 'PROMPT', + level: 5 as LEVEL_NAMES, + }, + } as DefenceConfigItemResetRequest; + + const res = responseMock(); + + handleResetDefenceConfigItem(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith('Invalid level'); + }); + + test('WHEN defence does not exist THEN does not reset config item', () => { + const req = { + body: { + defenceId: 'badDefenceID' as DEFENCE_ID, + configItemId: 'PROMPT', + level: LEVEL_NAMES.SANDBOX, + }, + session: { + levelState: [ + {}, + {}, + {}, + { + level: LEVEL_NAMES.SANDBOX, + defences: [] as Defence[], + }, + ], + }, + } as DefenceConfigItemResetRequest; + + const res = responseMock(); + + handleResetDefenceConfigItem(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + `Defence with id badDefenceID not found` + ); + }); + + test('WHEN config item id does not exist for given defence THEN does not reset config item', () => { + const req = { + body: { + defenceId: DEFENCE_ID.PROMPT_EVALUATION_LLM, + configItemId: 'badConfigItemId' as DEFENCE_CONFIG_ITEM_ID, + level: LEVEL_NAMES.SANDBOX, + }, + session: { + levelState: [ + {}, + {}, + {}, + { + level: LEVEL_NAMES.SANDBOX, + defences: [ + { + id: DEFENCE_ID.PROMPT_EVALUATION_LLM, + isActive: true, + config: [{ id: 'PROMPT', value: 'old value' }], + }, + ] as Defence[], + }, + ], + }, + } as DefenceConfigItemResetRequest; + + const res = responseMock(); + + handleResetDefenceConfigItem(req, res); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.send).toHaveBeenCalledWith( + `Config item with id badConfigItemId not found for defence with id ${DEFENCE_ID.PROMPT_EVALUATION_LLM}` + ); + }); }); diff --git a/backend/test/unit/controller/levelController.test.ts b/backend/test/unit/controller/levelController.test.ts index 7e898610b..4652a0582 100644 --- a/backend/test/unit/controller/levelController.test.ts +++ b/backend/test/unit/controller/levelController.test.ts @@ -24,15 +24,9 @@ afterEach(() => { mockSend.mockClear(); }); -[ - LEVEL_NAMES.LEVEL_1, - LEVEL_NAMES.LEVEL_2, - LEVEL_NAMES.LEVEL_3, - LEVEL_NAMES.SANDBOX, -].forEach((level) => { - test(`GIVEN level ${ - level + 1 - } WHEN client asks to load the level THEN the backend sends the correct level information`, () => { +test.each(Object.values(LEVEL_NAMES))( + `GIVEN level [%s] WHEN client asks to load the level THEN the backend sends the correct level information`, + (level) => { const req = { query: { level, @@ -71,8 +65,8 @@ afterEach(() => { chatHistory: `level ${level + 1} chat history`, defences: `level ${level + 1} defences`, }); - }); -}); + } +); test('WHEN client does not provide a level THEN the backend responds with BadRequest', () => { const req = { diff --git a/backend/test/unit/controller/startController.test.ts b/backend/test/unit/controller/startController.test.ts index bf061e661..c9c731adc 100644 --- a/backend/test/unit/controller/startController.test.ts +++ b/backend/test/unit/controller/startController.test.ts @@ -24,15 +24,9 @@ afterEach(() => { mockSend.mockClear(); }); -[ - LEVEL_NAMES.LEVEL_1, - LEVEL_NAMES.LEVEL_2, - LEVEL_NAMES.LEVEL_3, - LEVEL_NAMES.SANDBOX, -].forEach((level) => { - test(`GIVEN level ${ - level + 1 - } provided WHEN user starts the frontend THEN the backend sends the correct initial information`, () => { +test.each(Object.values(LEVEL_NAMES))( + `GIVEN level [%s] provided WHEN user starts the frontend THEN the backend sends the correct initial information`, + (level) => { const req = { query: { level, @@ -78,8 +72,8 @@ afterEach(() => { { level: 2, systemRole: 'systemRoleLevel3' }, ], }); - }); -}); + } +); test('GIVEN no level provided WHEN user starts the frontend THEN the backend responds with error message', () => { const req = { diff --git a/frontend/src/components/ControlPanel/ControlPanel.tsx b/frontend/src/components/ControlPanel/ControlPanel.tsx index 3a165192d..3f4d8a493 100644 --- a/frontend/src/components/ControlPanel/ControlPanel.tsx +++ b/frontend/src/components/ControlPanel/ControlPanel.tsx @@ -2,7 +2,12 @@ import { DEFENCES_HIDDEN_LEVEL3_IDS, MODEL_DEFENCES } from '@src/Defences'; import DefenceBox from '@src/components/DefenceBox/DefenceBox'; import DocumentViewButton from '@src/components/DocumentViewer/DocumentViewButton'; import ModelBox from '@src/components/ModelBox/ModelBox'; -import { DEFENCE_ID, DefenceConfigItem, Defence } from '@src/models/defence'; +import { + DEFENCE_ID, + DefenceConfigItem, + Defence, + DEFENCE_CONFIG_ITEM_ID, +} from '@src/models/defence'; import { LEVEL_NAMES } from '@src/models/level'; import './ControlPanel.css'; @@ -21,7 +26,10 @@ function ControlPanel({ defences: Defence[]; chatModelOptions: string[]; toggleDefence: (defence: Defence) => void; - resetDefenceConfiguration: (defenceId: DEFENCE_ID, configId: string) => void; + resetDefenceConfiguration: ( + defenceId: DEFENCE_ID, + configItemId: DEFENCE_CONFIG_ITEM_ID + ) => void; setDefenceConfiguration: ( defenceId: DEFENCE_ID, config: DefenceConfigItem[] diff --git a/frontend/src/components/DefenceBox/DefenceBox.tsx b/frontend/src/components/DefenceBox/DefenceBox.tsx index b68fd3930..a4b74ef88 100644 --- a/frontend/src/components/DefenceBox/DefenceBox.tsx +++ b/frontend/src/components/DefenceBox/DefenceBox.tsx @@ -1,5 +1,10 @@ import { PROMPT_ENCLOSURE_DEFENCES } from '@src/Defences'; -import { DEFENCE_ID, DefenceConfigItem, Defence } from '@src/models/defence'; +import { + DEFENCE_ID, + DefenceConfigItem, + Defence, + DEFENCE_CONFIG_ITEM_ID, +} from '@src/models/defence'; import DefenceMechanism from './DefenceMechanism'; @@ -16,7 +21,10 @@ function DefenceBox({ defences: Defence[]; showConfigurations: boolean; toggleDefence: (defence: Defence) => void; - resetDefenceConfiguration: (defenceId: DEFENCE_ID, configId: string) => void; + resetDefenceConfiguration: ( + defenceId: DEFENCE_ID, + configItemId: DEFENCE_CONFIG_ITEM_ID + ) => void; setDefenceConfiguration: ( defenceId: DEFENCE_ID, config: DefenceConfigItem[] diff --git a/frontend/src/components/DefenceBox/DefenceConfiguration.tsx b/frontend/src/components/DefenceBox/DefenceConfiguration.tsx index 45fa3b3bb..c40ecee23 100644 --- a/frontend/src/components/DefenceBox/DefenceConfiguration.tsx +++ b/frontend/src/components/DefenceBox/DefenceConfiguration.tsx @@ -1,5 +1,9 @@ import ThemedButton from '@src/components/ThemedButtons/ThemedButton'; -import { Defence, DefenceConfigItem } from '@src/models/defence'; +import { + DEFENCE_CONFIG_ITEM_ID, + Defence, + DefenceConfigItem, +} from '@src/models/defence'; import { defenceService } from '@src/service'; import DefenceConfigurationInput from './DefenceConfigurationInput'; @@ -21,7 +25,10 @@ function DefenceConfiguration({ configId: string, value: string ) => Promise; - resetConfigurationValue: (defence: Defence, configId: string) => void; + resetConfigurationValue: ( + defence: Defence, + configItemId: DEFENCE_CONFIG_ITEM_ID + ) => void; }) { const uniqueInputId = `${defence.id}-${config.id}`; const supportText = `reset ${config.name} to default`; diff --git a/frontend/src/components/DefenceBox/DefenceMechanism.tsx b/frontend/src/components/DefenceBox/DefenceMechanism.tsx index 48e6848cc..b9e321827 100644 --- a/frontend/src/components/DefenceBox/DefenceMechanism.tsx +++ b/frontend/src/components/DefenceBox/DefenceMechanism.tsx @@ -1,6 +1,11 @@ import { useState } from 'react'; -import { DEFENCE_ID, DefenceConfigItem, Defence } from '@src/models/defence'; +import { + DEFENCE_ID, + DefenceConfigItem, + Defence, + DEFENCE_CONFIG_ITEM_ID, +} from '@src/models/defence'; import DefenceConfiguration from './DefenceConfiguration'; import PromptEnclosureDefenceMechanism from './PromptEnclosureDefenceMechanism'; @@ -19,7 +24,10 @@ function DefenceMechanism({ showConfigurations: boolean; promptEnclosureDefences: Defence[]; toggleDefence: (defence: Defence) => void; - resetDefenceConfiguration: (defenceId: DEFENCE_ID, configId: string) => void; + resetDefenceConfiguration: ( + defenceId: DEFENCE_ID, + configItemId: DEFENCE_CONFIG_ITEM_ID + ) => void; setDefenceConfiguration: ( defenceId: DEFENCE_ID, config: DefenceConfigItem[] @@ -27,8 +35,11 @@ function DefenceMechanism({ }) { const [configKey, setConfigKey] = useState(0); - function resetConfigurationValue(defence: Defence, configId: string) { - resetDefenceConfiguration(defence.id, configId); + function resetConfigurationValue( + defence: Defence, + configItemId: DEFENCE_CONFIG_ITEM_ID + ) { + resetDefenceConfiguration(defence.id, configItemId); } async function setConfigurationValue( diff --git a/frontend/src/components/DefenceBox/PromptEnclosureDefenceMechanism.tsx b/frontend/src/components/DefenceBox/PromptEnclosureDefenceMechanism.tsx index f93965a36..ce28201a6 100644 --- a/frontend/src/components/DefenceBox/PromptEnclosureDefenceMechanism.tsx +++ b/frontend/src/components/DefenceBox/PromptEnclosureDefenceMechanism.tsx @@ -1,7 +1,11 @@ import { useState } from 'react'; import DefenceConfigurationRadioButton from '@src/components/DefenceBox/DefenceConfigurationRadioButton'; -import { DEFENCE_ID, Defence } from '@src/models/defence'; +import { + DEFENCE_CONFIG_ITEM_ID, + DEFENCE_ID, + Defence, +} from '@src/models/defence'; import DefenceConfiguration from './DefenceConfiguration'; @@ -20,7 +24,10 @@ function PromptEnclosureDefenceMechanism({ configId: string, value: string ) => Promise; - resetConfigurationValue: (defence: Defence, configId: string) => void; + resetConfigurationValue: ( + defence: Defence, + configItemId: DEFENCE_CONFIG_ITEM_ID + ) => void; }) { // Using local state, else we'd need to wait for API response before updating // selected radio. This could land us in trouble if there's a server error... diff --git a/frontend/src/components/MainComponent/MainBody.tsx b/frontend/src/components/MainComponent/MainBody.tsx index 538240987..150c1f142 100644 --- a/frontend/src/components/MainComponent/MainBody.tsx +++ b/frontend/src/components/MainComponent/MainBody.tsx @@ -2,7 +2,12 @@ import ChatBox from '@src/components/ChatBox/ChatBox'; import ControlPanel from '@src/components/ControlPanel/ControlPanel'; import EmailBox from '@src/components/EmailBox/EmailBox'; import { ChatMessage } from '@src/models/chat'; -import { DEFENCE_ID, DefenceConfigItem, Defence } from '@src/models/defence'; +import { + DEFENCE_ID, + DefenceConfigItem, + Defence, + DEFENCE_CONFIG_ITEM_ID, +} from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; import { LEVEL_NAMES } from '@src/models/level'; @@ -33,7 +38,10 @@ function MainBody({ addChatMessage: (message: ChatMessage) => void; addInfoMessage: (message: string) => void; addSentEmails: (emails: EmailInfo[]) => void; - resetDefenceConfiguration: (defenceId: DEFENCE_ID, configId: string) => void; + resetDefenceConfiguration: ( + defenceId: DEFENCE_ID, + configItemId: DEFENCE_CONFIG_ITEM_ID + ) => void; resetLevel: () => void; toggleDefence: (defence: Defence) => void; setDefenceConfiguration: ( diff --git a/frontend/src/components/MainComponent/MainComponent.tsx b/frontend/src/components/MainComponent/MainComponent.tsx index 266da665a..a6ed220bf 100644 --- a/frontend/src/components/MainComponent/MainComponent.tsx +++ b/frontend/src/components/MainComponent/MainComponent.tsx @@ -5,7 +5,12 @@ import HandbookOverlay from '@src/components/HandbookOverlay/HandbookOverlay'; import LevelMissionInfoBanner from '@src/components/LevelMissionInfoBanner/LevelMissionInfoBanner'; import ResetLevelOverlay from '@src/components/Overlay/ResetLevel'; import { ChatMessage } from '@src/models/chat'; -import { DEFENCE_ID, DefenceConfigItem, Defence } from '@src/models/defence'; +import { + DEFENCE_ID, + DefenceConfigItem, + Defence, + DEFENCE_CONFIG_ITEM_ID, +} from '@src/models/defence'; import { EmailInfo } from '@src/models/email'; import { LEVEL_NAMES, LevelSystemRole } from '@src/models/level'; import { @@ -190,18 +195,19 @@ function MainComponent({ async function resetDefenceConfiguration( defenceId: DEFENCE_ID, - configId: string + configItemId: DEFENCE_CONFIG_ITEM_ID ) { - const resetDefence = await defenceService.resetDefenceConfig( + const resetDefence = await defenceService.resetDefenceConfigItem( defenceId, - configId + configItemId, + currentLevel ); addConfigUpdateToChat(defenceId, 'reset'); // update state const newDefences = defences.map((defence) => { if (defence.id === defenceId) { defence.config.forEach((config) => { - if (config.id === configId) { + if (config.id === configItemId) { config.value = resetDefence.value.trim(); } }); @@ -311,9 +317,10 @@ function MainComponent({ addChatMessage={addChatMessage} addInfoMessage={addInfoMessage} addSentEmails={addSentEmails} - resetDefenceConfiguration={(defenceId: DEFENCE_ID, configId: string) => - void resetDefenceConfiguration(defenceId, configId) - } + resetDefenceConfiguration={( + defenceId: DEFENCE_ID, + configItemId: DEFENCE_CONFIG_ITEM_ID + ) => void resetDefenceConfiguration(defenceId, configItemId)} resetLevel={() => void resetLevel()} toggleDefence={(defence: Defence) => void setDefenceToggle(defence)} setDefenceConfiguration={setDefenceConfiguration} diff --git a/frontend/src/models/combined.ts b/frontend/src/models/combined.ts index 34b84903d..c77672b20 100644 --- a/frontend/src/models/combined.ts +++ b/frontend/src/models/combined.ts @@ -6,7 +6,7 @@ import { LevelSystemRole } from './level'; type StartReponse = { emails: EmailInfo[]; chatHistory: ChatMessageDTO[]; - defences: DefenceDTO[]; + defences?: DefenceDTO[]; availableModels: string[]; systemRoles: LevelSystemRole[]; }; @@ -14,7 +14,7 @@ type StartReponse = { type LoadLevelResponse = { emails: EmailInfo[]; chatHistory: ChatMessageDTO[]; - defences: DefenceDTO[]; + defences?: DefenceDTO[]; }; export type { StartReponse, LoadLevelResponse }; diff --git a/frontend/src/service/defenceService.ts b/frontend/src/service/defenceService.ts index 9ded6c7fb..66decf271 100644 --- a/frontend/src/service/defenceService.ts +++ b/frontend/src/service/defenceService.ts @@ -4,7 +4,9 @@ import { DefenceConfigItem, DefenceResetResponse, DefenceDTO, + DEFENCE_CONFIG_ITEM_ID, } from '@src/models/defence'; +import { LEVEL_NAMES } from '@src/models/level'; import { sendRequest } from './backendService'; @@ -35,7 +37,7 @@ function getDefencesFromDTOs(defenceDTOs: DefenceDTO[]) { async function toggleDefence( defenceId: DEFENCE_ID, isActive: boolean, - level: number + level: LEVEL_NAMES ): Promise { const requestPath = isActive ? 'deactivate' : 'activate'; const response = await sendRequest(`${PATH}${requestPath}`, { @@ -49,9 +51,9 @@ async function toggleDefence( } async function configureDefence( - defenceId: string, + defenceId: DEFENCE_ID, config: DefenceConfigItem[], - level: number + level: LEVEL_NAMES ): Promise { const response = await sendRequest(`${PATH}configure`, { method: 'POST', @@ -63,60 +65,65 @@ async function configureDefence( return response.status === 200; } -async function resetDefenceConfig( - defenceId: string, - configId: string +async function resetDefenceConfigItem( + defenceId: DEFENCE_ID, + configItemId: DEFENCE_CONFIG_ITEM_ID, + level: LEVEL_NAMES ): Promise { const response = await sendRequest(`${PATH}resetConfig`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, - body: JSON.stringify({ defenceId, configId }), + body: JSON.stringify({ defenceId, configItemId, level }), }); return (await response.json()) as DefenceResetResponse; } -function validatePositiveNumberConfig(config: string) { +function validatePositiveNumberConfig(value: string) { // config is a number greater than zero - return !isNaN(Number(config)) && Number(config) > 0; + return !isNaN(Number(value)) && Number(value) > 0; } -function validateNonEmptyStringConfig(config: string) { +function validateNonEmptyStringConfig(value: string) { // config is non empty string and is not a number - return config !== '' && !Number(config); + return value !== '' && !Number(value); } -function validateFilterConfig(config: string) { +function validateFilterConfig(value: string) { // config is not a list of empty commas - const commaPattern = /^,*,*$/; - return config === '' || !commaPattern.test(config); + return ( + value + .split(',') + .map((item) => item.trim()) + .filter((item) => item).length > 0 + ); } function validateDefence( defenceId: DEFENCE_ID, - configId: string, - config: string + configItemId: DEFENCE_CONFIG_ITEM_ID, + configItemValue: string ) { switch (defenceId) { case DEFENCE_ID.CHARACTER_LIMIT: - return validatePositiveNumberConfig(config); + return validatePositiveNumberConfig(configItemValue); case DEFENCE_ID.INPUT_FILTERING: case DEFENCE_ID.OUTPUT_FILTERING: - return validateFilterConfig(config); + return validateFilterConfig(configItemValue); case DEFENCE_ID.RANDOM_SEQUENCE_ENCLOSURE: - return configId === 'SEQUENCE_LENGTH' - ? validatePositiveNumberConfig(config) - : validateNonEmptyStringConfig(config); + return configItemId === 'SEQUENCE_LENGTH' + ? validatePositiveNumberConfig(configItemValue) + : validateNonEmptyStringConfig(configItemValue); default: - return validateNonEmptyStringConfig(config); + return validateNonEmptyStringConfig(configItemValue); } } export { toggleDefence, configureDefence, - resetDefenceConfig, + resetDefenceConfigItem, validateDefence, getDefencesFromDTOs, }; diff --git a/frontend/src/service/levelService.ts b/frontend/src/service/levelService.ts index 643f3a9a1..d376d4d98 100644 --- a/frontend/src/service/levelService.ts +++ b/frontend/src/service/levelService.ts @@ -16,7 +16,7 @@ async function loadLevel(level: number) { return { emails, chatHistory: getChatMessagesFromDTOResponse(chatHistory), - defences: getDefencesFromDTOs(defences), + defences: defences ? getDefencesFromDTOs(defences) : [], }; } diff --git a/frontend/src/service/startService.ts b/frontend/src/service/startService.ts index ef61eb559..089764d15 100644 --- a/frontend/src/service/startService.ts +++ b/frontend/src/service/startService.ts @@ -16,7 +16,7 @@ async function start(level: number) { return { emails, chatHistory: getChatMessagesFromDTOResponse(chatHistory), - defences: getDefencesFromDTOs(defences), + defences: defences ? getDefencesFromDTOs(defences) : [], availableModels, systemRoles, };