From e7545aece79a5a25c51ff5d5a1f9d97b862ef081 Mon Sep 17 00:00:00 2001
From: tl_kid
Date: Sun, 3 Mar 2024 22:20:46 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0MistralAI=E7=9A=84Adapter=20?=
 =?UTF-8?q?=E6=B7=BB=E5=8A=A0config.example.cfg=E4=B8=AD=E5=85=B3=E4=BA=8E?=
 =?UTF-8?q?MistralAI=E7=9A=84=E7=A4=BA=E4=BE=8B=20=E6=9B=B4=E6=96=B0Quart?=
 =?UTF-8?q?=E5=8C=85=E5=88=B00.19.4=EF=BC=88=E4=BF=AE=E5=A4=8D0.17.0?=
 =?UTF-8?q?=E4=B8=AD=E8=87=AA=E5=B8=A6=E7=9A=84flask=E7=89=88=E6=9C=AC?=
 =?UTF-8?q?=E5=92=8Curl=5Fdecode=E5=9C=A8python3.11=E4=B8=AD=E4=B8=8D?=
 =?UTF-8?q?=E5=85=BC=E5=AE=B9=E7=9A=84=E9=94=99=E8=AF=AF=EF=BC=89?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 adapter/mistral/mistral.py | 358 +++++++++++++++++++++++++++++++++++++
 config.example.cfg         |  10 ++
 config.py                  |  32 ++++
 constants.py               |   3 +
 conversation.py            |   7 +
 manager/bot.py             |  27 ++-
 requirements.txt           |   2 +-
 7 files changed, 436 insertions(+), 3 deletions(-)
 create mode 100644 adapter/mistral/mistral.py

