Skip to content

Commit

Permalink
🚀 feat: Assistants Streaming (#2159)
Browse files Browse the repository at this point in the history
* chore: bump openai to 4.29.0 and npm audit fix

* chore: remove unnecessary stream field from ContentData

* feat: new enum and types for AssistantStreamEvent

* refactor(AssistantService): remove stream field and add conversationId to text ContentData
> - return `finalMessage` and `text` on run completion
> - move `processMessages` to services/Threads to avoid circular dependencies with new stream handling
> - refactor(processMessages/retrieveAndProcessFile): add new `client` field to differentiate new RunClient type

* WIP: new assistants stream handling

* chore: stores messages to StreamRunManager

* chore: add additional typedefs

* fix: pass req and openai to StreamRunManager

* fix(AssistantService): pass openai as client to `retrieveAndProcessFile`

* WIP: streaming tool i/o, handle in_progress and completed run steps

* feat(assistants): process required actions with streaming enabled

* chore: condense early return check for useSSE useEffect

* chore: remove unnecessary comments and only handle completed tool calls when not function

* feat: add TTL for assistants run abort cacheKey

* feat: abort stream runs

* fix(assistants): render streaming cursor

* fix(assistants): hide edit icon as functionality is not supported

* fix(textArea): handle pasting edge cases; first, when onChange events wouldn't fire; second, when textarea wouldn't resize

* chore: memoize Conversations

* chore(useTextarea): reverse args order

* fix: load default capabilities when Azure is configured to support assistants, but the `assistants` endpoint is not configured

* fix(AssistantSelect): update form assistant model on assistant form select

* fix(actions): handle Azure's strict validation for function names, fixing CRUD for actions

* chore: remove content data debug log as it fires in rapid succession

* feat: improve UX for assistant errors mid-request

* feat: add tool call localizations and replace any domain separators from azure action names

* refactor(chat): error out tool calls without outputs during handleError

* fix(ToolService): handle domain separators allowing Azure use of actions

* refactor(StreamRunManager): types and throw Error if tool submission fails
  • Loading branch information
danny-avila committed Mar 22, 2024
1 parent ed64c76 commit f427ad7
Show file tree
Hide file tree
Showing 39 changed files with 1,502 additions and 329 deletions.
2 changes: 1 addition & 1 deletion api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"multer": "^1.4.5-lts.1",
"nodejs-gpt": "^1.37.4",
"nodemailer": "^6.9.4",
"openai": "^4.28.4",
"openai": "^4.29.0",
"openai-chat-tokens": "^0.2.8",
"openid-client": "^5.4.2",
"passport": "^0.6.0",
Expand Down
5 changes: 3 additions & 2 deletions api/server/middleware/abortRun.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendMessage } = require('~/server/utils');
// const spendTokens = require('~/models/spendTokens');
const { logger } = require('~/config');

const three_minutes = 1000 * 60 * 3;

