diff --git a/backend/src/app.js b/backend/src/app.js
index e7e4ccd39..fce21dee3 100644
--- a/backend/src/app.js
+++ b/backend/src/app.js
@@ -14,6 +14,9 @@ const port = process.env.PORT || 3001;
 const envOpenAiKey = process.env.OPENAI_API_KEY;
 
+// default GPT model for new sessions
+const defaultModel = "gpt-4";
+
 // Creating express server
 const app = express();
 
 // for parsing application/json
@@ -55,6 +58,9 @@ app.use(function (req, res, next) {
   if (!req.session.apiKey) {
     req.session.apiKey = envOpenAiKey || "";
   }
+  if (!req.session.gptModel) {
+    req.session.gptModel = defaultModel;
+  }
   next();
 });
 
@@ -65,6 +71,7 @@ app.listen(port, () => {
   // for dev purposes only - set the API key from the environment variable
   if (envOpenAiKey) {
     console.debug("Initializing models with API key from environment variable");
-    setOpenAiApiKey({ apiKey: "" }, process.env.OPENAI_API_KEY);
+    const defaultSession = { apiKey: "", gptModel: defaultModel };
+    setOpenAiApiKey(defaultSession, process.env.OPENAI_API_KEY);
   }
 });
diff --git a/backend/src/openai.js b/backend/src/openai.js
index 0f142218d..856ec9f53 100644
--- a/backend/src/openai.js
+++ b/backend/src/openai.js
@@ -11,6 +11,7 @@ const {
 // OpenAI config
 let config = null;
 let openai = null;
+
 // functions available to ChatGPT
 const chatGptFunctions = [
   {
@@ -60,30 +61,24 @@
   },
 ];
 
-function initOpenAi() {
-  config = new Configuration({
-    apiKey: process.env.OPENAI_API_KEY,
-  });
-  openai = new OpenAIApi(config);
-}
-
 // test the api key works with the model
-async function validateApiKey(apiKey) {
+async function validateApiKey(apiKey, gptModel) {
   try {
     const testOpenAI = new OpenAIApi(new Configuration({ apiKey: apiKey }));
     const response = await testOpenAI.createChatCompletion({
-      model: "gpt-4",
+      model: gptModel,
       messages: [{ role: "user", content: "this is a test prompt" }],
     });
     return true;
   } catch (error) {
+    console.error("Error validating API key: " + error);
     return false;
   }
 }
 
 async function setOpenAiApiKey(session, apiKey) {
   // initialise all models with the new key
-  if (await validateApiKey(apiKey)) {
+  if (await validateApiKey(apiKey, session.gptModel)) {
     console.debug("Setting API key and initialising models");
     session.apiKey = apiKey;
     initOpenAi(session);
@@ -99,6 +94,29 @@
   }
 }
 
+// initialise the OpenAI client with the API key held in the session
+function initOpenAi(session) {
+  config = new Configuration({
+    apiKey: session.apiKey,
+  });
+  openai = new OpenAIApi(config);
+  console.debug("OpenAI initialised");
+}
+
+async function setGptModel(session, model) {
+  if (model !== session.gptModel) {
+    if (await validateApiKey(session.apiKey, model)) {
+      console.debug(
+        "Setting GPT model from: " + session.gptModel + " to: " + model
+      );
+      session.gptModel = model;
+      return true;
+    } else {
+      console.debug("Could not validate apiKey with model=" + model);
+      return false;
+    }
+  }
+}
+
 // returns true if the function is in the list of functions available to ChatGPT
 function isChatGptFunction(functionName) {
   return chatGptFunctions.find((func) => func.name === functionName);
@@ -142,10 +160,11 @@ async function chatGptCallFunction(functionCall, defenceInfo, session) {
         session
       );
     }
-  } else if (functionName == "askEmailWhitelist") {
-    response = askEmailWhitelist(session.defences);
+  } else if (functionName == "getEmailWhitelist") {
+    response = getEmailWhitelist(
+      isDefenceActive("EMAIL_WHITELIST", session.activeDefences)
+    );
   }
-
   if (functionName === "askQuestion") {
     console.debug("Asking question: " + params.question);
     // if asking a question, call the queryDocuments
@@ -192,7 +211,7 @@ async function chatGptChatCompletion(session) {
   }
 
   chat_completion = await openai.createChatCompletion({
-    model: "gpt-4",
+    model: session.gptModel,
     messages: session.chatHistory,
     functions: chatGptFunctions,
   });
@@ -254,4 +273,5 @@ module.exports = {
   chatGptSendMessage,
   setOpenAiApiKey,
   validateApiKey,
+  setGptModel,
 };
diff --git a/backend/src/router.js b/backend/src/router.js
index 248bf0aa0..c8457a921 100644
--- a/backend/src/router.js
+++ b/backend/src/router.js
@@ -1,6 +1,4 @@
 const express = require("express");
-const router = express.Router();
-
 const {
   activateDefence,
   deactivateDefence,
@@ -8,7 +6,12 @@ const {
   transformMessage,
   detectTriggeredDefences,
 } = require("./defence");
-const { chatGptSendMessage, setOpenAiApiKey } = require("./openai");
+const {
+  chatGptSendMessage,
+  setOpenAiApiKey,
+  setGptModel,
+} = require("./openai");
+const router = express.Router();
 
 // Activate a defence
 router.post("/defence/activate", (req, res, next) => {
@@ -150,5 +153,24 @@ router.get("/openai/apiKey", (req, res, next) => {
   res.send(req.session.apiKey);
 });
 
+// Set the ChatGPT model
+router.post("/openai/model", async (req, res, next) => {
+  const model = req.body?.model;
+  if (model) {
+    if (await setGptModel(req.session, model)) {
+      res.status(200).send("ChatGPT model set. ");
+    } else {
+      res.status(401).send("Could not set model");
+    }
+  } else {
+    // no model in the request body - reject rather than leaving the request hanging
+    res.status(400).send("Missing model");
+  }
+});
+
+router.get("/openai/model", (req, res, next) => {
+  res.send(req.session.gptModel);
+});
+
 // Importing the router
 module.exports = router;
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 210696976..5abed3fd2 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -6,6 +6,7 @@ import DefenceBox from "./components/DefenceBox/DefenceBox";
 import EmailBox from "./components/EmailBox/EmailBox";
 import ApiKeyBox from "./components/ApiKeyBox/ApiKeyBox";
 import Header from "./components/Header";
+import ModelSelectionBox from "./components/ModelSelectionBox/ModelSelectionBox";
 import { useState } from "react";
 
 import { EmailInfo } from "./models/email";
@@ -26,10 +27,13 @@ function App() {
           defence mechanisms
+
           attack mechanisms
           openai api key
+          model selection
+          <ModelSelectionBox />
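Not part of the patch, but a quick way to smoke-test the new backend routes above. This is a minimal sketch assuming the Express router is mounted at the app root on the default port 3001 from app.js, and it ignores session cookies, so each request may land on a fresh session:

```ts
// Hypothetical smoke test for the new /openai/model endpoints (not in this change).
// Assumes the backend is running locally on the default port with the router at the root.
const BASE_URL = "http://localhost:3001";

async function checkModelEndpoints(): Promise<void> {
  // GET returns the model stored in the session ("gpt-4" by default)
  const current = await (await fetch(`${BASE_URL}/openai/model`)).text();
  console.log("model in use:", current);

  // POST asks the backend to switch models; the backend re-validates the
  // API key against the requested model before updating the session
  const response = await fetch(`${BASE_URL}/openai/model`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ model: "gpt-3.5-turbo" }),
  });
  console.log(response.ok ? "model changed" : "model change rejected");
}

void checkModelEndpoints();
```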
diff --git a/frontend/src/components/ModelSelectionBox/ModelSelectionBox.css b/frontend/src/components/ModelSelectionBox/ModelSelectionBox.css
new file mode 100644
index 000000000..e08734413
--- /dev/null
+++ b/frontend/src/components/ModelSelectionBox/ModelSelectionBox.css
@@ -0,0 +1,34 @@
+#model-selection-box {
+  padding: 1px 32px;
+  position: relative;
+}
+
+#model-selection-row {
+  display: flex;
+  align-items: center;
+}
+
+#model-selection-info {
+  align-items: center;
+  text-align: center;
+}
+
+#model-selection-info p {
+  margin: 0;
+  padding: 0;
+  color: cadetblue;
+}
+
+#model-selection-info p.error {
+  color: red;
+}
+
+#model-selection-menu {
+  width: 60%;
+  padding: 5px;
+  margin-right: 10px;
+}
+
+#model-selection-button {
+  padding: 5px;
+}
\ No newline at end of file
diff --git a/frontend/src/components/ModelSelectionBox/ModelSelectionBox.tsx b/frontend/src/components/ModelSelectionBox/ModelSelectionBox.tsx
new file mode 100644
index 000000000..ca1b57f5e
--- /dev/null
+++ b/frontend/src/components/ModelSelectionBox/ModelSelectionBox.tsx
@@ -0,0 +1,73 @@
+// Path: frontend\src\components\ModelSelectionBox\ModelSelectionBox.tsx
+import React, { useEffect, useState } from "react";
+import "./ModelSelectionBox.css";
+import { setGptModel, getGptModel } from "../../service/chatService";
+import { CHAT_MODELS } from "../../models/chat";
+
+// drop down menu for choosing the GPT model
+function ModelSelectionBox() {
+  // model currently selected in the dropdown
+  const [selectedModel, setSelectedModel] = useState("gpt-4");
+  // model in use by the app
+  const [modelInUse, setModelInUse] = useState("gpt-4");
+
+  const [errorChangingModel, setErrorChangingModel] = useState(false);
+
+  // handle button click: ask the backend to switch to the selected model
+  const submitSelectedModel = async () => {
+    console.log("selected model: " + selectedModel);
+    if (await setGptModel(selectedModel)) {
+      setModelInUse(selectedModel);
+      setErrorChangingModel(false);
+    } else {
+      setErrorChangingModel(true);
+    }
+  };
+
+  // get the model currently in use from the backend
+  useEffect(() => {
+    const getModel = async () => {
+      const model = await getGptModel();
+      if (model) {
+        setModelInUse(model);
+      }
+    };
+    getModel();
+  }, []);
+
+  const chatModelOptions = Object.values(CHAT_MODELS);
+
+  // return a drop down menu with the models
+  return (
+    <div id="model-selection-box">
+      <div id="model-selection-row">
+        <p>Select a model: </p>
+        <select
+          id="model-selection-menu"
+          value={selectedModel}
+          onChange={(event) => setSelectedModel(event.target.value)}
+        >
+          {chatModelOptions.map((model) => (
+            <option key={model} value={model}>
+              {model}
+            </option>
+          ))}
+        </select>
+        <button id="model-selection-button" onClick={submitSelectedModel}>
+          Set model
+        </button>
+      </div>
+      <div id="model-selection-info">
+        {errorChangingModel ? (
+          <p className="error">
+            Could not change model. You are still chatting to: {modelInUse}
+          </p>
+        ) : (
+          <p>You are chatting to model: {modelInUse}</p>
+        )}
+      </div>
+    </div>
+  );
+}
+
+export default ModelSelectionBox;
diff --git a/frontend/src/models/chat.ts b/frontend/src/models/chat.ts
index 57ae60cb8..6762c6091 100644
--- a/frontend/src/models/chat.ts
+++ b/frontend/src/models/chat.ts
@@ -1,3 +1,14 @@
+enum CHAT_MODELS {
+  GPT_4 = "gpt-4",
+  GPT_4_0613 = "gpt-4-0613",
+  GPT_4_32K = "gpt-4-32k",
+  GPT_4_32K_0613 = "gpt-4-32k-0613",
+  GPT_3_5_TURBO = "gpt-3.5-turbo",
+  GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613",
+  GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k",
+  GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613",
+}
+
 interface ChatDefenceReport {
   blocked: boolean;
   triggeredDefences: string[];
@@ -17,3 +28,4 @@ interface ChatResponse {
 }
 
 export type { ChatMessage, ChatResponse };
+export { CHAT_MODELS };
diff --git a/frontend/src/service/chatService.ts b/frontend/src/service/chatService.ts
index 1e96af8ba..031c616c8 100644
--- a/frontend/src/service/chatService.ts
+++ b/frontend/src/service/chatService.ts
@@ -36,4 +36,27 @@ const getOpenAIApiKey = async (): Promise<string> => {
   return data;
 };
 
-export { clearChat, sendMessage, setOpenAIApiKey, getOpenAIApiKey };
+const setGptModel = async (model: string): Promise<boolean> => {
+  const response = await sendRequest(
+    PATH + "model",
+    "POST",
+    { "Content-Type": "application/json" },
+    JSON.stringify({ model })
+  );
+  return response.status === 200;
+};
+
+const getGptModel = async (): Promise<string> => {
+  const response = await sendRequest(PATH + "model", "GET");
+  const data = await response.text();
+  return data;
+};
+
+export {
+  clearChat,
+  sendMessage,
+  setOpenAIApiKey,
+  getOpenAIApiKey,
+  getGptModel,
+  setGptModel,
+};
diff --git a/package-lock.json b/package-lock.json
new file mode 100644
index 000000000..e6536e0c7
--- /dev/null
+++ b/package-lock.json
@@ -0,0 +1,6 @@
+{
+  "name": "prompt-injection",
+  "lockfileVersion": 3,
+  "requires": true,
+  "packages": {}
+}
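On the frontend, the new chatService helpers and the CHAT_MODELS enum compose as below. This is a minimal sketch and not part of the change: switchToTurbo is a hypothetical helper, and the import paths assume a module living directly in frontend/src:

```ts
import { CHAT_MODELS } from "./models/chat";
import { getGptModel, setGptModel } from "./service/chatService";

// Hypothetical helper: try to switch the session to GPT-3.5 Turbo and
// return whichever model the backend is actually left using.
async function switchToTurbo(): Promise<string> {
  const changed = await setGptModel(CHAT_MODELS.GPT_3_5_TURBO);
  if (!changed) {
    console.warn("Backend rejected the model change; keeping the current model");
  }
  return getGptModel();
}

export { switchToTurbo };
```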