From 6a1b9c4cecd5ca7530a7cc241d31bbfc34b26730 Mon Sep 17 00:00:00 2001
From: ned
Date: Thu, 2 Mar 2023 21:58:41 +0100
Subject: [PATCH] multi-chat support

---
 README.md       |  2 +-
 gpt_helper.py   | 35 +++++++++++++++++++++++++----------
 telegram_bot.py |  6 +++---
 3 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 5fe90b29..5d853ebc 100644
--- a/README.md
+++ b/README.md
@@ -18,9 +18,9 @@ A [Telegram bot](https://core.telegram.org/bots/api) that integrates with OpenAI
 - [x] (NEW!) Support multiple answers!
 - [x] (NEW!) Customizable model parameters (see [configuration](#configuration) section)
 - [x] (NEW!) See token usage after each answer
+- [x] (NEW!) Multi-chat support
 
 ## Coming soon
-- [ ] Multi-chat support
 - [ ] Image generation using DALL·E APIs
 
 ## Additional Features - help needed!
diff --git a/gpt_helper.py b/gpt_helper.py
index b05f6e40..30f36249 100644
--- a/gpt_helper.py
+++ b/gpt_helper.py
@@ -14,21 +14,25 @@ def __init__(self, config: dict):
         """
         openai.api_key = config['api_key']
         self.config = config
-        self.initial_history = [{"role": "system", "content": config['assistant_prompt']}]
-        self.history = self.initial_history
+        self.sessions: dict[int, list] = dict()  # {chat_id: history}
 
-    def get_response(self, query) -> str:
+
+    def get_response(self, chat_id: int, query: str) -> str:
         """
         Gets a response from the GPT-3 model.
+        :param chat_id: The chat ID
         :param query: The query to send to the model
         :return: The answer from the model
         """
         try:
-            self.history.append({"role": "user", "content": query})
+            if chat_id not in self.sessions:
+                self.reset_history(chat_id)
+
+            self.__add_to_history(chat_id, role="user", content=query)
 
             response = openai.ChatCompletion.create(
                 model=self.config['model'],
-                messages=self.history,
+                messages=self.sessions[chat_id],
                 temperature=self.config['temperature'],
                 n=self.config['n_choices'],
                 max_tokens=self.config['max_tokens'],
@@ -42,13 +46,13 @@ def get_response(self, query) -> str:
             if len(response.choices) > 1 and self.config['n_choices'] > 1:
                 for index, choice in enumerate(response.choices):
                     if index == 0:
-                        self.history.append({"role": "assistant", "content": choice['message']['content']})
+                        self.__add_to_history(chat_id, role="assistant", content=choice['message']['content'])
                     answer += f'{index+1}\u20e3\n'
                     answer += choice['message']['content']
                     answer += '\n\n'
             else:
                 answer = response.choices[0]['message']['content']
-                self.history.append({"role": "assistant", "content": answer})
+                self.__add_to_history(chat_id, role="assistant", content=answer)
 
             if self.config['show_usage']:
                 answer += "\n\n---\n" \
@@ -63,7 +67,7 @@ def get_response(self, query) -> str:
 
         except openai.error.RateLimitError as e:
             logging.exception(e)
-            return "⚠️ _OpenAI RateLimit exceeded_ ⚠️\nPlease try again in a while."
+            return f"⚠️ _OpenAI Rate Limit exceeded_ ⚠️\n{str(e)}"
 
         except openai.error.InvalidRequestError as e:
             logging.exception(e)
@@ -73,8 +77,19 @@ def get_response(self, query) -> str:
             logging.exception(e)
             return f"⚠️ _An error has occurred_ ⚠️\n{str(e)}"
 
-    def reset_history(self):
+
+    def reset_history(self, chat_id):
         """
         Resets the conversation history.
         """
-        self.history = self.initial_history
+        self.sessions[chat_id] = [{"role": "system", "content": self.config['assistant_prompt']}]
+
+
+    def __add_to_history(self, chat_id, role, content):
+        """
+        Adds a message to the conversation history.
+        :param chat_id: The chat ID
+        :param role: The role of the message sender
+        :param content: The message content
+        """
+        self.sessions[chat_id].append({"role": role, "content": content})
\ No newline at end of file
diff --git a/telegram_bot.py b/telegram_bot.py
index ab4eebfe..ceacc247 100644
--- a/telegram_bot.py
+++ b/telegram_bot.py
@@ -37,12 +37,12 @@ async def reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
         Resets the conversation.
         """
         if not self.is_allowed(update):
-            logging.warning(f'User {update.message.from_user.name} is not allowed to reset the bot')
+            logging.warning(f'User {update.message.from_user.name} is not allowed to reset the conversation')
             await self.send_disallowed_message(update, context)
             return
 
         logging.info(f'Resetting the conversation for user {update.message.from_user.name}...')
-        self.gpt.reset_history()
+        self.gpt.reset_history(chat_id=update.effective_chat.id)
         await context.bot.send_message(chat_id=update.effective_chat.id, text='Done!')
 
     async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
@@ -57,7 +57,7 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
 
         logging.info(f'New message received from user {update.message.from_user.name}')
         await context.bot.send_chat_action(chat_id=update.effective_chat.id, action=constants.ChatAction.TYPING)
-        response = self.gpt.get_response(update.message.text)
+        response = self.gpt.get_response(chat_id=update.effective_chat.id, query=update.message.text)
         await context.bot.send_message(
             chat_id=update.effective_chat.id,
             reply_to_message_id=update.message.message_id,
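
Below is a minimal usage sketch of the per-chat interface introduced by this patch. It is not part of the patch itself: the class name `GPTHelper` and the concrete configuration values are assumptions (the diff only shows the config keys it touches, and `openai.ChatCompletion.create(...)` may read additional parameters in the elided portion of the hunk).

```python
# Usage sketch only, not part of the patch. Assumes the helper class in
# gpt_helper.py is named GPTHelper (the class name is not visible in the diff)
# and that `config` carries at least the keys referenced by the shown hunks.
from gpt_helper import GPTHelper

config = {
    'api_key': 'sk-...',                                  # OpenAI API key
    'assistant_prompt': 'You are a helpful assistant.',   # system prompt used on reset
    'model': 'gpt-3.5-turbo',
    'temperature': 1.0,
    'n_choices': 1,
    'max_tokens': 1200,
    'show_usage': False,
    # ...plus any parameters read by the elided portion of the create() call
}

gpt = GPTHelper(config)

# Each chat_id gets its own history in gpt.sessions, created lazily on first use.
print(gpt.get_response(chat_id=111, query='Hello from chat 111'))
print(gpt.get_response(chat_id=222, query='Hello from chat 222'))

# Resetting one chat leaves every other chat's history untouched.
gpt.reset_history(chat_id=111)
```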