diff --git a/adapter/mistral/mistral.py b/adapter/mistral/mistral.py
new file mode 100644
index 00000000..1979b1ad
--- /dev/null
+++ b/adapter/mistral/mistral.py
@@ -0,0 +1,358 @@
+"""MistralAI chat-completions adapter: per-session history plus HTTP client."""
+
+import json
+import time
+import aiohttp
+import async_timeout
+import tiktoken
+from loguru import logger
+from typing import AsyncGenerator
+
+from adapter.botservice import BotAdapter
+from config import MistralAIAPIKey
+from constants import botManager, config
+
+DEFAULT_ENGINE: str = "mistral-large-latest"
+
+
+class MistralAIChatbot:
+    """Request parameters and per-session message history for one API key."""
+
+    def __init__(self, api_info: MistralAIAPIKey):
+        """Take key, proxy and model from *api_info*; sampling limits from config."""
+        self.api_key = api_info.api_key
+        self.proxy = api_info.proxy
+        self.top_p = config.mistral.mistral_params.top_p
+        self.temperature = config.mistral.mistral_params.temperature
+        self.max_tokens = config.mistral.mistral_params.max_tokens
+        self.engine = api_info.model or DEFAULT_ENGINE
+        self.timeout = config.response.max_timeout
+        self.conversation: dict[str, list[dict]] = {
+            "default": [
+                {
+                    "role": "system",
+                    "content": "你是 MistralAI,现在需要用中文进行交流。",
+                },
+            ],
+        }
+
+    async def rollback(self, session_id: str = "default", n: int = 1) -> None:
+        """Drop the last *n* messages of a session.
+
+        Raises:
+            ValueError: unknown session, or fewer than *n* messages stored.
+        """
+        try:
+            if session_id not in self.conversation:
+                raise ValueError(f"会话 ID {session_id} 不存在。")
+
+            if n > len(self.conversation[session_id]):
+                raise ValueError(f"回滚次数 {n} 超过了会话 {session_id} 的消息数量。")
+
+            for _ in range(n):
+                self.conversation[session_id].pop()
+
+        except ValueError as ve:
+            logger.error(ve)
+            raise
+        except Exception as e:
+            logger.error(f"未知错误: {e}")
+            raise
+
+    def add_to_conversation(self, message: str, role: str, session_id: str = "default") -> None:
+        """Append a message to a session.
+
+        Raises:
+            ValueError: *role* is empty or *message* is None.
+        """
+        if role and message is not None:
+            self.conversation[session_id].append({"role": role, "content": message})
+        else:
+            logger.warning("出现错误!返回消息为空,不添加到会话。")
+            raise ValueError("出现错误!返回消息为空,不添加到会话。")
+
+    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    def count_tokens(self, session_id: str = "default", model: str = DEFAULT_ENGINE):
+        """Return the number of tokens used by a list of messages."""
+        if model is None:
+            model = DEFAULT_ENGINE
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except KeyError:
+            encoding = tiktoken.get_encoding("cl100k_base")
+
+        tokens_per_message = 4
+        tokens_per_name = 1
+
+        num_tokens = 0
+        for message in self.conversation[session_id]:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                if value is not None:
+                    num_tokens += len(encoding.encode(value))
+                if key == "name":
+                    num_tokens += tokens_per_name
+        num_tokens += 3  # every reply is primed with assistant
+        return num_tokens
+
+    def get_max_tokens(self, session_id: str, model: str) -> int:
+        """Return the configured max_tokens minus the tokens the session already uses."""
+        return self.max_tokens - self.count_tokens(session_id, model)
+
+
+class MistralAIAPIAdapter(BotAdapter):
+    """Bot adapter that sends one session's history to the Mistral chat API."""
+
+    api_info: MistralAIAPIKey = None
+    """API Key"""
+
+    def __init__(self, session_id: str = "unknown"):
+        """Pick a 'mistral' account and start an empty history for *session_id*."""
+        self.latest_role = None
+        self.__conversation_keep_from = 0
+        self.session_id = session_id
+        self.api_info = botManager.pick('mistral')
+        self.bot = MistralAIChatbot(self.api_info)
+        self.conversation_id = None
+        self.parent_id = None
+        super().__init__()
+        self.bot.conversation[self.session_id] = []
+        self.current_model = self.bot.engine
+        self.supported_models = [
+            "mistral-large-latest",
+            "mistral-medium-latest",
+            "mistral-small-latest",
+            "open-mixtral-8x7b",
+            "open-mistral-7b",
+        ]
+
+    def manage_conversation(self, session_id: str, prompt: str):
+        """Create a missing session, then forget old messages until min_tokens are free."""
+        if session_id not in self.bot.conversation:
+            self.bot.conversation[session_id] = [
+                {"role": "system", "content": prompt}
+            ]
+            self.__conversation_keep_from = 1
+
+        # Messages before __conversation_keep_from (presets) are never dropped.
+        while self.bot.max_tokens - self.bot.count_tokens(session_id) < config.mistral.mistral_params.min_tokens and \
+                len(self.bot.conversation[session_id]) > self.__conversation_keep_from:
+            self.bot.conversation[session_id].pop(self.__conversation_keep_from)
+            logger.debug(
+                f"清理 token,历史记录遗忘后使用 token 数:{str(self.bot.count_tokens(session_id))}"
+            )
+
+    async def switch_model(self, model_name):
+        """Use *model_name* for subsequent requests."""
+        self.current_model = model_name
+        self.bot.engine = self.current_model
+
+    async def rollback(self):
+        """Undo the last user/assistant exchange; return False if there is none."""
+        # A round is two messages; with fewer, bot.rollback(n=2) would raise ValueError.
+        if len(self.bot.conversation[self.session_id]) < 2:
+            return False
+        await self.bot.rollback(self.session_id, n=2)
+        return True
+
+    async def on_reset(self):
+        """Pick an account again and clear this session's history."""
+        self.api_info = botManager.pick('mistral')
+        self.bot.api_key = self.api_info.api_key
+        self.bot.proxy = self.api_info.proxy
+        self.bot.conversation[self.session_id] = []
+        self.bot.engine = self.current_model
+        self.__conversation_keep_from = 0
+
+    def construct_data(self, messages: list = None, api_key: str = None, stream: bool = True):
+        """Return the (headers, JSON payload) pair for a chat-completions call."""
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {api_key}'
+        }
+        data = {
+            'model': self.bot.engine,
+            'messages': messages,
+            'stream': stream,
+            'temperature': self.bot.temperature,
+            'top_p': self.bot.top_p,
+            'max_tokens': self.bot.get_max_tokens(self.session_id, self.bot.engine),
+        }
+        return headers, data
+
+    def _prepare_request(self, session_id: str = None, messages: list = None, stream: bool = False):
+        """Pick an account and return (proxy, endpoint, headers, payload)."""
+        self.api_info = botManager.pick('mistral')
+        api_key = self.api_info.api_key
+        proxy = self.api_info.proxy
+        api_endpoint = config.mistral.api_endpoint or "https://api.mistral.ai/v1"
+
+        if not messages:
+            messages = self.bot.conversation[session_id]
+
+        headers, data = self.construct_data(messages, api_key, stream)
+
+        return proxy, api_endpoint, headers, data
+
+    async def _process_response(self, resp, session_id: str = None):
+        """Extract the reply from a non-streaming response and store it in the session."""
+
+        result = await resp.json()
+
+        total_tokens = result.get('usage', {}).get('total_tokens', None)
+        logger.debug(f"[MistralAI-API:{self.bot.engine}] 使用 token 数:{total_tokens}")
+        if total_tokens is None:
+            raise Exception("Response does not contain 'total_tokens'")
+
+        # `or [{}]` also covers an empty `choices` list, which would make [0] raise IndexError.
+        content = (result.get('choices') or [{}])[0].get('message', {}).get('content', None)
+        logger.debug(f"[MistralAI-API:{self.bot.engine}] 响应:{content}")
+        if content is None:
+            raise Exception("Response does not contain 'content'")
+
+        response_role = (result.get('choices') or [{}])[0].get('message', {}).get('role', None)
+        if response_role is None:
+            raise Exception("Response does not contain 'role'")
+
+        self.bot.add_to_conversation(content, response_role, session_id)
+
+        return content
+
+    async def request(self, session_id: str = None, messages: list = None) -> str:
+        """POST a non-streaming completion request and return the reply text."""
+        proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=False)
+
+        async with aiohttp.ClientSession() as session:
+            with async_timeout.timeout(self.bot.timeout):
+                async with session.post(f'{api_endpoint}/chat/completions', headers=headers,
+                                        data=json.dumps(data), proxy=proxy) as resp:
+                    if resp.status != 200:
+                        response_text = await resp.text()
+                        raise Exception(
+                            f"{resp.status} {resp.reason} {response_text}",
+                        )
+                    return await self._process_response(resp, session_id)
+
+    async def request_with_stream(self, session_id: str = None, messages: list = None) -> AsyncGenerator[str, None]:
+        """POST a streaming request, yield each text delta, then store the full reply."""
+        proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=True)
+
+        async with aiohttp.ClientSession() as session:
+            with async_timeout.timeout(self.bot.timeout):
+                async with session.post(f'{api_endpoint}/chat/completions', headers=headers, data=json.dumps(data),
+                                        proxy=proxy) as resp:
+                    if resp.status != 200:
+                        response_text = await resp.text()
+                        raise Exception(
+                            f"{resp.status} {resp.reason} {response_text}",
+                        )
+
+                    response_role: str = ''
+                    completion_text: str = ''
+
+                    async for line in resp.content:
+                        try:
+                            line = line.decode('utf-8').strip()
+                            if not line.startswith("data: "):
+                                continue
+                            line = line[len("data: "):]
+                            if line == "[DONE]":
+                                break
+                            if not line:
+                                continue
+                            event = json.loads(line)
+                        except json.JSONDecodeError:
+                            raise Exception(f"JSON解码错误: {line}") from None
+                        except Exception as e:
+                            logger.error(f"未知错误: {e}\n响应内容: {resp.content}")
+                            logger.error("请将该段日记提交到项目issue中,以便修复该问题。")
+                            raise Exception(f"未知错误: {e}") from None
+                        if 'error' in event:
+                            raise Exception(f"响应错误: {event['error']}")
+                        if 'choices' in event and len(event['choices']) > 0 and 'delta' in event['choices'][0]:
+                            delta = event['choices'][0]['delta']
+                            if 'role' in delta:
+                                if delta['role'] is not None:
+                                    response_role = delta['role']
+                            if 'content' in delta:
+                                event_text = delta['content']
+                                if event_text is not None:
+                                    completion_text += event_text
+                                    self.latest_role = response_role
+                                    yield event_text
+        self.bot.add_to_conversation(completion_text, response_role, session_id)
+
+    async def compressed_session(self, session_id: str):
+        """Past compressed_tokens, replace user/assistant turns with a model-written summary."""
+        if session_id not in self.bot.conversation or not self.bot.conversation[session_id]:
+            logger.debug(f"不存在该会话,不进行压缩: {session_id}")
+            return
+
+        if self.bot.count_tokens(session_id) > config.mistral.mistral_params.compressed_tokens:
+            logger.debug('开始进行会话压缩')
+
+            filtered_data = [entry for entry in self.bot.conversation[session_id] if entry['role'] != 'system']
+            self.bot.conversation[session_id] = [entry for entry in self.bot.conversation[session_id] if
+                                                 entry['role'] not in ['assistant', 'user']]
+
+            filtered_data.append(({"role": "system",
+                                   "content": "Summarize the discussion briefly in 200 words or less to use as a prompt for future context."}))
+
+            async for text in self.request_with_stream(session_id=session_id, messages=filtered_data):
+                pass
+
+            token_count = self.bot.count_tokens(session_id, self.bot.engine)
+            logger.debug(f"压缩会话后使用 token 数:{token_count}")
+
+    async def ask(self, prompt: str) -> AsyncGenerator[str, None]:
+        """Send a message to api and return the response with stream."""
+
+        self.manage_conversation(self.session_id, prompt)
+
+        if config.mistral.mistral_params.compressed_session:
+            await self.compressed_session(self.session_id)
+
+        event_time = None
+
+        try:
+            if self.bot.engine not in self.supported_models:
+                logger.warning(f"当前模型非官方支持的模型,请注意控制台输出,当前使用的模型为 {self.bot.engine}")
+            logger.debug(f"[尝试使用MistralAI-API:{self.bot.engine}] 请求:{prompt}")
+            self.bot.add_to_conversation(prompt, "user", session_id=self.session_id)
+            start_time = time.time()
+
+            full_response = ''
+
+            if config.mistral.mistral_params.stream:
+                async for resp in self.request_with_stream(session_id=self.session_id):
+                    full_response += resp
+                    yield full_response
+
+                token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
+                logger.debug(f"[MistralAI-API:{self.bot.engine}] 响应:{full_response}")
+                logger.debug(f"[MistralAI-API:{self.bot.engine}] 使用 token 数:{token_count}")
+            else:
+                yield await self.request(session_id=self.session_id)
+            event_time = time.time() - start_time
+            if event_time is not None:
+                logger.debug(f"[MistralAI-API:{self.bot.engine}] 接收到全部消息花费了{event_time:.2f}秒")
+
+        except Exception as e:
+            logger.error(f"[MistralAI-API:{self.bot.engine}] 请求失败:\n{e}")
+            yield f"发生错误: \n{e}"
+            raise
+
+    async def preset_ask(self, role: str, text: str):
+        """Add a preset message to the history and protect it from later trimming."""
+        self.bot.engine = self.current_model
+        if role.endswith('bot') or role in {'assistant', 'mistral'}:
+            logger.debug(f"[预设] 响应:{text}")
+            yield text
+            role = 'assistant'
+        if role not in ['assistant', 'user', 'system']:
+            raise ValueError(f"预设文本有误!仅支持设定 assistant、user 或 system 的预设文本,但你写了{role}。")
+        if self.session_id not in self.bot.conversation:
+            self.bot.conversation[self.session_id] = []
+            self.__conversation_keep_from = 0
+        self.bot.conversation[self.session_id].append({"role": role, "content": text})
+        self.__conversation_keep_from = len(self.bot.conversation[self.session_id])
diff --git a/config.example.cfg b/config.example.cfg
index c8d40c2c..10bf66df 100644
--- a/config.example.cfg
+++ b/config.example.cfg
@@ -33,6 +33,16 @@ alias = 'g4f-chatgpt'
 # ping bot时针对此AI的描述
 description = 'gpt4free的gpt-3.5-turbo'
 
