Support BAIR chat model (#38)
merrymercy committed Mar 23, 2023
1 parent 249fe85 commit fd02d2a
Showing 6 changed files with 115 additions and 41 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -12,6 +12,8 @@ pip3 install git+https://github.com/huggingface/transformers
```

## Serving

### Web UI
```
# Launch a controller
python3 -m chatserver.serve.controller
@@ -28,9 +30,13 @@ python3 -m chatserver.serve.gradio_web_server
# You can open your browser and chat with a model now.
```

### Command Line Interface
```
python3 -m chatserver.serve.cli --model facebook/opt-350m
```

## Deploy Chatbot on Any Cloud with SkyPilot
### Training on ShareGPT dataset
### Training on ShareGPT Dataset
1. Install SkyPilot and set up the credentials locally following the instructions [here](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html)
```
# This version of skypilot is needed for the `--env` flag fix.
62 changes: 53 additions & 9 deletions chatserver/conversation.py
@@ -1,23 +1,45 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()


@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None

def get_prompt(self):
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")

def append_message(self, role, message):
self.messages.append([role, message])
@@ -37,7 +59,9 @@ def copy(self):
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep=self.sep)
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2)

def dict(self):
return {
@@ -46,10 +70,11 @@ def dict(self):
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}


default_conversation = Conversation(
conv_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
@@ -72,8 +97,27 @@ def dict(self):
"help improve the quality of your sleep.")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)

conv_bair_v1 = Conversation(
system="BEGINNING OF CONVERSATION:",
roles=("USER", "GPT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)


default_conversation = conv_v1
conv_templates = {
"v1": conv_v1,
"bair_v1": conv_bair_v1,
}


if __name__ == "__main__":
print(default_conversation.get_prompt())
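
For reference, here is a minimal usage sketch (not part of this diff) of the new `conv_templates` registry. It assumes only the `bair_v1` template defined above and shows how `SeparatorStyle.TWO` alternates the two separators between turns:

```
from chatserver.conversation import conv_templates

conv = conv_templates["bair_v1"].copy()
conv.append_message(conv.roles[0], "Hello!")  # USER turn, ends with the space separator
conv.append_message(conv.roles[1], None)      # empty GPT turn to be completed by the model
print(conv.get_prompt())
# Expected output given the template above:
# BEGINNING OF CONVERSATION: USER: Hello! GPT:
```
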
40 changes: 30 additions & 10 deletions chatserver/serve/cli.py
@@ -2,29 +2,42 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from chatserver.conversation import default_conversation
from chatserver.conversation import conv_templates, SeparatorStyle
from chatserver.utils import disable_torch_init


@torch.inference_mode()
def main(args):
model_name = args.model_name
num_gpus = args.num_gpus

# Model
disable_torch_init()
if num_gpus == 1:
kwargs = {}
else:
kwargs = {
"device_map": "auto",
"max_memory": {i: "13GiB" for i in range(num_gpus)},
}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
torch_dtype=torch.float16).cuda()
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float16, **kwargs)

conv = default_conversation.copy()
if num_gpus == 1:
model.cuda()

# Chat
conv = conv_templates[args.conv_template].copy()
while True:
inp = input(f"{conv.roles[0]}: ")
if not inp:
print("exit...")
break

conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
output_ids = model.generate(
@@ -33,19 +46,26 @@ def main(args):
temperature=0.7,
max_new_tokens=256)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
sep = conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2
try:
index = outputs.index(conv.sep, len(prompt))
index = outputs.index(sep, len(prompt))
except ValueError:
outputs += conv.sep
index = outputs.index(conv.sep, len(prompt))
outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
outputs += sep
index = outputs.index(sep, len(prompt))

outputs = outputs[len(prompt) + 2:index].strip()
print(f"{conv.roles[1]}: {outputs}")
conv.append_message(conv.roles[1], outputs)
conv.messages[-1][-1] = outputs

if args.debug:
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument("--conv-template", type=str, default="v1")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
main(args)
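
A hedged usage example (not part of this diff) for the new `--conv-template` flag added to the CLI above; the model path is only a placeholder for wherever a BAIR-style chat checkpoint lives locally:

```
# Chat with a BAIR-style chat model using the two-separator template
python3 -m chatserver.serve.cli --model-name /path/to/bair-chat-13b --conv-template bair_v1 --num-gpus 2
```
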
22 changes: 16 additions & 6 deletions chatserver/serve/gradio_web_server.py
@@ -8,7 +8,8 @@
import gradio as gr
import requests

from chatserver.conversation import default_conversation
from chatserver.conversation import (default_conversation, conv_templates,
SeparatorStyle)
from chatserver.constants import LOGDIR
from chatserver.utils import build_logger
from chatserver.serve.gradio_patch import Chatbot as grChatbot
@@ -18,7 +19,6 @@

upvote_msg = "👍 Upvote the last response"
downvote_msg = "👎 Downvote the last response"
init_prompt = default_conversation.get_prompt()

priority = {
}
@@ -107,10 +107,21 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
start_tstamp = time.time()

if len(state.messages) == state.offset:
# skip empty "Regenerate"
# Skip empty "Regenerate"
yield state, state.to_gradio_chatbot()
return

if len(state.messages) == state.offset + 2:
# First round of conversation
if "bair-chat" in model_selector: # Hardcode the condition
template_name = "bair_v1"
else:
template_name = "v1"
new_state = conv_templates[template_name].copy()
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state

# Query worker address
controller_url = args.controller_url
ret = requests.post(controller_url + "/get_worker_address",
@@ -126,15 +137,14 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req

# Construct prompt
prompt = state.get_prompt()
txt = prompt.replace(state.sep, '\n')

# Make requests
headers = {"User-Agent": "Alpa Client"}
headers = {"User-Agent": "Client"}
pload = {
"prompt": prompt,
"temperature": float(temperature),
"max_new_tokens": int(max_new_tokens),
"stop": state.sep,
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
}
logger.info(f"==== request ====\n{pload}")
response = requests.post(worker_addr + "/generate_stream",
12 changes: 3 additions & 9 deletions chatserver/serve/model_worker.py
@@ -44,15 +44,9 @@ def load_model(model_name, num_gpus):
"max_memory": {i: "13GiB" for i in range(num_gpus)},
}

if model_name == "facebook/llama-7b":
hf_model_name = "/home/ubuntu/llama_weights/hf-llama-7b/"
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
model = AutoModelForCausalLM.from_pretrained(
hf_model_name, torch_dtype=torch.float16, **kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float16, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float16, **kwargs)

if num_gpus == 1:
model.cuda()
12 changes: 6 additions & 6 deletions docs/commands.md
@@ -1,20 +1,20 @@
### Launch a service with three models on 8 V100 (16 GB) GPUs.
### Launch a service with four models on 8 V100 (16 GB) GPUs.
```
# Launch a controller
python3 -m chatserver.serve.controller
# Launch model workers
CUDA_VISIBLE_DEVICES=4 python3 -m chatserver.serve.model_worker --model facebook/opt-350m --port 21004 --worker-address http://localhost:21004
CUDA_VISIBLE_DEVICES=2 python3 -m chatserver.serve.model_worker --model /home/ubuntu/model_weights/hf-llama-7b --port 21002 --worker-address http://localhost:21002
CUDA_VISIBLE_DEVICES=5 python3 -m chatserver.serve.model_worker --model facebook/opt-6.7b --port 21005 --worker-address http://localhost:21005
CUDA_VISIBLE_DEVICES=3 python3 -m chatserver.serve.model_worker --model /home/ubuntu/model_weights/alpaca-7b --port 21003 --worker-address http://localhost:21003
CUDA_VISIBLE_DEVICES=6,7 python3 -m chatserver.serve.model_worker --model facebook/llama-7b --port 21006 --worker-address http://localhost:21006 --num-gpus 2
CUDA_VISIBLE_DEVICES=4,5 python3 -m chatserver.serve.model_worker --model /home/ubuntu/model_weights/alpaca-13b --port 21004 --worker-address http://localhost:21004 --num-gpus 2
CUDA_VISIBLE_DEVICES=6,7 python3 -m chatserver.serve.model_worker --model /home/ubuntu/model_weights/bair-chat-13b --port 21006 --worker-address http://localhost:21006 --num-gpus 2
# Launch a gradio web server.
python3 -m chatserver.serve.gradio_web_server
```


### Host a gradio web server
```
sudo apt update
