From 846fbabd3dd50d201987b156f14259921017dbce Mon Sep 17 00:00:00 2001 From: weishan Date: Sat, 8 Jul 2023 14:28:59 +0800 Subject: [PATCH] test: add simple test cases --- tests/test_image_retrieval.py | 53 +++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/test_image_retrieval.py diff --git a/tests/test_image_retrieval.py b/tests/test_image_retrieval.py new file mode 100644 index 0000000..5ba9342 --- /dev/null +++ b/tests/test_image_retrieval.py @@ -0,0 +1,53 @@ +from collections import namedtuple +from pathlib import Path +import sys + +import pytest + +sys.path.insert(0, "..") +from src.dinov2_retrieval.image_retriever import ImageRetriver + + +def test_init(): + args = { + "verbose": True, + "model_size": "small", + "model_path": None, + "num": 5, + "disable_cache": False, + "database": "/path/to/database", + } + Args = namedtuple("Args", list(args.keys())) + args = Args(**args) + image_retriver = ImageRetriver(args) + assert image_retriver.top_k == 5 + assert image_retriver.model_name == "dinov2_vits14" + + +def test_glob_images(tmpdir): + # create some test image files + tmpdir.join("test1.jpg").write("test") + tmpdir.join("test2.png").write("test") + tmpdir.join("test3.bmp").write("test") + + args = { + "verbose": True, + "model_size": "small", + "model_path": None, + "num": 5, + "disable_cache": False, + "database": "/path/to/database", + } + Args = namedtuple("Args", list(args.keys())) + args = Args(**args) + image_retriver = ImageRetriver(args) + + tmpdir = Path(tmpdir) + images = image_retriver.glob_images(tmpdir) + assert len(images) == 3 + assert set(images) == set( + [tmpdir / "test1.jpg", tmpdir / "test2.png", tmpdir / "test3.bmp"] + ) + + +# vim: ts=4 sw=4 sts=4 expandtab