From 520566acb6b07cfbf2503d1763ad10b436e9c578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 19 Sep 2024 14:41:01 +0200 Subject: [PATCH] Inference server, lots of related changes (#42) --- .github/workflows/push.yml | 20 +- .../docs/concepts/transforms.md | 12 +- docs/source/Config/transforms.rst | 21 + eole/bin/convert/convert_HF.py | 116 ++++- eole/bin/run/serve.py | 416 ++++++++++++++++++ eole/bin/tools/LM_scoring.py | 9 +- eole/config/inference.py | 14 +- eole/config/run.py | 40 +- eole/constants.py | 6 + eole/inference_engine.py | 43 +- eole/models/model.py | 2 - eole/models/model_saver.py | 16 +- eole/predict/__init__.py | 1 + eole/predict/encoder.py | 8 +- eole/predict/generator.py | 8 +- eole/predict/greedy_search.py | 62 +-- eole/predict/inference.py | 33 +- eole/predict/translator.py | 8 +- eole/tests/data/inference-engine_py.yaml | 6 +- eole/tests/pull_request_check.sh | 20 +- eole/transforms/misc.py | 2 + eole/transforms/tokenize.py | 1 + eole/transforms/transform.py | 12 +- eole/utils/distributed.py | 7 +- eole/utils/loss.py | 4 +- eole/utils/misc.py | 4 +- eole/utils/scoring_utils.py | 4 +- recipes/gpt2/inference.yaml | 8 +- recipes/llama2/llama-inference-tp-2gpu.yaml | 7 +- recipes/llama2/llama-inference.yaml | 7 +- recipes/llama3.1/llama-inference.yaml | 10 - .../llama3.1/llama-instruct-inference.yaml | 15 - recipes/llama3/llama-inference.yaml | 5 +- recipes/llama3/llama-mmlu.yaml | 4 - .../mistral-7b-awq-gemm-inference.yaml | 7 +- recipes/mixtral/mixtral-inference-awq.yaml | 7 +- recipes/server/README.md | 52 +++ recipes/server/serve.example.yaml | 14 + recipes/wiki_103/README.md | 5 + recipes/wiki_103/inference.yaml | 20 +- 40 files changed, 847 insertions(+), 209 deletions(-) create mode 100644 docs/source/Config/transforms.rst create mode 100644 eole/bin/run/serve.py create mode 100644 recipes/server/README.md create mode 100644 recipes/server/serve.example.yaml diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 97ddc1bb..beb68a0f 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -223,8 +223,8 @@ jobs: -batch_size 10 \ -beam_size 1 \ -seed 1 \ - -random_sampling_topk "-1" \ - -random_sampling_temp 0.0001 \ + -top_k "-1" \ + -temperature 0.0001 \ -tgt eole/tests/data/morph/tgt.valid \ -out /tmp/trans diff eole/tests/data/morph/tgt.valid /tmp/trans && rm /tmp/trans @@ -253,8 +253,8 @@ jobs: -verbose -batch_size 1 \ -beam_size 1 \ -seed 1 \ - -random_sampling_topk -1 \ - -random_sampling_temp 0.0001 \ + -top_k -1 \ + -temperature 0.0001 \ -ban_unk_token \ -length_penalty none \ -out /tmp/gen @@ -266,9 +266,9 @@ jobs: -verbose -batch_size 1 \ -beam_size 1 \ -seed 3 \ - -random_sampling_topk -1 \ - -random_sampling_topp 0.95 \ - -random_sampling_temp 1 \ + -top_k -1 \ + -top_p 0.95 \ + -temperature 1 \ -ban_unk_token \ -length_penalty none \ -out /tmp/gen @@ -280,9 +280,9 @@ jobs: -verbose -batch_size 1 \ -beam_size 10 \ -seed 2 \ - -random_sampling_topk 50 \ - -random_sampling_topp 0.95 \ - -random_sampling_temp 1 \ + -top_k 50 \ + -top_p 0.95 \ + -temperature 1 \ -length_penalty avg \ -ban_unk_token \ -min_length 5 \ diff --git a/docs/docusaurus_tsx/docs/concepts/transforms.md b/docs/docusaurus_tsx/docs/concepts/transforms.md index 9f0602be..a1b05d94 100644 --- a/docs/docusaurus_tsx/docs/concepts/transforms.md +++ b/docs/docusaurus_tsx/docs/concepts/transforms.md @@ -7,7 +7,17 @@ description: Recap of available on-the-fly data transforms. It's your lucky day! 
We already embedded several transforms that can be used easily. -Note: all the details about every flag and options for each transform can be found in the [train](#train) section. +Note: all the details about every flag and options for each transform can be found in the [Transforms Config](../reference/Config/transforms.md) section. + +### Transform Types + +The concept of `TransformType` was introduced to facilitate transparent configuration management. The underlying issue at stake is that all transforms are not meant to be used in the same concept. For instance, the `filtertoolong` transform is meant as a "safeguard" to limit the size of training batches. Enabling this transform when predicting can introduce some unwanted behaviours and poor results. +For now, the possible transform types are: +- `Default` // `"any"`: usable in any context (default unless specified otherwise in the transform class definition); +- `Train` // `"train"`: usable only in training context; +- `Predict` // `"predict"`: usable only in prediction context. + +This concept might be extended later for various needs, such as different data types, etc. ### General purpose diff --git a/docs/source/Config/transforms.rst b/docs/source/Config/transforms.rst new file mode 100644 index 00000000..5c64489f --- /dev/null +++ b/docs/source/Config/transforms.rst @@ -0,0 +1,21 @@ +Transforms +================= + +.. autopydantic_model:: eole.transforms.tokenize.ONMTTokenizerConfig +.. autopydantic_model:: eole.transforms.tokenize.BaseTokenizerConfig +.. autopydantic_model:: eole.transforms.docify.DocifyConfig +.. autopydantic_model:: eole.transforms.clean.CleanConfig +.. autopydantic_model:: eole.transforms.bart.BARTNoiseConfig +.. autopydantic_model:: eole.transforms.fuzzymatch.FuzzyMatchConfig +.. autopydantic_model:: eole.transforms.inlinetags.InlineTagsConfig +.. autopydantic_model:: eole.transforms.uppercase.UpperCaseConfig +.. autopydantic_model:: eole.transforms.sampling.TokenDropConfig +.. autopydantic_model:: eole.transforms.sampling.TokenMaskConfig +.. autopydantic_model:: eole.transforms.sampling.SwitchOutConfig +.. autopydantic_model:: eole.transforms.terminology.TerminologyConfig +.. autopydantic_model:: eole.transforms.misc.FilterTooLongConfig +.. autopydantic_model:: eole.transforms.misc.PrefixConfig +.. autopydantic_model:: eole.transforms.misc.SuffixConfig +.. autopydantic_model:: eole.transforms.normalize.NormalizeConfig +.. autopydantic_model:: eole.transforms.concat.ConcatConfig +.. 
autopydantic_model:: eole.transforms.insert_mask_before_placeholder.InsertMaskBeforePlaceholderConfig \ No newline at end of file diff --git a/eole/bin/convert/convert_HF.py b/eole/bin/convert/convert_HF.py index 622c7f09..7da0499f 100755 --- a/eole/bin/convert/convert_HF.py +++ b/eole/bin/convert/convert_HF.py @@ -1,5 +1,6 @@ #!/usr/bin/env python import os +import re import json import torch import pyonmttok @@ -312,6 +313,12 @@ def run(cls, args): ) else: tokenizer_config_json = None + if os.path.exists(os.path.join(args.model_dir, "generation_config.json")): + generation_config_json = os.path.join( + args.model_dir, "generation_config.json" + ) + else: + generation_config_json = None else: directory_path = args.output os.makedirs(directory_path, exist_ok=True) @@ -365,6 +372,16 @@ def run(cls, args): raise huggingface_hub.utils.EntryNotFoundError( "Something went wrong the repo does not contain any tokenizer_config.json file" ) + try: + generation_config_json = huggingface_hub.hf_hub_download( + repo_id=args.model_dir, + filename="generation_config.json", + token=args.token, + ) + except huggingface_hub.utils.EntryNotFoundError: + raise huggingface_hub.utils.EntryNotFoundError( + "Something went wrong the repo does not contain any generation_config.json file" + ) try: wmap_path = huggingface_hub.hf_hub_download( repo_id=args.model_dir, @@ -567,6 +584,9 @@ def run(cls, args): "n_positions": 0, } left_pad = True + eos_token = None + optional_eos = [] + mapped_tokens = [] # ALL THESE IF SHOULD BE HANDLED IN MAPPINGS if arch == "PhiForCausalLM": @@ -913,13 +933,49 @@ def get_weight(checkpoint, tensor_name): add_bos_token = True else: add_bos_token = False + if "chat_template" in data.keys(): + chat_template = {"chat_template": data["chat_template"]} + else: + chat_template = {} + # Not sure if we could do much cleaner to retrieve optional eos tokens + eos_token_id = config.get("eos_token_id", None) + if isinstance(eos_token_id, list): + optional_eos = [ + data["added_tokens_decoder"][str(index)]["content"] + for index in eos_token_id[1:] + ] + eos_token = optional_eos[0] + elif isinstance(eos_token_id, int): + eos_token = data["added_tokens_decoder"][str(eos_token_id)][ + "content" + ] + # Automatically convert added_tokens into mapped_tokens + mapped_tokens = [ + ( + token["content"], + re.sub(r"<\|([^|]*)\|>", "\uff5f\\1\uff60", token["content"]), + ) + for token in data["added_tokens_decoder"].values() + ] else: add_bos_token = True + if generation_config_json is not None: + with open(generation_config_json, encoding="utf-8") as f: + data = json.load(f) + generation_config_dict = {} + # we probably need a better mapping at some point + keys = ["top_k", "top_p", "temperature", "max_length"] + for key in keys: + if key in data.keys(): + generation_config_dict[key] = data[key] + vocabs = {} if ( tokenizer_model is not None ): # sentencepiece mode (might be good to check it's a SP model) + src_subword_type = "sentencepiece" + tokenizer_basename = os.path.basename(tokenizer_model) tokenizer = Tokenizer(model_path=tokenizer_model) vocab = tokenizer.vocab if tokenizer_json is not None: @@ -937,9 +993,10 @@ def get_weight(checkpoint, tensor_name): if "<|startoftext|>" in vocab: index = vocab.index("<|startoftext|>") vocab[index] = DefaultTokens.BOS - if "<|endoftext|>" in vocab and "" not in vocab: - index = vocab.index("<|endoftext|>") - vocab[index] = DefaultTokens.EOS + if eos_token is not None: + if eos_token in vocab and "" not in vocab: + index = vocab.index(eos_token) + vocab[index] = 
DefaultTokens.EOS if "<0x00>" in vocab: index = vocab.index("<0x00>") vocab[index] = DefaultTokens.PAD @@ -948,13 +1005,20 @@ def get_weight(checkpoint, tensor_name): special_tokens=specials_table[arch], ) else: # # BPE mode - we leverage the HF tokenizer.json info + src_subword_type = "bpe" with open(tokenizer_json, encoding="utf-8") as f: data = json.load(f) - vocab = [ - tok if tok != "Ā" else DefaultTokens.PAD - # "Ā" is '\x00' in unicode (cf tokenize.py gpt2 mapping) - for tok in data["model"]["vocab"] - ] + # gpt2_pretok + gpt2_pretok = False + pretokenizers = data.get("pre_tokenizer", {}).get("pretokenizers", [{}]) + for pretokenizer in pretokenizers: + if pretokenizer.get("type", None) == "ByteLevel": + gpt2_pretok = True + vocab = [ + tok if tok != "Ā" else DefaultTokens.PAD + # "Ā" is '\x00' in unicode (cf tokenize.py gpt2 mapping) + for tok in data["model"]["vocab"] + ] voc_size = len(vocab) if vocab_size > voc_size: for i in range(vocab_size - voc_size): @@ -964,20 +1028,20 @@ def get_weight(checkpoint, tensor_name): if "<|startoftext|>" in vocab: index = vocab.index("<|startoftext|>") vocab[index] = DefaultTokens.BOS - if "<|endoftext|>" in vocab: - index = vocab.index("<|endoftext|>") - vocab[index] = DefaultTokens.EOS if "<|begin_of_text|>" in vocab: index = vocab.index("<|begin_of_text|>") vocab[index] = DefaultTokens.BOS - if "<|end_of_text|>" in vocab: - index = vocab.index("<|end_of_text|>") - vocab[index] = DefaultTokens.EOS + if eos_token is not None: + if eos_token in vocab and "" not in vocab: + index = vocab.index(eos_token) + vocab[index] = DefaultTokens.EOS src_vocab = pyonmttok.build_vocab_from_tokens(vocab) + tokenizer_basename = "bpe.model" + with open( - os.path.join(directory_path, "bpe.model"), "w", encoding="utf-8" + os.path.join(directory_path, tokenizer_basename), "w", encoding="utf-8" ) as bpemodel: bpemodel.write("v3;false;false;false;Ġ;Ġ\n") for merge in data["model"]["merges"]: @@ -1013,9 +1077,17 @@ def get_weight(checkpoint, tensor_name): tgt_vocab_size=vocab_size, vocab_size_multiple=8, decoder_start_token=vocabs["decoder_start_token"], - transforms=["filtertoolong"], + transforms=["onmt_tokenize", "filtertoolong"], transforms_configs={ - "filtertoolong": {"src_seq_length": 512, "tgt_seq_length": 512} + "filtertoolong": {"src_seq_length": 512, "tgt_seq_length": 512}, + "onmt_tokenize": { + "src_subword_type": src_subword_type, + "src_subword_model": os.path.join( + "${MODEL_PATH}", tokenizer_basename + ), + "gpt2_pretok": gpt2_pretok, + "mapped_tokens": mapped_tokens, + }, }, model=arch_table[arch]( layers=n_layers, @@ -1060,6 +1132,16 @@ def get_weight(checkpoint, tensor_name): ), ) config_dict = recursive_model_fields_set(config) + + inference_dict = { + "optional_eos": optional_eos, + # TODO: map other settings from HF decoding_config.json + **generation_config_dict, + **chat_template, + } + + config_dict["inference"] = inference_dict + with open( os.path.join(directory_path, "config.json"), "w", encoding="utf-8" ) as f: diff --git a/eole/bin/run/serve.py b/eole/bin/run/serve.py new file mode 100644 index 00000000..7387b6db --- /dev/null +++ b/eole/bin/run/serve.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python + +import os +import time +import gc +import yaml + +from typing import List, Union + +import torch +import uvicorn + +from fastapi import FastAPI, Request, Body +from fastapi.responses import HTMLResponse +from pydantic import BaseModel, Field, model_validator +from jinja2.exceptions import TemplateError +from jinja2.sandbox import 
ImmutableSandboxedEnvironment + +import eole +from eole.inference_engine import InferenceEnginePY +from eole.config.run import PredictConfig +from eole.config.inference import DecodingConfig +from eole.bin import register_bin, BaseBin +from eole.utils.logging import logger +from eole.constants import DefaultTokens + +STATUS_OK = "ok" +STATUS_ERROR = "error" + + +class TextRequest(DecodingConfig): + """ + Standard text "completion" request + (as well as encoder/decoder models e.g. translation). + """ + + model: int | str = Field(description="Model identifier from server configuration.") + inputs: Union[str, List[str]] = Field( + description="List of inputs to run inference on. " + "A single string will be automatically cast to a single item list." + ) + + class Config: + json_schema_extra = { + "example": { + "model": "llama3-8b-instruct", + "inputs": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>You are a funny guy.<|eot_id|><|start_header_id|>user<|end_header_id|>Tell me a joke :)<|eot_id|><|start_header_id|>assistant<|end_header_id|>", # noqa: E501 + } + } + + +class TextResponse(BaseModel): + """ + Response of TextRequest. + """ + + predictions: List[List[str]] = Field( + description="List of prediction(s) for each input(s)." + ) + scores: List[List[float]] = Field( + description="Pred scores from the model for each prediction." + ) + + class Config: + json_schema_extra = { + "example": { + "predictions": [ + [ + "\n\nHere's one:\n\nWhy couldn't the bicycle stand up by itself?\n\n(wait for it...)\n\nBecause it was two-tired!\n\nHope that made you laugh!" # noqa: E501 + ] + ], + "scores": [[-0.040771484375]], + } + } + + @model_validator(mode="after") + def _validate_response(self): + """ + Automatically apply some formatting to the provided text response. + This logic might be moved elsewhere at some point. + """ + self.predictions = [ + [pred.replace(DefaultTokens.SEP, "\n") for pred in preds] + for preds in self.predictions + ] + return self + + +class ChatRequest(DecodingConfig): + """ + Request format for chat-based interactions. + """ + + model: int | str = Field(description="Model identifier from server configuration.") + messages: List[dict] = Field( + description="List of message dictionaries with 'role' and 'content' keys." + ) + + class Config: + json_schema_extra = { + "example": { + "model": "llama3-8b-instruct", + "messages": [ + {"role": "system", "content": "You are a funny guy."}, + {"role": "user", "content": "Tell me a joke :)"}, + ], + } + } + + +# TODO: specific response model for chat mode? +# class ChatResponse(BaseModel): +# choices: List[dict] + + +class Server(object): + """ + Main server class to manage configuration, models and corresponding constraints. + """ + + def __init__(self): + self.start_time = time.time() + self.models = {} + self.models_root = None + + def start(self, server_config_path): + """ + Initialize the server with the given configuration. + """ + with open(server_config_path) as f: + server_config = yaml.safe_load(f) + self.models_root = server_config["models_root"] + for model in server_config["models"]: + # instantiate models + model_id = model["id"] + model_path = model["path"] + self.models[model_id] = Model( + model_id=model_id, + model_path=model_path, + models_root=self.models_root, + model_type=model.get("model_type", "default"), + pre_config=model.get("config", {}), + ) + if model.get("preload", False): + self.models[model_id].load() + + def available_models(self): + """ + Return a list of available models. 
+ """ + models = [] + for model_id, model in self.models.items(): + models.append({"id": model_id}) + return models + + def maybe_load_model(self, model_id_to_load): + """ + Very naive method to ensure a single model is loaded for now. + """ + for model_id, model in self.models.items(): + if model_id != model_id_to_load: + model.unload() + + +class Model(object): + """ + Represents a single model in the server. + """ + + def __init__( + self, + model_id=None, + model_path=None, + preload=False, + models_root=None, + model_type=False, + pre_config={}, + ): + self.loaded = False + self.engine = None + self.model_id = model_id + self.preload = preload + self.models_root = models_root + self.model_path = model_path + self.local_path = None + self.model_type = model_type + self.pre_config = pre_config + + def get_config(self): + """ + Instanciate the configuration for the model. + """ + # transforms and inference settings are retrieved from the model config for now + self.config = PredictConfig( + src="dummy", + model_path=self.local_path, + # TODO improve this + gpu_ranks=[0], + world_size=1, + **self.pre_config, + ) + + def maybe_retrieve_model(self): + """ + Download the model if it's not available locally. + """ + from huggingface_hub import HfApi, snapshot_download + + hf_api = HfApi() + try: + hf_api.model_info(self.model_path) + except Exception: + self.local_path = os.path.expandvars(self.model_path) + else: + self.local_path = os.path.expandvars( + os.path.join(self.models_root, self.model_path) + ) + logger.info( + f"Downloading {self.model_path} from huggingface, " + f"to local directory {self.local_path}" + ) + snapshot_download(repo_id=self.model_path, local_dir=self.local_path) + + def load(self): + """ + Create the inference engine. + """ + self.maybe_retrieve_model() + self.get_config() + self.engine = InferenceEnginePY(self.config) + self.loaded = True + logger.info(f"Loaded model {self.model_id} from: {self.model_path}") + + def unload(self): + """ + Not super clean, we might want to do better some day... + """ + del self.engine + gc.collect() + torch.cuda.empty_cache() + self.engine = None + self.loaded = False + logger.info(f"Unloaded model {self.model_id}") + + def apply_chat_template(self, inputs): + """ + Render the model input based on the model chat template + and the request inputs. + """ + + def raise_exception(message): + raise TemplateError(message) + + jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) + jinja_env.globals["raise_exception"] = raise_exception + template = jinja_env.from_string(self.config.chat_template) + rendered_output = template.render( + **{ + "messages": inputs, + "bos_token": "", # handled in numericalize + "add_generation_prompt": True, + } + ) + return rendered_output + + def infer(self, inputs, settings={}, is_chat=False): + """ + Run inference on the given inputs. + """ + if type(inputs) == str: + inputs = [inputs] + if not (self.loaded): + self.load() + if is_chat: + assert hasattr( + self.config, "chat_template" + ), "Chat requests can't be performed without a chat_template." + inputs = [self.apply_chat_template(inputs)] + scores, _, preds = self.engine.infer_list(inputs, settings=settings) + return scores, preds + + +def create_app(config_file): + """ + Create and configure the FastAPI application. 
+ """ + app = FastAPI( + title="Eole Inference Server", + version=eole.__version__, + summary="A simple inference server to expose various models.", + description="", # TODO + ) + + server = Server() + server.start(config_file) + + @app.get("/") + def root(request: Request): + """ + Root endpoint returning HTML content to help users find the docs. + """ + html_content = f""" + + + Eole Server + + +

+                <h1>Eole Server</h1>
+                <p>Probably not what you're looking for.</p>
+                <p>API docs --> {request.url}docs.</p>
+ + + """ + return HTMLResponse(content=html_content, status_code=200) + + @app.get("/models") + def models(): + """ + Return available models currently exposed. + """ + models = server.available_models() + out = {"models": models} + return out + + @app.post("/unload_model") + def unload_model(model_id): + """ + Unload a specific model. + """ + server.models[model_id].unload() + + @app.get("/health") + def health(): + """ + Health check endpoint. + """ + out = {} + out["status"] = STATUS_OK + return out + + @app.post("/infer", response_model=TextResponse) + def infer( + request: Union[TextRequest, ChatRequest] = Body( + openapi_examples={ + "text_request": { + "summary": "Text Request Example", + "description": "A sample text request", + "value": TextRequest.Config.json_schema_extra["example"], + }, + "chat_request": { + "summary": "Chat Request Example", + "description": "A sample chat request", + "value": ChatRequest.Config.json_schema_extra["example"], + }, + }, + ), + ): + """ + Run inference on the given input. + """ + if isinstance(request, TextRequest): + inputs = ( + request.inputs if isinstance(request.inputs, list) else [request.inputs] + ) + else: # ChatRequest + # no batch support right now + inputs = request.messages + model_id = request.model + # automatically grab anything that is not model/inputs + # (we could probably rely on pydantic model once properly implemented) + non_settings_keys = ["inputs", "messages", "model"] + settings = { + k: v for k, v in request.model_dump().items() if k not in non_settings_keys + } + # TODO: move this in some `infer` method in the `Server` class? + server.maybe_load_model(model_id) + scores, preds = server.models[model_id].infer( + inputs, + settings=settings, + is_chat=isinstance(request, ChatRequest), + ) + # returned scores are tensors which we need to cast (not anymore?) + # scores = [[score.item() for score in score_list] for score_list in scores] + response = {"predictions": preds, "scores": scores} + return response + + # @app.post("/openai/chat/completions") + # def openai_chat(request: ChatRequest): + # """ + # Simulate an OpenAI Request. + # The idea is to make this a viable alternative as a drop-in + # replacement for OpenAI or other LLM stacks. + # """ + # pass + + return app + + +@register_bin(name="serve") +class Serve(BaseBin): + @classmethod + def add_args(cls, parser): + parser.add_argument( + "--config", + "-config", + "-c", + default="./server_conf.yaml", + help="Path of server YAML config file.", + ) + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default="5000") + + @classmethod + def run(cls, args): + app = create_app(args.config) + uvicorn.run(app=app, host=args.host, port=args.port, log_level="info") diff --git a/eole/bin/tools/LM_scoring.py b/eole/bin/tools/LM_scoring.py index 65b53973..81c00dac 100644 --- a/eole/bin/tools/LM_scoring.py +++ b/eole/bin/tools/LM_scoring.py @@ -70,9 +70,16 @@ def run(cls, args): set_random_seed(config.seed, False) ppl_file = codecs.open(config.output + ".ppl", "w+", "utf-8") + # no tensor_parallel support device = ( - torch.device("cuda", config.gpu) if config.gpu > -1 else torch.device("cpu") + torch.device("cuda", config.gpu_ranks[0]) + if len(config.gpu_ranks) > 0 + else torch.device("cpu") ) + if len(config.gpu_ranks) > 1: + logger.warning( + f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used." 
+ ) vocabs, model, model_opt = config.model.model_class.load_test_model(config) padding_idx = vocabs["tgt"][DefaultTokens.PAD] diff --git a/eole/config/inference.py b/eole/config/inference.py index b5d9f89f..e120ddb3 100644 --- a/eole/config/inference.py +++ b/eole/config/inference.py @@ -11,13 +11,13 @@ class DecodingConfig(Config): ratio: float = Field( default=-0.0, description="Ratio based beam stop condition." ) # is the minus sign useful here? - random_sampling_topk: int = Field( + top_k: int = Field( default=0, description="Set this to -1 to do random sampling from full distribution. " "Set this to value k>1 to do random sampling restricted to " "the k most likely next tokens. Set this to 1 to use argmax.", ) - random_sampling_topp: float = Field( + top_p: float = Field( default=0.0, description="Probability for top-p/nucleus sampling. " "Restrict tokens to the most likely until the cumulated probability " @@ -25,7 +25,7 @@ class DecodingConfig(Config): ge=0.0, lte=1.0, ) - random_sampling_temp: float = Field( + temperature: float = Field( default=1.0, description="If doing random sampling, divide the logits by this " "before computing softmax during decoding.", @@ -45,14 +45,15 @@ class DecodingConfig(Config): default=False, description="Apply coverage penalty at every decoding step. Helpful for summary penalty.", ) - min_length: int = Field(default=0, description="Minimum prediction length.") + min_length: int = Field(default=0, description="Minimum prediction length.", ge=0) max_length: int = Field(default=250, description="Maximum prediction length.") max_length_ratio: float = Field( default=2, description="Maximum prediction length ratio. For European languages, " "2 is large enough, for target Asian charageters, " "need to increase to 2-3, for special languages (Burmese, Amharic) to 10.", - ) + ge=1, + ) # we might want to validate this against min_length block_ngram_repeat: int = Field( default=0, description="Block repetition of ngrams during decoding." ) @@ -137,9 +138,6 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig) batch_type: Literal["sents", "tokens"] = Field( default="sents", description="Batch grouping for batch size." ) - gpu: int = Field( - default=-1, description="Device to run on. -1 will default to CPU." - ) avg_raw_probs: bool = Field( default=False, description="If set, during ensembling scores from different models will be combined " diff --git a/eole/config/run.py b/eole/config/run.py index e0ed7a1e..4032f202 100644 --- a/eole/config/run.py +++ b/eole/config/run.py @@ -1,3 +1,5 @@ +import os +import json from typing import Dict, List, Any from eole.config.config import get_config_dict @@ -11,6 +13,8 @@ VocabConfig, NestedAllTransformsConfig, ) +from eole.transforms import get_transforms_cls +from eole.constants import TransformType from pydantic import Field, field_validator, model_validator @@ -23,9 +27,6 @@ class TrainConfig( description="Number of transformed samples per corpus to use to build the vocabulary. " "Set to -1 to use the full corpora.", ) # not sure how to handle the legacy build_vocab_only flag here (different default value in both cases) # noqa: E501 - override_opts: bool = Field( - default=False, description="Allow to override some checkpoint opts." 
- ) # this should probably be clarified down the line verbose: bool = Field( default=False, description="Print data loading and statistics for all process " @@ -99,6 +100,7 @@ class PredictConfig( None # patch for CT2 inference engine (to improve later) ) model: ModelConfig | None = None + chat_template: str | None = None optional_eos: List[str] | None = Field( default=[], description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3", @@ -106,10 +108,42 @@ class PredictConfig( @model_validator(mode="after") def _validate_predict_config(self): + # Not sure we want to call this at every validation + self._update_with_model_config() + # TODO: do we really need this _all_transform? if self._all_transform is None: self._all_transform = self.transforms return self + def _update_with_model_config(self): + # Note: in case of ensemble decoding, grabbing the first model's + # config and artifacts by default + os.environ["MODEL_PATH"] = self.model_path[0] + config_path = os.path.join(self.model_path[0], "config.json") + if os.path.exists(config_path): + with open(config_path) as f: + config_dict = json.loads(os.path.expandvars(f.read())) + else: + config_dict = {} + # Filter out Train transforms + transforms = config_dict.get("transforms", []) + transforms_cls = get_transforms_cls(transforms) + transforms = [ + t for t in transforms if transforms_cls[t].type != TransformType.Train + ] + + if "transforms" not in self.model_fields_set: + self.transforms = self._all_transform = transforms + if "transforms_configs" not in self.model_fields_set: + self.transforms_configs = config_dict.get("transforms_configs", {}) + if "compute_dtype" not in self.model_fields_set: + self.compute_dtype = config_dict.get("training", {}).get( + "compute_dtype", "fp16" + ) + for key, value in config_dict.get("inference", {}).items(): + if key not in self.model_fields_set: + setattr(self, key, value) + class BuildVocabConfig( DataConfig, MiscConfig, BaseVocabConfig diff --git a/eole/constants.py b/eole/constants.py index 2b228c0f..41d4cfbf 100644 --- a/eole/constants.py +++ b/eole/constants.py @@ -63,6 +63,12 @@ class ActivationFunction(str, Enum): gated_silu = "gated-silu" +class TransformType(str, Enum): + Default = "any" + Train = "train" + Predict = "predict" + + ACTIVATION_FUNCTIONS = { ActivationFunction.relu: F.relu, ActivationFunction.gelu: F.gelu, diff --git a/eole/inference_engine.py b/eole/inference_engine.py index 4fd708e0..3ca83584 100755 --- a/eole/inference_engine.py +++ b/eole/inference_engine.py @@ -39,7 +39,7 @@ def infer_file(self): scores, estims, preds = self.infer_file_parallel() return scores, estims, preds - def infer_list(self, src): + def infer_list(self, src, settings={}): """List of strings inference `src`""" if self.config.world_size <= 1: infer_iter = build_dynamic_dataset_iter( @@ -51,9 +51,9 @@ def infer_list(self, src): device_id=self.device_id, model_type=self.model_type, ) - scores, estims, preds = self._predict(infer_iter) + scores, estims, preds = self._predict(infer_iter, settings=settings) else: - scores, estims, preds = self.infer_list_parallel(src) + scores, estims, preds = self.infer_list_parallel(src, settings=settings) return scores, estims, preds def infer_file_parallel(self): @@ -62,7 +62,7 @@ def infer_file_parallel(self): "The inference in mulitprocessing with partitioned models is not implemented." 
) - def infer_list_parallel(self, src): + def infer_list_parallel(self, src, settings={}): """The inference in mulitprocessing with partitioned models.""" raise NotImplementedError( "The inference in mulitprocessing with partitioned models is not implemented." @@ -128,11 +128,13 @@ def __init__(self, config): self.error_queue = mp.SimpleQueue() self.error_handler = ErrorHandler(self.error_queue) self.queue_instruct = [] + self.queue_settings = [] self.queue_result = [] self.procs = [] for device_id in range(config.world_size): self.queue_instruct.append(mp.Queue()) + self.queue_settings.append(mp.Queue()) self.queue_result.append(mp.Queue()) self.procs.append( mp.Process( @@ -150,7 +152,10 @@ def __init__(self, config): self.procs[device_id].start() self.error_handler.add_child(self.procs[device_id].pid) else: - self.device_id = config.gpu + if len(config.gpu_ranks) > 0: + self.device_id = config.gpu_ranks[0] + else: + self.device_id = -1 # cpu self.predictor = build_predictor( config, self.device_id, logger=self.logger, report_score=True ) @@ -159,7 +164,8 @@ def __init__(self, config): self.transforms = make_transforms(config, self.transforms_cls, self.vocabs) self.transform_pipe = TransformPipe.build_from(self.transforms.values()) - def _predict(self, infer_iter): + def _predict(self, infer_iter, settings={}): + self.predictor.update_settings(**settings) scores, estims, preds = self.predictor._predict( infer_iter, infer_iter.transforms, @@ -191,10 +197,12 @@ def score_file_parallel(self): score_results.append(self.queue_result[device_id].get()) return score_results[0] - def infer_file_parallel(self): + def infer_file_parallel(self, settings={}): assert self.config.world_size > 1, "World size must be greater than 1." for device_id in range(self.config.world_size): self.queue_instruct[device_id].put(("infer_file", self.config)) + # not sure if we want a separate queue or additional info in queue_instruct + self.queue_settings[device_id].put(settings) scores, estims, preds = [], [], [] for device_id in range(self.config.world_size): scores.append(self.queue_result[device_id].get()) @@ -202,10 +210,11 @@ def infer_file_parallel(self): preds.append(self.queue_result[device_id].get()) return scores[0], estims[0], preds[0] - def infer_list_parallel(self, src): + def infer_list_parallel(self, src, settings={}): assert self.config.world_size > 1, "World size must be greater than 1." for device_id in range(self.config.world_size): self.queue_instruct[device_id].put(("infer_list", src)) + self.queue_settings[device_id].put(settings) scores, estims, preds = [], [], [] for device_id in range(self.config.world_size): scores.append(self.queue_result[device_id].get()) @@ -237,11 +246,12 @@ def __init__(self, config, model_type=None): ), "A model_type kwarg must be passed for CT2 models." self.logger = init_logger(config.log_file) assert self.config.world_size <= 1, "World size must be less than 1." 
- self.device_id = config.gpu if config.world_size == 1: + self.device_id = config.gpu_ranks[0] self.device_index = config.gpu_ranks self.device = "cuda" else: + self.device_id = -1 self.device_index = 0 self.device = "cpu" self.transforms_cls = get_transforms_cls(self.config._all_transform) @@ -293,9 +303,9 @@ def predict_batch(self, batch, config): max_length=config.max_length, return_scores=True, include_prompt_in_result=False, - sampling_topk=config.random_sampling_topk, - sampling_topp=config.random_sampling_topp, - sampling_temperature=config.random_sampling_temp, + sampling_topk=config.top_k, + sampling_topp=config.top_p, + sampling_temperature=config.temperature, ) preds = [ [self.transforms.apply_reverse(tokens) for tokens in out.sequences] @@ -311,9 +321,9 @@ def predict_batch(self, batch, config): num_hypotheses=config.n_best, max_decoding_length=config.max_length, return_scores=True, - sampling_topk=config.random_sampling_topk, - sampling_topp=config.random_sampling_topp, - sampling_temperature=config.random_sampling_temp, + sampling_topk=config.top_k, + sampling_topp=config.top_p, + sampling_temperature=config.temperature, ) preds = [ [self.transforms.apply_reverse(tokens) for tokens in out.hypotheses] @@ -323,7 +333,8 @@ def predict_batch(self, batch, config): return scores, None, preds - def _predict(self, infer_iter): + def _predict(self, infer_iter, settings={}): + # TODO: convert settings to CT2 naming scores = [] preds = [] for batch, bucket_idx in infer_iter: diff --git a/eole/models/model.py b/eole/models/model.py index 8a8f7b66..2354ef71 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -338,8 +338,6 @@ def inference_logic(self, checkpoint, running_config, vocabs, device_id=None): if use_gpu(running_config): if len(running_config.gpu_ranks) > 0: device_id = running_config.gpu_ranks[0] - elif running_config.gpu > -1: - device_id = running_config.gpu device = torch.device("cuda", device_id) else: device = torch.device("cpu") diff --git a/eole/models/model_saver.py b/eole/models/model_saver.py index 1a68fa38..c1affeb6 100644 --- a/eole/models/model_saver.py +++ b/eole/models/model_saver.py @@ -50,6 +50,9 @@ def load_checkpoint(model_path): config_dict = json.load(f) # drop data to prevent validation issues config_dict["data"] = {} + # drop inference to prevent validation issues + if "inference" in config_dict.keys(): + config_dict.pop("inference") _config = TrainConfig(**config_dict) checkpoint["config"] = _config else: @@ -271,7 +274,12 @@ def _save_config(self): def _save_transforms_artifacts(self): if self.transforms is not None: for transform_name, transform in self.transforms.items(): - transform._save_artifacts(self.model_path) + transform_save_config = transform._save_artifacts(self.model_path) + setattr( + self.config.transforms_configs, + transform_name, + transform_save_config, + ) # we probably do not need to save transforms artifacts for each checkpoint # transform._save_artifacts(os.path.join(self.model_path, self.step_dir)) @@ -293,9 +301,6 @@ def _save(self, step): if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: self.update_step_dir(step) - logger.info(f"Saving config and vocab to {self.model_path}") - self._save_vocab() - self._save_config() logger.info( f"Saving optimizer and weights to {self.step_dir}, and symlink to {self.model_path}" ) @@ -303,6 +308,9 @@ def _save(self, step): self._save_weights(model_state_dict) logger.info(f"Saving transforms artifacts, if any, to {self.model_path}") 
self._save_transforms_artifacts() + logger.info(f"Saving config and vocab to {self.model_path}") + self._save_vocab() + self._save_config() self.cleanup() # we shall trigger optional saves from transforms here + some default inference config ? if torch.distributed.is_initialized(): diff --git a/eole/predict/__init__.py b/eole/predict/__init__.py index 8cf3c6ce..cc9927e6 100644 --- a/eole/predict/__init__.py +++ b/eole/predict/__init__.py @@ -42,6 +42,7 @@ def build_predictor(config, device_id=0, report_score=True, logger=None, out_fil vocabs, config, model_config, + device_id=device_id, global_scorer=scorer, out_file=out_file, report_align=config.report_align, diff --git a/eole/predict/encoder.py b/eole/predict/encoder.py index 232af03b..4838cfc2 100644 --- a/eole/predict/encoder.py +++ b/eole/predict/encoder.py @@ -24,7 +24,7 @@ def predict_batch(self, batch, attn_debug): else: max_length = self.max_length with torch.no_grad(): - if self.sample_from_topk != 0 or self.sample_from_topp != 0: + if self.top_k != 0 or self.top_p != 0: decode_strategy = GreedySearch( pad=self._tgt_pad_idx, bos=self._tgt_bos_idx, @@ -39,9 +39,9 @@ def predict_batch(self, batch, attn_debug): block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, - sampling_temp=self.random_sampling_temp, - keep_topk=self.sample_from_topk, - keep_topp=self.sample_from_topp, + sampling_temp=self.temperature, + top_k=self.top_k, + top_p=self.top_p, beam_size=self.beam_size, ban_unk_token=self.ban_unk_token, ) diff --git a/eole/predict/generator.py b/eole/predict/generator.py index 95602a9a..7bf690ae 100644 --- a/eole/predict/generator.py +++ b/eole/predict/generator.py @@ -25,7 +25,7 @@ def predict_batch(self, batch, attn_debug, scoring=False): """Predict a batch of sentences.""" max_length = 0 if scoring else self.max_length with torch.no_grad(): - if self.sample_from_topk != 0 or self.sample_from_topp != 0: + if self.top_k != 0 or self.top_p != 0: decode_strategy = GreedySearchLM( pad=self._tgt_pad_idx, bos=self._tgt_bos_idx, @@ -40,9 +40,9 @@ def predict_batch(self, batch, attn_debug, scoring=False): block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, - sampling_temp=self.random_sampling_temp, - keep_topk=self.sample_from_topk, - keep_topp=self.sample_from_topp, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p, beam_size=self.beam_size, ban_unk_token=self.ban_unk_token, ) diff --git a/eole/predict/greedy_search.py b/eole/predict/greedy_search.py index 91dd90d6..e79192c5 100644 --- a/eole/predict/greedy_search.py +++ b/eole/predict/greedy_search.py @@ -3,11 +3,11 @@ from eole.predict.decode_strategy import DecodeStrategy -def sample_topp(logits, keep_topp): +def sample_topp(logits, top_p): sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=1) cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_keep = cumulative_probs.lt(keep_topp) + sorted_indices_to_keep = cumulative_probs.lt(top_p) # keep indices until overflowing p cumsum_mask = sorted_indices_to_keep.cumsum(dim=1) @@ -25,8 +25,8 @@ def sample_topp(logits, keep_topp): return logits.masked_fill(~keep_indices, -10000) -def sample_topk(logits, keep_topk): - top_values, _ = torch.topk(logits, keep_topk, dim=1) +def sample_topk(logits, top_k): + top_values, _ = torch.topk(logits, top_k, dim=1) kth_best = top_values[:, -1].view([-1, 1]) kth_best = 
kth_best.repeat([1, logits.shape[1]]).float() @@ -36,11 +36,11 @@ def sample_topk(logits, keep_topk): return logits.masked_fill(ignore, -10000) -def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp): +def sample_with_temperature(logits, temperature, top_k, top_p): """Select next tokens randomly from the top k possible next tokens. - Samples from a categorical distribution over the ``keep_topk`` words using - the category probabilities ``logits / sampling_temp``. + Samples from a categorical distribution over the ``top_k`` words using + the category probabilities ``logits / temperature``. Args: logits (FloatTensor): Shaped ``(batch_size, vocab_size)``. @@ -48,13 +48,13 @@ def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp): (The distribution actually uses the log-probabilities ``logits - logits.logsumexp(-1)``, which equals the logits if they are log-probabilities summing to 1.) - sampling_temp (float): Used to scale down logits. The higher the + temperature (float): Used to scale down logits. The higher the value, the more likely it is that a non-max word will be sampled. - keep_topk (int): This many words could potentially be chosen. The + top_k (int): This many words could potentially be chosen. The other logits are set to have probability 0. - keep_topp (float): Keep most likely words until the cumulated - probability is greater than p. If used with keep_topk: both + top_p (float): Keep most likely words until the cumulated + probability is greater than p. If used with top_k: both conditions will be applied Returns: @@ -63,21 +63,21 @@ def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp): * topk_ids: Shaped ``(batch_size, 1)``. These are the sampled word indices in the output vocab. * topk_scores: Shaped ``(batch_size, 1)``. These - are essentially ``(logits / sampling_temp)[topk_ids]``. + are essentially ``(logits / temperature)[topk_ids]``. """ - if sampling_temp == 0.0 or keep_topk == 1: + if temperature == 0.0 or top_k == 1: # For temp=0.0, take the argmax to avoid divide-by-zero errors. - # keep_topk=1 is also equivalent to argmax. + # top_k=1 is also equivalent to argmax. topk_scores, topk_ids = logits.topk(1, dim=-1) - if sampling_temp > 0: - topk_scores /= sampling_temp + if temperature > 0: + topk_scores /= temperature else: - logits = torch.div(logits, sampling_temp) - if keep_topp > 0: - logits = sample_topp(logits, keep_topp) - if keep_topk > 0: - logits = sample_topk(logits, keep_topk) + logits = torch.div(logits, temperature) + if top_p > 0: + logits = sample_topp(logits, top_p) + if top_k > 0: + logits = sample_topk(logits, top_k) dist = torch.distributions.Categorical(logits=logits) topk_ids = dist.sample().view(-1, 1) topk_scores = logits.gather(dim=1, index=topk_ids) @@ -108,11 +108,11 @@ class GreedySearch(DecodeStrategy): exclusion_tokens (set[int]): See base. return_attention (bool): See base. max_length (int): See base. - sampling_temp (float): See + temperature (float): See :func:`~eole.predict.greedy_search.sample_with_temperature()`. - keep_topk (int): See + top_k (int): See :func:`~eole.predict.greedy_search.sample_with_temperature()`. - keep_topp (float): See + top_p (float): See :func:`~eole.predict.greedy_search.sample_with_temperature()`. beam_size (int): Number of beams to use. 
""" @@ -132,9 +132,9 @@ def __init__( exclusion_tokens, return_attention, max_length, - sampling_temp, - keep_topk, - keep_topp, + temperature, + top_k, + top_p, beam_size, ban_unk_token, add_estimator=False, @@ -156,9 +156,9 @@ def __init__( ban_unk_token, add_estimator, ) - self.sampling_temp = sampling_temp - self.keep_topk = keep_topk - self.keep_topp = keep_topp + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p self.topk_scores = None self.beam_size = beam_size self.n_best = n_best @@ -201,7 +201,7 @@ def _pick(self, log_probs): # maybe fix some prediction at this step by modifying log_probs log_probs = self.target_prefixing(log_probs) topk_ids, topk_scores = sample_with_temperature( - log_probs, self.sampling_temp, self.keep_topk, self.keep_topp + log_probs, self.temperature, self.top_k, self.top_p ) return topk_ids, topk_scores diff --git a/eole/predict/inference.py b/eole/predict/inference.py index 6ef35c2b..b6dbbab4 100644 --- a/eole/predict/inference.py +++ b/eole/predict/inference.py @@ -28,9 +28,11 @@ class Inference(object): max_length (int): See :class:`eole.predict.decode_strategy.DecodeStrategy`. beam_size (int): Number of beams. - random_sampling_topk (int): See + top_p (float): See :class:`eole.predict.greedy_search.GreedySearch`. - random_sampling_temp (float): See + top_k (int): See + :class:`eole.predict.greedy_search.GreedySearch`. + temperature (float): See :class:`eole.predict.greedy_search.GreedySearch`. stepwise_penalty (bool): Whether coverage penalty is applied every step or not. @@ -62,9 +64,9 @@ def __init__( max_length_ratio=1.5, ratio=0.0, beam_size=30, - random_sampling_topk=0, - random_sampling_topp=0.0, - random_sampling_temp=1.0, + top_k=0, + top_p=0.0, + temperature=1.0, stepwise_penalty=None, dump_beam=False, block_ngram_repeat=0, @@ -112,9 +114,9 @@ def __init__( self.max_length_ratio = max_length_ratio self.beam_size = beam_size - self.random_sampling_temp = random_sampling_temp - self.sample_from_topk = random_sampling_topk - self.sample_from_topp = random_sampling_topp + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p self.min_length = min_length self.ban_unk_token = ban_unk_token @@ -169,6 +171,7 @@ def from_config( vocabs, config, # running/predict config model_config, + device_id=0, global_scorer=None, out_file=None, report_align=False, @@ -197,16 +200,16 @@ def from_config( return cls( model, vocabs, - gpu=config.gpu, + gpu=device_id, n_best=config.n_best, min_length=config.min_length, max_length=config.max_length, max_length_ratio=config.max_length_ratio, ratio=config.ratio, beam_size=config.beam_size, - random_sampling_topk=config.random_sampling_topk, - random_sampling_topp=config.random_sampling_topp, - random_sampling_temp=config.random_sampling_temp, + top_k=config.top_k, + top_p=config.top_p, + temperature=config.temperature, stepwise_penalty=config.stepwise_penalty, dump_beam=config.dump_beam, block_ngram_repeat=config.block_ngram_repeat, @@ -247,6 +250,12 @@ def _gold_score(self, batch, enc_out, src_len, enc_final_hs, batch_size, src): glp = None return gs, glp + def update_settings(self, **kwargs): + # we probably would need some validation at some point + for k, v in kwargs.items(): + if hasattr(self, k): + setattr(self, k, v) + def _predict( self, infer_iter, diff --git a/eole/predict/translator.py b/eole/predict/translator.py index c2d1c2d6..8b90d560 100644 --- a/eole/predict/translator.py +++ b/eole/predict/translator.py @@ -84,7 +84,7 @@ def predict_batch(self, batch, attn_debug): 
else: max_length = self.max_length with torch.no_grad(): - if self.sample_from_topk != 0 or self.sample_from_topp != 0: + if self.top_k != 0 or self.top_p != 0: decode_strategy = GreedySearch( pad=self._tgt_pad_idx, bos=self._tgt_bos_idx, @@ -99,9 +99,9 @@ def predict_batch(self, batch, attn_debug): block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, - sampling_temp=self.random_sampling_temp, - keep_topk=self.sample_from_topk, - keep_topp=self.sample_from_topp, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p, beam_size=self.beam_size, ban_unk_token=self.ban_unk_token, ) diff --git a/eole/tests/data/inference-engine_py.yaml b/eole/tests/data/inference-engine_py.yaml index 89db33dd..f738c678 100644 --- a/eole/tests/data/inference-engine_py.yaml +++ b/eole/tests/data/inference-engine_py.yaml @@ -2,9 +2,9 @@ world_size: 0 max_length: 512 batch_type: sents batch_size: 100 -random_sampling_topk: 40 -random_sampling_topp: 0.75 -random_sampling_temp: 0.1 +top_k: 40 +top_p: 0.75 +temperature: 0.1 beam_size: 2 n_best: 2 src: None diff --git a/eole/tests/pull_request_check.sh b/eole/tests/pull_request_check.sh index 87eb0de3..3c7f2e4b 100755 --- a/eole/tests/pull_request_check.sh +++ b/eole/tests/pull_request_check.sh @@ -352,8 +352,8 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model2 \ -verbose -batch_size 10 \ -beam_size 1 \ -seed 1 \ - -random_sampling_topk -1 \ - -random_sampling_temp 0.0001 \ + -top_k -1 \ + -temperature 0.0001 \ -tgt ${DATA_DIR}/morph/tgt.valid \ -out $TMP_OUT_DIR/trans_sampling >> ${LOG_FILE} 2>&1 diff ${DATA_DIR}/morph/tgt.valid $TMP_OUT_DIR/trans_sampling @@ -389,8 +389,8 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model_lm \ -verbose -batch_size 1 \ -beam_size 1 \ -seed 1 \ - -random_sampling_topk -1 \ - -random_sampling_temp 0.0001 \ + -top_k -1 \ + -temperature 0.0001 \ -ban_unk_token \ -length_penalty none \ -out $TMP_OUT_DIR/gen_sampling >> ${LOG_FILE} 2>&1 @@ -405,9 +405,9 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model_lm \ -verbose -batch_size 1 \ -beam_size 1 \ -seed 3 \ - -random_sampling_topk -1 \ - -random_sampling_topp 0.95 \ - -random_sampling_temp 1 \ + -top_k -1 \ + -top_p 0.95 \ + -temperature 1 \ -ban_unk_token \ -length_penalty none \ -out $TMP_OUT_DIR/gen_sampling >> ${LOG_FILE} 2>&1 @@ -422,9 +422,9 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model_lm \ -verbose -batch_size 1 \ -beam_size 10 \ -seed 2 \ - -random_sampling_topk 50 \ - -random_sampling_topp 0.95 \ - -random_sampling_temp 1 \ + -top_k 50 \ + -top_p 0.95 \ + -temperature 1 \ -length_penalty avg \ -ban_unk_token \ -min_length 5 \ diff --git a/eole/transforms/misc.py b/eole/transforms/misc.py index d681df5b..606b90c2 100644 --- a/eole/transforms/misc.py +++ b/eole/transforms/misc.py @@ -1,5 +1,6 @@ from eole.utils.logging import logger from eole.transforms import register_transform +from eole.constants import TransformType from .transform import Transform, ObservableStats, TransformConfig from pydantic import Field @@ -48,6 +49,7 @@ class FilterTooLongTransform(Transform): """Filter out sentence that are too long.""" config_model = FilterTooLongConfig + type = TransformType.Train def __init__(self, config): super().__init__(config) diff --git a/eole/transforms/tokenize.py b/eole/transforms/tokenize.py index e071a3a1..ccf492d5 100644 --- a/eole/transforms/tokenize.py +++ b/eole/transforms/tokenize.py 
@@ -532,6 +532,7 @@ def tokenize_string(self, sentence, side="src", is_train=False): if self.mapped_tokens is not None: mapped_dict = {b: a for a, b in self.mapped_tokens} segmented = [mapped_dict.get(tok, tok) for tok in segmented] + return segmented def _detokenize(self, tokens, side="src", is_train=False): diff --git a/eole/transforms/transform.py b/eole/transforms/transform.py index 6a35b987..90837568 100644 --- a/eole/transforms/transform.py +++ b/eole/transforms/transform.py @@ -1,11 +1,11 @@ """Base Transform class and relate utils.""" import os -import json import shutil import copy from eole.utils.logging import logger -from eole.config import Config, recursive_model_fields_set +from eole.config import Config +from eole.constants import TransformType class TransformConfig(Config): @@ -20,6 +20,7 @@ class Transform(object): """A Base class that every transform method should derived from.""" name = None # set in register_transform wrapper + type = TransformType.Default def __init__(self, config): """Initialize Transform by parsing `opts` and add them as attribute.""" @@ -29,6 +30,7 @@ def __init__(self, config): self.artifacts = [] # retain a copy of the full config for some specific cases (seed, share_vocab, etc.) self.full_config = config + # restrict usage to some context self._parse_config() def _set_seed(self, seed): @@ -69,11 +71,7 @@ def _save_artifacts(self, model_path): artifact, os.path.join("${MODEL_PATH}", os.path.basename(maybe_artifact)), ) - config_path = os.path.join(model_path, f"{self.name}.json") - with open(config_path, "w") as f: - json.dump( - recursive_model_fields_set(save_config), f, indent=2, ensure_ascii=False - ) + return save_config @classmethod def add_options(cls, parser): diff --git a/eole/utils/distributed.py b/eole/utils/distributed.py index 6127fd47..e7ba2261 100644 --- a/eole/utils/distributed.py +++ b/eole/utils/distributed.py @@ -186,7 +186,9 @@ def spawned_train(process_fn, config, device_id, error_queue): # noqa: E501 error_queue.put((config.training.gpu_ranks[device_id], traceback.format_exc())) -def spawned_infer(config, device_id, error_queue, queue_instruct, queue_result): +def spawned_infer( + config, device_id, error_queue, queue_instruct, queue_result, queue_settings=None +): """Run various functions for prediction in spawned process on `device_id`.""" try: running_config = ( @@ -205,6 +207,9 @@ def spawned_infer(config, device_id, error_queue, queue_instruct, queue_result): transforms = make_transforms(config, transforms_cls, predictor.vocabs) while True: instruction = queue_instruct.get() + if queue_settings is not None: + settings = queue_settings.get() + predictor.update_settings(**settings) if instruction[0] == "stop": break elif instruction[0] == "infer_list": diff --git a/eole/utils/loss.py b/eole/utils/loss.py index 6a8f7c33..0d6b8e10 100644 --- a/eole/utils/loss.py +++ b/eole/utils/loss.py @@ -107,12 +107,10 @@ def from_config(cls, config, model, vocab, train=True): lm_prior_tau = config.training.lm_prior_tau if config.training.lm_prior_model: if config.training.lm_prior_model[-3:] == ".pt": - # TODO: we should probably find a way around this - config.gpu = 0 _, lm_prior_model, lm_model_config = DecoderModel.load_test_model( config, model_path=config.training.lm_prior_model ) # lm_model_config does not seem used - lm_prior_model.to(torch.device("cuda", config.training.gpu)) + # lm_prior_model.to(torch.device("cuda", config.training.gpu)) lm_prior_model.eval() lm_generator = None else: diff --git a/eole/utils/misc.py 
b/eole/utils/misc.py index 6ef2cc47..1732bbbd 100644 --- a/eole/utils/misc.py +++ b/eole/utils/misc.py @@ -69,9 +69,7 @@ def use_gpu(config): """ Creates a boolean if gpu used """ - return (hasattr(config, "gpu_ranks") and len(config.gpu_ranks) > 0) or ( - hasattr(config, "gpu") and config.gpu > -1 - ) + return hasattr(config, "gpu_ranks") and len(config.gpu_ranks) > 0 def set_random_seed(seed, is_cuda): diff --git a/eole/utils/scoring_utils.py b/eole/utils/scoring_utils.py index 6bae5cb8..0e6648eb 100644 --- a/eole/utils/scoring_utils.py +++ b/eole/utils/scoring_utils.py @@ -50,7 +50,6 @@ def translate(self, model, gpu_rank, step): # (take 'inference' field of config if exists?) # Set "default" translation options on empty cfgfile predict_config = PredictConfig(model_path=["dummy"], src="dummy") - predict_config.gpu = gpu_rank predict_config.compute_dtype = self.config.training.compute_dtype if predict_config.transforms_configs.prefix.tgt_prefix != "": predict_config.tgt_file_prefix = True @@ -67,6 +66,7 @@ def translate(self, model, gpu_rank, step): self.vocabs, predict_config, model_config, + device_id=gpu_rank, global_scorer=scorer, out_file=out_file, report_align=predict_config.report_align, @@ -101,7 +101,7 @@ def translate(self, model, gpu_rank, step): translator.vocabs, task=CorpusTask.INFER, tgt="", # This force to clear the target side (needed when using tgt_file_prefix) - device_id=predict_config.gpu, + device_id=gpu_rank, ) # ########### # diff --git a/recipes/gpt2/inference.yaml b/recipes/gpt2/inference.yaml index dc000932..b3542422 100644 --- a/recipes/gpt2/inference.yaml +++ b/recipes/gpt2/inference.yaml @@ -7,17 +7,15 @@ transforms_configs: world_size: 1 gpu_ranks: [0] -gpu: 0 model_path: ${EOLE_MODEL_DIR}/openai_gpt2 src: lm_input.txt output: lm_pred.txt beam_size: 5 -# random_sampling_topp: 0.5 -random_sampling_temp: 1.0 -random_sampling_topk: 50 -random_sampling_topp: 1 +temperature: 1.0 +top_k: 50 +top_p: 1 n_best: 5 seed: 42 diff --git a/recipes/llama2/llama-inference-tp-2gpu.yaml b/recipes/llama2/llama-inference-tp-2gpu.yaml index 51273b90..3b157788 100755 --- a/recipes/llama2/llama-inference-tp-2gpu.yaml +++ b/recipes/llama2/llama-inference-tp-2gpu.yaml @@ -10,7 +10,6 @@ model_path: "${EOLE_MODEL_DIR}/llama2-7b-chat-hf" # Inference seed: 42 max_length: 256 -gpu: 0 batch_type: sents batch_size: 8 world_size: 2 @@ -19,9 +18,9 @@ parallel_mode: "tensor_parallel" quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] quant_type: "bnb_NF4" compute_dtype: fp16 -random_sampling_topk: 5 -random_sampling_topp: 0.8 -random_sampling_temp: 0.9 +top_k: 5 +top_p: 0.8 +temperature: 0.9 beam_size: 1 n_best: 1 report_time: true diff --git a/recipes/llama2/llama-inference.yaml b/recipes/llama2/llama-inference.yaml index 6099015d..6fa433c7 100755 --- a/recipes/llama2/llama-inference.yaml +++ b/recipes/llama2/llama-inference.yaml @@ -10,7 +10,6 @@ model_path: "${EOLE_MODEL_DIR}/llama2-7b-chat-hf" # Inference seed: 42 max_length: 256 -gpu: 0 batch_type: sents batch_size: 8 world_size: 1 @@ -19,9 +18,9 @@ gpu_ranks: [0] quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] quant_type: "bnb_NF4" compute_dtype: fp16 -#random_sampling_topk: 1 -#random_sampling_topp: 0.0 -#random_sampling_temp: 0.9 +#top_k: 1 +#top_p: 0.0 +#temperature: 0.9 beam_size: 1 n_best: 1 report_time: true diff --git a/recipes/llama3.1/llama-inference.yaml 
diff --git a/recipes/llama3.1/llama-inference.yaml b/recipes/llama3.1/llama-inference.yaml
index 1bc672e3..abf81f43 100755
--- a/recipes/llama3.1/llama-inference.yaml
+++ b/recipes/llama3.1/llama-inference.yaml
@@ -1,13 +1,3 @@
-transforms: [onmt_tokenize]
-
-transforms_configs:
-  onmt_tokenize:
-    src_subword_type: bpe
-    src_subword_model: "${EOLE_MODEL_DIR}/llama3.1-8b/bpe.model"
-    tgt_subword_type: bpe
-    tgt_subword_model: "${EOLE_MODEL_DIR}/llama3.1-8b/bpe.model"
-    gpt2_pretok: true
-
 # Model info
 model_path: "${EOLE_MODEL_DIR}/llama3.1-8b"
 
diff --git a/recipes/llama3.1/llama-instruct-inference.yaml b/recipes/llama3.1/llama-instruct-inference.yaml
index 75dd3525..7c4f907f 100755
--- a/recipes/llama3.1/llama-instruct-inference.yaml
+++ b/recipes/llama3.1/llama-instruct-inference.yaml
@@ -1,18 +1,3 @@
-transforms: [onmt_tokenize]
-
-transforms_configs:
-  onmt_tokenize:
-    src_subword_type: bpe
-    src_subword_model: "${EOLE_MODEL_DIR}/llama3.1-8b-instruct/bpe.model"
-    src_onmttok_kwargs: {"mode": "space", "spacer_annotate": True, "preserve_placeholders": True}
-    tgt_subword_type: bpe
-    tgt_subword_model: "${EOLE_MODEL_DIR}/llama3.1-8b-instruct/bpe.model"
-    tgt_onmttok_kwargs: {"mode": "space", "spacer_annotate": True, "preserve_placeholders": True}
-    gpt2_pretok: true
-    mapped_tokens: [['<|start_header_id|>', '⦅start_header_id⦆'], ['<|end_header_id|>', '⦅end_header_id⦆'], ['<|eot_id|>', '⦅eot_id⦆']]
-
-optional_eos: ['<|eot_id|>']
-
 # Model info
 model_path: "${EOLE_MODEL_DIR}/llama3.1-8b-instruct"
 
diff --git a/recipes/llama3/llama-inference.yaml b/recipes/llama3/llama-inference.yaml
index ebbd3e27..ab81c8e9 100755
--- a/recipes/llama3/llama-inference.yaml
+++ b/recipes/llama3/llama-inference.yaml
@@ -15,7 +15,6 @@ model_path: "${EOLE_MODEL_DIR}/llama3-8b-instruct"
 seed: 42
 max_length: 256
 # max_length: 1
-gpu: 0
 batch_type: sents
 batch_size: 4
 world_size: 1
@@ -23,8 +22,8 @@ gpu_ranks: [0]
 # world_size: 2
 # gpu_ranks: [0, 1]
 # parallel_mode: "tensor_parallel"
-# quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
-# quant_type: "bnb_NF4"
+quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
+quant_type: "bnb_NF4"
 compute_dtype: fp16
 #random_sampling_topk: 1
 #random_sampling_topp: 0.0
diff --git a/recipes/llama3/llama-mmlu.yaml b/recipes/llama3/llama-mmlu.yaml
index 5151a9d4..e5b7439a 100755
--- a/recipes/llama3/llama-mmlu.yaml
+++ b/recipes/llama3/llama-mmlu.yaml
@@ -15,7 +15,6 @@ model_path: "${EOLE_MODEL_DIR}/llama3-8b-instruct/model.pt"
 seed: 42
 # max_length: 256
 max_length: 1
-gpu: 0
 batch_type: sents
 batch_size: 1
 world_size: 1
@@ -26,9 +25,6 @@ gpu_ranks: [0]
 # quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
 # quant_type: "bnb_NF4"
 compute_dtype: fp16
-#random_sampling_topk: 1
-#random_sampling_topp: 0.0
-#random_sampling_temp: 0.9
 beam_size: 1
 n_best: 1
 report_time: true
diff --git a/recipes/mistral/mistral-7b-awq-gemm-inference.yaml b/recipes/mistral/mistral-7b-awq-gemm-inference.yaml
index 8e7282bb..1cbff4bb 100755
--- a/recipes/mistral/mistral-7b-awq-gemm-inference.yaml
+++ b/recipes/mistral/mistral-7b-awq-gemm-inference.yaml
@@ -10,7 +10,6 @@ model_path: "$EOLE_MODEL_DIR/mistral-7b-v0.3"
 # Inference
 seed: 42
 max_length: 256
-gpu: 0
 batch_type: sents
 batch_size: 8
 world_size: 1
@@ -22,9 +21,9 @@ gpu_ranks: [0]
 #quant_layers: ['gate_up_proj', 'down_proj', 'up_proj']
 #quant_type: "bnb_NF4"
 # precision: fp16
 precision: fp16
-#random_sampling_topk: 1
-#random_sampling_topp: 0.6
-#random_sampling_temp: 0.9
+#top_k: 1
+#top_p: 0.6
+#temperature: 0.9
 beam_size: 1
 n_best: 1
 profile: false
diff --git a/recipes/mixtral/mixtral-inference-awq.yaml b/recipes/mixtral/mixtral-inference-awq.yaml
index bae055f5..0b07ee1b 100755
--- a/recipes/mixtral/mixtral-inference-awq.yaml
+++ b/recipes/mixtral/mixtral-inference-awq.yaml
@@ -10,7 +10,6 @@ model_path: "${EOLE_MODEL_DIR}/mixtral-8x7b-instruct-v0.1-awq"
 # Inference
 seed: 42
 max_length: 256
-gpu: 0
 batch_type: sents
 batch_size: 1
 world_size: 2
@@ -20,9 +19,9 @@ parallel_mode: "tensor_parallel"
 #quant_layers: ['gate_up_proj', 'down_proj', 'up_proj']
 #quant_type: "bnb_sparse"
 compute_dtype: fp16
-#random_sampling_topk: 1
-#random_sampling_topp: 0.6
-#random_sampling_temp: 0.9
+#top_k: 1
+#top_p: 0.6
+#temperature: 0.9
 beam_size: 1
 n_best: 1
 profile: false
diff --git a/recipes/server/README.md b/recipes/server/README.md
new file mode 100644
index 00000000..819833a3
--- /dev/null
+++ b/recipes/server/README.md
@@ -0,0 +1,52 @@
+# Serving models with Eole
+
+The provided example configuration allows you to serve Llama3-8B-Instruct.
+
+```yaml
+models_root: "." # used only for HF downloads for now, but might override $EOLE_MODEL_DIR at some point
+models:
+# local model
+- id: "llama3-8b-instruct"
+  path: "${EOLE_MODEL_DIR}/llama3-8b-instruct"
+  preload: false
+  config:
+    quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
+    quant_type: "bnb_NF4"
+# HF repo id, automatically downloaded to models_root
+- id: "llama3-8b-instruct-hf"
+  path: "fhdz/llama3-8b-instruct"
+  preload: true
+```
+
+Note: the `preload` flag allows the corresponding model to be loaded at server startup. See below for the two options.
+
+## Retrieve and convert model
+
+### Set environment variables
+
+```
+export EOLE_MODEL_DIR=
+export HF_TOKEN=
+```
+
+### Option 1 - Download and convert model
+
+The first example, `"llama3-8b-instruct"`, requires you to manually convert the model into your desired `$EOLE_MODEL_DIR`.
+
+```
+eole convert HF --model_dir meta-llama/Meta-Llama-3-8B-Instruct --output $EOLE_MODEL_DIR/llama3-8b-instruct --token $HF_TOKEN
+```
+
+### Option 2 - Retrieve an already converted model from HF
+
+The second example, `"llama3-8b-instruct-hf"`, points to a model that has already been converted for the sake of this example; it is downloaded automatically to `models_root`.
+
+## Run server
+
+```
+eole serve -c serve.example.yaml
+```
+
+## Play with the API
+
+FastAPI exposes a swagger UI by default. It should be accessible via your browser at `http://localhost:5000/docs`.
\ No newline at end of file
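Once the server is up, the swagger UI at `/docs` is the authoritative reference for the exposed routes and payload schemas. The sketch below is a purely hypothetical client call: the route name and JSON fields are assumptions to be checked against `/docs`; only the host/port and the model id come from the example configuration above.

```python
# Hypothetical client sketch; verify the actual route and schema in the swagger UI.
import requests

resp = requests.post(
    "http://localhost:5000/infer",  # assumed route
    json={
        "model": "llama3-8b-instruct",  # id declared in serve.example.yaml
        "inputs": ["What is the capital of France?"],  # assumed field name
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json())
```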
diff --git a/recipes/server/serve.example.yaml b/recipes/server/serve.example.yaml
new file mode 100644
index 00000000..c844d793
--- /dev/null
+++ b/recipes/server/serve.example.yaml
@@ -0,0 +1,14 @@
+models_root: "." # used only for HF downloads for now, but might override $EOLE_MODEL_DIR at some point
+some_params: null # might be useful to setup some server level params (available gpus, network, etc.)
+models:
+# local model
+- id: "llama3-8b-instruct"
+  path: "${EOLE_MODEL_DIR}/llama3-8b-instruct"
+  preload: false
+  config:
+    quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
+    quant_type: "bnb_NF4"
+# HF repo id, automatically downloaded to models_root
+- id: "llama3-8b-instruct-hf"
+  path: "fhdz/llama3-8b-instruct"
+  preload: true
\ No newline at end of file
diff --git a/recipes/wiki_103/README.md b/recipes/wiki_103/README.md
index 90714c9a..4134a5ea 100644
--- a/recipes/wiki_103/README.md
+++ b/recipes/wiki_103/README.md
@@ -72,3 +72,8 @@ The following command will provide inference with nucleus sampling of p=0.9 and
 ```bash
 eole predict -config inference.yaml -model_path data/wikitext/wikitext-103-raw-v1/run/model-lm/step_1000000 -src data/wikitext/wikitext-103-raw-v1/test_input.txt -output data/wikitext/wikitext-103-raw-v1/test_pred.txt
 ```
+
+**Note**: the main transform-related settings are now stored within the model and its configuration, which makes the (rather complex) `inference.yaml` config no longer strictly necessary. The above command can be reduced to a simple command line carrying the desired settings:
+```bash
+eole predict -model_path data/wikitext/wikitext-103-raw-v1/run/model-lm/step_1000000 -src data/wikitext/wikitext-103-raw-v1/test_input.txt -output data/wikitext/wikitext-103-raw-v1/test_pred.txt -world_size 1 -gpu_ranks 0 -n_best 3 -top_p 0.9 -beam_size 10
+```
\ No newline at end of file
diff --git a/recipes/wiki_103/inference.yaml b/recipes/wiki_103/inference.yaml
index c7532183..87fe3e04 100644
--- a/recipes/wiki_103/inference.yaml
+++ b/recipes/wiki_103/inference.yaml
@@ -1,19 +1,19 @@
+# Note: kept for reference, but not needed anymore since transforms are loaded transparently from the model's config.json
 # transforms related stuff
-transforms: [onmt_tokenize]
-transforms_configs:
-  onmt_tokenize:
-    src_subword_type: bpe
-    src_subword_model: data/wikitext/wikitext-103-raw-v1/subwords.bpe
-    src_onmttok_kwargs: {"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders":
-      True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens":
-      True}
+# transforms: [onmt_tokenize]
+# transforms_configs:
+#   onmt_tokenize:
+#     src_subword_type: bpe
+#     src_subword_model: data/wikitext/wikitext-103-raw-v1/subwords.bpe
+#     src_onmttok_kwargs: {"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders":
+#       True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens":
+#       True}
 
 verbose: false
 
 n_best: 3
-random_sampling_topp: 0.9
+top_p: 0.9
 beam_size: 10
-gpu: 0
 world_size: 1
 gpu_ranks: [0]