support inference on the whole dataset #66

Merged 1 commit on Oct 20, 2023
321 changes: 321 additions & 0 deletions llm_ft/fastchat/serve/inf.py
@@ -0,0 +1,321 @@
"""
Chat with a model with command line interface.

Usage:
python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.3
python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0

Other commands:
- Type "!!exit" or an empty line to exit.
- Type "!!reset" to start a new conversation.
- Type "!!remove" to remove the last prompt.
- Type "!!regen" to regenerate the last message.
- Type "!!save <filename>" to save the conversation history to a json file.
- Type "!!load <filename>" to load a conversation history from a json file.
"""
import argparse
import os
import re
import sys
import json
from tqdm import tqdm

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.key_binding import KeyBindings
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
import torch

from fastchat.model.model_adapter import add_model_args
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.serve.inference import ChatIO, chat_loop, run_inference
from fastchat.utils import str_to_torch_dtype
from fastchat.model.model_adapter import (
load_model,
get_conversation_template,
get_generate_stream_function,
)


class SimpleChatIO(ChatIO):
def __init__(self, multiline: bool = False):
self._multiline = multiline

def prompt_for_input(self, role) -> str:
if not self._multiline:
return input(f"{role}: ")

prompt_data = []
line = input(f"{role} [ctrl-d/z on empty line to end]: ")
while True:
prompt_data.append(line.strip())
try:
line = input()
except EOFError:
break
return "\n".join(prompt_data)

def prompt_for_output(self, role: str):
print(f"{role}: ", end="", flush=True)

def stream_output(self, output_stream):
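# Stream whole words as they arrive: print every completed word except the
# last (possibly partial) one, then flush the remainder once the stream ends.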
pre = 0
for outputs in output_stream:
output_text = outputs["text"]
output_text = output_text.strip().split(" ")
now = len(output_text) - 1
if now > pre:
print(" ".join(output_text[pre:now]), end=" ", flush=True)
pre = now
print(" ".join(output_text[pre:]), flush=True)
return " ".join(output_text)

def print_output(self, text: str):
print(text)


class RichChatIO(ChatIO):
bindings = KeyBindings()

@bindings.add("escape", "enter")
def _(event):
event.app.current_buffer.newline()

def __init__(self, multiline: bool = False, mouse: bool = False):
self._prompt_session = PromptSession(history=InMemoryHistory())
self._completer = WordCompleter(
words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
pattern=re.compile("$"),
)
self._console = Console()
self._multiline = multiline
self._mouse = mouse

def prompt_for_input(self, role) -> str:
self._console.print(f"[bold]{role}:")
# TODO(suquark): multiline input has some issues. fix it later.
prompt_input = self._prompt_session.prompt(
completer=self._completer,
multiline=False,
mouse_support=self._mouse,
auto_suggest=AutoSuggestFromHistory(),
key_bindings=self.bindings if self._multiline else None,
)
self._console.print()
return prompt_input

def prompt_for_output(self, role: str):
self._console.print(f"[bold]{role}:")

def stream_output(self, output_stream):
"""Stream output from a role."""
# TODO(suquark): the console flickers when there is a code block
# above it. We need to cut off "live" when a code block is done.

# Create a Live context for updating the console output
with Live(console=self._console, refresh_per_second=4) as live:
# Read lines from the stream
for outputs in output_stream:
if not outputs:
continue
text = outputs["text"]
# Render the accumulated text as Markdown
# NOTE: this is a workaround for rendering "non-standard markdown"
# in rich. The chatbot's output treats "\n" as a new line for
# better compatibility with real-world text, but rendering it
# as markdown would break the format, because standard markdown
# treats a single "\n" in normal text as a space.
# Our workaround is to add two spaces at the end of each line.
# This is not a perfect solution, as it introduces trailing
# spaces (only) in code blocks, but it works well for console
# output, because the console generally does not care about
# trailing spaces.
lines = []
for line in text.splitlines():
lines.append(line)
if line.startswith("```"):
# Code block marker - do not add trailing spaces, as it would
# break the syntax highlighting
lines.append("\n")
else:
lines.append(" \n")
markdown = Markdown("".join(lines))
# Update the Live console output
live.update(markdown)
self._console.print()
return text

def print_output(self, text: str):
self.stream_output([{"text": text}])


class ProgrammaticChatIO(ChatIO):
def prompt_for_input(self, role) -> str:
contents = ""
# `end_sequence` signals the end of a message. It is unlikely to occur in
# message content.
end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
len_end = len(end_sequence)
while True:
if len(contents) >= len_end:
last_chars = contents[-len_end:]
if last_chars == end_sequence:
break
try:
char = sys.stdin.read(1)
contents = contents + char
except EOFError:
continue
contents = contents[:-len_end]
print(f"[!OP:{role}]: {contents}", flush=True)
return contents

def prompt_for_output(self, role: str):
print(f"[!OP:{role}]: ", end="", flush=True)

def stream_output(self, output_stream):
pre = 0
for outputs in output_stream:
output_text = outputs["text"]
output_text = output_text.strip().split(" ")
now = len(output_text) - 1
if now > pre:
print(" ".join(output_text[pre:now]), end=" ", flush=True)
pre = now
print(" ".join(output_text[pre:]), flush=True)
return " ".join(output_text)

def print_output(self, text: str):
print(text)


def main(args, input_str, model, tokenizer):
if args.gpus:
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
exllama_config = None
if args.style == "simple":
chatio = SimpleChatIO(args.multiline)
elif args.style == "rich":
chatio = RichChatIO(args.multiline, args.mouse)
elif args.style == "programmatic":
chatio = ProgrammaticChatIO()
else:
raise ValueError(f"Invalid style for console: {args.style}")

outputs = run_inference(
input_str,
model,
tokenizer,
args.model_path,
args.device,
args.num_gpus,
args.max_gpu_memory,
str_to_torch_dtype(args.dtype),
args.load_8bit,
args.cpu_offloading,
args.conv_template,
args.conv_system_msg,
args.temperature,
args.repetition_penalty,
args.max_new_tokens,
chatio,
exllama_config=exllama_config,
revision=args.revision,
judge_sent_end=args.judge_sent_end,
debug=args.debug,
history=not args.no_history,
)
return outputs



if __name__ == "__main__":
parser = argparse.ArgumentParser()
add_model_args(parser)
parser.add_argument(
"--conv-template", type=str, default=None, help="Conversation prompt template."
)
parser.add_argument(
"--conv-system-msg", type=str, default=None, help="Conversation system message."
)
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--repetition_penalty", type=float, default=1.0)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--no-history", action="store_true")
parser.add_argument(
"--style",
type=str,
default="simple",
choices=["simple", "rich", "programmatic"],
help="Display style.",
)
parser.add_argument(
"--multiline",
action="store_true",
help="Enable multiline input. Use ESC+Enter for newline.",
)
parser.add_argument(
"--mouse",
action="store_true",
help="[Rich Style]: Enable mouse support for cursor positioning.",
)
parser.add_argument(
"--judge-sent-end",
action="store_true",
help="Whether enable the correction logic that interrupts the output of sentences due to EOS.",
)
parser.add_argument(
"--debug",
action="store_true",
help="Print useful debug information (e.g., prompts)",
)
args = parser.parse_args()

gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)

# Model
model, tokenizer = load_model(
args.model_path,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
dtype=args.dtype,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
revision=args.revision,
debug=args.debug,
)

with open('./data/fastchat-ft-gpt4-gpt4-easy-2-side-partial.json', 'r') as f:
dataset = json.load(f)

new_dataset = []
for data in tqdm(dataset):
input_str = data['conversations'][0]['value']
gth_str = data['conversations'][-1]['value']
prediction = main(args, input_str, model, tokenizer)
print(prediction)
# Collect each prediction with its prompt and ground truth instead of dropping
# into the debugger here (the record field names are illustrative).
new_dataset.append({'input': input_str, 'ground_truth': gth_str, 'prediction': prediction})
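# A minimal sketch of persisting the collected results once the loop finishes,
# assuming './data/inference_predictions.json' as the output path (the path is
# an assumption, not part of the original change):
with open('./data/inference_predictions.json', 'w') as f:
json.dump(new_dataset, f, indent=2, ensure_ascii=False)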