52 different llms (#96)
* Dropdown menu for changing gpt model
heatherlogan-scottlogic committed Aug 9, 2023
1 parent 6a02a78 commit db1182b
Showing 9 changed files with 229 additions and 19 deletions.
9 changes: 8 additions & 1 deletion backend/src/app.js
@@ -14,6 +14,9 @@ const port = process.env.PORT || 3001;

const envOpenAiKey = process.env.OPENAI_API_KEY;

// use default model
const defaultModel = "gpt-4";

// Creating express server
const app = express();
// for parsing application/json
@@ -55,6 +58,9 @@ app.use(function (req, res, next) {
if (!req.session.apiKey) {
req.session.apiKey = envOpenAiKey || "";
}
if (!req.session.gptModel) {
req.session.gptModel = defaultModel;
}
next();
});

@@ -65,6 +71,7 @@ app.listen(port, () => {
// for dev purposes only - set the API key from the environment variable
if (envOpenAiKey) {
console.debug("Initializing models with API key from environment variable");
setOpenAiApiKey({ apiKey: "" }, process.env.OPENAI_API_KEY);
const defaultSession = { apiKey: "", gptModel: defaultModel };
setOpenAiApiKey(defaultSession, process.env.OPENAI_API_KEY);
}
});
48 changes: 34 additions & 14 deletions backend/src/openai.js
@@ -11,6 +11,7 @@ const {
// OpenAI config
let config = null;
let openai = null;

// functions available to ChatGPT
const chatGptFunctions = [
{
@@ -60,30 +61,24 @@
},
];

function initOpenAi() {
config = new Configuration({
apiKey: process.env.OPENAI_API_KEY,
});
openai = new OpenAIApi(config);
}

// test the api key works with the model
async function validateApiKey(apiKey) {
async function validateApiKey(apiKey, gptModel) {
try {
const testOpenAI = new OpenAIApi(new Configuration({ apiKey: apiKey }));
const response = await testOpenAI.createChatCompletion({
model: "gpt-4",
model: gptModel,
messages: [{ role: "user", content: "this is a test prompt" }],
});
return true;
} catch (error) {
console.error("Error validating API key: " + error);
return false;
}
}

async function setOpenAiApiKey(session, apiKey) {
// initialise all models with the new key
if (await validateApiKey(apiKey)) {
if (await validateApiKey(apiKey, session.gptModel)) {
console.debug("Setting API key and initialising models");
session.apiKey = apiKey;
initOpenAi(session);
@@ -99,6 +94,29 @@ async function setOpenAiApiKey(session, apiKey) {
}
}

function initOpenAi(session) {
config = new Configuration({
apiKey: session.apiKey,
});
openai = new OpenAIApi(config);
console.debug("OpenAI initialised");
}

async function setGptModel(session, model) {
if (model !== session.gptModel) {
if (await validateApiKey(session.apiKey, model)) {
console.debug(
"Setting GPT model from: " + session.gptModel + " to: " + model
);
session.gptModel = model;
return true;
} else {
console.debug("Could not validate apiKey with model=" + model);
return false;
}
}
}

// returns true if the function is in the list of functions available to ChatGPT
function isChatGptFunction(functionName) {
return chatGptFunctions.find((func) => func.name === functionName);
@@ -142,10 +160,11 @@ async function chatGptCallFunction(functionCall, defenceInfo, session) {
session
);
}
} else if (functionName == "askEmailWhitelist") {
response = askEmailWhitelist(session.defences);
} else if (functionName == "getEmailWhitelist") {
response = getEmailWhitelist(
isDefenceActive("EMAIL_WHITELIST", session.activeDefences)
);
}

if (functionName === "askQuestion") {
console.debug("Asking question: " + params.question);
// if asking a question, call the queryDocuments
@@ -192,7 +211,7 @@ async function chatGptChatCompletion(session) {
}

chat_completion = await openai.createChatCompletion({
model: "gpt-4",
model: session.gptModel,
messages: session.chatHistory,
functions: chatGptFunctions,
});
@@ -254,4 +273,5 @@ module.exports = {
chatGptSendMessage,
setOpenAiApiKey,
validateApiKey,
setGptModel,
};
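
As a rough usage sketch (not part of this commit; the session object below is a stand-in for the express-session state that app.js initialises), switching a session's model through the new setGptModel looks something like:

// Illustrative only: assumes OPENAI_API_KEY is set and ./openai is the module above.
const { setOpenAiApiKey, setGptModel } = require("./openai");

async function demo() {
  const session = { apiKey: process.env.OPENAI_API_KEY || "", gptModel: "gpt-4" };
  // validates the key and initialises the OpenAI client
  await setOpenAiApiKey(session, session.apiKey);
  // setGptModel re-validates the key against the requested model before switching
  if (await setGptModel(session, "gpt-3.5-turbo")) {
    console.debug("Session now using " + session.gptModel);
  } else {
    console.debug("Model unchanged, still " + session.gptModel);
  }
}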
25 changes: 22 additions & 3 deletions backend/src/router.js
@@ -1,14 +1,17 @@
const express = require("express");
const router = express.Router();

const {
activateDefence,
deactivateDefence,
configureDefence,
transformMessage,
detectTriggeredDefences,
} = require("./defence");
const { chatGptSendMessage, setOpenAiApiKey } = require("./openai");
const {
chatGptSendMessage,
setOpenAiApiKey,
setGptModel,
} = require("./openai");
const router = express.Router();

// Activate a defence
router.post("/defence/activate", (req, res, next) => {
@@ -150,5 +153,21 @@ router.get("/openai/apiKey", (req, res, next) => {
res.send(req.session.apiKey);
});

// Set the ChatGPT model
router.post("/openai/model", async (req, res, next) => {
const model = req.body?.model;
if (model) {
if (await setGptModel(req.session, model)) {
res.status(200).send("ChatGPT model set. ");
} else {
res.status(401).send("Could not set model");
}
}
});

router.get("/openai/model", (req, res, next) => {
res.send(req.session.gptModel);
});

// Importing the router
module.exports = router;
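
To poke the new endpoints from outside the React app, a fetch-based sketch along these lines should do (illustrative, not part of the commit; it assumes the router is mounted at the app root and the backend listens on the default port from app.js):

// Illustrative only.
async function tryModelEndpoints() {
  const base = "http://localhost:3001";

  // ask the backend to switch this session's model
  const setRes = await fetch(base + "/openai/model", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    credentials: "include", // the chosen model lives on the server-side session
    body: JSON.stringify({ model: "gpt-3.5-turbo" }),
  });
  console.log(setRes.status === 200 ? "model set" : "could not set model");

  // read back the model currently in use for this session
  const current = await (await fetch(base + "/openai/model", { credentials: "include" })).text();
  console.log("session model:", current);
}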
4 changes: 4 additions & 0 deletions frontend/src/App.tsx
@@ -6,6 +6,7 @@ import DefenceBox from "./components/DefenceBox/DefenceBox";
import EmailBox from "./components/EmailBox/EmailBox";
import ApiKeyBox from "./components/ApiKeyBox/ApiKeyBox";
import Header from "./components/Header";
import ModelSelectionBox from "./components/ModelSelectionBox/ModelSelectionBox";
import { useState } from "react";
import { EmailInfo } from "./models/email";

@@ -26,10 +27,13 @@ function App() {
<div className="side-bar">
<div className="side-bar-header">defence mechanisms</div>
<DefenceBox key={defenceBoxKey} triggeredDefences={triggeredDefences} />

<div className="side-bar-header">attack mechanisms</div>
<AttackBox />
<div className="side-bar-header">openai api key</div>
<ApiKeyBox />
<div className="side-bar-header">model selection</div>
<ModelSelectionBox />
</div>
<div id="centre-area">
<Header />
39 changes: 39 additions & 0 deletions frontend/src/components/ModelSelectionBox/ModelSelectionBox.css
@@ -0,0 +1,39 @@
#model-selection-box {
padding: 1px 32px;
position: relative;
}

#model-selection-row {
display: flex;
align-items: center;
}

#model-selection-info {
align-items: center;
text-align: center;
}

#model-selection-info p {
margin: 0;
padding: 0;
color: cadetblue;
}

p#error {
color: red;
}

#model-selection-menu {
width: 60%;
padding: 5px;
margin-right: 10px;
}

#model-selection-button {
padding: 5px;
}
80 changes: 80 additions & 0 deletions frontend/src/components/ModelSelectionBox/ModelSelectionBox.tsx
@@ -0,0 +1,80 @@
// Path: frontend\src\components\ModelSelectionBox\ModelSelectionBox.tsx
import React, { useEffect, useState } from "react";
import "./ModelSelectionBox.css";
import { setGptModel, getGptModel } from "../../service/chatService";
import { CHAT_MODELS } from "../../models/chat";

// return a drop down menu with the models
function ModelSelectionBox() {
// model currently selected in the dropdown
const [selectedModel, setSelectedModel] = useState("gpt-4");
// model in use by the app
const [modelInUse, setModelInUse] = useState("gpt-4");

const [errorChangingModel, setErrorChangingModel] = useState(false);

// handle button click to log the selected model
const submitSelectedModel = async () => {
console.log("selected model: " + selectedModel);
if (await setGptModel(selectedModel)) {
setModelInUse(selectedModel);
setErrorChangingModel(false);
} else {
setErrorChangingModel(true);
}
};

// get the model
useEffect(() => {
const getModel = async () => {
const model = await getGptModel();
if (model) {
setModelInUse(model);
}
};
getModel();
}, []);

const chatModelOptions = Object.values(CHAT_MODELS);

// return a drop down menu with the models
return (
<div id="model-selection-box">
<div id="model-selection-row">
<p>Select a model: </p>
<select
id="model-selection-menu"
aria-label="model-select"
onChange={(e) => setSelectedModel(e.target.value)}
placeholder={modelInUse}
>
{chatModelOptions.map((model) => (
<option key={model} value={model}>
{model}
</option>
))}
</select>
<button
id="model-selection-button"
onClick={async () => submitSelectedModel()}
>
Choose
</button>
</div>

<div id="model-selection-info">
{errorChangingModel ? (
<p id="error">
Could not change model. You are still chatting to:
<b> {modelInUse} </b>
</p>
) : (
<p>
You are chatting to model: <b> {modelInUse}</b>
</p>
)}
</div>
</div>
);
}

export default ModelSelectionBox;
12 changes: 12 additions & 0 deletions frontend/src/models/chat.ts
@@ -1,3 +1,14 @@
enum CHAT_MODELS {
GPT_4 = "gpt-4",
GPT_4_0613 = "gpt-4-0613",
GPT_4_32K = "gpt-4-32k",
GPT_4_32K_0613 = "gpt-4-32k-0613",
GPT_3_5_TURBO = "gpt-3.5-turbo",
GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613",
GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k",
GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613",
}

interface ChatDefenceReport {
blocked: boolean;
triggeredDefences: string[];
@@ -17,3 +28,4 @@ interface ChatResponse {
}

export type { ChatMessage, ChatResponse };
export { CHAT_MODELS };
25 changes: 24 additions & 1 deletion frontend/src/service/chatService.ts
@@ -36,4 +36,27 @@ const getOpenAIApiKey = async (): Promise<string> => {
return data;
};

export { clearChat, sendMessage, setOpenAIApiKey, getOpenAIApiKey };
const setGptModel = async (model: string): Promise<boolean> => {
const response = await sendRequest(
PATH + "model",
"POST",
{ "Content-Type": "application/json" },
JSON.stringify({ model })
);
return response.status === 200;
};

const getGptModel = async (): Promise<string> => {
const response = await sendRequest(PATH + "model", "GET");
const data = await response.text();
return data;
};

export {
clearChat,
sendMessage,
setOpenAIApiKey,
getOpenAIApiKey,
getGptModel,
setGptModel,
};
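
Outside of ModelSelectionBox, the same service functions can be driven directly. A small sketch using the CHAT_MODELS enum from frontend/src/models/chat.ts (illustrative, not part of the diff):

// Illustrative only: switch the session to GPT-3.5 Turbo and report the model in use.
import { CHAT_MODELS } from "../models/chat";
import { getGptModel, setGptModel } from "./chatService";

async function switchToTurbo(): Promise<void> {
  const ok = await setGptModel(CHAT_MODELS.GPT_3_5_TURBO);
  const inUse = await getGptModel();
  console.log(ok ? "switched, now chatting to " + inUse : "could not switch, still chatting to " + inUse);
}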
6 changes: 6 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.
