Skip to content

Commit

Permalink
fix: fix issue with model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Jul 11, 2024
1 parent 5b069af commit 1e15d72
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 14 deletions.
2 changes: 1 addition & 1 deletion robotoff/insights/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run_object_detection_model(
"""
if (
existing_image_prediction := ImagePrediction.get_or_none(
image=image_model, model_name=model_name.value
image=image_model, model_name=model_name.get_type()
)
) is not None:
if return_null_if_exist:
Expand Down
27 changes: 16 additions & 11 deletions robotoff/insights/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,11 @@ def sort_candidates(candidates: Iterable[ProductInsight]) -> list[ProductInsight
candidates,
key=lambda candidate: (
candidate.data.get("priority", 1),
-int(get_image_id(candidate.source_image) or 0)
if candidate.source_image
else 0,
(
-int(get_image_id(candidate.source_image) or 0)
if candidate.source_image
else 0
),
# automatically processable insights come first
-int(candidate.automatic_processing),
# hack to set a higher priority to prediction with a predictor
Expand Down Expand Up @@ -462,9 +464,11 @@ def sort_predictions(cls, predictions: Iterable[Prediction]) -> list[Prediction]
predictions,
key=lambda prediction: (
prediction.data.get("priority", 1),
-int(get_image_id(prediction.source_image) or 0)
if prediction.source_image
else 0,
(
-int(get_image_id(prediction.source_image) or 0)
if prediction.source_image
else 0
),
# hack to set a higher priority to prediction with a predictor
# value
prediction.predictor or "z",
Expand Down Expand Up @@ -1218,8 +1222,7 @@ def get_nutrition_table_predictions(
.where(
ImageModel.barcode == product_id.barcode,
ImageModel.server_type == product_id.server_type.name,
ImagePrediction.model_name
== ObjectDetectionModel.nutrition_table_yolo.value,
ImagePrediction.model_name == ObjectDetectionModel.get_type(),
ImagePrediction.max_confidence >= min_score,
)
.tuples()
Expand Down Expand Up @@ -1281,9 +1284,11 @@ def generate_candidates(
# `nutrient` prediction is optional, so the dict value associated
# with `nutrient` PredictionType may be null
image_prediction_by_type = {
type_: [p for p in image_predictions if p.type == type_][0]
if any(p for p in image_predictions if p.type == type_)
else None
type_: (
[p for p in image_predictions if p.type == type_][0]
if any(p for p in image_predictions if p.type == type_)
else None
)
for type_ in (
PredictionType.nutrient_mention,
PredictionType.nutrient,
Expand Down
22 changes: 20 additions & 2 deletions robotoff/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,28 @@

class ObjectDetectionModel(enum.Enum):
nutriscore = "nutriscore"
nutriscore_yolo = "nutriscore"
nutriscore_yolo = "nutriscore-yolo"
universal_logo_detector = "universal-logo-detector"
nutrition_table = "nutrition-table"
nutrition_table_yolo = "nutrition-table"
nutrition_table_yolo = "nutrition-table-yolo"

def get_type(self) -> str:
"""This helper function is useful as long as we have two model (yolo
and tf) for each type of detection.
Once we've migrated all models to Yolo, we can remove this function.
"""
if self in (
ObjectDetectionModel.nutriscore,
ObjectDetectionModel.nutriscore_yolo,
):
return "nutriscore"
if self in (
ObjectDetectionModel.nutrition_table,
ObjectDetectionModel.nutrition_table_yolo,
):
return "nutrition-table"

return "universal-logo-detector"


@enum.unique
Expand Down

0 comments on commit 1e15d72

Please sign in to comment.