diff --git a/.env.example b/.env.example index d3d368d0..020bbbc1 100644 --- a/.env.example +++ b/.env.example @@ -54,4 +54,5 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2 # TTS_PRICES=0.015,0.030 # BOT_LANGUAGE=en # ENABLE_VISION_FOLLOW_UP_QUESTIONS="true" -# VISION_MODEL="gpt-4-vision-preview" \ No newline at end of file +# VISION_MODEL="gpt-4-vision-preview" +# SHOW_COMMANDS_IN_HELP=true diff --git a/.gitignore b/.gitignore index a156be6a..11368887 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ __pycache__ /usage_logs venv /.cache + +storage/conversations.json +storage/last_updated.json diff --git a/bot/main.py b/bot/main.py index 41604c84..b52a52e3 100644 --- a/bot/main.py +++ b/bot/main.py @@ -100,6 +100,7 @@ def main(): 'tts_prices': [float(i) for i in os.environ.get('TTS_PRICES', "0.015,0.030").split(",")], 'transcription_price': float(os.environ.get('TRANSCRIPTION_PRICE', 0.006)), 'bot_language': os.environ.get('BOT_LANGUAGE', 'en'), + 'show_commands_in_help': os.environ.get('SHOW_COMMANDS_IN_HELP', 'true').lower() == 'true', } plugin_config = { diff --git a/bot/openai_helper.py b/bot/openai_helper.py index 24eaa49f..b021dfb1 100644 --- a/bot/openai_helper.py +++ b/bot/openai_helper.py @@ -112,9 +112,39 @@ def __init__(self, config: dict, plugin_manager: PluginManager): self.client = openai.AsyncOpenAI(api_key=config['api_key'], http_client=http_client) self.config = config self.plugin_manager = plugin_manager - self.conversations: dict[int: list] = {} # {chat_id: history} + self.__load_conversations_and_last_updated() self.conversations_vision: dict[int: bool] = {} # {chat_id: is_vision} - self.last_updated: dict[int: datetime] = {} # {chat_id: last_update_timestamp} + + def __load_file(self, file_path: str): + try: + with open(file_path, 'r', encoding='utf-8') as f: + logging.info(f"Loading from {file_path}") + return json.load(f) + except Exception as e: + logging.error(f"An error occurred while loading conversations from {file_path}: {e}") + return {} + + def __save_file(self, file_path: str, data: dict): + try: + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=4) + except Exception as e: + logging.error(f"An error occurred while saving conversations to {file_path}: {e}") + + def __save_conversations(self): + self.__save_file('storage/conversations.json', self.conversations) + + def __save_last_updated(self): + last_updated_copy = {chat_id: last_update.isoformat() for chat_id, last_update in self.last_updated.items()} + self.__save_file('storage/last_updated.json', last_updated_copy) + + def __load_conversations_and_last_updated(self): + self.conversations = {int(k): v for k, v in self.__load_file('storage/conversations.json').items()} + last_updated_copy = {int(k): v for k, v in self.__load_file('storage/last_updated.json').items()} + if last_updated_copy: + self.last_updated = {chat_id: datetime.datetime.fromisoformat(last_update) for chat_id, last_update in last_updated_copy.items()} + else: + self.last_updated = {} def get_conversation_stats(self, chat_id: int) -> tuple[int, int]: """ @@ -135,7 +165,7 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]: """ plugins_used = () response = await self.__common_get_chat_response(chat_id, query) - if self.config['enable_functions'] and not self.conversations_vision[chat_id]: + if self.config['enable_functions'] and chat_id not in self.conversations_vision: response, plugins_used = await self.__handle_function_call(chat_id, response) if is_direct_result(response): return response, '0' @@ -226,6 +256,7 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals self.reset_chat_history(chat_id) self.last_updated[chat_id] = datetime.datetime.now() + self.__save_last_updated() self.__add_to_history(chat_id, role="user", content=query) @@ -245,9 +276,14 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals except Exception as e: logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...') self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:] + self.__save_conversations() + + model = self.config['model'] + if chat_id in self.conversations_vision: + model = self.config['vision_model'] common_args = { - 'model': self.config['model'] if not self.conversations_vision[chat_id] else self.config['vision_model'], + 'model': model, 'messages': self.conversations[chat_id], 'temperature': self.config['temperature'], 'n': self.config['n_choices'], @@ -257,7 +293,7 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals 'stream': stream } - if self.config['enable_functions'] and not self.conversations_vision[chat_id]: + if self.config['enable_functions'] and chat_id not in self.conversations_vision: functions = self.plugin_manager.get_functions_specs() if len(functions) > 0: common_args['functions'] = self.plugin_manager.get_functions_specs() @@ -408,6 +444,7 @@ async def __common_get_chat_response_vision(self, chat_id: int, content: list, s self.reset_chat_history(chat_id) self.last_updated[chat_id] = datetime.datetime.now() + self.__save_last_updated() if self.config['enable_vision_follow_up_questions']: self.conversations_vision[chat_id] = True @@ -427,16 +464,17 @@ async def __common_get_chat_response_vision(self, chat_id: int, content: list, s if exceeded_max_tokens or exceeded_max_history_size: logging.info(f'Chat history for chat ID {chat_id} is too long. Summarising...') try: - last = self.conversations[chat_id][-1] summary = await self.__summarise(self.conversations[chat_id][:-1]) logging.debug(f'Summary: {summary}') self.reset_chat_history(chat_id, self.conversations[chat_id][0]['content']) self.__add_to_history(chat_id, role="assistant", content=summary) self.conversations[chat_id] += [last] + self.__save_conversations() except Exception as e: logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...') self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:] + self.__save_conversations() message = {'role':'user', 'content':content} @@ -573,6 +611,7 @@ def reset_chat_history(self, chat_id, content=''): if content == '': content = self.config['assistant_prompt'] self.conversations[chat_id] = [{"role": "system", "content": content}] + self.__save_conversations() self.conversations_vision[chat_id] = False def __max_age_reached(self, chat_id) -> bool: @@ -602,6 +641,7 @@ def __add_to_history(self, chat_id, role, content): :param content: The message content """ self.conversations[chat_id].append({"role": role, "content": content}) + self.__save_conversations() async def __summarise(self, conversation) -> str: """ diff --git a/bot/telegram_bot.py b/bot/telegram_bot.py index 7a536b1f..9f9d602a 100644 --- a/bot/telegram_bot.py +++ b/bot/telegram_bot.py @@ -66,17 +66,17 @@ async def help(self, update: Update, _: ContextTypes.DEFAULT_TYPE) -> None: """ commands = self.group_commands if is_group_chat(update) else self.commands commands_description = [f'/{command.command} - {command.description}' for command in commands] + commands_description_text = '\n'.join(commands_description) + '\n\n' if self.config.get('show_commands_in_help', True) else '' bot_language = self.config['bot_language'] help_text = ( localized_text('help_text', bot_language)[0] + '\n\n' + - '\n'.join(commands_description) + - '\n\n' + + commands_description_text + localized_text('help_text', bot_language)[1] + '\n\n' + localized_text('help_text', bot_language)[2] ) - await update.message.reply_text(help_text, disable_web_page_preview=True) + await update.message.reply_text(help_text, parse_mode=constants.ParseMode.MARKDOWN, disable_web_page_preview=True) async def stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ diff --git a/storage/.gitignore b/storage/.gitignore new file mode 100644 index 00000000..e69de29b