
Commit

Fix the workers (#61)
merrymercy authored Mar 29, 2023
1 parent d04577b commit b63fdd6
Showing 6 changed files with 74 additions and 51 deletions.
4 changes: 2 additions & 2 deletions chatserver/constants.py
@@ -1,4 +1,4 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 3 * 60
WORKER_HEART_BEAT_INTERVAL = 60
CONTROLLER_HEART_BEAT_EXPIRATION = 2 * 60
WORKER_HEART_BEAT_INTERVAL = 30

LOGDIR = "."
14 changes: 10 additions & 4 deletions chatserver/serve/controller.py
@@ -112,6 +112,13 @@ def get_worker_address(self, model_name: str):
return ""
worker_speeds = worker_speeds / norm

if True: # Directly return address
pt = np.random.choice(np.arange(len(worker_names)),
p=worker_speeds)
worker_name = worker_names[pt]
return worker_name

# Check status before returning
while True:
pt = np.random.choice(np.arange(len(worker_names)),
p=worker_speeds)
@@ -121,7 +128,7 @@ def get_worker_address(self, model_name: str):
break
else:
self.remove_worker(worker_name)
self.worker_speeds[pt] = 0
worker_speeds[pt] = 0
norm = np.sum(worker_speeds)
if norm < 1e-4:
return ""
@@ -150,7 +157,6 @@ def remove_stable_workers_by_expiration(self):
self.remove_worker(worker_name)

def worker_api_generate_stream(self, params):
headers = {"User-Agent": "ChatServer Client"}
worker_addr = self.get_worker_address(params["model"])
if not worker_addr:
ret = {
@@ -159,9 +165,9 @@ def worker_api_generate_stream(self, params):
}
yield (json.dumps(ret) + "\0").encode("utf-8")

response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
response = requests.post(worker_addr + "/worker_generate_stream",
json=params, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk + b"\0"
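
For context, a minimal standalone sketch of the speed-weighted worker selection that `get_worker_address` performs in the hunk above; the worker addresses and speeds below are made-up placeholders, not values from this repository.

```python
import numpy as np

# Hypothetical registry: worker address -> advertised relative speed.
workers = {
    "http://node-02:31000": 1.0,
    "http://node-02:31001": 2.0,
}

worker_names = list(workers.keys())
worker_speeds = np.array(list(workers.values()), dtype=np.float32)

norm = np.sum(worker_speeds)
if norm < 1e-4:
    # No usable worker; the controller returns "" in this case.
    chosen = ""
else:
    # Normalize speeds into a probability distribution and sample one worker,
    # so faster workers receive proportionally more requests.
    worker_speeds = worker_speeds / norm
    pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
    chosen = worker_names[pt]

print(chosen)
```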

61 changes: 37 additions & 24 deletions chatserver/serve/gradio_web_server.py
@@ -18,6 +18,8 @@

logger = build_logger("gradio_web_server", "gradio_web_server.log")

headers = {"User-Agent": "ChatServer Client"}

upvote_msg = "👍 Upvote the last response"
downvote_msg = "👎 Downvote the last response"

@@ -46,6 +48,7 @@ def load_demo(request: gr.Request):
logger.info(f"load demo: {request.client.host}")
state = default_conversation.copy()
return (state,
gr.Dropdown.update(visible=True),
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Row.update(visible=True),
@@ -156,7 +159,6 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
prompt = state.get_prompt()

# Make requests
headers = {"User-Agent": "Client"}
pload = {
"model": model_name,
"prompt": prompt,
@@ -165,21 +167,33 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
}
logger.info(f"==== request ====\n{pload}")
response = requests.post(worker_addr + "/worker_generate_stream",
headers=headers, json=pload, stream=True)

# Stream output
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
if data["error_code"] == 0:
output = data["text"][len(prompt) + 2:]
state.messages[-1][-1] = output
yield state, state.to_gradio_chatbot()
else:
output = data["text"]
state.messages[-1][-1] = output
yield state, state.to_gradio_chatbot()

state.messages[-1][-1] = "▌"
yield state, state.to_gradio_chatbot()

try:
# Stream output
response = requests.post(worker_addr + "/worker_generate_stream",
headers=headers, json=pload, stream=True, timeout=10)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][len(prompt) + 2:]
state.messages[-1][-1] = output + "▌"
yield state, state.to_gradio_chatbot()
else:
output = data["text"]
state.messages[-1][-1] = output + "▌"
yield state, state.to_gradio_chatbot()
time.sleep(0.05)
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg
yield state, state.to_gradio_chatbot()
return

state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield state, state.to_gradio_chatbot()

finish_tstamp = time.time()
logger.info(f"{output}")
@@ -209,11 +223,11 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req


learn_more_markdown = ("""
# Links
- Support this project by staring ChatServer on [Github](https://github.com/lm-sys/ChatServer).
- Read the blog [post]() about this model.
### Learn More
- Support this project by starring ChatServer on [Github](https://github.com/lm-sys/ChatServer).
- Read this blog [post]() about the Vicuna model.
# License
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMa and [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI. Please contact us if you find any potential violation.
""")

@@ -232,7 +246,7 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
def build_demo():
models = get_model_list()

with gr.Blocks(title="Chat Server", theme=gr.themes.Soft(), css=css) as demo:
with gr.Blocks(title="Chat Server", theme=gr.themes.Base(), css=css) as demo:
state = gr.State()

# Draw layout
@@ -259,8 +273,7 @@ def build_demo():
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Temperature",)
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

with gr.Accordion("Learn more", open=False) as learn_more_row:
gr.Markdown(learn_more_markdown)
gr.Markdown(learn_more_markdown)

# Register listeners
upvote_btn.click(upvote_last_response,
@@ -281,7 +294,7 @@ def build_demo():
[state, chatbot])

if args.model_list_mode == "once":
demo.load(load_demo, None, [state,
demo.load(load_demo, None, [state, model_selector,
chatbot, textbox, button_row, parameter_row])
elif args.model_list_mode == "reload":
demo.load(load_demo_refresh_model_list, None, [state, model_selector,
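
For context, a minimal client-side sketch of the streaming protocol that `http_bot` uses above: the worker returns NUL-delimited JSON chunks, and request failures fall back to the server error message. The endpoint address and payload values are illustrative placeholders, not values from this repository.

```python
import json
import requests

worker_addr = "http://localhost:31000"   # placeholder; normally obtained from the controller
pload = {
    "model": "vicuna-13b",
    "prompt": "A chat between a human and an assistant.\n\nHuman: Hello!\nAssistant:",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "stop": "\n",
}

try:
    response = requests.post(worker_addr + "/worker_generate_stream",
                             headers={"User-Agent": "ChatServer Client"},
                             json=pload, stream=True, timeout=10)
    # Each chunk is one JSON object terminated by a NUL byte.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])   # partial output; the web UI shows it with a "▌" cursor
        else:
            print(data["text"])   # worker-reported error text
            break
except requests.exceptions.RequestException:
    print("**NETWORK ERROR. PLEASE REGENERATE OR REFRESH THIS PAGE.**")
```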
36 changes: 20 additions & 16 deletions chatserver/serve/model_worker.py
@@ -11,7 +11,7 @@
import threading
import uuid

from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -25,6 +25,7 @@

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0


def heart_beat_worker(controller):
@@ -139,17 +140,15 @@ def generate_stream(self, params):
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1).cuda()
out = model(input_ids=torch.as_tensor([[token]]).cuda(),
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device="cuda")
out = model(input_ids=torch.as_tensor([[token]], device="cuda"),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values

assert out.hidden_states is None
assert out.attentions is None

last_token_logits = logits[0][-1]
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
@@ -173,14 +172,14 @@ def generate_stream(self, params):
"text": output,
"error_code": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
yield json.dumps(ret).encode() + b"\0"

if stopped:
break

del past_key_values

def generate_stream_gate(self, params, release_semaphore):
def generate_stream_gate(self, params):
try:
for x in self.generate_stream(params):
yield x
@@ -189,25 +188,30 @@ def generate_stream_gate(self, params, release_semaphore):
"text": server_error_msg,
"error_code": 1,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
if release_semaphore:
release_semaphore.release()
yield json.dumps(ret).encode() + b"\0"


app = FastAPI()
model_semaphore = None


def release_model_semaphore():
model_semaphore.release()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
global model_semaphore
global model_semaphore, global_counter
global_counter += 1
params = await request.json()

if model_semaphore is None:
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
await model_semaphore.acquire()

generator = worker.generate_stream_gate(params, model_semaphore)
return StreamingResponse(generator)
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
@@ -227,7 +231,7 @@ async def get_status(request: Request):
parser.add_argument("--model-name", type=str)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument("--limit-model-concurrency", type=int, default=4)
parser.add_argument("--stream-interval", type=int, default=4)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
args = parser.parse_args()

2 changes: 1 addition & 1 deletion chatserver/utils.py
@@ -6,7 +6,7 @@

from chatserver.constants import LOGDIR

server_error_msg = "**NETWORK ERROR. PLEASE REFRESH THIS PAGE.**"
server_error_msg = "**NETWORK ERROR. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

handler = None

8 changes: 4 additions & 4 deletions docs/commands.md
@@ -51,10 +51,10 @@ python3 -m chatserver.serve.gradio_web_server --controller http://localhost:2100

#### Local GPU cluster (node-02)
```
CUDA_VISIBLE_DEVICES=0 python3 -m chatserver.serve.model_worker --model-path ~/model_weights/vicuna-13b/ --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker http://node-02:31000
CUDA_VISIBLE_DEVICES=1 python3 -m chatserver.serve.model_worker --model-path ~/model_weights/vicuna-13b/ --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker http://node-02:31001
CUDA_VISIBLE_DEVICES=2 python3 -m chatserver.serve.model_worker --model-path ~/model_weights/vicuna-13b/ --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker http://node-02:31002
CUDA_VISIBLE_DEVICES=3 python3 -m chatserver.serve.model_worker --model-path ~/model_weights/vicuna-13b/ --controller http://node-01:10002 --host 0.0.0.0 --port 31003 --worker http://node-02:31003
CUDA_VISIBLE_DEVICES=0 python3 -m chatserver.serve.model_worker --model-path ~/model_weights/vicuna-13b/ --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker http://$(hostname):31000
CUDA_VISIBLE_DEVICES=1 python3 -m chatserver.serve.model_worker --model-path ~/model_weights/vicuna-13b/ --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker http://$(hostname):31001
CUDA_VISIBLE_DEVICES=2 python3 -m chatserver.serve.model_worker --model-path ~/model_weights/vicuna-13b/ --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker http://$(hostname):31002
CUDA_VISIBLE_DEVICES=3 python3 -m chatserver.serve.model_worker --model-path ~/model_weights/vicuna-13b/ --controller http://node-01:10002 --host 0.0.0.0 --port 31003 --worker http://$(hostname):31003
```

### Host a gradio web server
