Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: logo improvements #1273

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,13 @@
sleep_time: float = typer.Option(
0.0, help="Time to sleep between each query (in s)"
),
existing_ids_path: Optional[Path] = typer.Argument(
None,
file_okay=True,
dir_okay=False,
help="Path of the plain text file containing logo IDs (one ID per line). If not provided, "
"existing IDs will be fetched from Elasticsearch.",
),
) -> None:
"""Index all missing logos in Elasticsearch ANN index."""
import logging
Expand All @@ -490,26 +497,28 @@
from robotoff.elasticsearch import get_es_client
from robotoff.logos import add_logos_to_ann, get_stored_logo_ids
from robotoff.models import LogoEmbedding, db
from robotoff.utils import get_logger
from robotoff.utils import get_logger, text_file_iter

Check warning on line 500 in robotoff/cli/main.py

View check run for this annotation

Codecov / codecov/patch

robotoff/cli/main.py#L500

Added line #L500 was not covered by tests

logger = get_logger()
logging.getLogger("elastic_transport.transport").setLevel(logging.WARNING)

es_client = get_es_client()
seen = get_stored_logo_ids(es_client)
if existing_ids_path is not None and existing_ids_path.is_file():
seen = set(int(x) for x in text_file_iter(existing_ids_path))

Check warning on line 507 in robotoff/cli/main.py

View check run for this annotation

Codecov / codecov/patch

robotoff/cli/main.py#L506-L507

Added lines #L506 - L507 were not covered by tests
else:
seen = get_stored_logo_ids(es_client)

Check warning on line 509 in robotoff/cli/main.py

View check run for this annotation

Codecov / codecov/patch

robotoff/cli/main.py#L509

Added line #L509 was not covered by tests

added = 0

with db.connection_context():
logger.info("Fetching logo embedding to index...")
query = LogoEmbedding.select().objects()
logo_embedding_iter = tqdm.tqdm(
(
logo_embedding
for logo_embedding in ServerSide(query)
if logo_embedding.logo_id not in seen
),
desc="logo",
logo_embedding_iter = (

Check warning on line 516 in robotoff/cli/main.py

View check run for this annotation

Codecov / codecov/patch

robotoff/cli/main.py#L516

Added line #L516 was not covered by tests
logo_embedding
for logo_embedding in tqdm.tqdm(ServerSide(query), desc="logo")
if logo_embedding.logo_id not in seen
)

for logo_embedding_batch in chunked(logo_embedding_iter, 500):
try:
add_logos_to_ann(es_client, logo_embedding_batch, server_type)
Expand Down
4 changes: 3 additions & 1 deletion robotoff/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ def event_api() -> str:


# ANN index parameters
K_NEAREST_NEIGHBORS = 100
# K_NEAREST_NEIGHBORS is the number of nearest neighbors we consider
# when predicting the value of a logo
K_NEAREST_NEIGHBORS = 10

# image moderation service
IMAGE_MODERATION_SERVICE_URL: Optional[str] = os.environ.get(
Expand Down
4 changes: 2 additions & 2 deletions scripts/category_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ def run(lang: Optional[str] = None):

category_taxonomy = get_taxonomy("category")
with open(WRITE_PATH / "categories.full.json", "w") as f:
f.write(json.dumps(category_taxonomy.to_dict()))
json.dump(category_taxonomy.to_dict(), f)

ingredient_taxonomy = get_taxonomy("ingredient")
with open(WRITE_PATH / "ingredients.full.json", "w") as f:
f.write(json.dumps(category_taxonomy.to_dict()))
json.dump(ingredient_taxonomy.to_dict(), f)

datasets = generate_train_test_val_datasets(
category_taxonomy, ingredient_taxonomy, training_stream, lang
Expand Down
Loading