Skip to content

Commit

Permalink
refactor: use openfoodfacts.ml for object detection (#1493)
Browse files Browse the repository at this point in the history
- use ml submodule for object detection
- add a ModelConfig pydantic class to specify model configuration.
  Previously, the model configuration was not clear and could lead to
  bugs.
- add a `threshold` parameter in model predict route
  • Loading branch information
raphael0202 authored Dec 10, 2024
1 parent bb72dce commit 54843be
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 291 deletions.
1 change: 0 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ ARG PYTHON_VERSION=3.11
FROM python:$PYTHON_VERSION-slim AS python-base
RUN apt-get update && \
apt-get install --no-install-suggests --no-install-recommends -y gettext curl build-essential && \
apt-get install ffmpeg libsm6 libxext6 -y && \
apt-get autoremove --purge && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
Expand Down
27 changes: 14 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ python-redis-lock = "~4.0.0"
transformers = "~4.44.2"
lark = "~1.1.4"
h5py = "~3.8.0"
opencv-contrib-python-headless = "~4.10.0.84"
opencv-python-headless = "~4.10.0.84"
toml = "~0.10.2"
openfoodfacts = "1.1.5"
openfoodfacts = "2.3.4"
imagehash = "~4.3.1"
peewee-migrate = "~1.12.2"
diskcache = "~5.6.3"
Expand Down
7 changes: 5 additions & 2 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ class ImagePredictorResource:
def on_get(self, req: falcon.Request, resp: falcon.Response):
image_url = req.get_param("image_url", required=True)
models: list[str] = req.get_param_as_list("models", required=True)
threshold: float = req.get_param_as_float("threshold", default=0.5)

available_object_detection_models = list(
ObjectDetectionModel.__members__.keys()
Expand Down Expand Up @@ -921,14 +922,16 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
model = ObjectDetectionModelRegistry.get(
ObjectDetectionModel[model_name]
)
result = model.detect_from_image(image, output_image=output_image)
result = model.detect_from_image(
image, output_image=output_image, threshold=threshold
)

if output_image:
boxed_image = cast(Image.Image, result.boxed_image)
image_response(boxed_image, resp)
return
else:
predictions[model_name] = result.to_json()
predictions[model_name] = result.to_list()
else:
model_enum = ImageClassificationModel[model_name]
classifier = image_classifier.ImageClassifier(
Expand Down
6 changes: 2 additions & 4 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def import_logos(
"""
from robotoff.cli import logos
from robotoff.models import db
from robotoff.prediction.object_detection import OBJECT_DETECTION_MODEL_VERSION
from robotoff.prediction.object_detection import MODELS_CONFIG
from robotoff.utils import get_logger

logger = get_logger()
Expand All @@ -935,9 +935,7 @@ def import_logos(
imported = logos.import_logos(
data_path,
ObjectDetectionModel.universal_logo_detector.value,
OBJECT_DETECTION_MODEL_VERSION[
ObjectDetectionModel.universal_logo_detector
],
MODELS_CONFIG[ObjectDetectionModel.universal_logo_detector].model_version,
batch_size,
server_type,
)
Expand Down
6 changes: 3 additions & 3 deletions robotoff/insights/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from robotoff.off import get_source_from_url
from robotoff.prediction import ocr
from robotoff.prediction.object_detection import (
OBJECT_DETECTION_MODEL_VERSION,
MODELS_CONFIG,
ObjectDetectionModelRegistry,
)
from robotoff.types import (
Expand Down Expand Up @@ -88,13 +88,13 @@ def run_object_detection_model(
results = ObjectDetectionModelRegistry.get(model_name).detect_from_image(
image, output_image=False, triton_uri=triton_uri, threshold=threshold
)
data = results.to_json()
data = results.to_list()
max_confidence = max((item["score"] for item in data), default=None)
return ImagePrediction.create(
image=image_model,
type="object_detection",
model_name=model_name.name,
model_version=OBJECT_DETECTION_MODEL_VERSION[model_name],
model_version=MODELS_CONFIG[model_name].model_version,
data={"objects": data},
timestamp=timestamp,
max_confidence=max_confidence,
Expand Down
6 changes: 1 addition & 5 deletions robotoff/prediction/object_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
# flake8: noqa
from .core import (
OBJECT_DETECTION_MODEL_VERSION,
ObjectDetectionModelRegistry,
ObjectDetectionRawResult,
)
from .core import MODELS_CONFIG, ObjectDetectionModelRegistry, ObjectDetectionResult
Loading

0 comments on commit 54843be

Please sign in to comment.