+[mistral]
+api_endpoint = "https://api.mistral.ai/v1"
+safe_prompt = true
+temperature = 0.7
+top_p = 1.0
+
+[[mistral.accounts]]
+api_key = ""
+# proxy="http://127.0.0.1:7890"
+
 [presets]
 # 切换预设的命令: 加载预设 猫娘
 command = "加载预设 (\\w+)"
diff --git a/config.py b/config.py
index 1886175c..a36e85a7 100644
--- a/config.py
+++ b/config.py
@@ -279,6 +279,37 @@ class G4fAuths(BaseModel):
     """支持的模型"""
 
 
+class MistralAIParams(BaseModel):
+    temperature: float = 0.7
+    max_tokens: int = 4000
+    top_p: float = 1.0
+    min_tokens: int = 1000
+    compressed_session: bool = False
+    compressed_tokens: int = 1000
+    stream: bool = True
+
+
+class MistralAIAPIKey(BaseModel):
+    api_key: str
+    """Mistral API key"""
+    model: Optional[str] = "mistral-large-latest"
+    """Default model to use; this option has the highest priority"""
+    proxy: Optional[str] = None
+    """Optional proxy address; when left empty the system proxy is detected"""
+
+
+class MistralAuths(BaseModel):
+    api_endpoint: Optional[str] = None
+    """Custom Mistral API endpoint"""
+    temperature: float = 0.7
+    top_p: float = 1.0
+
+    mistral_params: MistralAIParams = MistralAIParams()
+
+    accounts: List[MistralAIAPIKey] = []
+    """List of MistralAI accounts"""
+
+
 class SlackAppAccessToken(BaseModel):
     channel_id: str
     """负责与机器人交互的 Channel ID"""
