Skip to content

Commit

Permalink
add SFD2 (CVPR 2023) and IMP (CVPR 2023) (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
feixue94 authored Jul 5, 2024
1 parent 36f69c1 commit df195f4
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 17 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@
[submodule "third_party/RoMa"]
path = third_party/RoMa
url = https://github.com/Vincentqyw/RoMa.git
[submodule "third_party/pram"]
path = third_party/pram
url = https://github.com/feixue94/pram.git
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ The tool currently supports various popular image matching algorithms, namely:
- [ ] [DUSt3R](https://github.com/naver/dust3r), arXiv 2023
- [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
- [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
- [x] [SFD2](https://github.com/feixue94/sfd2), CVPR 2023
- [x] [IMP](https://github.com/feixue94/imp-release), CVPR 2023
- [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
- [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023
- [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023
Expand Down
36 changes: 30 additions & 6 deletions common/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ matcher_zoo:
enable: true
matcher: duster
dense: true
info:
info:
name: DUSt3R #dispaly name
source: "CVPR 2024"
github: https://github.com/naver/dust3r
paper: https://arxiv.org/abs/2312.14132
project: https://dust3r.europe.naverlabs.com
display: true
display: true
GIM(dkm):
enable: true
matcher: gim(dkm)
dense: true
info:
info:
name: GIM(DKM) #dispaly name
source: "ICLR 2024"
github: https://github.com/xuelunshen/gim
Expand All @@ -53,7 +53,7 @@ matcher_zoo:
RoMa:
matcher: roma
dense: true
info:
info:
name: RoMa #dispaly name
source: "CVPR 2024"
github: https://github.com/Parskatt/RoMa
Expand All @@ -63,7 +63,7 @@ matcher_zoo:
dkm:
matcher: dkm
dense: true
info:
info:
name: DKM #dispaly name
source: "CVPR 2023"
github: https://github.com/Parskatt/DKM
Expand All @@ -73,7 +73,7 @@ matcher_zoo:
loftr:
matcher: loftr
dense: true
info:
info:
name: LoFTR #dispaly name
source: "CVPR 2021"
github: https://github.com/zju3dv/LoFTR
Expand Down Expand Up @@ -363,3 +363,27 @@ matcher_zoo:
paper: https://arxiv.org/abs/2104.03362
project: null
display: true

sfd2+imp:
matcher: imp
feature: sfd2
dense: false
info:
name: SFD2+IMP #dispaly name
source: "CVPR 2023"
github: https://github.com/feixue94/imp-release
paper: https://arxiv.org/pdf/2304.14837
project: https://feixue94.github.io/
display: true

sfd2+mnn:
matcher: NN-mutual
feature: sfd2
dense: false
info:
name: SFD2+MNN #dispaly name
source: "CVPR 2023"
github: https://github.com/feixue94/sfd2
paper: https://arxiv.org/abs/2304.14845
project: https://feixue94.github.io/
display: true
39 changes: 28 additions & 11 deletions hloc/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from .utils.parsers import parse_image_lists
from .utils.io import read_image, list_h5_names


"""
A set of standard configurations that can be directly selected from the command
line using their name. Each is a dictionary with the following entries:
Expand Down Expand Up @@ -290,6 +289,24 @@
"dfactor": 8,
},
},
"sfd2": {
"output": "feats-sfd2-n4096-r1600",
"model": {
"name": "sfd2",
"max_keypoints": 4096,
},
"preprocessing": {
"grayscale": False,
"force_resize": True,
"resize_max": 1600,
"width": 640,
"height": 480,
'conf_th': 0.001,
'multiscale': False,
'scales': [1.0],

},
},
# Global descriptors
"dir": {
"output": "global-feats-dir",
Expand All @@ -316,13 +333,13 @@

def resize_image(image, size, interp):
if interp.startswith("cv2_"):
interp = getattr(cv2, "INTER_" + interp[len("cv2_") :].upper())
interp = getattr(cv2, "INTER_" + interp[len("cv2_"):].upper())
h, w = image.shape[:2]
if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
interp = cv2.INTER_LINEAR
resized = cv2.resize(image, size, interpolation=interp)
elif interp.startswith("pil_"):
interp = getattr(PIL.Image, interp[len("pil_") :].upper())
interp = getattr(PIL.Image, interp[len("pil_"):].upper())
resized = PIL.Image.fromarray(image.astype(np.uint8))
resized = resized.resize(size, resample=interp)
resized = np.asarray(resized, dtype=image.dtype)
Expand Down Expand Up @@ -376,7 +393,7 @@ def __getitem__(self, idx):
size = image.shape[:2][::-1]

if self.conf.resize_max and (
self.conf.force_resize or max(size) > self.conf.resize_max
self.conf.force_resize or max(size) > self.conf.resize_max
):
scale = self.conf.resize_max / max(size)
size_new = tuple(int(round(x * scale)) for x in size)
Expand Down Expand Up @@ -467,13 +484,13 @@ def preprocess(image: np.ndarray, conf: SimpleNamespace):

@torch.no_grad()
def main(
conf: Dict,
image_dir: Path,
export_dir: Optional[Path] = None,
as_half: bool = True,
image_list: Optional[Union[Path, List[str]]] = None,
feature_path: Optional[Path] = None,
overwrite: bool = False,
conf: Dict,
image_dir: Path,
export_dir: Optional[Path] = None,
as_half: bool = True,
image_list: Optional[Union[Path, List[str]]] = None,
feature_path: Optional[Path] = None,
overwrite: bool = False,
) -> Path:
logger.info(
"Extracting local features with configuration:"
Expand Down
43 changes: 43 additions & 0 deletions hloc/extractors/sfd2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: UTF-8 -*-
import sys
from pathlib import Path
import torchvision.transforms as tvf

import torch

from ..utils.base_model import BaseModel
from .. import logger

pram_path = Path(__file__).parent / "../../third_party/pram"
sys.path.append(str(pram_path))

from nets.sfd2 import load_sfd2, extract_sfd2_return


class SFD2(BaseModel):
default_conf = {
"max_keypoints": 4096,
"model_name": 'sfd2_20230511_210205_resnet4x.79.pth',
"conf_th": 0.001,

}
required_inputs = ["image"]

def _init(self, conf):
self.conf = {**self.default_conf, **conf}
self.norm_rgb = tvf.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
model_fn = pram_path / 'weights' / self.conf['model_name']
self.net = load_sfd2(weight_path=model_fn).cuda().eval()

logger.info(f"Load SFD2 model done.")

def _forward(self, data):
pred = self.net.extract_local_global(data={'image': self.norm_rgb(data['image'])}, config=self.conf)
out = {
"keypoints": pred["keypoints"][0][None],
"scores": pred["scores"][0][None],
"descriptors": pred["descriptors"][0][None],
}
return out
7 changes: 7 additions & 0 deletions hloc/match_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@
"match_threshold": 0.2,
},
},
"imp": {
"output": "matches-imp",
"model": {
"name": "imp",
"match_threshold": 0.2,
},
},
}


Expand Down
45 changes: 45 additions & 0 deletions hloc/matchers/imp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# -*- coding: UTF-8 -*-
import sys
from pathlib import Path

import torch

from ..utils.base_model import BaseModel
from .. import logger

pram_path = Path(__file__).parent / "../../third_party/pram"
sys.path.append(str(pram_path))

from nets.gml import GML


class IMP(BaseModel):
default_conf = {
"match_threshold": 0.2,
"features": "sfd2",
"model_name": "imp_gml.920.pth",
'sinkhorn_iterations': 20,
}
required_inputs = [
"image0",
"keypoints0",
"scores0",
"descriptors0",
"image1",
"keypoints1",
"scores1",
"descriptors1",
]

def _init(self, conf):
self.conf = {**self.default_conf, **conf}
weight_path = pram_path / "weights" / self.conf["model_name"]
self.net = GML(self.conf).eval().cuda()
self.net.load_state_dict(torch.load(weight_path)['model'], strict=True)
logger.info(f"Load IMP model done.")

def _forward(self, data):
data['descriptors0'] = data['descriptors0'].transpose(2, 1).float()
data['descriptors1'] = data['descriptors1'].transpose(2, 1).float()

return self.net.produce_matches(data, p=0.2)
1 change: 1 addition & 0 deletions third_party/pram
Submodule pram added at 7135cb

0 comments on commit df195f4

Please sign in to comment.