feat: Add triton CLI commands #1510

Merged · 2 commits · Dec 26, 2024
4 changes: 4 additions & 0 deletions robotoff/cli/main.py
@@ -4,6 +4,7 @@

import typer

from robotoff.cli.triton import app as triton_app
from robotoff.types import (
ImportImageFlag,
ObjectDetectionModel,
@@ -1250,5 +1251,8 @@ def launch_normalize_barcode_job(
logger.info("Updated %d images", updated)


app.add_typer(triton_app, name="triton")


def main() -> None:
app()
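For context, `app.add_typer` is Typer's standard way to mount a command group under a prefix, which is what gives every command in robotoff/cli/triton.py its `triton` namespace. A minimal, self-contained sketch of the pattern (names here are illustrative, not from this PR):

import typer

app = typer.Typer()
models_app = typer.Typer()


@models_app.command()
def show(name: str):
    """Available as `<prog> models show NAME` once mounted."""
    typer.echo(f"model: {name}")


app.add_typer(models_app, name="models")

if __name__ == "__main__":
    app()

Typer also converts underscores in function names to hyphens, so the commands below are exposed as `triton load-model`, `triton list-models`, and so on.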
80 changes: 80 additions & 0 deletions robotoff/cli/triton.py
@@ -0,0 +1,80 @@
from typing import Optional

import typer

app = typer.Typer()


@app.command()
def load_model(model_name: str, model_version: Optional[str] = None):
"""Load a model in Triton Inference Server.

If the model was never loaded, it will be loaded with the default
configuration generated by Triton.

Otherwise, the behavior will depend on whether the `--model-version` option is
provided:

- If the option is provided, only the specified version will be loaded; the
other versions will be unloaded.
- If the option is not provided, the two latest versions will be loaded.
"""
from robotoff import triton
from robotoff.utils import get_logger

get_logger()

typer.echo(f"Loading model {model_name}")
typer.echo("** Current models (before) **")
list_models()
triton_stub = triton.get_triton_inference_stub()
triton.load_model(triton_stub, model_name, model_version=model_version)
typer.echo("Done.")
typer.echo("**Current models (after) **")
list_models()


@app.command()
def list_models():
"""List all models loaded in Triton Inference Server."""
from robotoff import triton

triton_stub = triton.get_triton_inference_stub()
models = triton.list_models(triton_stub)

for model in models:
typer.echo(f"{model.name} (version: {model.version}), state: {model.state}")


@app.command()
def get_model_config(model_name: str, model_version: Optional[str] = None):
"""Display the configuration of a model in Triton Inference Server."""
from robotoff import triton

typer.echo(f"Getting config for model {model_name}")
triton_stub = triton.get_triton_inference_stub()
config = triton.get_model_config(triton_stub, model_name, model_version)
typer.echo(config)


@app.command()
def unload_model(model_name: str):
"""Unload all versions of a model from Triton Inference Server."""
from robotoff import triton

typer.echo(f"Unloading model {model_name}")
triton_stub = triton.get_triton_inference_stub()
triton.unload_model(triton_stub=triton_stub, model_name=model_name)
typer.echo("Done.")
typer.echo("**Current models (after) **")
list_models()


@app.command()
def download_models():
"""Download all models."""
from robotoff import triton
from robotoff.utils import get_logger

get_logger()
triton.download_models()
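Assuming the Typer app is reachable through robotoff's usual CLI entry point (the executable name is not shown in this diff, so `robotoff-cli` below is an assumption), the new commands would be invoked along these lines:

# Hypothetical invocations; `robotoff-cli` stands in for the real entry point.
robotoff-cli triton list-models
robotoff-cli triton load-model nutrition_extractor --model-version 2
robotoff-cli triton unload-model nutrition_extractor
robotoff-cli triton download-models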
149 changes: 149 additions & 0 deletions robotoff/triton.py
@@ -1,10 +1,18 @@
import functools
import json
import shutil
import struct
import tempfile
from pathlib import Path

import grpc
import numpy as np
from google.protobuf.json_format import MessageToJson
from huggingface_hub import snapshot_download
from more_itertools import chunked
from openfoodfacts.types import JSONType
from PIL import Image
from pydantic import BaseModel
from transformers import CLIPImageProcessor
from tritonclient.grpc import service_pb2, service_pb2_grpc
from tritonclient.grpc.service_pb2_grpc import GRPCInferenceServiceStub
@@ -21,6 +29,30 @@
# Get model config: /v2/models/{MODEL_NAME}/config


class HuggingFaceModel(BaseModel):
name: str
version: int
repo_id: str
subfolder: str = "onnx"
revision: str = "main"


HUGGINGFACE_MODELS = [
HuggingFaceModel(
name="nutrition_extractor",
version=1,
repo_id="openfoodfacts/nutrition-extractor",
revision="dea426bf3c3d289ad7b65d29a7744ea6851632a6",
),
HuggingFaceModel(
name="nutrition_extractor",
version=2,
repo_id="openfoodfacts/nutrition-extractor",
revision="7a43f38725f50f37a8c7bce417fc75741bea49fe",
),
]
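These `name`/`version` pairs map one-to-one onto the Triton model-repository layout produced by `download_models` below. Assuming `settings.TRITON_MODELS_DIR` points at a directory named `models/` (its real value is configured elsewhere in robotoff), the two entries above resolve to:

models/
└── nutrition_extractor/
    ├── 1/
    │   └── model.onnx
    └── 2/
        └── model.onnx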


@functools.cache
def get_triton_inference_stub(
triton_uri: str | None = None,
@@ -175,3 +207,120 @@ def add_triton_infer_input_tensor(request, name: str, data: np.ndarray, datatype
input_tensor.shape.extend(data.shape)
request.inputs.extend([input_tensor])
request.raw_input_contents.extend([data.tobytes()])


def load_model(
triton_stub: GRPCInferenceServiceStub,
model_name: str,
model_version: str | None = None,
) -> None:
"""Load a model in Triton Inference Server.

If the model was never loaded, it will be loaded with the default
configuration generated by Triton.

Otherwise, the behavior will depend on whether `model_version` is provided:

- If it is provided, only the specified version will be loaded; the other
versions will be unloaded.
- If it is not provided, the two latest versions will be loaded.

:param triton_stub: gRPC stub for Triton Inference Server
:param model_name: name of the model to load
:param model_version: version of the model to load, defaults to None
"""
request = service_pb2.RepositoryModelLoadRequest()
request.model_name = model_name

current_models = list_models(triton_stub)
first_load = not any(
model.name == model_name and model.state == "READY" for model in current_models
)

if first_load:
logger.info("First load of model")
else:
logger.info("Previous model already loaded")
model_config = json.loads(
MessageToJson(get_model_config(triton_stub, model_name))
)
if model_version:
logger.info(
f"Model version specified, only loading that version ({model_version})"
)
version_policy: JSONType = {"specific": {"versions": [model_version]}}
else:
logger.info("No model version specified, loading 2 latest version")
version_policy = {"latest": {"num_versions": 2}}

request.parameters["config"].string_param = json.dumps(
{
"input": model_config["input"],
"output": model_config["output"],
"versionPolicy": version_policy,
"max_batch_size": model_config["maxBatchSize"],
"backend": model_config["backend"],
"platform": model_config["platform"],
}
)

triton_stub.RepositoryModelLoad(request)
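Concretely, the `versionPolicy` embedded in the reload request takes one of two shapes, both lifted directly from the branches above (`model_version` is a string here, hence the quoted value):

{"specific": {"versions": ["2"]}}   # --model-version 2 was passed
{"latest": {"num_versions": 2}}     # default: keep the two latest versions

The other fields (input, output, max batch size, backend, platform) are copied from the currently loaded config, so a reload only ever changes the version policy.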


def unload_model(triton_stub: GRPCInferenceServiceStub, model_name: str) -> None:
"""Unload completely a model from Triton Inference Server."""
request = service_pb2.RepositoryModelUnloadRequest()
request.model_name = model_name
triton_stub.RepositoryModelUnload(request)


def list_models(triton_stub: GRPCInferenceServiceStub):
request = service_pb2.RepositoryIndexRequest()
response = triton_stub.RepositoryIndex(request)
return response.models


def get_model_config(
triton_stub: GRPCInferenceServiceStub,
model_name: str,
model_version: str | None = None,
):
request = service_pb2.ModelConfigRequest()
request.name = model_name
if model_version:
request.version = model_version

response = triton_stub.ModelConfig(request)
return response.config


def download_models():
"""Downloading all models from Hugging Face Hub.

The models are downloaded in the Triton models directory. If the model
already exists, it is not downloaded.
"""
for model in HUGGINGFACE_MODELS:
base_model_dir = settings.TRITON_MODELS_DIR / model.name
base_model_dir.mkdir(parents=True, exist_ok=True)
model_with_version_dir = base_model_dir / str(model.version) / "model.onnx"

if model_with_version_dir.exists():
logger.info(
f"Model {model.name} version {model.version} already downloaded"
)
continue

with tempfile.TemporaryDirectory() as temp_dir_str:
logger.info(f"Temporary cache directory: {temp_dir_str}")
temp_dir = Path(temp_dir_str)
snapshot_download(
repo_id=model.repo_id,
allow_patterns=[f"{model.subfolder}/*"],
revision=model.revision,
local_dir=temp_dir,
)
model_with_version_dir.parent.mkdir(parents=True, exist_ok=True)
logger.info(f"Copying model files to {model_with_version_dir}")
shutil.move(temp_dir / model.subfolder, model_with_version_dir)
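Note that the final `shutil.move` renames the downloaded `onnx/` snapshot subfolder to `model.onnx`, so each version directory ends up holding a directory (not a single file) under that name; Triton's ONNX Runtime backend accepts this layout for models shipped with external weight files. Staging the snapshot in a temporary directory also guarantees that a partially fetched model never lands in the live repository. A standalone sketch of the download step (paths are illustrative):

# Minimal sketch of one snapshot download, outside robotoff (hypothetical paths).
from pathlib import Path

from huggingface_hub import snapshot_download

temp_dir = Path("/tmp/nutrition-extractor")
snapshot_download(
    repo_id="openfoodfacts/nutrition-extractor",
    allow_patterns=["onnx/*"],  # fetch only the exported ONNX subfolder
    revision="dea426bf3c3d289ad7b65d29a7744ea6851632a6",
    local_dir=temp_dir,
)
# temp_dir / "onnx" now holds the model files; download_models moves this
# folder into place as <TRITON_MODELS_DIR>/nutrition_extractor/1/model.onnx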