From 0624efeb666c63a2cd2975bbee536041f6a901d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Fri, 27 Oct 2023 16:18:12 +0200 Subject: [PATCH] fix: fix previously introduced issues --- robotoff/prediction/ocr/location.py | 2 +- robotoff/prediction/ocr/packager_code.py | 4 ++-- robotoff/prediction/ocr/trace.py | 7 +++++-- robotoff/utils/image.py | 3 +++ tests/unit/prediction/ocr/test_location.py | 9 +++++---- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/robotoff/prediction/ocr/location.py b/robotoff/prediction/ocr/location.py index 8b4b456f54..79c6b8bfc4 100644 --- a/robotoff/prediction/ocr/location.py +++ b/robotoff/prediction/ocr/location.py @@ -31,7 +31,7 @@ class City: coordinates: Optional[tuple[float, float]] -@cache() +@cache def load_cities_fr(source: Union[Path, BinaryIO, None] = None) -> set[City]: """Load French cities dataset. diff --git a/robotoff/prediction/ocr/packager_code.py b/robotoff/prediction/ocr/packager_code.py index 3131a95ee8..d0e336a5cb 100644 --- a/robotoff/prediction/ocr/packager_code.py +++ b/robotoff/prediction/ocr/packager_code.py @@ -57,7 +57,7 @@ def process_USDA_match_to_flashtext(match) -> Optional[str]: return USDA_code -@cache() +@cache def generate_USDA_code_keyword_processor() -> KeywordProcessor: """Builds the KeyWordProcessor for USDA codes.""" @@ -184,7 +184,7 @@ def find_packager_codes_regex(content: Union[OCRResult, str]) -> list[Prediction return results -@cache() +@cache def generate_fishing_code_keyword_processor() -> KeywordProcessor: codes = text_file_iter(settings.OCR_FISHING_FLASHTEXT_DATA_PATH) return generate_keyword_processor(("{}||{}".format(c.upper(), c) for c in codes)) diff --git a/robotoff/prediction/ocr/trace.py b/robotoff/prediction/ocr/trace.py index 63462c8a1f..8b3494c0e2 100644 --- a/robotoff/prediction/ocr/trace.py +++ b/robotoff/prediction/ocr/trace.py @@ -13,6 +13,7 @@ from robotoff import settings from robotoff.types import Prediction, PredictionType from robotoff.utils import text_file_iter +from robotoff.utils.text.flashtext import KeywordProcessor from .utils import generate_keyword_processor @@ -21,8 +22,10 @@ PREDICTOR_VERSION = "1" -@cache() -def generate_trace_keyword_processor(labels: Optional[list[str]] = None): +@cache +def generate_trace_keyword_processor( + labels: Optional[list[str]] = None, +) -> KeywordProcessor: if labels is None: labels = list(text_file_iter(settings.OCR_TRACE_ALLERGEN_DATA_PATH)) diff --git a/robotoff/utils/image.py b/robotoff/utils/image.py index 800485551b..e909a310ab 100644 --- a/robotoff/utils/image.py +++ b/robotoff/utils/image.py @@ -72,6 +72,9 @@ def get_image_from_url( error_raise=error_raise, session=session, ) + + if content_bytes is None: + return None else: r = _get_image_from_url(image_url, error_raise, session) if r is None: diff --git a/tests/unit/prediction/ocr/test_location.py b/tests/unit/prediction/ocr/test_location.py index dec32a6504..c9681eada0 100644 --- a/tests/unit/prediction/ocr/test_location.py +++ b/tests/unit/prediction/ocr/test_location.py @@ -28,7 +28,8 @@ def test_load_cities_fr(mocker): ], ) - res = load_cities_fr() + # Bypass the cache + res = load_cities_fr.__wrapped__() m_gzip_open.assert_called_once_with(settings.OCR_CITIES_FR_PATH, "rb") m_json_load.assert_called_once_with(m_gzip_open.return_value.__enter__.return_value) @@ -46,7 +47,7 @@ def test_load_cities_fr(mocker): with pytest.raises( ValueError, match="'123', invalid FR postal code for city 'yolo'" ): - load_cities_fr() + load_cities_fr.__wrapped__() m_gzip_open.assert_called_once_with(settings.OCR_CITIES_FR_PATH, "rb") m_json_load.assert_called_once_with(m_gzip_open.return_value.__enter__.return_value) @@ -60,14 +61,14 @@ def test_load_cities_fr(mocker): with pytest.raises( ValueError, match="'12A42', invalid FR postal code for city 'yolo'" ): - load_cities_fr() + load_cities_fr.__wrapped__() m_gzip_open.assert_called_once_with(settings.OCR_CITIES_FR_PATH, "rb") m_json_load.assert_called_once_with(m_gzip_open.return_value.__enter__.return_value) def test_cities_fr_dataset(): - cities_fr = load_cities_fr() + cities_fr = load_cities_fr.__wrapped__() assert all(isinstance(item, City) for item in cities_fr) assert len(set(cities_fr)) == len(cities_fr)