Skip to content

Commit

Permalink
[fix] 优化wss连接时机
Browse files Browse the repository at this point in the history
  • Loading branch information
Wizerd authored and Wizerd committed Feb 3, 2024
1 parent 7a28cf6 commit 8b4cf6d
Showing 1 changed file with 74 additions and 60 deletions.
134 changes: 74 additions & 60 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ def generate_gpts_payload(model, messages):
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'


VERSION = '0.7.3'
VERSION = '0.7.4'
# VERSION = 'test'
UPDATE_INFO = '优化检测是否为sse的方式'
UPDATE_INFO = '优化wss连接时机'
# UPDATE_INFO = '【仅供临时测试使用】 '

with app.app_context():
Expand Down Expand Up @@ -1572,11 +1572,16 @@ def process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_m
"file_output_buffer": "",
"file_output_accumulating": False,
"execution_output_image_url_buffer": "",
"execution_output_image_id_buffer": ""
"execution_output_image_id_buffer": "",
"is_sse": False,
"upstream_response": None,
"messages": messages,
"api_key": api_key,
"model": model
}

def on_message(ws, message):
# logger.debug(f"on_message: {message}")
logger.debug(f"on_message: {message}")
if stop_event.is_set():
logger.info(f"接受到停止信号,停止 Websocket 处理线程")
ws.close()
Expand Down Expand Up @@ -1611,8 +1616,7 @@ def on_message(ws, message):
q_data = complete_data
data_queue.put(q_data)
stop_event.set()
ws.close()

ws.close()

def on_error(ws, error):
logger.error(error)
Expand All @@ -1621,66 +1625,76 @@ def on_close(ws, b, c):
logger.debug("wss closed")

def on_open(ws):
logger.debug(f"on_open: wss")
upstream_response = send_text_prompt_and_get_response(context["messages"], context["api_key"], True, context["model"])
# upstream_wss_url = None
# 检查 Content-Type 是否为 SSE 响应
content_type = upstream_response.headers.get('Content-Type')
logger.debug(f"Content-Type: {content_type}")
# 判断content_type是否包含'text/event-stream'
if content_type and 'text/event-stream' in content_type:
logger.debug("上游响应为 SSE 响应")
context["is_sse"] = True
context["upstream_response"] = upstream_response
ws.close()
return
else:
if upstream_response.status_code != 200:
logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}")
complete_data = 'data: [DONE]\n\n'
timestamp = context["timestamp"]

new_data = {
"id": chat_message_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```")
},
"finish_reason": None
}
]
}
q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
data_queue.put(q_data)

q_data = complete_data
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```"))
data_queue.put(q_data)
stop_event.set()
ws.close()
try:
upstream_response_json = upstream_response.json()
logger.debug(f"upstream_response_json: {upstream_response_json}")
# upstream_wss_url = upstream_response_json.get("wss_url", None)
upstream_response_id = upstream_response_json.get("response_id", None)
context["response_id"] = upstream_response_id
except json.JSONDecodeError:
pass
def run(*args):
logger.debug(f"on_open: wss")
while True:
if stop_event.is_set():
logger.debug(f"接受到停止信号,停止 Websocket")
ws.close()
break
upstream_response = send_text_prompt_and_get_response(messages, api_key, True, model)
upstream_wss_url = None
# 检查 Content-Type 是否为 SSE 响应
content_type = upstream_response.headers.get('Content-Type')
# 判断content_type是否包含'text/event-stream'
if content_type and 'text/event-stream' in content_type:
logger.debug("上游响应为 SSE 响应")
old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format)
else:
if upstream_response.status_code != 200:
logger.error(f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}")
complete_data = 'data: [DONE]\n\n'
timestamp = context["timestamp"]

new_data = {
"id": chat_message_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```")
},
"finish_reason": None
}
]
}
q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n'
data_queue.put(q_data)

q_data = complete_data
data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```"))
data_queue.put(q_data)
stop_event.set()
return
try:
upstream_response_json = upstream_response.json()
upstream_wss_url = upstream_response_json.get("wss_url", None)
upstream_response_id = upstream_response_json.get("response_id", None)
context["response_id"] = upstream_response_id
except json.JSONDecodeError:
pass
if upstream_wss_url is not None:
logger.debug(f"start wss...")
ws = websocket.WebSocketApp(wss_url,
on_message = on_message,
on_error = on_error,
on_close = on_close)
ws.on_open = on_open
ws.run_forever()

logger.debug(f"end wss...")
logger.debug(f"start wss...")
ws = websocket.WebSocketApp(wss_url,
on_message = on_message,
on_error = on_error,
on_close = on_close,
on_open = on_open)
ws.on_open = on_open
ws.run_forever()

logger.debug(f"end wss...")
if context["is_sse"] == True:
logger.debug(f"process sse...")
old_data_fetcher(context["upstream_response"], data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format)



Expand Down

0 comments on commit 8b4cf6d

Please sign in to comment.