Merge conflicts
heatherlogan-scottlogic committed Aug 8, 2023
2 parents 35c8b38 + a09c06c commit 38c3279
Showing 41 changed files with 684 additions and 415 deletions.
31 changes: 20 additions & 11 deletions backend/src/app.js
@@ -4,7 +4,7 @@ const dotenv = require("dotenv");
const cors = require("cors");
const session = require("express-session");
const { initOpenAi } = require("./openai");
-const { initQAModel } = require("./documents");
+const { initQAModel, initPromptEvaluationModel } = require("./langchain");

dotenv.config();

@@ -15,21 +15,30 @@ const port = process.env.PORT || 3001;
const app = express();
// for parsing application/json
app.use(express.json());

-// use session
-app.use(
-  session({
-    secret: process.env.SESSION_SECRET,
-    resave: false,
-    saveUninitialized: true,
-  })
-);
+const sess = {
+  secret: process.env.SESSION_SECRET,
+  name: "prompt-injection.sid",
+  resave: false,
+  saveUninitialized: true,
+  cookie: {},
+};
+// serve secure cookies in production
+if (app.get("env") === "production") {
+  sess.cookie.secure = true;
+}
+app.use(session(sess));

-// initialise openai
+// main model for chat
initOpenAi();

-// initialise question answering llm
+// model for question answering of documents
initQAModel();

+// model for LLM prompt evaluation
+initPromptEvaluationModel();

app.use(
  cors({
    credentials: true,
@@ -55,4 +64,4 @@ app.use(function (req, res, next) {
app.use("/", router);
app.listen(port, () => {
  console.log("Server is running on port: " + port);
});
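One caveat on the secure-cookie change above: with cookie.secure set to true, express-session only sends the session cookie over HTTPS. If the production deployment terminates TLS at a reverse proxy, Express must also be told to trust that proxy, or every request looks like plain HTTP and the cookie is never set. A minimal sketch of the conventional fix (the trust proxy line is standard Express/express-session guidance, not part of this commit):

    if (app.get("env") === "production") {
      app.set("trust proxy", 1); // assumed: first proxy hop terminates TLS
      sess.cookie.secure = true;
    }
    app.use(session(sess));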
1 change: 1 addition & 0 deletions backend/src/defence.js
@@ -127,6 +127,7 @@ function detectTriggeredDefences(message, activeDefences) {
  return { reply: null, defenceInfo: defenceInfo };
}


module.exports = {
  activateDefence,
  deactivateDefence,
91 changes: 0 additions & 91 deletions backend/src/documents.js

This file was deleted.

161 changes: 161 additions & 0 deletions backend/src/langchain.js
@@ -0,0 +1,161 @@
const { TextLoader } = require("langchain/document_loaders/fs/text");
const { PDFLoader } = require("langchain/document_loaders/fs/pdf");
const { CSVLoader } = require("langchain/document_loaders/fs/csv");
const { DirectoryLoader } = require("langchain/document_loaders/fs/directory");
const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
const { OpenAIEmbeddings } = require("langchain/embeddings/openai");
const { MemoryVectorStore } = require("langchain/vectorstores/memory");
const { ChatOpenAI } = require("langchain/chat_models/openai");
const { RetrievalQAChain, LLMChain, SequentialChain } = require("langchain/chains");
const { PromptTemplate } = require("langchain/prompts");
const { OpenAI } = require("langchain/llms/openai");
const { retrievalQATemplate, promptInjectionEvalTemplate, maliciousPromptTemplate } = require("./promptTemplates");

// chain we use in question/answer request
let qaChain = null;

// chain we use in prompt evaluation request
let promptEvaluationChain = null;

// load the documents from filesystem
async function getDocuments() {
  const filePath = "resources/documents/";
  const loader = new DirectoryLoader(filePath, {
    ".pdf": (path) => new PDFLoader(path),
    ".txt": (path) => new TextLoader(path),
    ".csv": (path) => new CSVLoader(path),
  });
  const docs = await loader.load();

  // split the documents into chunks
  const textSplitter = new RecursiveCharacterTextSplitter({
    chunkSize: 1000,
    chunkOverlap: 0,
  });
  const splitDocs = await textSplitter.splitDocuments(docs);

  console.debug(
    "Documents loaded and split. Number of chunks: " + splitDocs.length
  );

  return splitDocs;
}

// QA Chain - ask the chat model a question about the documents
async function initQAModel() {
  // get the documents
  const docs = await getDocuments();

  // embed and store the splits
  const embeddings = new OpenAIEmbeddings();
  const vectorStore = await MemoryVectorStore.fromDocuments(
    docs,
    embeddings,
  );

  // initialise model
  const model = new ChatOpenAI({
    modelName: "gpt-4",
    streaming: true,
  });

  // prompt template for question answering
  const qaPrompt = PromptTemplate.fromTemplate(retrievalQATemplate);

  // set chain to retrieval QA chain
  qaChain = RetrievalQAChain.fromLLM(model, vectorStore.asRetriever(), {
    prompt: qaPrompt,
  });
  console.debug("QA Retrieval chain initialised.");
}


// initialise the prompt evaluation model
async function initPromptEvaluationModel() {
  // create chain to detect prompt injection
  const promptInjectionPrompt = PromptTemplate.fromTemplate(
    promptInjectionEvalTemplate
  );
  const promptInjectionChain = new LLMChain({
    llm: new OpenAI({ modelName: "gpt-3.5-turbo", temperature: 0 }),
    prompt: promptInjectionPrompt,
    outputKey: "promptInjectionEval",
  });

  // create chain to detect malicious prompts
  const maliciousInputPrompt = PromptTemplate.fromTemplate(
    maliciousPromptTemplate
  );
  const maliciousInputChain = new LLMChain({
    llm: new OpenAI({ modelName: "gpt-3.5-turbo", temperature: 0 }),
    prompt: maliciousInputPrompt,
    outputKey: "maliciousInputEval",
  });

  promptEvaluationChain = new SequentialChain({
    chains: [promptInjectionChain, maliciousInputChain],
    inputVariables: ["prompt"],
    outputVariables: ["promptInjectionEval", "maliciousInputEval"],
  });
  console.debug("Prompt evaluation chain initialised.");
}
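// (The chain above maps { prompt } -> { promptInjectionEval, maliciousInputEval }:
// SequentialChain feeds the same input through both LLMChains in order, and each
// chain's outputKey becomes one of the output variables.)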


// ask the question and return the model's answer
async function queryDocuments(question) {
  const response = await qaChain.call({
    query: question,
  });
  console.debug("QA model response: " + response.text);
  const result = {
    reply: response.text,
    questionAnswered: true,
  };
  return result;
}

// ask the LLM whether the prompt is malicious
async function queryPromptEvaluationModel(input) {
  const response = await promptEvaluationChain.call({ prompt: input });

  const promptInjectionEval = formatEvaluationOutput(response.promptInjectionEval);
  const maliciousInputEval = formatEvaluationOutput(response.maliciousInputEval);

  console.debug("Prompt injection eval: " + JSON.stringify(promptInjectionEval));
  console.debug("Malicious input eval: " + JSON.stringify(maliciousInputEval));

  // if both are malicious, combine the reasons
  if (promptInjectionEval.isMalicious && maliciousInputEval.isMalicious) {
    return {
      isMalicious: true,
      reason: `${promptInjectionEval.reason} & ${maliciousInputEval.reason}`,
    };
  } else if (promptInjectionEval.isMalicious) {
    return { isMalicious: true, reason: promptInjectionEval.reason };
  } else if (maliciousInputEval.isMalicious) {
    return { isMalicious: true, reason: maliciousInputEval.reason };
  }
  return { isMalicious: false, reason: "" };
}

// format the evaluation model output. text should be a Yes or No answer followed by a reason
function formatEvaluationOutput(response) {
  try {
    // split the response on the first full stop or comma
    const splitResponse = response.split(/\.|,/);
    const answer = splitResponse[0].replace(/\W/g, "").toLowerCase();
    const reason = splitResponse[1] ? splitResponse[1].trim() : "";
    return {
      isMalicious: answer === "yes",
      reason: reason,
    };
  } catch (error) {
    // in case the model does not respond in the format we asked for
    console.error(error);
    console.debug(
      "Did not get a valid response from the prompt evaluation model. Original response: " +
        response
    );
    return { isMalicious: false, reason: "" };
  }
}
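// Worked example (hypothetical model outputs, assuming the templates request a
// yes/no verdict followed by a reason):
//   formatEvaluationOutput("Yes, the prompt asks the model to ignore its instructions.")
//     -> { isMalicious: true, reason: "the prompt asks the model to ignore its instructions" }
//   formatEvaluationOutput("No.")
//     -> { isMalicious: false, reason: "" }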

module.exports = {
  getDocuments,
  initQAModel,
  initPromptEvaluationModel,
  queryDocuments,
  queryPromptEvaluationModel,
};
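For orientation, here is a minimal sketch of how a caller might combine the two exported query functions. The handler name and refusal message are illustrative assumptions, not code from this commit, and both init functions must have resolved before querying:

    const {
      queryDocuments,
      queryPromptEvaluationModel,
    } = require("./langchain");

    // hypothetical handler: screen a message, then answer it from the documents
    async function answerSafely(message) {
      const evaluation = await queryPromptEvaluationModel(message);
      if (evaluation.isMalicious) {
        // block the message and surface the evaluation model's reason
        return { reply: "Blocked: " + evaluation.reason, questionAnswered: false };
      }
      return queryDocuments(message); // resolves to { reply, questionAnswered }
    }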