async function abortRun(req, res) {
res.setHeader('Content-Type', 'application/json');
const { abortKey } = req.body;
Expand Down Expand Up @@ -40,7 +41,7 @@ async function abortRun(req, res) {
const { openai } = await initializeClient({ req, res });

try {
await cache.set(cacheKey, 'cancelled');
await cache.set(cacheKey, 'cancelled', three_minutes);
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
logger.debug('[abortRun] Cancelled run:', cancelledRun);
} catch (error) {
Expand Down
14 changes: 10 additions & 4 deletions api/server/routes/assistants/actions.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ const { v4 } = require('uuid');
const express = require('express');
const { actionDelimiter } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { updateAssistant, getAssistant } = require('~/models/Assistant');
const { encryptMetadata } = require('~/server/services/ActionService');
const { logger } = require('~/config');

const router = express.Router();
Expand Down Expand Up @@ -44,7 +44,10 @@ router.post('/:assistant_id', async (req, res) => {

let metadata = encryptMetadata(_metadata);

const { domain } = metadata;
let { domain } = metadata;
/* Azure doesn't support periods in function names */
domain = domainParser(req, domain, true);

if (!domain) {
return res.status(400).json({ message: 'No domain provided' });
}
Expand Down Expand Up @@ -141,9 +144,10 @@ router.post('/:assistant_id', async (req, res) => {
* @param {string} req.params.action_id - The ID of the action to delete.
* @returns {Object} 200 - success response - application/json
*/
router.delete('/:assistant_id/:action_id', async (req, res) => {
router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
try {
const { assistant_id, action_id } = req.params;
const { assistant_id, action_id, model } = req.params;
req.body.model = model;

/** @type {{ openai: OpenAI }} */
const { openai } = await initializeClient({ req, res });
Expand All @@ -167,6 +171,8 @@ router.delete('/:assistant_id/:action_id', async (req, res) => {
return true;
});

domain = domainParser(req, domain, true);

const updatedTools = tools.filter(
(tool) => !(tool.function && tool.function.name.includes(domain)),
);
Expand Down
173 changes: 132 additions & 41 deletions api/server/routes/assistants/chat.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ const {
Constants,
RunStatus,
CacheKeys,
ContentTypes,
EModelEndpoint,
ViolationTypes,
AssistantStreamEvents,
} = require('librechat-data-provider');
const {
initThread,
Expand All @@ -18,8 +20,8 @@ const {
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { getTransactions } = require('~/models/Transaction');
const { createRun } = require('~/server/services/Runs');
const checkBalance = require('~/models/checkBalance');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
Expand All @@ -38,6 +40,8 @@ const {

router.post('/abort', handleAbort());

const ten_minutes = 1000 * 60 * 10;

/**
* @route POST /
* @desc Chat with an assistant
Expand Down Expand Up @@ -147,7 +151,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
return sendResponse(res, messageData, defaultErrorMessage);
}

await sleep(3000);
await sleep(2000);

try {
const status = await cache.get(cacheKey);
Expand Down Expand Up @@ -187,6 +191,42 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
latestMessageId: responseMessageId,
});

const errorContentPart = {
text: {
value:
error?.message ?? 'There was an error processing your request. Please try again later.',
},
type: ContentTypes.ERROR,
};

if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
runMessages[runMessages.length - 1].content = [errorContentPart];
} else {
const contentParts = runMessages[runMessages.length - 1].content;
for (let i = 0; i < contentParts.length; i++) {
const currentPart = contentParts[i];
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
if (
toolCall &&
toolCall?.function &&
!(toolCall?.function?.output || toolCall?.function?.output?.length)
) {
contentParts[i] = {
...currentPart,
[ContentTypes.TOOL_CALL]: {
...toolCall,
function: {
...toolCall.function,
output: 'error processing tool',
},
},
};
}
}
runMessages[runMessages.length - 1].content.push(errorContentPart);
}

finalEvent = {
title: 'New Chat',
final: true,
Expand Down Expand Up @@ -358,53 +398,107 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
body.instructions = instructions;
}

/* NOTE:
* By default, a Run will use the model and tools configuration specified in Assistant object,
* but you can override most of these when creating the Run for added flexibility:
*/
const run = await createRun({
openai,
thread_id,
body,
});
/**
 * Sends the initial sync SSE message to the client so the UI can render
 * the user's request message and a placeholder response message (with the
 * preassigned `openai.responseMessage.messageId`) before any streamed
 * content arrives.
 *
 * NOTE(review): relies entirely on closure variables from the enclosing
 * route handler (`res`, `req`, `openai`, `conversationId`, `requestMessage`,
 * `userMessageId`, `assistant_id`, `thread_id`).
 */
const sendInitialResponse = () => {
sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
requestMessage,
responseMessage: {
user: req.user.id,
messageId: openai.responseMessage.messageId,
parentMessageId: userMessageId,
conversationId,
assistant_id,
thread_id,
// the assistant_id doubles as the model identifier for assistant runs
model: assistant_id,
},
});
};

run_id = run.id;
await cache.set(cacheKey, `${thread_id}:${run_id}`);
/** @type {RunResponse | typeof StreamRunManager | undefined} */
let response;

const processRun = async (retry = false) => {
if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
if (retry) {
response = await runAssistant({
openai,
thread_id,
run_id,
in_progress: openai.in_progress,
});
return;
}

/* NOTE:
* By default, a Run will use the model and tools configuration specified in Assistant object,
* but you can override most of these when creating the Run for added flexibility:
*/
const run = await createRun({
openai,
thread_id,
body,
});

sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
requestMessage,
responseMessage: {
user: req.user.id,
messageId: openai.responseMessage.messageId,
parentMessageId: userMessageId,
conversationId,
assistant_id,
thread_id,
model: assistant_id,
},
});
run_id = run.id;
await cache.set(cacheKey, `${thread_id}:${run_id}`, ten_minutes);
sendInitialResponse();

// todo: retry logic
let response = await runAssistant({ openai, thread_id, run_id });
logger.debug('[/assistants/chat/] response', response);
// todo: retry logic
response = await runAssistant({ openai, thread_id, run_id });
return;
}

if (response.run.status === RunStatus.IN_PROGRESS) {
response = await runAssistant({
/** @type {{[AssistantStreamEvents.ThreadRunCreated]: (event: ThreadRunCreated) => Promise<void>}} */
const handlers = {
[AssistantStreamEvents.ThreadRunCreated]: async (event) => {
await cache.set(cacheKey, `${thread_id}:${event.data.id}`, ten_minutes);
run_id = event.data.id;
sendInitialResponse();
},
};

const streamRunManager = new StreamRunManager({
req,
res,
openai,
thread_id,
run_id,
in_progress: openai.in_progress,
responseMessage: openai.responseMessage,
handlers,
// streamOptions: {

// },
});

await streamRunManager.runAssistant({
thread_id,
body,
});

response = streamRunManager;
};

await processRun();
logger.debug('[/assistants/chat/] response', {
run: response.run,
steps: response.steps,
});

if (response.run.status === RunStatus.CANCELLED) {
logger.debug('[/assistants/chat/] Run cancelled, handled by `abortRun`');
return res.end();
}

if (response.run.status === RunStatus.IN_PROGRESS) {
processRun(true);
}

completedRun = response.run;

/** @type {ResponseMessage} */
const responseMessage = {
...openai.responseMessage,
...response.finalMessage,
parentMessageId: userMessageId,
conversationId,
user: req.user.id,
Expand All @@ -413,9 +507,6 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
model: assistant_id,
};

// TODO: token count from usage returned in run
// TODO: parse responses, save to db, send to user

sendMessage(res, {
title: 'New Chat',
final: true,
Expand All @@ -432,7 +523,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
addTitle(req, {
text,
responseText: openai.responseText,
responseText: response.text,
conversationId,
client,
});
Expand All @@ -447,7 +538,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res

if (!response.run.usage) {
await sleep(3000);
completedRun = await openai.beta.threads.runs.retrieve(thread_id, run.id);
completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id);
if (completedRun.usage) {
await recordUsage({
...completedRun.usage,
Expand Down
30 changes: 29 additions & 1 deletion api/server/services/ActionService.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,35 @@
const { AuthTypeEnum } = require('librechat-data-provider');
const { AuthTypeEnum, EModelEndpoint, actionDomainSeparator } = require('librechat-data-provider');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions } = require('~/models/Action');
const { logger } = require('~/config');

/**
 * Parses the domain for an action.
 *
 * Azure OpenAI Assistants API doesn't support periods in function
 * names due to `[a-zA-Z0-9_-]*` Regex Validation.
 *
 * @param {Express.Request} req - Express Request object
 * @param {string} domain - The domain for the action
 * @param {boolean} inverse - If true, replaces periods with `actionDomainSeparator`;
 * otherwise, restores periods from `actionDomainSeparator`
 * @returns {string|undefined} The parsed domain, or `undefined` if no domain was provided
 */
function domainParser(req, domain, inverse = false) {
  if (!domain) {
    return;
  }

  // Only Azure requires the separator workaround; all other endpoints keep the domain as-is.
  if (!req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
    return domain;
  }

  if (inverse) {
    return domain.replace(/\./g, actionDomainSeparator);
  }

  // Restore EVERY separator so multi-level domains (e.g. `api.sub.example.com`)
  // round-trip correctly. `String.prototype.replace` with a string pattern only
  // replaces the first occurrence, which would leave later separators intact.
  return domain.split(actionDomainSeparator).join('.');
}

/**
* Loads action sets based on the user and assistant ID.
*
Expand Down Expand Up @@ -117,4 +144,5 @@ module.exports = {
createActionTool,
encryptMetadata,
decryptMetadata,
domainParser,
};
Loading

0 comments on commit f427ad7

Please sign in to comment.