diff --git a/docs/commands/webserver.md b/docs/commands/webserver.md index fc304b86..91675b8f 100644 --- a/docs/commands/webserver.md +++ b/docs/commands/webserver.md @@ -20,10 +20,9 @@ pip3 install git+https://github.com/huggingface/transformers ### Launch servers ``` python3 -m fastchat.serve.controller --host 0.0.0.0 --port 21001 - python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name https:// - python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://localhost:21001 -python3 -m fastchat.serve.gradio_web_server --controller http://localhost:21001 +export OPENAI_API_KEY= +python3 -m fastchat.serve.gradio_web_server --controller http://localhost:21001 --moderate --concurrency 20 ``` diff --git a/fastchat/serve/controller.py b/fastchat/serve/controller.py index 22b5c779..9f9e4578 100644 --- a/fastchat/serve/controller.py +++ b/fastchat/serve/controller.py @@ -5,6 +5,7 @@ import argparse import asyncio import dataclasses +import json import logging import time from typing import List, Union @@ -17,7 +18,7 @@ import uvicorn from fastchat.constants import CONTROLLER_HEART_BEAT_EXPIRATION -from fastchat.utils import build_logger +from fastchat.utils import build_logger, server_error_msg logger = build_logger("controller", "controller.log") @@ -116,6 +117,7 @@ def get_worker_address(self, model_name: str): pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) worker_name = worker_names[pt] + #logger.info(f"speeds: {worker_speeds}, pt: {pt}, worker_name: {worker_name}") return worker_name # Check status before returning @@ -159,17 +161,27 @@ def remove_stable_workers_by_expiration(self): def worker_api_generate_stream(self, params): worker_addr = self.get_worker_address(params["model"]) if not worker_addr: + logger.info(f"no worker: {params['model']}") ret = { "text": server_error_msg, "error_code": 2, } - yield (json.dumps(ret) + "\0").encode("utf-8") + yield json.dumps(ret).encode() + b"\0" + + try: + response = requests.post(worker_addr + "/worker_generate_stream", + json=params, stream=True, timeout=5) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + logger.info(f"worker timeout: {worker_addr}") + ret = { + "text": server_error_msg, + "error_code": 3, + } + yield json.dumps(ret).encode() + b"\0" - response = requests.post(worker_addr + "/worker_generate_stream", - json=params, stream=True) - for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): - if chunk: - yield chunk + b"\0" # Let the controller act as a worker to achieve hierarchical # management. This can be used to connect isolated sub networks. diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 2b56943e..531130c1 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -219,14 +219,15 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req if chunk: data = json.loads(chunk.decode()) if data["error_code"] == 0: - output = data["text"][len(prompt) + 2:] + output = data["text"][len(prompt) + 1:].strip() output = post_process_code(output) state.messages[-1][-1] = output + "▌" yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 else: output = data["text"] - state.messages[-1][-1] = output + "▌" + state.messages[-1][-1] = output yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return time.sleep(0.04) except requests.exceptions.RequestException as e: state.messages[-1][-1] = server_error_msg @@ -304,8 +305,12 @@ def build_demo(): show_label=False).style(container=False) chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550) - textbox = gr.Textbox(show_label=False, - placeholder="Enter text and press ENTER", visible=False).style(container=False) + with gr.Row(): + with gr.Column(scale=10): + textbox = gr.Textbox(show_label=False, + placeholder="Enter text and press ENTER", visible=False).style(container=False) + with gr.Column(scale=1, min_width=60): + submit_btn = gr.Button(value="Submit") with gr.Row(visible=False) as button_row: upvote_btn = gr.Button(value="👍 Upvote", interactive=False) @@ -339,6 +344,9 @@ def build_demo(): textbox.submit(add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then(http_bot, [state, model_selector, temperature, max_output_tokens], [state, chatbot] + btn_list) + submit_btn.click(add_text, [state, textbox], [state, chatbot, textbox] + btn_list + ).then(http_bot, [state, model_selector, temperature, max_output_tokens], + [state, chatbot] + btn_list) if args.model_list_mode == "once": demo.load(load_demo, [url_params], [state, model_selector, @@ -367,6 +375,7 @@ def build_demo(): models = get_model_list() + logger.info(args) demo = build_demo() demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False).launch( diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py index 63585561..ffcdd8ea 100644 --- a/fastchat/serve/model_worker.py +++ b/fastchat/serve/model_worker.py @@ -95,7 +95,9 @@ def register_to_controller(self): assert r.status_code == 200 def send_heart_beat(self): - logger.info(f"Send heart beat. Models: {[self.model_name]}") + logger.info(f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {model_semaphore}. " + f"global_counter: {global_counter}") url = self.controller_addr + "/receive_heart_beat" try: