Skip to content

Commit

Permalink
feat: first work version
Browse files Browse the repository at this point in the history
  • Loading branch information
weishan committed Jul 1, 2023
1 parent 4bf4a76 commit 35cd55f
Show file tree
Hide file tree
Showing 3 changed files with 371 additions and 0 deletions.
32 changes: 32 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "dinov2-retrieval"
version = "0.0.1"
authors = [
{ name="Yunfeng Wang", email="wyf.brz@gmail.com" },
]
description = "A cli program of image retrieval using dinov2"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"appdirs",
"torch",
"torchvision",
"tqdm",
"pillow",
"numpy",
"loguru",
]

[project.scripts]
dinov2_retrieval="dinov2_retrieval.__init__:main"

[project.urls]
"Homepage" = "https://github.com/vra/dinov2-retrieval"
"Bug Tracker" = "https://github.com/vra/dinov2-retrieval/issues"
77 changes: 77 additions & 0 deletions src/dinov2_retrieval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import argparse

from .image_retriever import ImageRetriver


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-s",
"--model-size",
type=str,
default="small",
choices=["small", "base", "large", "largest"],
help="DinoV2 model type",
)
parser.add_argument(
"-p",
"--model-path",
type=str,
default=None,
help="path to dinov2 model, useful when github is unavailable",
)
parser.add_argument(
"-o",
"--output-root",
type=str,
default="output",
help="root folder to save output results",
)
parser.add_argument(
"-q",
"--query",
type=str,
help="path to a query image file or image folder",
)
parser.add_argument(
"-d",
"--database",
type=str,
help="path to the database image file or image folder",
)
parser.add_argument(
"-n",
"--num",
type=int,
default=1,
help="How many images to show in retrieval results",
)
parser.add_argument(
"--size",
type=int,
default=224,
help="image output size",
)
parser.add_argument(
"-m",
"--margin",
type=int,
default=10,
choices=range(0, 100),
help="margin size (in pixel) between concatenated images",
)
parser.add_argument(
"--disable-cache",
action="store_true",
help="don't cache database features, will extract features each time, quite time-consuming for large database",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="show detailed logs",
)
args = parser.parse_args()

image_retriever = ImageRetriver(args)
image_retriever.run(args)
262 changes: 262 additions & 0 deletions src/dinov2_retrieval/image_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
import http
from pathlib import Path
import pickle
import sys

import appdirs
from loguru import logger
import numpy as np
import PIL
from PIL import Image, ImageOps
import torch
from torchvision import models, transforms
from tqdm import tqdm


