Improve prompts & fix OOM (#11)
merrymercy committed Mar 20, 2023
1 parent 2a62af7 commit 79049a1
Showing 5 changed files with 55 additions and 17 deletions.
7 changes: 4 additions & 3 deletions chatserver/conversation.py
@@ -34,15 +34,16 @@ def copy(self):
sep=self.sep)



default_conversation = Conversation(
system="A chat between a curious human and a knowledgeable artificial intelligence assistant.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hello! What can you do?"),
("Assistant", "As an AI assistant, I can answer questions and chat with you."),
("Human", "What is the name of the tallest mountain in the world?"),
("Assistant", "Everest."),
("Human", "Give three tips for staying healthy."),
("Assistant", "1. Eat a balanced diet and make sure to include plenty of fruits and vegetables.\n"
"2. Exercise regularly to keep your body active and strong.\n"
"3. Get enough sleep and maintain a consistent sleep schedule."),
)
)

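For reference, a minimal sketch of how the few-shot template above might render into a single prompt string, assuming a plain "Role: message" layout joined by the conversation separator. The "###" separator and the exact layout are assumptions for illustration, not necessarily what Conversation.get_prompt() produces:

# Illustrative only: renders a conversation template like default_conversation above
# into one prompt string. The "###" separator and the "Role: message" layout are
# assumed for this sketch; the real formatting lives in Conversation.get_prompt().
sep = "###"  # placeholder for conv.sep
system = ("A chat between a curious human and a knowledgeable "
          "artificial intelligence assistant.")
messages = [
    ("Human", "Hello! What can you do?"),
    ("Assistant", "As an AI assistant, I can answer questions and chat with you."),
    ("Human", "Give three tips for staying healthy."),
    ("Assistant", None),  # left open so the model completes this turn
]

prompt = system + sep
for role, text in messages:
    prompt += f"{role}: {text}{sep}" if text else f"{role}:"
print(prompt)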
15 changes: 13 additions & 2 deletions chatserver/server/controller.py
@@ -30,7 +30,7 @@ def heart_beat_controller(controller):

while True:
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
controller.remove_stable_workers()
controller.remove_stable_workers_by_expiration()


class Controller:
@@ -68,6 +68,8 @@ def register_model_worker(self, model_name: str, worker_name: str):

logger.info(f"Register new. {(model_name, worker_name)}")

self.remove_stable_workers_by_checking()

def get_worker_address(self, model_name: str):
if model_name not in self.model_info:
return ""
@@ -110,7 +112,7 @@ def receive_heart_beat(self, worker_name: str):
logger.info(f"Receive heart beat. {worker_name}")
return True

def remove_stable_workers(self):
def remove_stable_workers_by_expiration(self):
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
to_delete = []
for worker_name, w_info in self.worker_info.items():
@@ -120,6 +122,15 @@ def remove_stable_workers(self):
for worker_name in to_delete:
self.remove_worker(worker_name)

def remove_stable_workers_by_checking(self):
to_delete = []
for worker_name in self.worker_info:
if not self.check_worker_status(worker_name):
to_delete.append(worker_name)

for worker_name in to_delete:
self.remove_worker(worker_name)

def list_models(self):
models = []
for model, m_info in self.model_info.items():
21 changes: 15 additions & 6 deletions chatserver/server/gradio_web_server.py
@@ -25,36 +25,44 @@ def http_bot(history, model_selector):
worker_addr = ret.json()["address"]
print(f"worker_addr: {worker_addr}")

# Fix some bugs in gradio UI
for i in range(len(history)):
history[i][0] = history[i][0].replace("<br>", "")
if history[i][1]:
history[i][1] = history[i][1].replace("<br>", "")

# No available worker
if worker_addr == "":
history[-1][-1] = "**NETWORK ERROR. PLEASE TRY AGAIN OR CHOOSE OTHER MODELS.**"
yield history
return

# Construct prompt
conv = default_conversation.copy()
conv.append_gradio_chatbot_history(history)
prompt = conv.get_prompt()

txt = prompt.replace(conv.sep, '\n')
print(f"==== Conversation ====\n{txt}")

# Make requests
headers = {"User-Agent": "Alpa Client"}
pload = {
"prompt": prompt,
"max_new_tokens": 64,
"temperature": 0.8,
"max_new_tokens": 512,
"temperature": 0.7,
"stop": conv.sep,
}
response = requests.post(worker_addr + "/generate_stream",
headers=headers, json=pload, stream=True)

# Stream output
sep = f"{conv.sep}{conv.roles[1]}: "
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"].split(sep)[-1]
history[-1][-1] = output
yield history

print(f"{output}")


@@ -67,7 +75,7 @@ def http_bot(history, model_selector):

def build_demo(models):
models.sort(key=lambda x: priority[x])
css = """#model_selector_row {width: 300px;}"""
css = """#model_selector_row {width: 350px;}"""

with gr.Blocks(title="Chat Server", css=css) as demo:
gr.Markdown(
@@ -103,6 +111,7 @@ def build_demo(models):
parser.add_argument("--port", type=int)
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
parser.add_argument("--concurrency-count", type=int, default=2)
parser.add_argument("--share", action="store_true")
args = parser.parse_args()

ret = requests.post(args.controller_url + "/list_models")
@@ -111,4 +120,4 @@ def build_demo(models):

demo = build_demo(models)
demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10).launch(
server_name=args.host, server_port=args.port)
server_name=args.host, server_port=args.port, share=args.share)
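For completeness, a hypothetical standalone client for the worker's /generate_stream endpoint, mirroring the payload fields and the null-byte-delimited JSON streaming that gradio_web_server.py uses above. The worker address, prompt, and separator value here are placeholders rather than values from this commit:

# Sketch of a direct client for the streaming endpoint; not part of this commit.
import json
import requests

worker_addr = "http://localhost:21002"  # assumed worker address
sep = "###"                             # assumed conversation separator (conv.sep)

pload = {
    "prompt": f"Human: Give three tips for staying healthy.{sep}Assistant: ",
    "max_new_tokens": 512,
    "temperature": 0.7,
    "stop": sep,
}
response = requests.post(worker_addr + "/generate_stream",
                         headers={"User-Agent": "Test Client"},
                         json=pload, stream=True)

# The worker streams null-byte-delimited JSON objects; each one carries the full
# text generated so far in its "text" field.
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"], flush=True)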
27 changes: 22 additions & 5 deletions chatserver/server/model_worker.py
@@ -15,6 +15,7 @@

from chatserver.server.constants import WORKER_HEART_BEAT_INTERVAL

GB = 1 << 30

logger = logging.getLogger("model_worker")

@@ -54,16 +55,19 @@ def load_model(model_name, num_gpus):
model = AutoModelForCausalLM.from_pretrained(
hf_model_name + "llama-7b/", torch_dtype=torch.float16, **kwargs)
else:
hf_model_name = model_name

tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
hf_model_name, torch_dtype=torch.float16, **kwargs)
model_name, torch_dtype=torch.float16, **kwargs)

if num_gpus == 1:
model.cuda()

return tokenizer, model, 2048
if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 2048

return tokenizer, model, context_len


class ModelWorker:
@@ -72,6 +76,7 @@ def __init__(self, controller_addr, worker_addr, model_name, num_gpus):
self.worker_addr = worker_addr
self.model_name = model_name

logger.info("Loading the model...")
self.tokenizer, self.model, self.context_len = load_model(model_name, num_gpus)

self.register_to_controller()
@@ -99,7 +104,12 @@ def send_heart_beat(self):
if not exist:
self.register_to_controller()

@torch.inference_mode()
def generate_stream(self, args):
#cur_mem = torch.cuda.memory_allocated()
#max_mem = torch.cuda.max_memory_allocated()
#logging.info(f"cur mem: {cur_mem/GB:.2f} GB, max_mem: {max_mem/GB:.2f} GB")

tokenizer, model = self.tokenizer, self.model

context = args["prompt"]
@@ -132,6 +142,11 @@
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))

assert out.hidden_states is None
assert out.attentions is None

if token == tokenizer.eos_token_id:
break
output_ids.append(token)
output = tokenizer.decode(output_ids, skip_special_tokens=True)

@@ -150,6 +165,8 @@
if stopped:
break

del past_key_values


app = FastAPI()

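The two memory-related changes above, the @torch.inference_mode() decorator and the explicit del past_key_values, follow a common pattern for incremental decoding with a Hugging Face causal LM: run the whole loop without autograd bookkeeping and release the key/value cache as soon as generation finishes. A minimal sketch of that pattern, using generic placeholder model handling rather than the worker's actual code:

# Illustrative sketch of the OOM-avoidance pattern; not the worker's real generate_stream.
import torch

@torch.inference_mode()  # no autograd state is recorded, cutting activation memory
def greedy_decode(model, tokenizer, prompt, max_new_tokens=32):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output_ids = input_ids[0].tolist()
    past_key_values = None
    for _ in range(max_new_tokens):
        if past_key_values is None:
            out = model(input_ids, use_cache=True)
        else:
            out = model(torch.as_tensor([[output_ids[-1]]], device=model.device),
                        use_cache=True, past_key_values=past_key_values)
        past_key_values = out.past_key_values  # reuse the KV cache between steps
        token = int(torch.argmax(out.logits[0, -1]))
        if token == tokenizer.eos_token_id:
            break
        output_ids.append(token)
    del past_key_values  # drop the cache promptly so the allocator can reuse the memory
    return tokenizer.decode(output_ids, skip_special_tokens=True)

# Usage (assumed model; matches the name used in test_message.py below):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
#   model = AutoModelForCausalLM.from_pretrained(
#       "facebook/opt-350m", torch_dtype=torch.float16).cuda()
#   print(greedy_decode(model, tokenizer, "Hello!"))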
2 changes: 1 addition & 1 deletion chatserver/server/test_message.py
@@ -24,7 +24,7 @@ def main():
pload = {
"model": "facebook/opt-350m",
"prompt": prompt,
"max_new_tokens": 64,
"max_new_tokens": 32,
"temperature": 0.8,
"stop": conv.sep,
}