@@ -563,6 +594,7 @@ class Config(BaseModel):
     slack: SlackAuths = SlackAuths()
     xinghuo: XinghuoAuths = XinghuoAuths()
     gpt4free: G4fAuths = G4fAuths()
+    mistral: MistralAuths = MistralAuths()
 
     # === Response Settings ===
     text_to_image: TextToImage = TextToImage()
diff --git a/constants.py b/constants.py
index 9cbe58f3..3912f85e 100644
--- a/constants.py
+++ b/constants.py
@@ -31,6 +31,9 @@ class LlmName(Enum):
     YiYan = "yiyan"
     ChatGLM = "chatglm-api"
     XunfeiXinghuo = "xinghuo"
+    MistralSmall = "mistral-small-latest"
+    MistralMedium = "mistral-medium-latest"
+    MistralLarge = "mistral-large-latest"
 
 
 class BotPlatform(Enum):
diff --git a/conversation.py b/conversation.py
index 89f07282..162725b1 100644
--- a/conversation.py
+++ b/conversation.py
@@ -21,6 +21,7 @@
 from adapter.quora.poe import PoeBot, PoeAdapter
 from adapter.thudm.chatglm_6b import ChatGLM6BAdapter
 from adapter.xunfei.xinghuo import XinghuoAdapter
