From 23f76a8adf8142bc24b5c50deb08373385f431be Mon Sep 17 00:00:00 2001
From: lwaekfjlk <1125027232@qq.com>
Date: Fri, 20 Oct 2023 22:41:22 +0000
Subject: [PATCH] support inference on the whole dataset

---
 llm_ft/fastchat/serve/inf.py       | 321 +++++++++++++++++++++++++++++
 llm_ft/fastchat/serve/inference.py |  79 +++++++
 llm_ft/mistral-inference.sh        |   2 +-
 3 files changed, 401 insertions(+), 1 deletion(-)
 create mode 100644 llm_ft/fastchat/serve/inf.py

diff --git a/llm_ft/fastchat/serve/inf.py b/llm_ft/fastchat/serve/inf.py
new file mode 100644
index 00000000..1c7d54f1
--- /dev/null
+++ b/llm_ft/fastchat/serve/inf.py
@@ -0,0 +1,321 @@
+"""
+Run inference with a model over a whole dataset from the command line.
+
+Usage:
+python3 -m fastchat.serve.inf --model-path ./Mistral-7B-v0.1 --conv-template "vicuna_v1.1"
+
+The ChatIO helpers below are copied from fastchat.serve.cli so that the same
+console styles ("simple", "rich", "programmatic") can be reused for batch inference.
+"""
+import argparse
+import os
+import re
+import sys
+import json
+from tqdm import tqdm
+
+from prompt_toolkit import PromptSession
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from prompt_toolkit.completion import WordCompleter
+from prompt_toolkit.history import InMemoryHistory
+from prompt_toolkit.key_binding import KeyBindings
+from rich.console import Console
+from rich.live import Live
+from rich.markdown import Markdown
+import torch
+
+from fastchat.model.model_adapter import add_model_args
+from fastchat.modules.awq import AWQConfig
+from fastchat.modules.exllama import ExllamaConfig
+from fastchat.modules.gptq import GptqConfig
+from fastchat.serve.inference import ChatIO, chat_loop, run_inference
+from fastchat.utils import str_to_torch_dtype
+from fastchat.model.model_adapter import (
+    load_model,
+    get_conversation_template,
+    get_generate_stream_function,
+)
+
+
+class SimpleChatIO(ChatIO):
+    def __init__(self, multiline: bool = False):
+        self._multiline = multiline
+
+    def prompt_for_input(self, role) -> str:
+        if not self._multiline:
+            return input(f"{role}: ")
+
+        prompt_data = []
+        line = input(f"{role} [ctrl-d/z on empty line to end]: ")
+        while True:
+            prompt_data.append(line.strip())
+            try:
+                line = input()
+            except EOFError:
+                break
+        return "\n".join(prompt_data)
+
+    def prompt_for_output(self, role: str):
+        print(f"{role}: ", end="", flush=True)
+
+    def stream_output(self, output_stream):
+        pre = 0
+        for outputs in output_stream:
+            output_text = outputs["text"]
+            output_text = output_text.strip().split(" ")
+            now = len(output_text) - 1
+            if now > pre:
+                print(" ".join(output_text[pre:now]), end=" ", flush=True)
+                pre = now
+        print(" ".join(output_text[pre:]), flush=True)
+        return " ".join(output_text)
+
+    def print_output(self, text: str):
+        print(text)
+
+
+class RichChatIO(ChatIO):
+    bindings = KeyBindings()
+
+    @bindings.add("escape", "enter")
+    def _(event):
+        event.app.current_buffer.newline()
+
+    def __init__(self, multiline: bool = False, mouse: bool = False):
+        self._prompt_session = PromptSession(history=InMemoryHistory())
+        self._completer = WordCompleter(
+            words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
+            pattern=re.compile("$"),
+        )
+        self._console = Console()
+        self._multiline = multiline
+        self._mouse = mouse
+
+    def prompt_for_input(self, role) -> str:
+        self._console.print(f"[bold]{role}:")
+        # TODO(suquark): multiline input has some issues. Fix it later.
+        prompt_input = self._prompt_session.prompt(
+            completer=self._completer,
+            multiline=False,
+            mouse_support=self._mouse,
+            auto_suggest=AutoSuggestFromHistory(),
+            key_bindings=self.bindings if self._multiline else None,
+        )
+        self._console.print()
+        return prompt_input
+
+    def prompt_for_output(self, role: str):
+        self._console.print(f"[bold]{role}:")
+
+    def stream_output(self, output_stream):
+        """Stream output from a role."""
+        # TODO(suquark): the console flickers when there is a code block
+        # above it. We need to cut off "live" when a code block is done.
+
+        # Create a Live context for updating the console output
+        with Live(console=self._console, refresh_per_second=4) as live:
+            # Read lines from the stream
+            for outputs in output_stream:
+                if not outputs:
+                    continue
+                text = outputs["text"]
+                # Render the accumulated text as Markdown
+                # NOTE: this is a workaround for rendering non-standard markdown
+                # in rich. The chatbot's output treats "\n" as a new line for
+                # better compatibility with real-world text. However, rendering
+                # it as markdown would break the format, because standard markdown
+                # treats a single "\n" in normal text as a space.
+                # Our workaround is to add two spaces at the end of each line.
+                # This is not a perfect solution, as it introduces trailing
+                # spaces (only) in code blocks, but it works well for console
+                # output, because in general the console does not care about
+                # trailing spaces.
+                lines = []
+                for line in text.splitlines():
+                    lines.append(line)
+                    if line.startswith("```"):
+                        # Code block marker - do not add trailing spaces, as it
+                        # would break the syntax highlighting
+                        lines.append("\n")
+                    else:
+                        lines.append("  \n")
+                markdown = Markdown("".join(lines))
+                # Update the Live console output
+                live.update(markdown)
+        self._console.print()
+        return text
+
+    def print_output(self, text: str):
+        self.stream_output([{"text": text}])
+
+
+class ProgrammaticChatIO(ChatIO):
+    def prompt_for_input(self, role) -> str:
+        contents = ""
+        # `end_sequence` signals the end of a message. It is unlikely to occur in
+        # message content.
+        end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
+        len_end = len(end_sequence)
+        while True:
+            if len(contents) >= len_end:
+                last_chars = contents[-len_end:]
+                if last_chars == end_sequence:
+                    break
+            try:
+                char = sys.stdin.read(1)
+                contents = contents + char
+            except EOFError:
+                continue
+        contents = contents[:-len_end]
+        print(f"[!OP:{role}]: {contents}", flush=True)
+        return contents
+
+    def prompt_for_output(self, role: str):
+        print(f"[!OP:{role}]: ", end="", flush=True)
+
+    def stream_output(self, output_stream):
+        pre = 0
+        for outputs in output_stream:
+            output_text = outputs["text"]
+            output_text = output_text.strip().split(" ")
+            now = len(output_text) - 1
+            if now > pre:
+                print(" ".join(output_text[pre:now]), end=" ", flush=True)
+                pre = now
+        print(" ".join(output_text[pre:]), flush=True)
+        return " ".join(output_text)
+
+    def print_output(self, text: str):
+        print(text)
+
+
+def main(args, input_str, model, tokenizer):
+    """Build a ChatIO for the requested style and run one inference pass on input_str."""
+    if args.gpus:
+        if len(args.gpus.split(",")) < args.num_gpus:
+            raise ValueError(
+                f"--num-gpus ({args.num_gpus}) is larger than the number of GPUs in --gpus ({args.gpus})!"
+            )
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+        os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
+    exllama_config = None
+    if args.style == "simple":
+        chatio = SimpleChatIO(args.multiline)
+    elif args.style == "rich":
+        chatio = RichChatIO(args.multiline, args.mouse)
+    elif args.style == "programmatic":
+        chatio = ProgrammaticChatIO()
+    else:
+        raise ValueError(f"Invalid style for console: {args.style}")
+
+    outputs = run_inference(
+        input_str,
+        model,
+        tokenizer,
+        args.model_path,
+        args.device,
+        args.num_gpus,
+        args.max_gpu_memory,
+        str_to_torch_dtype(args.dtype),
+        args.load_8bit,
+        args.cpu_offloading,
+        args.conv_template,
+        args.conv_system_msg,
+        args.temperature,
+        args.repetition_penalty,
+        args.max_new_tokens,
+        chatio,
+        exllama_config=exllama_config,
+        revision=args.revision,
+        judge_sent_end=args.judge_sent_end,
+        debug=args.debug,
+        history=not args.no_history,
+    )
+    return outputs
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    add_model_args(parser)
+    parser.add_argument(
+        "--conv-template", type=str, default=None, help="Conversation prompt template."
+    )
+    parser.add_argument(
+        "--conv-system-msg", type=str, default=None, help="Conversation system message."
+    )
+    parser.add_argument("--temperature", type=float, default=0.7)
+    parser.add_argument("--repetition_penalty", type=float, default=1.0)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
+    parser.add_argument("--no-history", action="store_true")
+    parser.add_argument(
+        "--style",
+        type=str,
+        default="simple",
+        choices=["simple", "rich", "programmatic"],
+        help="Display style.",
+    )
+    parser.add_argument(
+        "--multiline",
+        action="store_true",
+        help="Enable multiline input. Use ESC+Enter for newline.",
+    )
+    parser.add_argument(
+        "--mouse",
+        action="store_true",
+        help="[Rich Style]: Enable mouse support for cursor positioning.",
+    )
+    parser.add_argument(
+        "--judge-sent-end",
+        action="store_true",
+        help="Whether to enable the correction logic that interrupts the output of sentences due to EOS.",
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        help="Print useful debug information (e.g., prompts)",
+    )
+    args = parser.parse_args()
+
+    gptq_config = GptqConfig(
+        ckpt=args.gptq_ckpt or args.model_path,
+        wbits=args.gptq_wbits,
+        groupsize=args.gptq_groupsize,
+        act_order=args.gptq_act_order,
+    )
+    awq_config = AWQConfig(
+        ckpt=args.awq_ckpt or args.model_path,
+        wbits=args.awq_wbits,
+        groupsize=args.awq_groupsize,
+    )
+
+    # Model
+    model, tokenizer = load_model(
+        args.model_path,
+        device=args.device,
+        num_gpus=args.num_gpus,
+        max_gpu_memory=args.max_gpu_memory,
+        dtype=args.dtype,
+        load_8bit=args.load_8bit,
+        cpu_offloading=args.cpu_offloading,
+        gptq_config=gptq_config,
+        awq_config=awq_config,
+        revision=args.revision,
+        debug=args.debug,
+    )
+
+    with open('./data/fastchat-ft-gpt4-gpt4-easy-2-side-partial.json', 'r') as f:
+        dataset = json.load(f)
+
+    # Run inference over the whole dataset and collect predictions.
+    new_dataset = []
+    for data in tqdm(dataset):
+        input_str = data['conversations'][0]['value']
+        gth_str = data['conversations'][-1]['value']
+        prediction = main(args, input_str, model, tokenizer)
+        print(prediction)
+        data['prediction'] = prediction
+        data['ground_truth'] = gth_str
+        new_dataset.append(data)
+
+    # NOTE: the output path is a placeholder; adjust it to wherever predictions should be stored.
+    with open('./data/fastchat-ft-gpt4-gpt4-easy-2-side-partial-predictions.json', 'w') as f:
+        json.dump(new_dataset, f, indent=2)
diff --git a/llm_ft/fastchat/serve/inference.py b/llm_ft/fastchat/serve/inference.py
index 08c593cc..e7f5f410 100644
--- a/llm_ft/fastchat/serve/inference.py
+++ b/llm_ft/fastchat/serve/inference.py
@@ -502,3 +502,82 @@ def reload_conv(conv):
             conv.messages.pop()
         reload_conv(conv)
+
+
+
+def run_inference(
+    inp: str,
+    model: AutoModelForCausalLM,
+    tokenizer: AutoTokenizer,
+    model_path: str,
+    device: str,
+    num_gpus: int,
+    max_gpu_memory: str,
+    dtype: Optional[torch.dtype],
+    load_8bit: bool,
+    cpu_offloading: bool,
+    conv_template: Optional[str],
+    conv_system_msg: Optional[str],
+    temperature: float,
+    repetition_penalty: float,
+    max_new_tokens: int,
+    chatio: ChatIO,
+    gptq_config: Optional[GptqConfig] = None,
+    awq_config: Optional[AWQConfig] = None,
+    exllama_config: Optional[ExllamaConfig] = None,
+    revision: str = "main",
+    judge_sent_end: bool = True,
+    debug: bool = True,
+    history: bool = True,
+):
+    """Run a single, non-streaming generation for `inp` and return the decoded completion."""
+    generate_stream_func = get_generate_stream_function(model, model_path)
+
+    model_type = str(type(model)).lower()
+    is_t5 = "t5" in model_type
+    is_codet5p = "codet5p" in model_type
+
+    # Hardcode T5's default repetition penalty to be 1.2
+    if is_t5 and repetition_penalty == 1.0:
+        repetition_penalty = 1.2
+
+    # Set context length
+    context_len = get_context_length(model.config)
+
+    # Chat
+    def new_chat():
+        if conv_template:
+            conv = get_conv_template(conv_template)
+        else:
+            conv = get_conversation_template(model_path)
+        if conv_system_msg is not None:
+            conv.set_system_message(conv_system_msg)
+        return conv
+
+    def reload_conv(conv):
+        """
+        Reprints the conversation from the start.
+        """
+        for message in conv.messages[conv.offset :]:
+            chatio.prompt_for_output(message[0])
+            chatio.print_output(message[1])
+
+    conv = new_chat()
+
+    conv.append_message(conv.roles[0], inp)
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt()
+
+    if is_codet5p:  # codet5p is a code completion model.
+        prompt = inp
+
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+    # `generate` does not accept `stop_token_ids`; map the conversation's stop tokens
+    # onto `eos_token_id` (falling back to the tokenizer's EOS token), and only sample
+    # when a non-trivial temperature is requested.
+    do_sample = temperature > 1e-5
+    output_ids = model.generate(
+        input_ids=input_ids,
+        do_sample=do_sample,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        max_new_tokens=max_new_tokens,
+        eos_token_id=conv.stop_token_ids or tokenizer.eos_token_id,
+    )
+    if model.config.is_encoder_decoder:
+        # Encoder-decoder models (e.g., T5) return only the newly generated tokens.
+        newly_generated_tokens = output_ids[0]
+    else:
+        newly_generated_tokens = output_ids[0][input_ids.shape[1]:]
+    generated_text = tokenizer.decode(newly_generated_tokens, skip_special_tokens=True)
+    return generated_text
diff --git a/llm_ft/mistral-inference.sh b/llm_ft/mistral-inference.sh
index e709cc2d..cd32f62b 100644
--- a/llm_ft/mistral-inference.sh
+++ b/llm_ft/mistral-inference.sh
@@ -1 +1 @@
-python3 -m fastchat.serve.cli --model-path ./checkpoint-ft/checkpoint-4525 --conv-template "vicuna_v1.1"
\ No newline at end of file
+python3 -m fastchat.serve.inf --model-path ./Mistral-7B-v0.1 --conv-template "vicuna_v1.1"
\ No newline at end of file
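
For reference, a minimal sketch of how the new run_inference helper could be driven directly from Python, outside of fastchat.serve.inf. This is not part of the patch: the model path, device, prompt, and output handling are placeholders, and it assumes the patched fastchat package from this change is importable.

# Minimal sketch (assumptions: patched fastchat on PYTHONPATH, a local checkpoint, a CUDA device).
from fastchat.model.model_adapter import load_model
from fastchat.serve.inference import run_inference
from fastchat.serve.inf import SimpleChatIO

model_path = "./Mistral-7B-v0.1"  # placeholder checkpoint directory
model, tokenizer = load_model(model_path, device="cuda", num_gpus=1)

prediction = run_inference(
    "Hello, how are you?",  # placeholder prompt
    model,
    tokenizer,
    model_path,
    "cuda",          # device
    1,               # num_gpus
    None,            # max_gpu_memory
    None,            # dtype
    False,           # load_8bit
    False,           # cpu_offloading
    "vicuna_v1.1",   # conv_template
    None,            # conv_system_msg
    0.7,             # temperature
    1.0,             # repetition_penalty
    512,             # max_new_tokens
    SimpleChatIO(),  # chatio (kept for signature parity with chat_loop; unused here)
)
print(prediction)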