Skip to content

Commit

Permalink
Changes for python 3.11 and new docstring style
Browse files Browse the repository at this point in the history
  • Loading branch information
CorentinBrtx committed Aug 6, 2023
1 parent 7ee7eda commit daf89da
Show file tree
Hide file tree
Showing 14 changed files with 227 additions and 355 deletions.
1 change: 0 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ max-line-length=100
[MESSAGES CONTROL]
disable=invalid-name


[BASIC]
good-names=i,j,x1,x2,y1,y2,df,id,n,fr,en,db,t,ws
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The implementation is strongly based on the 2007 paper **Automatic Panoramic Ima

## Install

**Python 3.11** or higher is required.
Clone the repository, and run the following command:

pip install -r requirements.txt
Expand Down
74 changes: 48 additions & 26 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

import argparse
import logging
import os
import time
from typing import List
from pathlib import Path

import cv2
import numpy as np

from src.images import Image
from src.matching import MultiImageMatches, PairMatch, build_homographies, find_connected_components
from src.matching import (
MultiImageMatches,
PairMatch,
build_homographies,
find_connected_components,
)
from src.rendering import multi_band_blending, set_gain_compensations, simple_blending

parser = argparse.ArgumentParser(
Expand All @@ -22,23 +26,33 @@
Multiple panoramas can be created at once"
)

parser.add_argument(dest="data_dir", help="directory containing the images")
parser.add_argument(dest="data_dir", type=Path, help="directory containing the images")
parser.add_argument(
"-mbb",
"--multi-band-blending",
action="store_true",
help="use multi-band blending instead of simple blending",
)
parser.add_argument("--size", type=int, help="maximum dimension to resize the images to")
parser.add_argument(
"--size", type=int, help="maximum dimension to resize the images to"
)
parser.add_argument(
"--num-bands", type=int, default=5, help="number of bands for multi-band blending"
)
parser.add_argument("--mbb-sigma", type=float, default=1, help="sigma for multi-band blending")
parser.add_argument(
"--mbb-sigma", type=float, default=1, help="sigma for multi-band blending"
)

parser.add_argument("--gain-sigma-n", type=float, default=10, help="sigma_n for gain compensation")
parser.add_argument("--gain-sigma-g", type=float, default=0.1, help="sigma_g for gain compensation")
parser.add_argument(
"--gain-sigma-n", type=float, default=10, help="sigma_n for gain compensation"
)
parser.add_argument(
"--gain-sigma-g", type=float, default=0.1, help="sigma_g for gain compensation"
)

parser.add_argument("-v", "--verbose", action="store_true", help="increase output verbosity")
parser.add_argument(
"-v", "--verbose", action="store_true", help="increase output verbosity"
)

args = vars(parser.parse_args())

Expand All @@ -49,10 +63,11 @@

valid_images_extensions = {".jpg", ".png", ".bmp", ".jpeg"}

image_paths = []
for file in os.listdir(args["data_dir"]):
if os.path.splitext(file)[1].lower() in valid_images_extensions:
image_paths.append(os.path.join(args["data_dir"], file))
image_paths = [
str(filepath)
for filepath in args["data_dir"].iterdir()
if filepath.suffix.lower() in valid_images_extensions
]

images = [Image(path, args.get("size")) for path in image_paths]

Expand All @@ -65,7 +80,7 @@
logging.info("Matching images with features...")

matcher = MultiImageMatches(images)
pair_matches: List[PairMatch] = matcher.get_pair_matches()
pair_matches: list[PairMatch] = matcher.get_pair_matches()
pair_matches.sort(key=lambda pair_match: len(pair_match.matches), reverse=True)

logging.info("Finding connected components...")
Expand All @@ -83,11 +98,13 @@

for connected_component in connected_components:
component_matches = [
pair_match for pair_match in pair_matches if pair_match.image_a in connected_components[0]
pair_match
for pair_match in pair_matches
if pair_match.image_a in connected_component
]

