Skip to content

Commit

Permalink
feat: add command to download models from HF
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 26, 2024
1 parent 826d536 commit c008acd
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
10 changes: 10 additions & 0 deletions robotoff/cli/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,13 @@ def unload_model(model_name: str):
typer.echo("Done.")
typer.echo("**Current models (after) **")
list_models()


@app.command()
def download_models():
"""Download all models."""
from robotoff import triton
from robotoff.utils import get_logger

get_logger()
triton.download_models()
60 changes: 60 additions & 0 deletions robotoff/triton.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import functools
import json
import shutil
import struct
import tempfile
from pathlib import Path

import grpc
import numpy as np
from google.protobuf.json_format import MessageToJson
from huggingface_hub import snapshot_download
from more_itertools import chunked
from openfoodfacts.types import JSONType
from PIL import Image
from pydantic import BaseModel
from transformers import CLIPImageProcessor
from tritonclient.grpc import service_pb2, service_pb2_grpc
from tritonclient.grpc.service_pb2_grpc import GRPCInferenceServiceStub
Expand All @@ -24,6 +29,30 @@
# Get model config: /v2/models/{MODEL_NAME}/config


class HuggingFaceModel(BaseModel):
name: str
version: int
repo_id: str
subfolder: str = "onnx"
revision: str = "main"


HUGGINGFACE_MODELS = [
HuggingFaceModel(
name="nutrition_extractor",
version=1,
repo_id="openfoodfacts/nutrition-extractor",
revision="dea426bf3c3d289ad7b65d29a7744ea6851632a6",
),
HuggingFaceModel(
name="nutrition_extractor",
version=2,
repo_id="openfoodfacts/nutrition-extractor",
revision="7a43f38725f50f37a8c7bce417fc75741bea49fe",
),
]


@functools.cache
def get_triton_inference_stub(
triton_uri: str | None = None,
Expand Down Expand Up @@ -264,3 +293,34 @@ def get_model_config(

response = triton_stub.ModelConfig(request)
return response.config


def download_models():
"""Downloading all models from Hugging Face Hub.
The models are downloaded in the Triton models directory. If the model
already exists, it is not downloaded.
"""
for model in HUGGINGFACE_MODELS:
base_model_dir = settings.TRITON_MODELS_DIR / model.name
base_model_dir.mkdir(parents=True, exist_ok=True)
model_with_version_dir = base_model_dir / str(model.version) / "model.onnx"

if model_with_version_dir.exists():
logger.info(
f"Model {model.name} version {model.version} already downloaded"
)
continue

with tempfile.TemporaryDirectory() as temp_dir_str:
logger.info(f"Temporary cache directory: {temp_dir_str}")
temp_dir = Path(temp_dir_str)
snapshot_download(
repo_id=model.repo_id,
allow_patterns=[f"{model.subfolder}/*"],
revision=model.revision,
local_dir=temp_dir,
)
model_with_version_dir.parent.mkdir(parents=True, exist_ok=True)
logger.info(f"Copying model files to {model_with_version_dir}")
shutil.move(temp_dir / model.subfolder, model_with_version_dir)

0 comments on commit c008acd

Please sign in to comment.