From 8b4cf6db9758521a6bbf8288300cb6668bbc1fd3 Mon Sep 17 00:00:00 2001 From: Wizerd Date: Sat, 3 Feb 2024 13:37:13 +0800 Subject: [PATCH] =?UTF-8?q?[fix]=20=E4=BC=98=E5=8C=96wss=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E6=97=B6=E6=9C=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 134 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 74 insertions(+), 60 deletions(-) diff --git a/main.py b/main.py index 6d555ad..4e4b706 100644 --- a/main.py +++ b/main.py @@ -197,9 +197,9 @@ def generate_gpts_payload(model, messages): # PANDORA_UPLOAD_URL = 'files.pandoranext.com' -VERSION = '0.7.3' +VERSION = '0.7.4' # VERSION = 'test' -UPDATE_INFO = '优化检测是否为sse的方式' +UPDATE_INFO = '优化wss连接时机' # UPDATE_INFO = '【仅供临时测试使用】 ' with app.app_context(): @@ -1572,11 +1572,16 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m "file_output_buffer": "", "file_output_accumulating": False, "execution_output_image_url_buffer": "", - "execution_output_image_id_buffer": "" + "execution_output_image_id_buffer": "", + "is_sse": False, + "upstream_response": None, + "messages": messages, + "api_key": api_key, + "model": model } def on_message(ws, message): - # logger.debug(f"on_message: {message}") + logger.debug(f"on_message: {message}") if stop_event.is_set(): logger.info(f"接受到停止信号,停止 Websocket 处理线程") ws.close() @@ -1611,8 +1616,7 @@ def on_message(ws, message): q_data = complete_data data_queue.put(q_data) stop_event.set() - ws.close() - + ws.close() def on_error(ws, error): logger.error(error) @@ -1621,66 +1625,76 @@ def on_close(ws, b, c): logger.debug("wss closed") def on_open(ws): + logger.debug(f"on_open: wss") + upstream_response = send_text_prompt_and_get_response(context["messages"], context["api_key"], True, context["model"]) + # upstream_wss_url = None + # 检查 Content-Type 是否为 SSE 响应 + content_type = upstream_response.headers.get('Content-Type') + logger.debug(f"Content-Type: {content_type}") + # 判断content_type是否包含'text/event-stream' + if content_type and 'text/event-stream' in content_type: + logger.debug("上游响应为 SSE 响应") + context["is_sse"] = True + context["upstream_response"] = upstream_response + ws.close() + return + else: + if upstream_response.status_code != 200: + logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}") + complete_data = 'data: [DONE]\n\n' + timestamp = context["timestamp"] + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```") + }, + "finish_reason": None + } + ] + } + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + + q_data = complete_data + data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```")) + data_queue.put(q_data) + stop_event.set() + ws.close() + try: + upstream_response_json = upstream_response.json() + logger.debug(f"upstream_response_json: {upstream_response_json}") + # upstream_wss_url = upstream_response_json.get("wss_url", None) + upstream_response_id = upstream_response_json.get("response_id", None) + context["response_id"] = upstream_response_id + except json.JSONDecodeError: + pass def run(*args): - logger.debug(f"on_open: wss") while True: if stop_event.is_set(): + logger.debug(f"接受到停止信号,停止 Websocket") ws.close() break - upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model) - upstream_wss_url = None - # 检查 Content-Type 是否为 SSE 响应 - content_type = upstream_response.headers.get('Content-Type') - # 判断content_type是否包含'text/event-stream' - if content_type and 'text/event-stream' in content_type: - logger.debug("上游响应为 SSE 响应") - old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format) - else: - if upstream_response.status_code != 200: - logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}") - complete_data = 'data: [DONE]\n\n' - timestamp = context["timestamp"] - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```") - }, - "finish_reason": None - } - ] - } - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - - q_data = complete_data - data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```")) - data_queue.put(q_data) - stop_event.set() - return - try: - upstream_response_json = upstream_response.json() - upstream_wss_url = upstream_response_json.get("wss_url", None) - upstream_response_id = upstream_response_json.get("response_id", None) - context["response_id"] = upstream_response_id - except json.JSONDecodeError: - pass - if upstream_wss_url is not None: - logger.debug(f"start wss...") - ws = websocket.WebSocketApp(wss_url, - on_message = on_message, - on_error = on_error, - on_close = on_close) - ws.on_open = on_open - ws.run_forever() - - logger.debug(f"end wss...") + logger.debug(f"start wss...") + ws = websocket.WebSocketApp(wss_url, + on_message = on_message, + on_error = on_error, + on_close = on_close, + on_open = on_open) + ws.on_open = on_open + ws.run_forever() + + logger.debug(f"end wss...") + if context["is_sse"] == True: + logger.debug(f"process sse...") + old_data_fetcher(context["upstream_response"], data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format)