set_gain_compensations(
connected_components[0],
connected_component,
component_matches,
sigma_n=args["gain_sigma_n"],
sigma_g=args["gain_sigma_g"],
Expand All @@ -102,20 +119,25 @@

if args["multi_band_blending"]:
logging.info("Applying multi-band blending...")
for connected_component in connected_components:
results.append(
multi_band_blending(
connected_component, num_bands=args["num_bands"], sigma=args["mbb_sigma"]
)
results = [
multi_band_blending(
connected_component,
num_bands=args["num_bands"],
sigma=args["mbb_sigma"],
)
for connected_component in connected_components
]


else:
logging.info("Applying simple blending...")
for connected_component in connected_components:
results.append(simple_blending(connected_component))
results = [
simple_blending(connected_component)
for connected_component in connected_components
]

logging.info("Saving results to %s", os.path.join(args["data_dir"], "results"))
logging.info("Saving results to %s", args["data_dir"] / "results")

os.makedirs(os.path.join(args["data_dir"], "results"), exist_ok=True)
(args["data_dir"] / "results").mkdir(exist_ok=True, parents=True)
for i, result in enumerate(results):
cv2.imwrite(os.path.join(args["data_dir"], "results", f"pano_{i}.jpg"), result)
cv2.imwrite(str(args["data_dir"] / "results" / f"pano_{i}.jpg"), result)
17 changes: 6 additions & 11 deletions src/images/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@


class Image:
def __init__(self, path: str, size: int = None):
def __init__(self, path: str, size: int | None = None) -> None:
"""
Image constructor.
Parameters
----------
path : str
path to the image
size : int, optional
maximum dimension to resize the image to, by default None
Args:
path: path to the image
size: maximum dimension to resize the image to
"""
self.path = path
self.image: np.ndarray = cv2.imread(path)
Expand All @@ -30,10 +27,8 @@ def __init__(self, path: str, size: int = None):
self.component_id: int = 0
self.gain: np.ndarray = np.ones(3, dtype=np.float32)

def compute_features(self):
"""
Compute the features and the keypoints of the image using SIFT.
"""
def compute_features(self) -> None:
"""Compute the features and the keypoints of the image using SIFT."""
descriptor = cv2.SIFT_create()
keypoints, features = descriptor.detectAndCompute(self.image, None)
self.keypoints = keypoints
Expand Down
2 changes: 2 additions & 0 deletions src/matching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
from .connected_components import find_connected_components
from .multi_images_matches import MultiImageMatches
from .pair_match import PairMatch

__all__ = ["build_homographies", "find_connected_components", "MultiImageMatches", "PairMatch"]
15 changes: 5 additions & 10 deletions src/matching/build_homographies.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
from typing import List

import numpy as np

from src.images import Image
from src.matching.pair_match import PairMatch


def build_homographies(
connected_components: List[List[Image]], pair_matches: List[PairMatch]
connected_components: list[list[Image]], pair_matches: list[PairMatch]
) -> None:
"""
Build homographies for each image of each connected component, using the pair matches.
The homographies are saved in the images themselves.
Parameters
----------
connected_components : List[List[Image]]
The connected components of the panorama.
pair_matches : List[PairMatch]
The valid pair matches.
Args:
connected_components: The connected components of the panorama
pair_matches: The valid pair matches
"""
for connected_component in connected_components:
component_matches = [
Expand Down Expand Up @@ -65,7 +60,7 @@ def build_homographies(
images_added.add(pair_match.image_b)
break

elif pair_match.image_a not in images_added and pair_match.image_b in images_added:
if pair_match.image_a not in images_added and pair_match.image_b in images_added:
pair_match.compute_homography()
homography = np.linalg.inv(pair_match.H) @ current_homography
pair_match.image_a.H = pair_match.image_b.H @ homography
Expand Down
18 changes: 6 additions & 12 deletions src/matching/connected_components.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
from typing import List

from src.matching.pair_match import PairMatch


def find_connected_components(pair_matches: List[PairMatch]) -> List[List[PairMatch]]:
def find_connected_components(pair_matches: list[PairMatch]) -> list[list[PairMatch]]:
"""
Find the connected components of the given pair matches.
Parameters
----------
pair_matches : List[PairMatch]
The list of pair matches.
Args:
pair_matches: The list of pair matches.
Returns
-------
connected_components : List[List[PairMatch]]
List of connected components.
Returns:
connected_components: List of connected components.
"""
connected_components = []
pair_matches_to_check = pair_matches.copy()
component_id = 0
while len(pair_matches_to_check) > 0:
pair_match = pair_matches_to_check.pop(0)
connected_component = set([pair_match.image_a, pair_match.image_b])
connected_component = {pair_match.image_a, pair_match.image_b}
size = len(connected_component)
stable = False
while not stable:
Expand Down
71 changes: 25 additions & 46 deletions src/matching/multi_images_matches.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,48 @@
from typing import List

import cv2

from src.images import Image
from src.matching.pair_match import PairMatch


class MultiImageMatches:
def __init__(self, images: List[Image], ratio: float = 0.75):
def __init__(self, images: list[Image], ratio: float = 0.75) -> None:
"""
Create a new MultiImageMatches object.
Parameters
----------
images : List[Image]
images to compare
ratio : float, optional
ratio used for the Lowe's ratio test, by default 0.75
Args:
images: images to compare
ratio: ratio used for the Lowe's ratio test
"""

self.images = images
self.matches = {image.path: {} for image in images}
self.ratio = ratio

def get_matches(self, image_a: Image, image_b: Image) -> List:
def get_matches(self, image_a: Image, image_b: Image) -> list:
"""
Get matches for the given images.
Parameters
----------
image_a : Image
First image.
image_b : Image
Second image.
Returns
-------
matches : List
List of matches between the two images.
Args:
image_a: First image
image_b: Second image
Returns:
matches: List of matches between the two images
"""
if image_b.path not in self.matches[image_a.path]:
matches = self.compute_matches(image_a, image_b)
self.matches[image_a.path][image_b.path] = matches

return self.matches[image_a.path][image_b.path]

def get_pair_matches(self, max_images: int = 6) -> List[PairMatch]:
def get_pair_matches(self, max_images: int = 6) -> list[PairMatch]:
"""
Get the pair matches for the given images.
Parameters
----------
max_images : int, optional
Number of matches maximum for each image, by default 6
Args:
max_images: Number of matches maximum for each image
Returns
-------
pair_matches : List[PairMatch]
List of pair matches.
Returns:
pair_matches: List of pair matches
"""
pair_matches = []
for i, image_a in enumerate(self.images):
Expand All @@ -73,30 +58,24 @@ def get_pair_matches(self, max_images: int = 6) -> List[PairMatch]:
pair_matches.append(pair_match)
return pair_matches

def compute_matches(self, image_a: Image, image_b: Image) -> List:
def compute_matches(self, image_a: Image, image_b: Image) -> list:
"""
Compute matches between image_a and image_b.
Parameters
----------
image_a : Image
First image.
image_b : Image
Second image.
Returns
-------
matches : List
Matches between image_a and image_b.
"""
Args:
image_a: First image
image_b: Second image
Returns:
matches: Matches between image_a and image_b
"""
matcher = cv2.DescriptorMatcher_create("BruteForce")
matches = []

rawMatches = matcher.knnMatch(image_a.features, image_b.features, 2)
raw_matches = matcher.knnMatch(image_a.features, image_b.features, 2)
matches = []

for m, n in rawMatches:
for m, n in raw_matches:
# ensure the distance is within a certain ratio of each
# other (i.e. Lowe's ratio test)
if m.distance < n.distance * self.ratio:
Expand Down
Loading

0 comments on commit daf89da

Please sign in to comment.