class ImageRetriver:
    """Image retrieval using DINOv2 backbone features.

    Extracts one global feature vector per image with a pretrained DINOv2
    model, then ranks database images by cosine similarity to each query
    image and saves a side-by-side visualization of the top matches.
    """

    # Accepted image extensions; compared case-insensitively in glob_images.
    IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}

    def __init__(self, args):
        """Load the DINOv2 model and prepare preprocessing and cache paths.

        Args:
            args: parsed CLI namespace; reads model_size, model_path,
                verbose, num, disable_cache and database.

        Exits the process (code 1) when the model cannot be loaded.
        """
        # Route loguru output to stdout at the requested verbosity.
        logger.remove()
        logger.add(sys.stdout, level="DEBUG" if args.verbose else "INFO")

        # Map the user-facing size name to the torch.hub model identifier.
        model_name_dict = {
            "small": "dinov2_vits14",
            "base": "dinov2_vitb14",
            "large": "dinov2_vitl14",
            "largest": "dinov2_vitg14",
        }
        model_name = model_name_dict[args.model_size]

        # Load from GitHub by default, or from a local checkout when
        # --model-path is given (useful when github.com is unreachable).
        model_folder = (
            "facebookresearch/dinov2" if args.model_path is None else args.model_path
        )
        model_source = "github" if args.model_path is None else "local"

        try:
            logger.info(f"loading {model_name=} from {model_folder=}")
            self.model = torch.hub.load(
                model_folder,
                model_name,
                source=model_source,
            )
        except FileNotFoundError:
            logger.error(f"load model failed. please check if {model_folder=} exists")
            sys.exit(1)
        except http.client.RemoteDisconnected:
            logger.error(
                "connect to github is reset. maybe set --model-path to $HOME/.cache/torch/hub/facebookresearch_dinov2_main ?"
            )
            sys.exit(1)

        self.model.eval()

        # Standard ImageNet preprocessing expected by DINOv2.
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        self.top_k = args.num
        self.model_name = model_name

        if not args.disable_cache:
            # Cache extracted database features per (database, model) pair.
            # NOTE: the cache is keyed only by folder name and model; it is
            # not invalidated when the database contents change.
            cache_root_folder = Path(
                appdirs.user_cache_dir(appname="dinov2_retrieval", appauthor="vra")
            )
            cache_root_folder.mkdir(parents=True, exist_ok=True)
            self.database_features_cache_path = cache_root_folder / (
                Path(args.database).name + "_" + model_name + ".pkl"
            )
            logger.debug(f"{cache_root_folder=}, {self.database_features_cache_path=}")

    def glob_images(self, path):
        """Recursively find image files under *path*.

        Extensions are matched case-insensitively (so ``.PNG`` or ``.Jpeg``
        are found too), each file is listed exactly once even on
        case-insensitive filesystems, and the result is sorted so the
        database ordering — and therefore the feature cache — is stable
        across runs.
        """
        return sorted(
            p
            for p in path.rglob("*")
            if p.is_file() and p.suffix.lower() in self.IMAGE_EXTENSIONS
        )

    def extract_database_features(self, database_img_paths):
        """Extract a DINOv2 feature vector for every database image."""
        database_features = []
        for img_path in tqdm(database_img_paths):
            img = Image.open(str(img_path)).convert("RGB")
            feature = self.extract_single_image_feature(img)
            database_features.append(feature)
        return database_features

    def run(self, args):
        """Run image retrieval for the query image(s) against the database.

        For each query image, computes cosine similarity against all
        database features and saves a concatenated visualization of the
        top-k matches under args.output_root.
        """
        # Collect query images: a single file or every image in a folder.
        query_path = Path(args.query)
        if query_path.is_dir():
            query_paths = self.glob_images(query_path)
        else:
            query_paths = [query_path]

        logger.debug(f"query image paths: {list(query_paths)}")

        if len(query_paths) < 1:
            logger.warning("no query image, exit")
            return

        database_img_paths = self.glob_images(Path(args.database))

        if len(database_img_paths) < 1:
            logger.warning("database does not contain images, exit")
            return

        # Clamp top_k so we never ask for more results than images exist.
        self.top_k = min(self.top_k, len(database_img_paths))

        # Extract (or load cached) features for the database images.
        # Note: args.disable_cache is checked first because the cache path
        # attribute only exists when caching is enabled.
        if args.disable_cache or not self.database_features_cache_path.exists():
            logger.info("preparing database features")
            database_features = self.extract_database_features(database_img_paths)
            if not args.disable_cache:
                # Context manager closes the file even if pickling fails.
                with open(self.database_features_cache_path, "wb") as cache_file:
                    pickle.dump(database_features, cache_file)
        else:
            logger.info(
                f"Load cached database features from {self.database_features_cache_path}"
            )
            with open(self.database_features_cache_path, "rb") as cache_file:
                database_features = pickle.load(cache_file)

        for img_path in query_paths:
            logger.info(f"processing {img_path}")
            try:
                img = Image.open(str(img_path)).convert("RGB")
            except PIL.UnidentifiedImageError:
                logger.debug(f"query path is not a image: {img_path}")
                continue

            logger.debug("Extracting features on query image")
            feature = self.extract_single_image_feature(img)

            logger.debug("Calculate similarity")
            # calculate_distance returns cosine *similarity* (higher is
            # closer), hence the descending ([::-1]) sort below.
            distances = self.calculate_distance(feature, database_features)
            closest_indices = np.argsort(distances)[::-1][: self.top_k]
            sorted_distances = np.sort(distances)[::-1][: self.top_k]

            self.save_result(
                args,
                img,
                img_path,
                database_img_paths,
                closest_indices,
                sorted_distances,
            )

    def calculate_distance(self, query_feature, database_features):
        """Return cosine similarity between the query and each database
        feature (despite the name, larger values mean more similar)."""
        query_norm = np.linalg.norm(query_feature)
        return [
            np.dot(query_feature, feature)
            / (query_norm * np.linalg.norm(feature))
            for feature in database_features
        ]

    def save_result(
        self,
        args,
        query_image,
        query_path,
        database_img_paths,
        closest_indices,
        sorted_distances,
    ):
        """Save query and its top-k retrieved images concatenated
        horizontally (query first) to the output folder."""
        # Output layout: <output_root>/<database name>/<model name>/
        img_save_folder = (
            Path(args.output_root) / Path(args.database).name / self.model_name
        )
        img_save_path = img_save_folder / (
            query_path.stem + "_output" + query_path.suffix
        )
        logger.info(f"Save results to {img_save_path}")

        img_save_folder.mkdir(parents=True, exist_ok=True)

        # Pad and resize every image to args.size squares so the query and
        # retrieved images can be concatenated into a single row.
        query_image = self.process_image_for_visualization(args, query_image)

        vis_img_list = [query_image]
        for idx, img_idx in enumerate(closest_indices):
            img_path = database_img_paths[img_idx]
            similarity = sorted_distances[idx]
            logger.debug(
                f"{idx}th similar image is {img_path}, similarity is {similarity}"
            )
            cur_img = Image.open(img_path)
            cur_img = self.process_image_for_visualization(args, cur_img)
            vis_img_list.append(cur_img)

        # Canvas width: (top_k + 1) tiles plus top_k margins between them.
        x_offset = 0
        out_img = Image.new(
            "RGB", (args.size * (self.top_k + 1) + args.margin * self.top_k, args.size)
        )
        for img in vis_img_list:
            out_img.paste(img, (x_offset, 0))
            x_offset += img.width + args.margin

        out_img.save(str(img_save_path))

    def process_image_for_visualization(self, args, img):
        """Letterbox *img* into an args.size square.

        Resizes so the longer side equals args.size (keeping aspect ratio),
        then pads the shorter side symmetrically with black.
        """
        width, height = img.size
        if width > height:
            new_width = args.size
            # max(1, ...) guards against a 0-pixel dimension for extreme
            # aspect ratios, which would make resize() fail.
            new_height = max(1, int((new_width / width) * height))
        else:
            new_height = args.size
            new_width = max(1, int((new_height / height) * width))

        img = img.resize((new_width, new_height))

        # Split the leftover pixels between both sides; any odd pixel goes
        # to the right/bottom edge.
        delta_w = args.size - new_width
        delta_h = args.size - new_height
        padding = (
            delta_w // 2,
            delta_h // 2,
            delta_w - (delta_w // 2),
            delta_h - (delta_h // 2),
        )

        # fill=0 pads with black.
        return ImageOps.expand(img, padding, fill=0)

    def extract_single_image_feature(self, image):
        """Extract the DINOv2 backbone feature of a single PIL image as a
        1-D numpy array."""
        net_input = self.transform(image).unsqueeze(0)

        # Inference only: no gradients needed.
        with torch.no_grad():
            feature = self.model(net_input).squeeze().numpy()
        return feature

0 comments on commit 35cd55f

Please sign in to comment.