diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py index 5a97ff9f2f..4c4e7e1171 100644 --- a/robotoff/cli/main.py +++ b/robotoff/cli/main.py @@ -477,6 +477,13 @@ def add_logo_to_ann( sleep_time: float = typer.Option( 0.0, help="Time to sleep between each query (in s)" ), + existing_ids_path: Optional[Path] = typer.Argument( + None, + file_okay=True, + dir_okay=False, + help="Path of the plain text file containing logo IDs (one ID per line). If not provided, " + "existing IDs will be fetched from Elasticsearch.", + ), ) -> None: """Index all missing logos in Elasticsearch ANN index.""" import logging @@ -490,26 +497,28 @@ def add_logo_to_ann( from robotoff.elasticsearch import get_es_client from robotoff.logos import add_logos_to_ann, get_stored_logo_ids from robotoff.models import LogoEmbedding, db - from robotoff.utils import get_logger + from robotoff.utils import get_logger, text_file_iter logger = get_logger() logging.getLogger("elastic_transport.transport").setLevel(logging.WARNING) es_client = get_es_client() - seen = get_stored_logo_ids(es_client) + if existing_ids_path is not None and existing_ids_path.is_file(): + seen = set(int(x) for x in text_file_iter(existing_ids_path)) + else: + seen = get_stored_logo_ids(es_client) + added = 0 with db.connection_context(): logger.info("Fetching logo embedding to index...") query = LogoEmbedding.select().objects() - logo_embedding_iter = tqdm.tqdm( - ( - logo_embedding - for logo_embedding in ServerSide(query) - if logo_embedding.logo_id not in seen - ), - desc="logo", + logo_embedding_iter = ( + logo_embedding + for logo_embedding in tqdm.tqdm(ServerSide(query), desc="logo") + if logo_embedding.logo_id not in seen ) + for logo_embedding_batch in chunked(logo_embedding_iter, 500): try: add_logos_to_ann(es_client, logo_embedding_batch, server_type)