Skip to content

Commit

Permalink
Merge with LLM Evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
heatherlogan-scottlogic committed Aug 7, 2023
2 parents 9cfec44 + 24c161b commit 26fa3b8
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 101 deletions.
9 changes: 6 additions & 3 deletions backend/src/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const dotenv = require("dotenv");
const cors = require("cors");
const session = require("express-session");
const { initOpenAi } = require("./openai");
const { initQAModel } = require("./documents");
const { initQAModel, initPromptEvaluationModel } = require("./langchain");

dotenv.config();

Expand All @@ -24,12 +24,15 @@ app.use(
})
);

// initialise openai
// main model for chat
initOpenAi();

// initialise question answering llm
// model for question answering of documenents
initQAModel();

// model for LLM prompt evaluation
initPromptEvaluationModel();

app.use(
cors({
credentials: true,
Expand Down
1 change: 1 addition & 0 deletions backend/src/defence.js
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ function detectTriggeredDefences(message, activeDefences) {
return { reply: null, defenceInfo: defenceInfo };
}


module.exports = {
activateDefence,
deactivateDefence,
Expand Down
92 changes: 0 additions & 92 deletions backend/src/documents.js

This file was deleted.

161 changes: 161 additions & 0 deletions backend/src/langchain.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
const { TextLoader } = require("langchain/document_loaders/fs/text");
const { PDFLoader } = require("langchain/document_loaders/fs/pdf");
const { CSVLoader } = require("langchain/document_loaders/fs/csv");
const { DirectoryLoader } = require("langchain/document_loaders/fs/directory");
const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
const { OpenAIEmbeddings } = require("langchain/embeddings/openai");
const { MemoryVectorStore } = require("langchain/vectorstores/memory");
const { ChatOpenAI } = require("langchain/chat_models/openai");
const { RetrievalQAChain, LLMChain, SequentialChain } = require("langchain/chains");
const { PromptTemplate } = require("langchain/prompts");
const { OpenAI } = require("langchain/llms/openai");
const { retrievalQATemplate, promptInjectionEvalTemplate, maliciousPromptTemplate } = require("./promptTemplates");

// chain we use in question/answer request
let qaChain = null;

// chain we use in prompt evaluation request
let promptEvaluationChain = null;

// load the documents from filesystem
async function getDocuments() {
const filePath = "resources/documents/";
const loader = new DirectoryLoader(filePath, {
".pdf": (path) => new PDFLoader(path),
".txt": (path) => new TextLoader(path),
".csv": (path) => new CSVLoader(path),
});
const docs = await loader.load();

// split the documents into chunks
const textSplitter = new RecursiveCharacterTextSplitter({
chunkSize: 1000,
chunkOverlap: 0,
});
const splitDocs = await textSplitter.splitDocuments(docs);

console.debug(
"Documents loaded and split. Number of chunks: " + splitDocs.length
);

return splitDocs;
}

// QA Chain - ask the chat model a question about the documents
async function initQAModel() {
// get the documents
const docs = await getDocuments();

// embed and store the splits
const embeddings = new OpenAIEmbeddings();
const vectorStore = await MemoryVectorStore.fromDocuments(
docs,
embeddings,
);
// initialise model
const model = new ChatOpenAI({
modelName: "gpt-4",
stream: true,
});

// prompt template for question and answering
const qaPrompt = PromptTemplate.fromTemplate(retrievalQATemplate);

// set chain to retrieval QA chain
qaChain = RetrievalQAChain.fromLLM(model, vectorStore.asRetriever(), {
prompt: qaPrompt,
});
console.debug("QA Retrieval chain initialised.");
}


// initialise the prompt evaluation model
async function initPromptEvaluationModel() {

// create chain to detect prompt injection
const promptInjectionPrompt = PromptTemplate.fromTemplate(promptInjectionEvalTemplate);

const promptInjectionChain = new LLMChain({
llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0 }),
prompt: promptInjectionPrompt,
outputKey: "promptInjectionEval"
});

// create chain to detect malicious prompts
const maliciousInputPrompt = PromptTemplate.fromTemplate(maliciousPromptTemplate);
const maliciousInputChain = new LLMChain({
llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0 }),
prompt: maliciousInputPrompt,
outputKey: "maliciousInputEval"
});

promptEvaluationChain = new SequentialChain({
chains: [promptInjectionChain, maliciousInputChain],
inputVariables: ['prompt'],
outputVariables: ["promptInjectionEval", "maliciousInputEval"],
});
console.debug("Prompt evaluation chain initialised.");
}


// ask the question and return models answer
async function queryDocuments(question) {
const response = await qaChain.call({
query: question,
});
console.debug("QA model response: " + response.text);
const result = {
reply: response.text,
questionAnswered: true,
};
return result;
}

