Store LLM conversation history in a file #2

Merged (2 commits) on Jun 21, 2024
3 changes: 2 additions & 1 deletion .env.example
@@ -54,4 +54,5 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2
 # TTS_PRICES=0.015,0.030
 # BOT_LANGUAGE=en
 # ENABLE_VISION_FOLLOW_UP_QUESTIONS="true"
-# VISION_MODEL="gpt-4-vision-preview"
+# VISION_MODEL="gpt-4-vision-preview"
+# SHOW_COMMANDS_IN_HELP=true
3 changes: 3 additions & 0 deletions .gitignore
@@ -5,3 +5,6 @@ __pycache__
 /usage_logs
 venv
 /.cache
+
+storage/conversations.json
+storage/last_updated.json
1 change: 1 addition & 0 deletions bot/main.py
@@ -100,6 +100,7 @@ def main():
         'tts_prices': [float(i) for i in os.environ.get('TTS_PRICES', "0.015,0.030").split(",")],
         'transcription_price': float(os.environ.get('TRANSCRIPTION_PRICE', 0.006)),
         'bot_language': os.environ.get('BOT_LANGUAGE', 'en'),
+        'show_commands_in_help': os.environ.get('SHOW_COMMANDS_IN_HELP', 'true').lower() == 'true',
     }

     plugin_config = {
52 changes: 46 additions & 6 deletions bot/openai_helper.py
@@ -112,9 +112,39 @@ def __init__(self, config: dict, plugin_manager: PluginManager):
         self.client = openai.AsyncOpenAI(api_key=config['api_key'], http_client=http_client)
         self.config = config
         self.plugin_manager = plugin_manager
-        self.conversations: dict[int: list] = {} # {chat_id: history}
+        self.__load_conversations_and_last_updated()
         self.conversations_vision: dict[int: bool] = {} # {chat_id: is_vision}
-        self.last_updated: dict[int: datetime] = {} # {chat_id: last_update_timestamp}
+
+    def __load_file(self, file_path: str):
+        try:
+            with open(file_path, 'r', encoding='utf-8') as f:
+                logging.info(f"Loading from {file_path}")
+                return json.load(f)
+        except Exception as e:
+            logging.error(f"An error occurred while loading conversations from {file_path}: {e}")
+            return {}
+
+    def __save_file(self, file_path: str, data: dict):
+        try:
+            with open(file_path, 'w', encoding='utf-8') as f:
+                json.dump(data, f, ensure_ascii=False, indent=4)
+        except Exception as e:
+            logging.error(f"An error occurred while saving conversations to {file_path}: {e}")
+
+    def __save_conversations(self):
+        self.__save_file('storage/conversations.json', self.conversations)
+
+    def __save_last_updated(self):
+        last_updated_copy = {chat_id: last_update.isoformat() for chat_id, last_update in self.last_updated.items()}
+        self.__save_file('storage/last_updated.json', last_updated_copy)
+
+    def __load_conversations_and_last_updated(self):
+        self.conversations = {int(k): v for k, v in self.__load_file('storage/conversations.json').items()}
+        last_updated_copy = {int(k): v for k, v in self.__load_file('storage/last_updated.json').items()}
+        if last_updated_copy:
+            self.last_updated = {chat_id: datetime.datetime.fromisoformat(last_update) for chat_id, last_update in last_updated_copy.items()}
+        else:
+            self.last_updated = {}

     def get_conversation_stats(self, chat_id: int) -> tuple[int, int]:
         """
@@ -135,7 +165,7 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:
         """
         plugins_used = ()
         response = await self.__common_get_chat_response(chat_id, query)
-        if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
+        if self.config['enable_functions'] and chat_id not in self.conversations_vision:
             response, plugins_used = await self.__handle_function_call(chat_id, response)
             if is_direct_result(response):
                 return response, '0'
@@ -226,6 +256,7 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=False):
             self.reset_chat_history(chat_id)

         self.last_updated[chat_id] = datetime.datetime.now()
+        self.__save_last_updated()

         self.__add_to_history(chat_id, role="user", content=query)

@@ -245,9 +276,14 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=False):
             except Exception as e:
                 logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...')
                 self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]
+                self.__save_conversations()

+        model = self.config['model']
+        if chat_id in self.conversations_vision:
+            model = self.config['vision_model']
+
         common_args = {
-            'model': self.config['model'] if not self.conversations_vision[chat_id] else self.config['vision_model'],
+            'model': model,
             'messages': self.conversations[chat_id],
             'temperature': self.config['temperature'],
             'n': self.config['n_choices'],
@@ -257,7 +293,7 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=False):
             'stream': stream
         }

-        if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
+        if self.config['enable_functions'] and chat_id not in self.conversations_vision:
             functions = self.plugin_manager.get_functions_specs()
             if len(functions) > 0:
                 common_args['functions'] = self.plugin_manager.get_functions_specs()
@@ -408,6 +444,7 @@ async def __common_get_chat_response_vision(self, chat_id: int, content: list, stream=False):
             self.reset_chat_history(chat_id)

         self.last_updated[chat_id] = datetime.datetime.now()
+        self.__save_last_updated()

         if self.config['enable_vision_follow_up_questions']:
             self.conversations_vision[chat_id] = True
@@ -427,16 +464,17 @@ async def __common_get_chat_response_vision(self, chat_id: int, content: list, stream=False):
         if exceeded_max_tokens or exceeded_max_history_size:
             logging.info(f'Chat history for chat ID {chat_id} is too long. Summarising...')
             try:
-
                 last = self.conversations[chat_id][-1]
                 summary = await self.__summarise(self.conversations[chat_id][:-1])
                 logging.debug(f'Summary: {summary}')
                 self.reset_chat_history(chat_id, self.conversations[chat_id][0]['content'])
                 self.__add_to_history(chat_id, role="assistant", content=summary)
                 self.conversations[chat_id] += [last]
+                self.__save_conversations()
             except Exception as e:
                 logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...')
                 self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]
+                self.__save_conversations()

         message = {'role':'user', 'content':content}

@@ -573,6 +611,7 @@ def reset_chat_history(self, chat_id, content=''):
         if content == '':
             content = self.config['assistant_prompt']
         self.conversations[chat_id] = [{"role": "system", "content": content}]
+        self.__save_conversations()
         self.conversations_vision[chat_id] = False

     def __max_age_reached(self, chat_id) -> bool:
@@ -602,6 +641,7 @@ def __add_to_history(self, chat_id, role, content):
         :param content: The message content
         """
         self.conversations[chat_id].append({"role": role, "content": content})
+        self.__save_conversations()

     async def __summarise(self, conversation) -> str:
         """
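Side note on the loader, not part of the diff: the int(k) and fromisoformat() conversions in __load_conversations_and_last_updated are needed because JSON object keys are always strings and datetime objects are not JSON-serializable. A minimal standalone sketch of the round-trip the new save/load helpers rely on (the chat ID 123456 and the message are made-up examples):

import datetime
import json

# Made-up in-memory state, shaped like OpenAIHelper's attributes.
conversations = {123456: [{"role": "system", "content": "You are a helpful assistant."}]}
last_updated = {123456: datetime.datetime.now()}

# Saving: int keys are coerced to strings, datetimes are written as ISO-8601 strings.
conversations_json = json.dumps(conversations, ensure_ascii=False, indent=4)
last_updated_json = json.dumps({k: v.isoformat() for k, v in last_updated.items()})

# Loading: cast the keys back to int and parse the timestamps again.
restored_conversations = {int(k): v for k, v in json.loads(conversations_json).items()}
restored_last_updated = {int(k): datetime.datetime.fromisoformat(v) for k, v in json.loads(last_updated_json).items()}

assert restored_conversations == conversations
assert restored_last_updated == last_updated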
6 changes: 3 additions & 3 deletions bot/telegram_bot.py
@@ -66,17 +66,17 @@ async def help(self, update: Update, _: ContextTypes.DEFAULT_TYPE) -> None:
         """
         commands = self.group_commands if is_group_chat(update) else self.commands
         commands_description = [f'/{command.command} - {command.description}' for command in commands]
+        commands_description_text = '\n'.join(commands_description) + '\n\n' if self.config.get('show_commands_in_help', True) else ''
         bot_language = self.config['bot_language']
         help_text = (
             localized_text('help_text', bot_language)[0] +
             '\n\n' +
-            '\n'.join(commands_description) +
-            '\n\n' +
+            commands_description_text +
             localized_text('help_text', bot_language)[1] +
             '\n\n' +
             localized_text('help_text', bot_language)[2]
         )
-        await update.message.reply_text(help_text, disable_web_page_preview=True)
+        await update.message.reply_text(help_text, parse_mode=constants.ParseMode.MARKDOWN, disable_web_page_preview=True)

     async def stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
         """
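Side note on the /help change, not part of the diff: when SHOW_COMMANDS_IN_HELP is false, only the command list block and its trailing blank line are dropped; the surrounding localized text is kept. A rough sketch of the string construction, with made-up placeholder strings standing in for localized_text() and the real command list:

# Placeholders; the real bot builds these from self.commands and localized_text('help_text', ...).
commands_description = ['/reset - Reset the conversation', '/stats - Get usage statistics']
intro = 'Intro text'
middle = 'More help text'
footer = 'Footer text'

def build_help_text(show_commands_in_help: bool) -> str:
    commands_description_text = '\n'.join(commands_description) + '\n\n' if show_commands_in_help else ''
    return intro + '\n\n' + commands_description_text + middle + '\n\n' + footer

print(build_help_text(True))   # command list appears between intro and middle
print(build_help_text(False))  # intro and middle stay separated by a single blank line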
Empty file added storage/.gitignore