Skip to content

Commit

Permalink
feat: save ingredient list detection in DB
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Oct 27, 2023
1 parent 0624efe commit ffd9c5f
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 0 deletions.
3 changes: 3 additions & 0 deletions robotoff/prediction/ingredient_list/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

INGREDIENT_ID2LABEL = {0: "O", 1: "B-ING", 2: "I-ING"}

MODEL_NAME = "ingredient-detection"
MODEL_VERSION = "ingredient-detection-1.0"


@dataclasses.dataclass
class IngredientPredictionAggregatedEntity:
Expand Down
58 changes: 58 additions & 0 deletions robotoff/workers/tasks/import_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import datetime
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -32,6 +33,7 @@
with_db,
)
from robotoff.off import generate_image_url, get_source_from_url
from robotoff.prediction import ingredient_list
from robotoff.prediction.upc_image import UPCImageType, find_image_is_upc
from robotoff.products import get_product_store
from robotoff.slack import NotifierFactory
Expand Down Expand Up @@ -114,6 +116,15 @@ def run_import_image_job(product_id: ProductIdentifier, image_url: str, ocr_url:
image_url=image_url,
ocr_url=ocr_url,
)
# Only extract ingredient lists for food products, as the model was not
# trained on non-food products
enqueue_job(
extract_ingredients_job,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0},
product_id=product_id,
ocr_url=ocr_url,
)
# We make sure there are no concurrent insight processing by sending
# the job to the same queue. The queue is selected based on the product
# barcode. See `get_high_queue` documentation for more details.
Expand Down Expand Up @@ -535,3 +546,50 @@ def add_image_fingerprint_job(image_model_id: int):
return

add_image_fingerprint(image_model)


@with_db
def extract_ingredients_job(product_id: ProductIdentifier, ocr_url: str):
"""Extracts ingredients using ingredient extraction model from an image
OCR.
:param product_id: The identifier of the product to extract ingredients
for.
:param ocr_url: The URL of the image to extract ingredients from.
"""
source_image = get_source_from_url(ocr_url)

with db:
image_model = ImageModel.get_or_none(
source_image=source_image, server_type=product_id.server_type.name
)

if not image_model:
logger.info("Missing image in DB for image %s", source_image)
return

# Stop the job here if the image has already been processed
if (
ImagePrediction.get_or_none(
image=image_model, model_name=ingredient_list.MODEL_NAME
)
) is not None:
return

output = ingredient_list.predict_from_ocr(ocr_url)
logger.warning("predict_from_ocr output: %s", output)
entities: list[
ingredient_list.IngredientPredictionAggregatedEntity
] = output.entities # type: ignore (we know it's an
# aggregated entity)

ImagePrediction.create(
image=image_model,
type="ner",
model_name=ingredient_list.MODEL_NAME,
model_version=ingredient_list.MODEL_VERSION,
data=dataclasses.asdict(output),
timestamp=datetime.datetime.utcnow(),
max_confidence=max(entity.score for entity in entities),
)
logger.info("create image prediction (ingredient detection) from %s", ocr_url)
Empty file.
Empty file.
99 changes: 99 additions & 0 deletions tests/integration/workers/tasks/test_import_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import dataclasses
from unittest.mock import patch

import pytest

from robotoff.models import ImagePrediction
from robotoff.prediction.ingredient_list import (
IngredientPredictionAggregatedEntity,
IngredientPredictionOutput,
)
from robotoff.types import ProductIdentifier, ServerType
from robotoff.workers.tasks.import_image import extract_ingredients_job

from ...models_utils import ImageModelFactory, ImagePredictionFactory, clean_db


@pytest.fixture(autouse=True)
def _set_up_and_tear_down(peewee_db):
with peewee_db:
clean_db()
# Run the test case.
yield

with peewee_db:
clean_db()


@patch("robotoff.workers.tasks.import_image.ingredient_list")
def test_extract_ingredients_job(mocker, peewee_db):
full_text = "Best product ever!\ningredients: water, salt, sugar."
entities = [
IngredientPredictionAggregatedEntity(
start=19, end=51, score=0.9, text="water, salt, sugar."
)
]
mocker.predict_from_ocr.return_value = IngredientPredictionOutput(
entities=entities, text=full_text
)
mocker.MODEL_NAME = "ingredient-detection"
mocker.MODEL_VERSION = "ingredient-detection-1.0"

barcode = "1234567890123"
ocr_url = "https://images.openfoodfacts.org/images/products/123/456/789/0123/1.json"

with peewee_db:
image = ImageModelFactory(
barcode=barcode, server_type=ServerType.off, image_id="1"
)
extract_ingredients_job(
ProductIdentifier(barcode, ServerType.off), ocr_url=ocr_url
)
mocker.predict_from_ocr.assert_called_once_with(ocr_url)
image_prediction = ImagePrediction.get_or_none(
ImagePrediction.model_name == "ingredient-detection",
ImagePrediction.image_id == image.id,
)
assert image_prediction is not None
assert image_prediction.data == {
"text": full_text,
"entities": [dataclasses.asdict(entities[0])],
}
assert image_prediction.max_confidence == 0.9
assert image_prediction.type == "ner"
assert image_prediction.model_name == "ingredient-detection"
assert image_prediction.model_version == "ingredient-detection-1.0"


@patch("robotoff.workers.tasks.import_image.ingredient_list")
def test_extract_ingredients_job_missing_image(mocker, peewee_db):
barcode = "1234567890123"
ocr_url = "https://images.openfoodfacts.org/images/products/123/456/789/0123/1.json"

with peewee_db:
extract_ingredients_job(
ProductIdentifier(barcode, ServerType.off), ocr_url=ocr_url
)
mocker.predict_from_ocr.assert_not_called()


@patch("robotoff.workers.tasks.import_image.ingredient_list")
def test_extract_ingredients_job_existing_image_prediction(mocker, peewee_db):
mocker.MODEL_NAME = "ingredient-detection"
mocker.MODEL_VERSION = "ingredient-detection-1.0"
barcode = "1234567890123"
ocr_url = "https://images.openfoodfacts.org/images/products/123/456/789/0123/1.json"

with peewee_db:
image = ImageModelFactory(
barcode=barcode, server_type=ServerType.off, image_id="1"
)
ImagePredictionFactory(
image=image,
model_name="ingredient-detection",
model_version="ingredient-detection-1.0",
)
extract_ingredients_job(
ProductIdentifier(barcode, ServerType.off), ocr_url=ocr_url
)
mocker.predict_from_ocr.assert_not_called()

0 comments on commit ffd9c5f

Please sign in to comment.