Skip to content

Commit

Permalink
don't detect triggered defences on phase 0 and 1 (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
heatherlogan-scottlogic committed Aug 22, 2023
1 parent cf38f71 commit 44c6647
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 40 deletions.
24 changes: 10 additions & 14 deletions backend/src/defence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,7 @@ function transformMessage(message: string, defences: DefenceInfo[]) {
// detects triggered defences in original message and blocks the message if necessary
async function detectTriggeredDefences(
message: string,
defences: DefenceInfo[],
useLlmEvaluation: boolean
defences: DefenceInfo[]
) {
// keep track of any triggered defences
const defenceReport: ChatDefenceReport = {
Expand Down Expand Up @@ -263,18 +262,15 @@ async function detectTriggeredDefences(
}

// evaluate the message for prompt injection
// to speed up replies, only do this on phases where defences are active
if (useLlmEvaluation) {
const evalPrompt = await queryPromptEvaluationModel(message);
if (evalPrompt.isMalicious) {
defenceReport.triggeredDefences.push("LLM_EVALUATION");
if (isDefenceActive("LLM_EVALUATION", defences)) {
console.debug("LLM evalutation defence active.");
defenceReport.isBlocked = true;
defenceReport.blockedReason =
"Message blocked by the malicious prompt evaluator." +
evalPrompt.reason;
}
const evalPrompt = await queryPromptEvaluationModel(message);
if (evalPrompt.isMalicious) {
defenceReport.triggeredDefences.push("LLM_EVALUATION");
if (isDefenceActive("LLM_EVALUATION", defences)) {
console.debug("LLM evalutation defence active.");
defenceReport.isBlocked = true;
defenceReport.blockedReason =
"Message blocked by the malicious prompt evaluator." +
evalPrompt.reason;
}
}

Expand Down
18 changes: 9 additions & 9 deletions backend/src/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,18 @@ router.post("/openai/chat", async (req, res) => {
prevPhase = currentPhase;
initQAModel(req.session.apiKey, retrievalQAPrePrompt, currentPhase);
}

if (message) {
chatResponse.transformedMessage = message;
// see if this message triggers any defences
const useLlmEvaluation =
// see if this message triggers any defences (only for phase 2 and sandbox)
if (
currentPhase === PHASE_NAMES.PHASE_2 ||
currentPhase === PHASE_NAMES.SANDBOX;
chatResponse.defenceInfo = await detectTriggeredDefences(
message,
req.session.defences,
useLlmEvaluation
);
currentPhase === PHASE_NAMES.SANDBOX
) {
chatResponse.defenceInfo = await detectTriggeredDefences(
message,
req.session.defences
);
}
// if blocked, send the response
if (!chatResponse.defenceInfo.isBlocked) {
// transform the message according to active defences
Expand Down
22 changes: 5 additions & 17 deletions backend/test/unit/defence.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ test("GIVEN XML_TAGGING defence is active WHEN transforming message THEN message
test("GIVEN no defences are active WHEN detecting triggered defences THEN no defences are triggered", async () => {
const message = "Hello";
const defences = getInitialDefences();
const defenceReport = await detectTriggeredDefences(message, defences, false);
const defenceReport = await detectTriggeredDefences(message, defences);
expect(defenceReport.blockedReason).toBe(null);
expect(defenceReport.isBlocked).toBe(false);
expect(defenceReport.triggeredDefences.length).toBe(0);
Expand All @@ -133,11 +133,7 @@ test(
value: String(3),
},
]);
const defenceReport = await detectTriggeredDefences(
message,
defences,
false
);
const defenceReport = await detectTriggeredDefences(message, defences);
expect(defenceReport.blockedReason).toBe("Message is too long");
expect(defenceReport.isBlocked).toBe(true);
expect(defenceReport.triggeredDefences).toContain("CHARACTER_LIMIT");
Expand All @@ -160,11 +156,7 @@ test(
value: String(280),
},
]);
const defenceReport = await detectTriggeredDefences(
message,
defences,
false
);
const defenceReport = await detectTriggeredDefences(message, defences);
expect(defenceReport.blockedReason).toBe(null);
expect(defenceReport.isBlocked).toBe(false);
expect(defenceReport.triggeredDefences.length).toBe(0);
Expand All @@ -185,11 +177,7 @@ test(
value: String(3),
},
]);
const defenceReport = await detectTriggeredDefences(
message,
defences,
false
);
const defenceReport = await detectTriggeredDefences(message, defences);
expect(defenceReport.blockedReason).toBe(null);
expect(defenceReport.isBlocked).toBe(false);
expect(defenceReport.triggeredDefences).toContain("CHARACTER_LIMIT");
Expand All @@ -199,7 +187,7 @@ test(
test("GIVEN message contains XML tags WHEN detecting triggered defences THEN XML_TAGGING defence is triggered", async () => {
const message = "<Hello>";
const defences = getInitialDefences();
const defenceReport = await detectTriggeredDefences(message, defences, false);
const defenceReport = await detectTriggeredDefences(message, defences);
expect(defenceReport.blockedReason).toBe(null);
expect(defenceReport.isBlocked).toBe(false);
expect(defenceReport.triggeredDefences).toContain("XML_TAGGING");
Expand Down

0 comments on commit 44c6647

Please sign in to comment.