+from adapter.mistral.mistral import MistralAIAPIAdapter
 from constants import LlmName
 from constants import config
 from drawing import DrawingAPI, SDWebUI as SDDrawing, OpenAI as OpenAIDrawing
@@ -112,6 +113,12 @@ def __init__(self, _type: str, session_id: str):
             self.adapter = XinghuoAdapter(self.session_id)
         elif g4f_parse(_type):
             self.adapter = Gpt4FreeAdapter(self.session_id, g4f_parse(_type))
+        elif _type == LlmName.MistralLarge.value:
+            self.adapter = MistralAIAPIAdapter(self.session_id)
+        elif _type == LlmName.MistralMedium.value:
+            self.adapter = MistralAIAPIAdapter(self.session_id)
+        elif _type == LlmName.MistralSmall.value:
+            self.adapter = MistralAIAPIAdapter(self.session_id)
         else:
             raise BotTypeNotFoundException(_type)
         self.type = _type
diff --git a/manager/bot.py b/manager/bot.py
index 9c42ce82..72bf5f29 100644
--- a/manager/bot.py
+++ b/manager/bot.py
@@ -29,7 +29,7 @@
 from adapter.gpt4free import g4f_helper
 from chatbot.chatgpt import ChatGPTBrowserChatbot
 from config import OpenAIAuthBase, OpenAIAPIKey, Config, BingCookiePath, BardCookiePath, YiyanCookiePath, ChatGLMAPI, \
