From 826d536e3d1ebd04290beda9f594950c9e2d4710 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?=
Date: Thu, 26 Dec 2024 12:08:35 +0100
Subject: [PATCH 1/2] feat: add commands to manage triton models

---
 robotoff/cli/main.py   |  4 ++
 robotoff/cli/triton.py | 70 +++++++++++++++++++++++++++++++++
 robotoff/triton.py     | 89 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 163 insertions(+)
 create mode 100644 robotoff/cli/triton.py

diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py
index 8ae28b4be4..7db5bcb579 100644
--- a/robotoff/cli/main.py
+++ b/robotoff/cli/main.py
@@ -4,6 +4,7 @@
 
 import typer
 
+from robotoff.cli.triton import app as triton_app
 from robotoff.types import (
     ImportImageFlag,
     ObjectDetectionModel,
@@ -1250,5 +1251,8 @@ def launch_normalize_barcode_job(
     logger.info("Updated %d images", updated)
 
 
+app.add_typer(triton_app, name="triton")
+
+
 def main() -> None:
     app()
diff --git a/robotoff/cli/triton.py b/robotoff/cli/triton.py
new file mode 100644
index 0000000000..4fd736eb37
--- /dev/null
+++ b/robotoff/cli/triton.py
@@ -0,0 +1,70 @@
+from typing import Optional
+
+import typer
+
+app = typer.Typer()
+
+
+@app.command()
+def load_model(model_name: str, model_version: Optional[str] = None):
+    """Load a model in Triton Inference Server.
+
+    If the model was never loaded, it will be loaded with the default
+    configuration generated by Triton.
+
+    Otherwise, the behavior will depend on whether the `--model-version`
+    option is provided:
+
+    - If the option is provided, only the specified version will be loaded,
+      and the other versions will be unloaded.
+    - If the option is not provided, the two latest versions will be loaded.
+    """
+    from robotoff import triton
+    from robotoff.utils import get_logger
+
+    get_logger()
+
+    typer.echo(f"Loading model {model_name}")
+    typer.echo("** Current models (before) **")
+    list_models()
+    triton_stub = triton.get_triton_inference_stub()
+    triton.load_model(triton_stub, model_name, model_version=model_version)
+    typer.echo("Done.")
+    typer.echo("** Current models (after) **")
+    list_models()
+
+
+@app.command()
+def list_models():
+    """List all models loaded in Triton Inference Server."""
+    from robotoff import triton
+
+    triton_stub = triton.get_triton_inference_stub()
+    models = triton.list_models(triton_stub)
+
+    for model in models:
+        typer.echo(f"{model.name} (version: {model.version}), state: {model.state}")
+
+
+@app.command()
+def get_model_config(model_name: str, model_version: Optional[str] = None):
+    """Display the configuration of a model in Triton Inference Server."""
+    from robotoff import triton
+
+    typer.echo(f"Getting config for model {model_name}")
+    triton_stub = triton.get_triton_inference_stub()
+    config = triton.get_model_config(triton_stub, model_name, model_version)
+    typer.echo(config)
+
+
+@app.command()
+def unload_model(model_name: str):
+    """Unload all versions of a model from Triton Inference Server."""
+    from robotoff import triton
+
+    typer.echo(f"Unloading model {model_name}")
+    triton_stub = triton.get_triton_inference_stub()
+    triton.unload_model(triton_stub=triton_stub, model_name=model_name)
+    typer.echo("Done.")
+    typer.echo("** Current models (after) **")
+    list_models()
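As an aside (not part of the patch), the new sub-app can be exercised on its own with Typer's built-in test runner; a minimal sketch, assuming a Triton server is reachable at the URI configured in Robotoff's settings:

    from typer.testing import CliRunner

    from robotoff.cli.triton import app

    runner = CliRunner()
    # Equivalent to running the `list-models` sub-command from a shell;
    # Typer derives hyphenated command names from the function names.
    result = runner.invoke(app, ["list-models"])
    print(result.output)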
diff --git a/robotoff/triton.py b/robotoff/triton.py
index 84a06d763b..510bc99ecc 100644
--- a/robotoff/triton.py
+++ b/robotoff/triton.py
@@ -1,9 +1,12 @@
 import functools
+import json
 import struct
 
 import grpc
 import numpy as np
+from google.protobuf.json_format import MessageToJson
 from more_itertools import chunked
+from openfoodfacts.types import JSONType
 from PIL import Image
 from transformers import CLIPImageProcessor
 from tritonclient.grpc import service_pb2, service_pb2_grpc
@@ -175,3 +178,89 @@ def add_triton_infer_input_tensor(request, name: str, data: np.ndarray, datatype
     input_tensor.shape.extend(data.shape)
     request.inputs.extend([input_tensor])
     request.raw_input_contents.extend([data.tobytes()])
+
+
+def load_model(
+    triton_stub: GRPCInferenceServiceStub,
+    model_name: str,
+    model_version: str | None = None,
+) -> None:
+    """Load a model in Triton Inference Server.
+
+    If the model was never loaded, it will be loaded with the default
+    configuration generated by Triton.
+
+    Otherwise, the behavior will depend on whether `model_version` is
+    provided:
+
+    - If it is provided, only the specified version will be loaded, and the
+      other versions will be unloaded.
+    - If it is not provided, the two latest versions will be loaded.
+
+    :param triton_stub: gRPC stub for Triton Inference Server
+    :param model_name: name of the model to load
+    :param model_version: version of the model to load, defaults to None
+    """
+    request = service_pb2.RepositoryModelLoadRequest()
+    request.model_name = model_name
+
+    current_models = list_models(triton_stub)
+    first_load = not any(
+        model.name == model_name and model.state == "READY" for model in current_models
+    )
+
+    if first_load:
+        logger.info("First load of model")
+    else:
+        logger.info("Previous model already loaded")
+        model_config = json.loads(
+            MessageToJson(get_model_config(triton_stub, model_name))
+        )
+        if model_version:
+            logger.info(
+                f"Model version specified, only loading that version ({model_version})"
+            )
+            version_policy: JSONType = {"specific": {"versions": [model_version]}}
+        else:
+            logger.info("No model version specified, loading the 2 latest versions")
+            version_policy = {"latest": {"num_versions": 2}}
+
+        request.parameters["config"].string_param = json.dumps(
+            {
+                "input": model_config["input"],
+                "output": model_config["output"],
+                "versionPolicy": version_policy,
+                "max_batch_size": model_config["maxBatchSize"],
+                "backend": model_config["backend"],
+                "platform": model_config["platform"],
+            }
+        )
+
+    triton_stub.RepositoryModelLoad(request)
+
+
+def unload_model(triton_stub: GRPCInferenceServiceStub, model_name: str) -> None:
+    """Completely unload a model from Triton Inference Server."""
+    request = service_pb2.RepositoryModelUnloadRequest()
+    request.model_name = model_name
+    triton_stub.RepositoryModelUnload(request)
+
+
+def list_models(triton_stub: GRPCInferenceServiceStub):
+    request = service_pb2.RepositoryIndexRequest()
+    response = triton_stub.RepositoryIndex(request)
+    return response.models
+
+
+def get_model_config(
+    triton_stub: GRPCInferenceServiceStub,
+    model_name: str,
+    model_version: str | None = None,
+):
+    request = service_pb2.ModelConfigRequest()
+    request.name = model_name
+    if model_version:
+        request.version = model_version
+
+    response = triton_stub.ModelConfig(request)
+    return response.config
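For context (not part of the patch), the new helpers can also be driven directly from Python; a minimal sketch, assuming a running Triton server and that a model named `nutrition_extractor` (introduced in the next patch) is present in the model repository:

    from robotoff import triton

    stub = triton.get_triton_inference_stub()
    # Load only version 2; if other versions of this model were already
    # loaded, load_model's version policy unloads them.
    triton.load_model(stub, "nutrition_extractor", model_version="2")
    for model in triton.list_models(stub):
        print(model.name, model.version, model.state)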
From c008acd2b7e08bee78cc7ad67a6697c2c8051e9d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?=
Date: Thu, 26 Dec 2024 12:40:20 +0100
Subject: [PATCH 2/2] feat: add command to download models from HF

---
 robotoff/cli/triton.py | 10 +++++++
 robotoff/triton.py     | 60 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+)

diff --git a/robotoff/cli/triton.py b/robotoff/cli/triton.py
index 4fd736eb37..01b70356f8 100644
--- a/robotoff/cli/triton.py
+++ b/robotoff/cli/triton.py
@@ -68,3 +68,13 @@ def unload_model(model_name: str):
     typer.echo("Done.")
     typer.echo("** Current models (after) **")
     list_models()
+
+
+@app.command()
+def download_models():
+    """Download all models from the Hugging Face Hub."""
+    from robotoff import triton
+    from robotoff.utils import get_logger
+
+    get_logger()
+    triton.download_models()
diff --git a/robotoff/triton.py b/robotoff/triton.py
index 510bc99ecc..5b320a7b35 100644
--- a/robotoff/triton.py
+++ b/robotoff/triton.py
@@ -1,13 +1,18 @@
 import functools
 import json
+import shutil
 import struct
+import tempfile
+from pathlib import Path
 
 import grpc
 import numpy as np
 from google.protobuf.json_format import MessageToJson
+from huggingface_hub import snapshot_download
 from more_itertools import chunked
 from openfoodfacts.types import JSONType
 from PIL import Image
+from pydantic import BaseModel
 from transformers import CLIPImageProcessor
 from tritonclient.grpc import service_pb2, service_pb2_grpc
 from tritonclient.grpc.service_pb2_grpc import GRPCInferenceServiceStub
@@ -24,6 +29,30 @@
 # Get model config: /v2/models/{MODEL_NAME}/config
 
 
+class HuggingFaceModel(BaseModel):
+    name: str
+    version: int
+    repo_id: str
+    subfolder: str = "onnx"
+    revision: str = "main"
+
+
+HUGGINGFACE_MODELS = [
+    HuggingFaceModel(
+        name="nutrition_extractor",
+        version=1,
+        repo_id="openfoodfacts/nutrition-extractor",
+        revision="dea426bf3c3d289ad7b65d29a7744ea6851632a6",
+    ),
+    HuggingFaceModel(
+        name="nutrition_extractor",
+        version=2,
+        repo_id="openfoodfacts/nutrition-extractor",
+        revision="7a43f38725f50f37a8c7bce417fc75741bea49fe",
+    ),
+]
+
+
 @functools.cache
 def get_triton_inference_stub(
     triton_uri: str | None = None,
@@ -264,3 +293,34 @@
 
     response = triton_stub.ModelConfig(request)
     return response.config
+
+
+def download_models():
+    """Download all models from the Hugging Face Hub.
+
+    The models are downloaded into the Triton models directory. If a model
+    version already exists, it is not downloaded again.
+    """
+    for model in HUGGINGFACE_MODELS:
+        base_model_dir = settings.TRITON_MODELS_DIR / model.name
+        base_model_dir.mkdir(parents=True, exist_ok=True)
+        model_with_version_dir = base_model_dir / str(model.version) / "model.onnx"
+
+        if model_with_version_dir.exists():
+            logger.info(
+                f"Model {model.name} version {model.version} already downloaded"
+            )
+            continue
+
+        with tempfile.TemporaryDirectory() as temp_dir_str:
+            logger.info(f"Temporary cache directory: {temp_dir_str}")
+            temp_dir = Path(temp_dir_str)
+            snapshot_download(
+                repo_id=model.repo_id,
+                allow_patterns=[f"{model.subfolder}/*"],
+                revision=model.revision,
+                local_dir=temp_dir,
+            )
+            model_with_version_dir.parent.mkdir(parents=True, exist_ok=True)
+            logger.info(f"Copying model files to {model_with_version_dir}")
+            shutil.move(temp_dir / model.subfolder, model_with_version_dir)
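For reference, a sketch of the on-disk layout that `download_models` produces (illustrative; the actual root comes from `settings.TRITON_MODELS_DIR`). Note that each `model.onnx` entry is a directory, since the whole `onnx` subfolder of the Hugging Face snapshot is moved there, which keeps any external weight files next to the ONNX graph:

    # <TRITON_MODELS_DIR>/
    # └── nutrition_extractor/
    #     ├── 1/
    #     │   └── model.onnx/
    #     └── 2/
    #         └── model.onnx/
    from robotoff import triton

    triton.download_models()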