diff --git a/backend/src/app.ts b/backend/src/app.ts
index 477db2b54..053560320 100644
--- a/backend/src/app.ts
+++ b/backend/src/app.ts
@@ -3,7 +3,7 @@
 import dotenv from "dotenv";
 import express from "express";
 import session from "express-session";
-import { getInitialDefences, getQALLMprePrompt } from "./defence";
+import { getInitialDefences } from "./defence";
 import { setOpenAiApiKey } from "./openai";
 import { router } from "./router";
 import { ChatCompletionRequestMessage } from "openai";
diff --git a/backend/src/defence.ts b/backend/src/defence.ts
index 2a6a99cfe..79ca0815f 100644
--- a/backend/src/defence.ts
+++ b/backend/src/defence.ts
@@ -230,8 +230,7 @@ function transformMessage(message: string, defences: DefenceInfo[]) {
 // detects triggered defences in original message and blocks the message if necessary
 async function detectTriggeredDefences(
   message: string,
-  defences: DefenceInfo[],
-  useLlmEvaluation: boolean
+  defences: DefenceInfo[]
 ) {
   // keep track of any triggered defences
   const defenceReport: ChatDefenceReport = {
@@ -263,18 +262,15 @@ async function detectTriggeredDefences(
   }
 
   // evaluate the message for prompt injection
-  // to speed up replies, only do this on phases where defences are active
-  if (useLlmEvaluation) {
-    const evalPrompt = await queryPromptEvaluationModel(message);
-    if (evalPrompt.isMalicious) {
-      defenceReport.triggeredDefences.push("LLM_EVALUATION");
-      if (isDefenceActive("LLM_EVALUATION", defences)) {
-        console.debug("LLM evalutation defence active.");
-        defenceReport.isBlocked = true;
-        defenceReport.blockedReason =
-          "Message blocked by the malicious prompt evaluator." +
-          evalPrompt.reason;
-      }
+  const evalPrompt = await queryPromptEvaluationModel(message);
+  if (evalPrompt.isMalicious) {
+    defenceReport.triggeredDefences.push("LLM_EVALUATION");
+    if (isDefenceActive("LLM_EVALUATION", defences)) {
+      console.debug("LLM evaluation defence active.");
+      defenceReport.isBlocked = true;
+      defenceReport.blockedReason =
+        "Message blocked by the malicious prompt evaluator." +
+        evalPrompt.reason;
     }
   }
 
diff --git a/backend/src/router.ts b/backend/src/router.ts
index 5d2700371..232a27105 100644
--- a/backend/src/router.ts
+++ b/backend/src/router.ts
@@ -130,18 +130,18 @@ router.post("/openai/chat", async (req, res) => {
       prevPhase = currentPhase;
       initQAModel(req.session.apiKey, retrievalQAPrePrompt, currentPhase);
     }
-
     if (message) {
       chatResponse.transformedMessage = message;
-      // see if this message triggers any defences
-      const useLlmEvaluation =
+      // see if this message triggers any defences (only for phase 2 and sandbox)
+      if (
         currentPhase === PHASE_NAMES.PHASE_2 ||
-        currentPhase === PHASE_NAMES.SANDBOX;
-      chatResponse.defenceInfo = await detectTriggeredDefences(
-        message,
-        req.session.defences,
-        useLlmEvaluation
-      );
+        currentPhase === PHASE_NAMES.SANDBOX
+      ) {
+        chatResponse.defenceInfo = await detectTriggeredDefences(
+          message,
+          req.session.defences
+        );
+      }
       // if blocked, send the response
       if (!chatResponse.defenceInfo.isBlocked) {
         // transform the message according to active defences
diff --git a/backend/test/unit/defence.test.ts b/backend/test/unit/defence.test.ts
index 3129d1b69..9bcdcc993 100644
--- a/backend/test/unit/defence.test.ts
+++ b/backend/test/unit/defence.test.ts
@@ -111,7 +111,7 @@ test("GIVEN XML_TAGGING defence is active WHEN transforming message THEN message
 test("GIVEN no defences are active WHEN detecting triggered defences THEN no defences are triggered", async () => {
   const message = "Hello";
   const defences = getInitialDefences();
-  const defenceReport = await detectTriggeredDefences(message, defences, false);
+  const defenceReport = await detectTriggeredDefences(message, defences);
   expect(defenceReport.blockedReason).toBe(null);
   expect(defenceReport.isBlocked).toBe(false);
   expect(defenceReport.triggeredDefences.length).toBe(0);
@@ -133,11 +133,7 @@ test(
         value: String(3),
       },
     ]);
-    const defenceReport = await detectTriggeredDefences(
-      message,
-      defences,
-      false
-    );
+    const defenceReport = await detectTriggeredDefences(message, defences);
     expect(defenceReport.blockedReason).toBe("Message is too long");
     expect(defenceReport.isBlocked).toBe(true);
     expect(defenceReport.triggeredDefences).toContain("CHARACTER_LIMIT");
@@ -160,11 +156,7 @@ test(
         value: String(280),
       },
     ]);
-    const defenceReport = await detectTriggeredDefences(
-      message,
-      defences,
-      false
-    );
+    const defenceReport = await detectTriggeredDefences(message, defences);
     expect(defenceReport.blockedReason).toBe(null);
     expect(defenceReport.isBlocked).toBe(false);
     expect(defenceReport.triggeredDefences.length).toBe(0);
@@ -185,11 +177,7 @@ test(
         value: String(3),
       },
     ]);
-    const defenceReport = await detectTriggeredDefences(
-      message,
-      defences,
-      false
-    );
+    const defenceReport = await detectTriggeredDefences(message, defences);
     expect(defenceReport.blockedReason).toBe(null);
     expect(defenceReport.isBlocked).toBe(false);
     expect(defenceReport.triggeredDefences).toContain("CHARACTER_LIMIT");
@@ -199,7 +187,7 @@ test(
 test("GIVEN message contains XML tags WHEN detecting triggered defences THEN XML_TAGGING defence is triggered", async () => {
   const message = "<xml>";
   const defences = getInitialDefences();
-  const defenceReport = await detectTriggeredDefences(message, defences, false);
+  const defenceReport = await detectTriggeredDefences(message, defences);
   expect(defenceReport.blockedReason).toBe(null);
   expect(defenceReport.isBlocked).toBe(false);
   expect(defenceReport.triggeredDefences).toContain("XML_TAGGING");