From 74dec15f6cbfaec044bc8cc131d031a943c94066 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Tue, 22 Oct 2024 12:04:45 +0200 Subject: [PATCH] feat: add nutrition extractor --- robotoff/cli/main.py | 7 + robotoff/insights/annotate.py | 19 - robotoff/insights/importer.py | 28 + robotoff/off.py | 12 +- .../nutrition_extraction/__init__.py | 547 ++++++++++++++++++ robotoff/products.py | 3 + robotoff/triton.py | 16 + robotoff/types.py | 10 +- robotoff/workers/tasks/import_image.py | 110 +++- .../prediction/test_nutrition_extraction.py | 370 ++++++++++++ 10 files changed, 1093 insertions(+), 29 deletions(-) create mode 100644 robotoff/prediction/nutrition_extraction/__init__.py create mode 100644 tests/unit/prediction/test_nutrition_extraction.py diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py index d242d44e9a..e2a8ad9168 100644 --- a/robotoff/cli/main.py +++ b/robotoff/cli/main.py @@ -93,12 +93,19 @@ def create_redis_update( get_logger() client = get_redis_client() + flavor_to_product_type = { + "off": "food", + "obf": "beauty", + "opff": "petfood", + "opf": "product", + } event = { "code": barcode, "flavor": flavor, "user_id": user_id, "action": action, "comment": comment, + "product_type": flavor_to_product_type[flavor], } diffs: JSONType diff --git a/robotoff/insights/annotate.py b/robotoff/insights/annotate.py index bf4a424fc6..a066e43811 100644 --- a/robotoff/insights/annotate.py +++ b/robotoff/insights/annotate.py @@ -653,24 +653,6 @@ def process_annotation( return UPDATED_ANNOTATION_RESULT -class NutritionTableStructureAnnotator(InsightAnnotator): - @classmethod - def process_annotation( - cls, - insight: ProductInsight, - data: Optional[dict] = None, - auth: Optional[OFFAuthentication] = None, - is_vote: bool = False, - ) -> AnnotationResult: - insight.data["annotation"] = data - insight.save() - return SAVED_ANNOTATION_RESULT - - @classmethod - def is_data_required(cls) -> bool: - return True - - ANNOTATOR_MAPPING: dict[str, Type] = { InsightType.packager_code.name: PackagerCodeAnnotator, InsightType.label.name: LabelAnnotator, @@ -681,7 +663,6 @@ def is_data_required(cls) -> bool: InsightType.store.name: StoreAnnotator, InsightType.packaging.name: PackagingAnnotator, InsightType.nutrition_image.name: NutritionImageAnnotator, - InsightType.nutrition_table_structure.name: NutritionTableStructureAnnotator, InsightType.is_upc_image.name: UPCImageAnnotator, } diff --git a/robotoff/insights/importer.py b/robotoff/insights/importer.py index bc4c757904..6bc449c345 100644 --- a/robotoff/insights/importer.py +++ b/robotoff/insights/importer.py @@ -1524,6 +1524,33 @@ def _keep_prediction( ) +class NutrientExtractionImporter(InsightImporter): + @staticmethod + def get_type() -> InsightType: + return InsightType.nutrient_extraction + + @classmethod + def get_required_prediction_types(cls) -> set[PredictionType]: + return {PredictionType.nutrient_extraction} + + @classmethod + def generate_candidates( + cls, + product: Optional[Product], + predictions: list[Prediction], + product_id: ProductIdentifier, + ) -> Iterator[ProductInsight]: + for prediction in predictions: + yield ProductInsight(**prediction.to_dict()) + + @classmethod + def is_conflicting_insight( + cls, candidate: ProductInsight, reference: ProductInsight + ) -> bool: + # Only one insight per product + return True + + class PackagingElementTaxonomyException(Exception): pass @@ -1860,6 +1887,7 @@ def import_product_predictions( UPCImageImporter, NutritionImageImporter, 
IngredientSpellcheckImporter, + NutrientExtractionImporter, ] diff --git a/robotoff/off.py b/robotoff/off.py index efbb47b89f..c183fc44c3 100644 --- a/robotoff/off.py +++ b/robotoff/off.py @@ -68,8 +68,16 @@ def get_username(self) -> Optional[str]: return None -def get_source_from_url(ocr_url: str) -> str: - url_path = urlparse(ocr_url).path +def get_source_from_url(url: str) -> str: + """Get the `source_image` field from an image or OCR URL. + + It's the path of the image or OCR JSON file, but without the `/images/products` + prefix. It always ends with `.jpg`, whether it's an image or an OCR JSON file. + + :param url: the URL of the image or OCR JSON file + :return: the source image path + """ + url_path = urlparse(url).path if url_path.startswith("/images/products"): url_path = url_path[len("/images/products") :] diff --git a/robotoff/prediction/nutrition_extraction/__init__.py b/robotoff/prediction/nutrition_extraction/__init__.py new file mode 100644 index 0000000000..df423e1c25 --- /dev/null +++ b/robotoff/prediction/nutrition_extraction/__init__.py @@ -0,0 +1,547 @@ +import dataclasses +import functools +import re +import typing +from collections import Counter +from pathlib import Path + +import numpy as np +from openfoodfacts.ocr import OCRResult +from openfoodfacts.utils import load_json +from PIL import Image +from transformers import AutoProcessor, BatchEncoding, PreTrainedTokenizerBase +from tritonclient.grpc import service_pb2 + +from robotoff import settings +from robotoff.triton import ( + GRPCInferenceServiceStub, + add_triton_infer_input_tensor, + get_triton_inference_stub, +) +from robotoff.types import JSONType +from robotoff.utils.logger import get_logger + +logger = get_logger(__name__) + +MODEL_NAME = "nutrition_extractor" +MODEL_VERSION = f"{MODEL_NAME}-1.0" + +# The tokenizer assets are stored in the model directory +MODEL_DIR = settings.TRITON_MODELS_DIR / f"{MODEL_NAME}/1/model.onnx" + + +@dataclasses.dataclass +class NutrientPrediction: + entity: str + text: str + value: str | None + unit: str | None + score: float + start: int + end: int + char_start: int + char_end: int + + +@dataclasses.dataclass +class NutritionEntities: + raw: list[dict] + aggregated: list[dict] + postprocessed: list[dict] + + +@dataclasses.dataclass +class NutritionExtractionPrediction: + nutrients: dict[str, NutrientPrediction] + entities: NutritionEntities + + +def predict( + image: Image.Image, + ocr_result: OCRResult, + model_version: str = "1", + triton_uri: str | None = None, +) -> NutritionExtractionPrediction | None: + """Predict the nutrient values from an image and an OCR result. + + The function returns a `NutritionExtractionPrediction` object with the following + fields: + + - `nutrients`: a dictionary mapping nutrient names to `NutrientPrediction` objects + - `entities`: a `NutritionEntities` object containing the raw, aggregated and + postprocessed entities + + If the OCR result does not contain any text annotation, the function returns + `None`.
+ + :param image: the *original* image (not resized) + :param ocr_result: the OCR result + :param model_version: the version of the model to use, defaults to "1" + :param triton_uri: the URI of the Triton Inference Server, if not provided, the + default value from settings is used + :return: a `NutritionExtractionPrediction` object + """ + triton_stub = get_triton_inference_stub(triton_uri) + id2label = get_id2label(MODEL_DIR) + processor = get_processor(MODEL_DIR) + + preprocess_result = preprocess(image, ocr_result, processor) + + if preprocess_result is None: + return None + + words, char_offsets, _, batch_encoding = preprocess_result + logits = send_infer_request( + input_ids=batch_encoding.input_ids, + attention_mask=batch_encoding.attention_mask, + bbox=batch_encoding.bbox, + pixel_values=batch_encoding.pixel_values, + model_name=MODEL_NAME, + triton_stub=triton_stub, + model_version=model_version, + ) + return postprocess(logits[0], words, char_offsets, batch_encoding, id2label) + + +def preprocess( + image: Image.Image, ocr_result: OCRResult, processor +) -> ( + tuple[ + list[str], list[tuple[int, int]], list[tuple[int, int, int, int]], BatchEncoding + ] + | None +): + """Preprocess an image and OCR result for the LayoutLMv3 model. + + The *original* image must be provided, as we use the image size to normalize + the bounding boxes. + + The function returns a tuple containing the following elements: + + - `words`: a list of words + - `char_offsets`: a list of character offsets + - `bboxes`: a list of bounding boxes + - `batch_encoding`: the BatchEncoding returned by the tokenizer + + If the OCR result does not contain any text annotation, the function returns + `None`. + + :param image: the original image + :param ocr_result: the OCR result + :param processor: the LayoutLM processor + :return: a tuple containing the words, character offsets, bounding boxes and + BatchEncoding + """ + if not ocr_result.full_text_annotation: + return None + + words = [] + char_offsets = [] + bboxes = [] + width, height = image.size + for page in ocr_result.full_text_annotation.pages: + for block in page.blocks: + for paragraph in block.paragraphs: + for word in paragraph.words: + words.append(word.text) + char_offsets.append((word.start_idx, word.end_idx)) + vertices = word.bounding_poly.vertices + # LayoutLM requires integer bounding box coordinates + # between 0 and 1000 (excluded) + x_min = int(min(v[0] for v in vertices) * 1000 / width) + x_max = int(max(v[0] for v in vertices) * 1000 / width) + y_min = int(min(v[1] for v in vertices) * 1000 / height) + y_max = int(max(v[1] for v in vertices) * 1000 / height) + bboxes.append( + ( + max(0, min(999, x_min)), + max(0, min(999, y_min)), + max(0, min(999, x_max)), + max(0, min(999, y_max)), + ) + ) + + batch_encoding = processor( + [image], + [words], + boxes=[bboxes], + truncation=True, + padding=False, + return_tensors="np", + return_offsets_mapping=True, + return_special_tokens_mask=True, + ) + return words, char_offsets, bboxes, batch_encoding + + +def postprocess( + logits: np.ndarray, + words: list[str], + char_offsets: list[tuple[int, int]], + batch_encoding: BatchEncoding, + id2label: dict[int, str], +) -> NutritionExtractionPrediction: + """Postprocess the model output to extract the nutrient predictions.
+ + The function returns a `NutritionExtractionPrediction` object with the following + fields: + + - `nutrients`: a dictionary mapping nutrient names to `NutrientPrediction` objects + - `entities`: a `NutritionEntities` object containing the raw, aggregated and + postprocessed entities + + :param logits: the predicted logits + :param words: the words corresponding to the input + :param char_offsets: the character offsets of the words + :param batch_encoding: the BatchEncoding returned by the tokenizer + :param id2label: a dictionary mapping label IDs to label names + :return: a `NutritionExtractionPrediction` object + """ + pre_entities = gather_pre_entities( + logits, words, char_offsets, batch_encoding, id2label + ) + aggregated_entities = aggregate_entities(pre_entities) + postprocessed_entities = postprocess_aggregated_entities(aggregated_entities) + return NutritionExtractionPrediction( + nutrients={ + entity["entity"]: NutrientPrediction( + **{k: v for k, v in entity.items() if k != "valid"} + ) + for entity in postprocessed_entities + if entity["valid"] + }, + entities=NutritionEntities( + raw=pre_entities, + aggregated=aggregated_entities, + postprocessed=postprocessed_entities, + ), + ) + + +def gather_pre_entities( + logits: np.ndarray, + words: list[str], + char_offsets: list[tuple[int, int]], + batch_encoding: BatchEncoding, + id2label: dict[int, str], +) -> list[JSONType]: + """Gather the pre-entities extracted by the model. + + This function takes as input the predicted logits returned by the model and + additional tokenizer outputs (words, char_offsets, batch_encoding) and returns a + list of pre-entities with the following fields: + + - `word`: the word corresponding to the entity + - `entity`: the entity type (string, ex: "ENERGY_KCAL_100G") + - `score`: the score of the entity (float) + - `index`: the index of the word in the input + - `char_start`: the character start index of the entity + - `char_end`: the character end index of the entity + + :param logits: the predicted logits + :param words: the words corresponding to the input + :param char_offsets: the character offsets of the words + :param batch_encoding: the BatchEncoding returned by the tokenizer + :param id2label: a dictionary mapping label IDs to label names + :return: a list of pre-entities + """ + offset_mapping = batch_encoding.offset_mapping[0] + # For each sub-token returned by the tokenizer, the offset mapping gives us a + # tuple indicating the sub-token’s start position and end position relative to + # the original token it was split from. + # That means that if the first position in the tuple is anything other than 0, + # it's a subword token. 
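+ # For example, if the OCR word "525" were split into the sub-tokens "5" and "25", + # their offsets would be (0, 1) and (1, 3): only the first sub-token (offset start 0) + # is kept when mapping predictions back to words.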
+ is_not_subword = offset_mapping[:, 0] == 0 + is_not_special_token = ~batch_encoding.special_tokens_mask[0].astype(bool) + attention_mask = batch_encoding.attention_mask[0].astype(bool) + mask = attention_mask & is_not_subword & is_not_special_token + + maxes = np.max(logits[mask], axis=-1, keepdims=True) + shifted_exp = np.exp(logits[mask] - maxes) + scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True) + label_ids = logits[mask].argmax(axis=-1) + + if len(label_ids) != len(words): + raise ValueError( + f"Number of labels ({len(label_ids)}) does not match number of words ({len(words)})" + ) + + pre_entities = [] + for word_idx in range(len(words)): + label_id = label_ids[word_idx] + word = words[word_idx] + label = id2label[label_id] + entity = label.split("-", maxsplit=1)[-1] + pre_entities.append( + { + "word": word, + "entity": entity, + "score": float(scores[word_idx, label_id]), + "index": word_idx, + "char_start": char_offsets[word_idx][0], + "char_end": char_offsets[word_idx][1], + } + ) + return pre_entities + + +def aggregate_entities(pre_entities: list[JSONType]) -> list[JSONType]: + """Aggregate the entities extracted by the model. + + This function takes as input the list of pre-entities (see the + `gather_pre_entities` function) and aggregate them into entities with the + following fields: + + - `entity`: the entity type (string, ex: "ENERGY_KCAL_100G") + - `words`: the words forming the entity (list of strings) + - `score`: the score of the entity (float), we use the score of the first token + - `start`: the token start index of the entity + - `end`: the token end index of the entity + - `char_start`: the character start index of the entity + - `char_end`: the character end index of the entity + + The entities are aggregated by grouping consecutive tokens with the same entity + type. + """ + entities = [] + + current_entity = None + for pre_entity in pre_entities: + if pre_entity["entity"] == "O": + if current_entity is not None: + entities.append(current_entity) + current_entity = None + continue + + if current_entity is None: + current_entity = { + "entity": pre_entity["entity"], + "words": [pre_entity["word"]], + # We use the score of the first word as the score of the entity + "score": pre_entity["score"], + "start": pre_entity["index"], + "end": pre_entity["index"] + 1, + "char_start": pre_entity["char_start"], + "char_end": pre_entity["char_end"], + } + continue + + if current_entity["entity"] == pre_entity["entity"]: + current_entity["words"].append(pre_entity["word"]) + current_entity["end"] = pre_entity["index"] + 1 + current_entity["char_end"] = pre_entity["char_end"] + continue + + # If we reach this point, the entity has changed + entities.append(current_entity) + current_entity = { + "entity": pre_entity["entity"], + "words": [pre_entity["word"]], + "score": pre_entity["score"], + "start": pre_entity["index"], + "end": pre_entity["index"] + 1, + "char_start": pre_entity["char_start"], + "char_end": pre_entity["char_end"], + } + + if current_entity is not None: + entities.append(current_entity) + + return entities + + +NUTRIENT_VALUE_REGEX = re.compile(r"([0-9]+[,.]?[0-9]*) ?(g|mg|µg|mcg|kj|kcal)?", re.I) + + +def postprocess_aggregated_entities( + aggregated_entities: list[JSONType], +) -> list[JSONType]: + """Postprocess the aggregated entities to extract the nutrient values. 
+ + This function takes as input the list of aggregated entities (see the + `aggregate_entities` function) and adds the following fields to each entity: + + - `value`: the nutrient value (string, ex: "12.5") + - `unit`: the nutrient unit (string, ex: "g") + - `valid`: a boolean indicating whether the entity is valid or not + - `invalid_reason`: a string indicating the reason why the entity is invalid + - `text`: the text of the entity + + The field `words` is removed from the aggregated entities. + + Some additional postprocessing steps are also done to handle specific cases: + + - The OCR engine can incorrectly split tokens for energy nutrients + - The OCR engine can fail to detect the word corresponding to the unit after the + value + """ + postprocessed_entities = [] + + for entity in aggregated_entities: + words = [word.strip().strip("()/") for word in entity["words"]] + if entity["entity"].startswith("ENERGY_"): + # Due to incorrect token split by the OCR, the unit (kcal or kj) can be + # attached to the next token. + # Ex: "525 kJ/126 kcal" is often tokenized into ["525", "kJ/126", "kcal"] + # We handle this case here. + if len(words[0]) > 3 and words[0][:3].lower() == "kj/": + words[0] = words[0][3:] + + words_str = " ".join(words) + value = None + unit = None + is_valid = True + + if entity["entity"] == "SERVING_SIZE": + value = words_str + elif words_str in ("trace", "traces"): + value = "traces" + else: + match = NUTRIENT_VALUE_REGEX.search(words_str) + if match: + value = match.group(1).replace(",", ".") + unit = match.group(2) + # Unit can be None if the OCR engine didn't detect the unit after the + # value as a word + if unit is None: + if entity["entity"].startswith("ENERGY_"): + # Due to incorrect splitting by OCR engine, we don't necessarily + # have a unit for energy, but as the entity can only have a + # single unit (kcal or kJ), we infer the unit from the entity + # name + unit = entity["entity"].split("_")[1].lower() + else: + unit = unit.lower() + if unit == "mcg": + unit = "µg" + else: + logger.warning("Could not extract nutrient value from %s", words_str) + is_valid = False + + postprocessed_entity = { + "entity": entity["entity"].lower(), + "text": words_str, + "value": value, + "unit": unit, + "score": entity["score"], + "start": entity["start"], + "end": entity["end"], + "char_start": entity["char_start"], + "char_end": entity["char_end"], + "valid": is_valid, + } + postprocessed_entities.append(postprocessed_entity) + + entity_type_multiple = set( + entity + for entity, count in Counter( + entity["entity"] for entity in postprocessed_entities + ).items() + if count > 1 + ) + for postprocessed_entity in postprocessed_entities: + if postprocessed_entity["entity"] in entity_type_multiple: + postprocessed_entity["valid"] = False + postprocessed_entity["invalid_reason"] = "multiple_entities" + + return postprocessed_entities + + +@functools.cache +def get_processor(model_dir: Path) -> PreTrainedTokenizerBase: + """Return the processor located in `model_dir`. + + The processor is only loaded once and then cached in memory.
+ + :param model_dir: the model directory + :return: the processor + """ + return AutoProcessor.from_pretrained(model_dir) + + +@functools.cache +def get_id2label(model_dir: Path) -> dict[int, str]: + """Return a dictionary mapping label IDs to labels for a model located in + `model_dir`.""" + config_path = model_dir / "config.json" + + if not config_path.exists(): + raise ValueError(f"Model config not found in {model_dir}") + + id2label = typing.cast(dict, load_json(config_path))["id2label"] + return {int(k): v for k, v in id2label.items()} + + +def send_infer_request( + input_ids: np.ndarray, + attention_mask: np.ndarray, + bbox: np.ndarray, + pixel_values: np.ndarray, + model_name: str, + triton_stub: GRPCInferenceServiceStub, + model_version: str = "1", +) -> np.ndarray: + """Send a NER infer request to the Triton inference server. + + The first dimension of `input_ids` and `attention_mask` must be the batch + dimension. This function returns the predicted logits. + + :param input_ids: input IDs, generated using the transformers tokenizer. + :param attention_mask: attention mask, generated using the transformers + tokenizer. + :param bbox: bounding boxes of the tokens, generated using the transformers + tokenizer. + :param pixel_values: pixel values of the image, generated using the + transformers tokenizer. + :param model_name: the name of the model to use + :param model_version: version of the model to use, defaults to "1" + :return: the predicted logits + """ + request = build_triton_request( + input_ids=input_ids, + attention_mask=attention_mask, + bbox=bbox, + pixel_values=pixel_values, + model_name=model_name, + model_version=model_version, + ) + response = triton_stub.ModelInfer(request) + num_tokens = response.outputs[0].shape[1] + num_labels = response.outputs[0].shape[2] + return np.frombuffer( + response.raw_output_contents[0], + dtype=np.float32, + ).reshape((len(input_ids), num_tokens, num_labels)) + + +def build_triton_request( + input_ids: np.ndarray, + attention_mask: np.ndarray, + bbox: np.ndarray, + pixel_values: np.ndarray, + model_name: str, + model_version: str = "1", +): + """Build a Triton ModelInferRequest gRPC request for LayoutLMv3 models. + + :param input_ids: input IDs, generated using the transformers tokenizer. + :param attention_mask: attention mask, generated using the transformers + tokenizer. + :param bbox: bounding boxes of the tokens, generated using the transformers + tokenizer. + :param pixel_values: pixel values of the image, generated using the + transformers tokenizer. + :param model_name: the name of the model to use. + :param model_version: version of the model to use, defaults to "1".
+ :return: the gRPC ModelInferRequest + """ + request = service_pb2.ModelInferRequest() + request.model_name = model_name + request.model_version = model_version + + add_triton_infer_input_tensor(request, "input_ids", input_ids, "INT64") + add_triton_infer_input_tensor(request, "attention_mask", attention_mask, "INT64") + add_triton_infer_input_tensor(request, "bbox", bbox, "INT64") + add_triton_infer_input_tensor(request, "pixel_values", pixel_values, "FP32") + + return request diff --git a/robotoff/products.py b/robotoff/products.py index b7c1dae89c..aaeb16cae5 100644 --- a/robotoff/products.py +++ b/robotoff/products.py @@ -418,6 +418,7 @@ class Product: "packagings", "lang", "ingredients_text", + "nutriments", ) def __init__(self, product: JSONType): @@ -441,6 +442,7 @@ def __init__(self, product: JSONType): ) self.lang: Optional[str] = product.get("lang") self.ingredients_text: Optional[str] = product.get("ingredients_text") + self.nutriments: JSONType = product.get("nutriments") or {} @staticmethod def get_fields(): @@ -458,6 +460,7 @@ def get_fields(): "images", "lang", "ingredients_text", + "nutriments", } diff --git a/robotoff/triton.py b/robotoff/triton.py index 63182bce75..84a06d763b 100644 --- a/robotoff/triton.py +++ b/robotoff/triton.py @@ -159,3 +159,19 @@ def serialize_byte_tensor(input_tensor): flattened = b"".join(flattened_ls) return flattened return None + + +def add_triton_infer_input_tensor(request, name: str, data: np.ndarray, datatype: str): + """Create and add an input tensor to a Triton gRPC Inference request. + + :param request: the Triton Inference request + :param name: the name of the input tensor + :param data: the input tensor data + :param datatype: the datatype of the input tensor (e.g. "FP32") + """ + input_tensor = service_pb2.ModelInferRequest().InferInputTensor() + input_tensor.name = name + input_tensor.datatype = datatype + input_tensor.shape.extend(data.shape) + request.inputs.extend([input_tensor]) + request.raw_input_contents.extend([data.tobytes()]) diff --git a/robotoff/types.py b/robotoff/types.py index d0ad8443bd..37b9b2b9f0 100644 --- a/robotoff/types.py +++ b/robotoff/types.py @@ -49,8 +49,8 @@ class PredictionType(str, enum.Enum): nutrient_mention = "nutrient_mention" image_lang = "image_lang" nutrition_image = "nutrition_image" - nutrition_table_structure = "nutrition_table_structure" is_upc_image = "is_upc_image" + nutrient_extraction = "nutrient_extraction" @enum.unique @@ -150,15 +150,13 @@ class InsightType(str, enum.Enum): # product main language. nutrition_image = "nutrition_image" - # The 'nutritional_table_structure' insight detects the nutritional table - # structure from the image. NOTE: this insight has not been generated since - # 2020. 
- nutrition_table_structure = "nutrition_table_structure" - # The 'is_upc_image' insight predicts whether or not the image is largely # dominated by a UPC (barcode) is_upc_image = "is_upc_image" + # Nutrient values extracted from images + nutrient_extraction = "nutrient_extraction" + class ServerType(str, enum.Enum): """ServerType is used to refer to a specific Open*Facts project: diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index 1a7ef458d8..b5b140a047 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -32,13 +32,12 @@ ImagePrediction, LogoAnnotation, LogoEmbedding, - Prediction, db, with_db, ) from robotoff.notifier import NotifierFactory from robotoff.off import generate_image_url, get_source_from_url, parse_ingredients -from robotoff.prediction import ingredient_list +from robotoff.prediction import ingredient_list, nutrition_extraction from robotoff.prediction.upc_image import UPCImageType, find_image_is_upc from robotoff.products import get_product_store from robotoff.taxonomy import get_taxonomy @@ -50,6 +49,7 @@ from robotoff.types import ( JSONType, ObjectDetectionModel, + Prediction, PredictionType, ProductIdentifier, ServerType, @@ -137,6 +137,14 @@ def run_import_image_job(product_id: ProductIdentifier, image_url: str, ocr_url: product_id=product_id, ocr_url=ocr_url, ) + enqueue_job( + extract_nutrition_job, + get_high_queue(product_id), + job_kwargs={"result_ttl": 0, "timeout": "2m"}, + product_id=product_id, + image_url=image_url, + ocr_url=ocr_url, + ) # We make sure there are no concurrent insight processing by sending # the job to the same queue. The queue is selected based on the product # barcode. See `get_high_queue` documentation for more details. @@ -779,3 +787,101 @@ def add_ingredient_in_taxonomy_field( known_ingredients_n += known_sub_ingredients_n return ingredients_n, known_ingredients_n + + +@with_db +def extract_nutrition_job( + product_id: ProductIdentifier, + image_url: str, + ocr_url: str, + triton_uri: str | None = None, +) -> None: + """Extract nutrition information from an image OCR, and save the prediction + in the DB. + + :param product_id: The identifier of the product to extract nutrition + information for. + :param image_url: The URL of the image to extract nutrition information + from. + :param ocr_url: The URL of the OCR JSON file + :param triton_uri: URI of the Triton Inference Server, defaults to None. If + not provided, the default value from settings is used. 
+ """ + logger.info("Running nutrition extraction for %s, image %s", product_id, image_url) + source_image = get_source_from_url(image_url) + + with db: + image_model = ImageModel.get_or_none( + source_image=source_image, server_type=product_id.server_type.name + ) + + if not image_model: + logger.info("Missing image in DB for image %s", source_image) + return + + # Stop the job here if the image has already been processed + if ( + ImagePrediction.get_or_none( + image=image_model, model_name=nutrition_extraction.MODEL_NAME + ) + ) is not None: + return + + image = get_image_from_url( + image_url, error_raise=False, session=http_session, use_cache=True + ) + + if image is None: + logger.info("Error while downloading image %s", image_url) + return + + ocr_result = OCRResult.from_url(ocr_url, http_session, error_raise=False) + + if ocr_result is None: + logger.info("Error while downloading OCR JSON %s", ocr_url) + return + + output = nutrition_extraction.predict(image, ocr_result, triton_uri=triton_uri) + + if output is None: + data: JSONType = {"error": "missing_text"} + max_confidence = None + else: + max_confidence = max( + entity["score"] for entity in output.entities.aggregated + ) + data = { + "nutrients": { + entity: dataclasses.asdict(nutrient) + for entity, nutrient in output.nutrients.items() + }, + "entities": dataclasses.asdict(output.entities), + } + logger.info("create image prediction (nutrition extraction) from %s", ocr_url) + ImagePrediction.create( + image=image_model, + type="nutrition_extraction", + model_name=nutrition_extraction.MODEL_NAME, + model_version=nutrition_extraction.MODEL_VERSION, + data=data, + timestamp=datetime.datetime.now(datetime.timezone.utc), + max_confidence=max_confidence, + ) + + if max_confidence is not None: + prediction = Prediction( + barcode=product_id.barcode, + type=PredictionType.nutrient_extraction, + # value and value_tag are None, all data is in data field + value_tag=None, + value=None, + automatic_processing=False, + predictor=nutrition_extraction.MODEL_NAME, + predictor_version=nutrition_extraction.MODEL_VERSION, + data=data, + confidence=None, + server_type=product_id.server_type, + source_image=source_image, + ) + imported = import_insights([prediction], server_type=product_id.server_type) + logger.info(imported) diff --git a/tests/unit/prediction/test_nutrition_extraction.py b/tests/unit/prediction/test_nutrition_extraction.py new file mode 100644 index 0000000000..6dea1e2480 --- /dev/null +++ b/tests/unit/prediction/test_nutrition_extraction.py @@ -0,0 +1,370 @@ +from robotoff.prediction.nutrition_extraction import ( + aggregate_entities, + postprocess_aggregated_entities, +) + + +class TestProcessAggregatedEntities: + def test_postprocess_aggregated_entities_single_entity(self): + aggregated_entities = [ + { + "entity": "ENERGY_KCAL_100G", + "words": ["525", "kcal"], + "score": 0.99, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 7, + } + ] + expected_output = [ + { + "entity": "energy_kcal_100g", + "text": "525 kcal", + "value": "525", + "unit": "kcal", + "score": 0.99, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 7, + "valid": True, + } + ] + assert postprocess_aggregated_entities(aggregated_entities) == expected_output + + def test_postprocess_aggregated_entities_multiple_entities(self): + aggregated_entities = [ + { + "entity": "ENERGY_KCAL_100G", + "words": ["525", "kcal"], + "score": 0.99, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 7, + }, + { + "entity": "ENERGY_KCAL_100G", + 
"words": ["126", "kcal"], + "score": 0.95, + "start": 3, + "end": 5, + "char_start": 8, + "char_end": 15, + }, + ] + expected_output = [ + { + "entity": "energy_kcal_100g", + "text": "525 kcal", + "value": "525", + "unit": "kcal", + "score": 0.99, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 7, + "valid": False, + "invalid_reason": "multiple_entities", + }, + { + "entity": "energy_kcal_100g", + "text": "126 kcal", + "value": "126", + "unit": "kcal", + "score": 0.95, + "start": 3, + "end": 5, + "char_start": 8, + "char_end": 15, + "valid": False, + "invalid_reason": "multiple_entities", + }, + ] + assert postprocess_aggregated_entities(aggregated_entities) == expected_output + + def test_postprocess_aggregated_entities_no_value(self): + aggregated_entities = [ + { + "entity": "FAT_SERVING", + "words": ["fat"], + "score": 0.85, + "start": 0, + "end": 1, + "char_start": 0, + "char_end": 3, + } + ] + expected_output = [ + { + "entity": "fat_serving", + "text": "fat", + "value": None, + "unit": None, + "score": 0.85, + "start": 0, + "end": 1, + "char_start": 0, + "char_end": 3, + "valid": False, + } + ] + assert postprocess_aggregated_entities(aggregated_entities) == expected_output + + def test_postprocess_aggregated_entities_trace_value(self): + aggregated_entities = [ + { + "entity": "SALT_SERVING", + "words": ["trace"], + "score": 0.90, + "start": 0, + "end": 1, + "char_start": 0, + "char_end": 5, + } + ] + expected_output = [ + { + "entity": "salt_serving", + "text": "trace", + "value": "traces", + "unit": None, + "score": 0.90, + "start": 0, + "end": 1, + "char_start": 0, + "char_end": 5, + "valid": True, + } + ] + assert postprocess_aggregated_entities(aggregated_entities) == expected_output + + def test_postprocess_aggregated_entities_serving_size(self): + aggregated_entities = [ + { + "entity": "SERVING_SIZE", + "words": ["25", "g"], + "score": 0.95, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 5, + } + ] + expected_output = [ + { + "entity": "serving_size", + "text": "25 g", + "value": "25 g", + "unit": None, + "score": 0.95, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 5, + "valid": True, + } + ] + assert postprocess_aggregated_entities(aggregated_entities) == expected_output + + def test_postprocess_aggregated_entities_mcg(self): + aggregated_entities = [ + { + "entity": "SALT_100G", + "words": ["1.2", "mcg"], + "score": 0.95, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 7, + } + ] + expected_output = [ + { + "entity": "salt_100g", + "text": "1.2 mcg", + "value": "1.2", + "unit": "µg", + "score": 0.95, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 7, + "valid": True, + } + ] + assert postprocess_aggregated_entities(aggregated_entities) == expected_output + + def test_postprocess_aggregated_entities_merged_kcal_kj(self): + aggregated_entities = [ + { + "entity": "ENERGY_KJ_100G", + "words": ["525"], + "score": 0.99, + "start": 0, + "end": 1, + "char_start": 0, + "char_end": 3, + }, + { + "entity": "ENERGY_KCAL_100G", + "words": ["kj/126", "kcal"], + "score": 0.99, + "start": 1, + "end": 3, + "char_start": 4, + "char_end": 15, + }, + ] + expected_output = [ + { + "entity": "energy_kj_100g", + "text": "525", + "value": "525", + "unit": "kj", + "score": 0.99, + "start": 0, + "end": 1, + "char_start": 0, + "char_end": 3, + "valid": True, + }, + { + "entity": "energy_kcal_100g", + "text": "126 kcal", + "value": "126", + "unit": "kcal", + "score": 0.99, + "start": 1, + "end": 3, + "char_start": 4, + "char_end": 15, + 
"valid": True, + }, + ] + assert postprocess_aggregated_entities(aggregated_entities) == expected_output + + +class TestAggregateEntities: + def test_aggregate_entities_single_entity(self): + pre_entities = [ + { + "entity": "ENERGY_KCAL_100G", + "word": "525", + "score": 0.99, + "index": 0, + "char_start": 0, + "char_end": 3, + }, + { + "entity": "ENERGY_KCAL_100G", + "word": "KJ", + "score": 0.99, + "index": 1, + "char_start": 4, + "char_end": 6, + }, + { + "entity": "O", + "word": "matières", + "score": 0.99, + "index": 2, + "char_start": 7, + "char_end": 15, + }, + ] + expected_output = [ + { + "entity": "ENERGY_KCAL_100G", + "words": ["525", "KJ"], + "score": 0.99, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 6, + } + ] + assert aggregate_entities(pre_entities) == expected_output + + def test_aggregate_entities_multiple_entities(self): + pre_entities = [ + { + "entity": "SALT_SERVING", + "word": "0.1", + "score": 0.99, + "index": 0, + "char_start": 0, + "char_end": 3, + }, + { + "entity": "SALT_SERVING", + "word": "g", + "score": 0.99, + "index": 1, + "char_start": 4, + "char_end": 5, + }, + { + "entity": "PROTEINS_SERVING", + "word": "101", + "score": 0.93, + "index": 2, + "char_start": 6, + "char_end": 9, + }, + { + "entity": "O", + "word": "portion", + "score": 0.99, + "index": 3, + "char_start": 10, + "char_end": 17, + }, + { + "entity": "CARBOHYDRATES_SERVING", + "word": "126", + "score": 0.91, + "index": 4, + "char_start": 18, + "char_end": 21, + }, + { + "entity": "CARBOHYDRATES_SERVING", + "word": "g", + "score": 0.95, + "index": 5, + "char_start": 22, + "char_end": 23, + }, + ] + expected_output = [ + { + "entity": "SALT_SERVING", + "words": ["0.1", "g"], + "score": 0.99, + "start": 0, + "end": 2, + "char_start": 0, + "char_end": 5, + }, + { + "entity": "PROTEINS_SERVING", + "words": ["101"], + "score": 0.93, + "start": 2, + "end": 3, + "char_start": 6, + "char_end": 9, + }, + { + "entity": "CARBOHYDRATES_SERVING", + "words": ["126", "g"], + "score": 0.91, + "start": 4, + "end": 6, + "char_start": 18, + "char_end": 23, + }, + ] + assert aggregate_entities(pre_entities) == expected_output