-    PoeCookieAuth, SlackAppAccessToken, XinghuoCookiePath, G4fModels
+    PoeCookieAuth, SlackAppAccessToken, XinghuoCookiePath, G4fModels, MistralAIAPIKey
 from exceptions import NoAvailableBotException, APIKeyNoFundsError
 
 
@@ -46,6 +46,7 @@ class BotManager:
         "xinghuo-cookie": [],
         "slack-accesstoken": [],
         "gpt4free": [],
+        "mistral": [],
     }
     """Bot list"""
 
@@ -76,6 +77,9 @@ class BotManager:
     gpt4free: List[G4fModels]
     """gpt4free Account Infos"""
 
+    mistral: List[MistralAIAPIKey]
+    """MistralAIAPIKey Account Infos"""
+
     roundrobin: Dict[str, itertools.cycle] = {}
 
     def __init__(self, config: Config) -> None:
@@ -89,6 +93,7 @@ def __init__(self, config: Config) -> None:
         self.slack = config.slack.accounts if config.slack else []
         self.xinghuo = config.xinghuo.accounts if config.xinghuo else []
         self.gpt4free = config.gpt4free.accounts if config.gpt4free else []
+        self.mistral = config.mistral.accounts if config.mistral else []
 
         try:
             os.mkdir('data')
@@ -149,6 +154,7 @@ async def login(self):
             "chatglm-api": [],
             "slack-accesstoken": [],
             "gpt4free": [],
+            "mistral": [],
         }
 
         self.__setup_system_proxy()
@@ -162,7 +168,8 @@ async def login(self):
             'openai': self.handle_openai,
             'yiyan': self.login_yiyan,
             'chatglm': self.login_chatglm,
-            'gpt4free': self.login_gpt4free
+            'gpt4free': self.login_gpt4free,
+            'mistral': self.login_mistral,
         }
 
         for key, login_func in login_funcs.items():
@@ -194,6 +201,7 @@ async def login(self):
                 "chatglm-api": "chatglm-api",
                 "xinghuo-cookie": "xinghuo",
                 "gpt4free": self.bots["gpt4free"][0].alias if len(self.bots["gpt4free"]) > 0 else "",
+                "mistral": "mistral-large-latest",
             }
 
             self.config.response.default_ai = next(
@@ -398,6 +406,17 @@ async def login_openai(self):  # sourcery skip: raise-specific-error
             logger.error("所有 OpenAI 账号均登录失败!")
         logger.success(f"成功登录 {counter}/{len(self.openai)} 个 OpenAI 账号!")
 
+    def login_mistral(self):
+        for i, account in enumerate(self.mistral):
+            logger.info("正在解析第 {i} 个 MistralAI 账号", i=i + 1)
+            if proxy := self.__check_proxy(account.proxy):
+                account.proxy = proxy
+            self.bots["mistral"].append(account)
+            logger.success("解析成功!", i=i + 1)
+        if len(self.bots["mistral"]) < 1:  # check this provider's list; the bots dict itself is never empty
+            logger.error("所有 MistralAI 账号均解析失败!")
+        logger.success(f"成功解析 {len(self.bots['mistral'])}/{len(self.mistral)} 个 MistralAI 账号!")
+
     def __login_browser(self, account) -> ChatGPTBrowserChatbot:
         logger.info("模式:浏览器登录")
         logger.info("这需要你拥有最新版的 Chrome 浏览器。")
@@ -600,4 +619,8 @@ def bots_info(self):
         if len(self.bots['gpt4free']) > 0:
             for model in self.bots['gpt4free']:
                 bot_info += f"* {model.alias} : {model.description}\n"
+        if len(self.bots['mistral']) > 0:
+            bot_info += f"* {LlmName.MistralLarge.value} : Mistral Large 模型\n"
+            bot_info += f"* {LlmName.MistralMedium.value} : Mistral Medium 模型\n"
+            bot_info += f"* {LlmName.MistralSmall.value} : Mistral Small 模型\n"
         return bot_info
diff --git a/requirements.txt b/requirements.txt
index 84b460ae..99ac7d42 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -40,7 +40,7 @@ tls-client
 python-dateutil~=2.8.2
 regex~=2023.6.3
 httpx~=0.24.1
-Quart==0.17.0
+Quart==0.19.4
 creart~=0.3.0
 pydub~=0.25.1
 httpcore~=0.17.3