// ask LLM whether the prompt is malicious
async function queryPromptEvaluationModel(input) {
const response = await promptEvaluationChain.call({ prompt: input });

const promptInjectionEval = formatEvaluationOutput(response.promptInjectionEval);
const maliciousInputEval = formatEvaluationOutput(response.maliciousInputEval);

console.debug("Prompt injection eval: " + JSON.stringify(promptInjectionEval));
console.debug("Malicious input eval: " + JSON.stringify(maliciousInputEval));

// if both are malicious, combine reason
if (promptInjectionEval.isMalicious && maliciousInputEval.isMalicious) {
return { isMalicious: true, reason: `${promptInjectionEval.reason} & ${maliciousInputEval.reason}` };
} else if (promptInjectionEval.isMalicious) {
return { isMalicious: true, reason: promptInjectionEval.reason };
} else if (maliciousInputEval.isMalicious) {
return { isMalicious: true, reason: maliciousInputEval.reason };
}
return { isMalicious: false, reason: "" };
}

// format the evaluation model output. text should be a Yes or No answer followed by a reason
function formatEvaluationOutput(response){
try {
// split response on first full stop or comma
const splitResponse = response.split(/\.|,/);
const answer = splitResponse[0].replace(/\W/g, '').toLowerCase();
const reason = splitResponse[1];
return {
isMalicious: answer === "yes",
reason: reason,
}

} catch (error) {
// in case the model does not respond in the format we have asked
console.error(error);
console.debug("Did not get a valid response from the prompt evaluation model. Original response: " + response.text);
return { isMalicious: false, reason: "" };
}
}

module.exports = {
getDocuments,
initQAModel,
initPromptEvaluationModel,
queryDocuments,
queryPromptEvaluationModel
};
26 changes: 21 additions & 5 deletions backend/src/openai.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
const { Configuration, OpenAIApi } = require("openai");
const { isDefenceActive } = require("./defence");
const { sendEmail, getEmailWhitelist, isEmailInWhitelist } = require("./email");
const { queryDocuments } = require("./documents");
const { queryDocuments, queryPromptEvaluationModel } = require("./langchain");

// OpenAI configuration
let configuration = null;
Expand Down Expand Up @@ -33,15 +33,16 @@ const chatGptFunctions = [
{
name: "getEmailWhitelist",
description:
"user asks who is on the email whitelist and the system replies with the list of emails. if the email whitelist defence is not active then user should be able to email anyone. ",
"user asks who is on the email whitelist and the system replies with the list of emails.",
parameters: {
type: "object",
properties: {},
},
},
{
name: "askQuestion",
description: "Ask a question about the documents with company infomration",
description:
"Ask a question about the documents containing company information ",
parameters: {
type: "object",
properties: {
Expand Down Expand Up @@ -104,7 +105,7 @@ async function chatGptCallFunction(functionCall, defenceInfo, session) {
session
);
}
} else if (functionName == "getEmailWhitelist") {
} else if (functionName === "getEmailWhitelist") {
response = getEmailWhitelist(
isDefenceActive("EMAIL_WHITELIST", session.activeDefences)
);
Expand All @@ -113,7 +114,7 @@ async function chatGptCallFunction(functionCall, defenceInfo, session) {
if (functionName === "askQuestion") {
console.debug("Asking question: " + params.question);
// if asking a question, call the queryDocuments
response = await queryDocuments(params.question);
response = (await queryDocuments(params.question)).reply;
}

// add function call to chat
Expand Down Expand Up @@ -163,6 +164,20 @@ async function chatGptChatCompletion(session) {
async function chatGptSendMessage(message, session) {
// init defence info
let defenceInfo = { triggeredDefences: [], blocked: false };

// evaluate the message for prompt injection
const evalPrompt = await queryPromptEvaluationModel(message);
if (evalPrompt.isMalicious) {
defenceInfo.triggeredDefences.push("LLM_EVALUATION");
if (isDefenceActive("LLM_EVALUATION", session.activeDefences)) {
console.debug("LLM evalutation defence active.");
defenceInfo.blocked = true;
const evalResponse =
"Message blocked by the malicious prompt evaluator." +
evalPrompt.reason;
return { reply: evalResponse, defenceInfo: defenceInfo };
}
}
// add user message to chat
session.chatHistory.push({ role: "user", content: message });

Expand All @@ -181,6 +196,7 @@ async function chatGptSendMessage(message, session) {
session.chatHistory.push(functionCallReply.reply);
// update the defence info
defenceInfo = functionCallReply.defenceInfo;

// get a new reply from ChatGPT now that the function has been called
reply = await chatGptChatCompletion(session);
}
Expand Down
Loading

0 comments on commit 26fa3b8

Please sign in to comment.