forked from cvg/LightGlue
Commit
Merge pull request #82 from fabio-sim/feat/torch-export
feat: Dynamic batch
Showing 23 changed files with 2,105 additions and 197 deletions.
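The headline change is that the exported ONNX pipeline can now mark its batch dimension (and, independently, image height and width) as dynamic rather than fixed at export time. As a hypothetical example (the CLI module's actual script name is not shown in this view), a fully dynamic export would be invoked as:

python <cli module> export superpoint --batch-size 0 --height 0 --width 0

Setting an option to 0 is the sentinel for "export this dimension under a symbolic name".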
.gitignore:
@@ -1,15 +1,78 @@
*.egg-info
*.pyc
/.idea/
/data/
/outputs/
__pycache__
/lightglue/weights/
lightglue/_flash/
*-checkpoint.ipynb
# Models / data
*.pth
*.pt2
*.onnx
*.engine
*.profile
.vscode
*.timing
data
megadepth_test_1500

# VSCode
.vscode/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
New file — Typer-based CLI for export and inference (module path not shown in this view):
@@ -0,0 +1,203 @@
from pathlib import Path
from typing import Annotated, Optional

import cv2
import typer

from lightglue_dynamo.cli_utils import multiple_of
from lightglue_dynamo.config import Extractor, InferenceDevice

app = typer.Typer()


@app.callback()
def callback():
    """LightGlue Dynamo CLI"""


@app.command()
def export(
    extractor_type: Annotated[Extractor, typer.Argument()] = Extractor.superpoint,
    output: Annotated[
        Optional[Path],  # typer does not support Path | None  # noqa: UP007
        typer.Option("-o", "--output", dir_okay=False, writable=True, help="Path to save exported model."),
    ] = None,
    batch_size: Annotated[
        int,
        typer.Option(
            "-b",
            "--batch-size",
            min=0,
            help="Batch size of exported ONNX model. Set to 0 to mark as dynamic.",
            callback=multiple_of(2),
        ),
    ] = 0,
    height: Annotated[
        int,
        typer.Option(
            "-h", "--height", min=0, help="Height of input image. Set to 0 to mark as dynamic.", callback=multiple_of(8)
        ),
    ] = 0,
    width: Annotated[
        int,
        typer.Option(
            "-w", "--width", min=0, help="Width of input image. Set to 0 to mark as dynamic.", callback=multiple_of(8)
        ),
    ] = 0,
    num_keypoints: Annotated[
        int, typer.Option(min=128, help="Number of keypoints outputted by feature extractor.")
    ] = 1024,
    fuse_multi_head_attention: Annotated[
        bool,
        typer.Option(
            "--fuse-multi-head-attention",
            help="Fuse multi-head attention subgraph into one optimized operation (ONNX Runtime only).",
        ),
    ] = False,
    opset: Annotated[int, typer.Option(min=16, max=20, help="ONNX opset version of exported model.")] = 17,
    fp16: Annotated[bool, typer.Option("--fp16", help="Whether to also convert to FP16.")] = False,
):
    """Export LightGlue to ONNX."""
    import onnx
    import torch
    from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
    from onnxruntime.transformers.float16 import convert_float_to_float16

    from lightglue_dynamo.models import LightGlue, Pipeline, SuperPoint
    from lightglue_dynamo.ops import use_fused_multi_head_attention

    match extractor_type:
        case Extractor.superpoint:
            extractor = SuperPoint(num_keypoints=num_keypoints).eval()
    matcher = LightGlue(extractor_type).eval()
    pipeline = Pipeline(extractor, matcher)

    if output is None:
        output = Path(f"weights/{extractor_type}_lightglue_pipeline.onnx")

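    # A SuperPoint-style extractor selects keypoints from a per-pixel score
    # map, so it can never return more keypoints than the image has pixels.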
    if height > 0 and width > 0 and num_keypoints > height * width:
        raise typer.BadParameter("num_keypoints cannot be greater than height * width.")

    if fuse_multi_head_attention:
        typer.echo(
            "Warning: Multi-head attention nodes will be fused. Exported model will only work with ONNX Runtime CPU & CUDA execution providers."
        )
        if torch.__version__ < "2.4":
            raise typer.Abort("Fused multi-head attention requires PyTorch 2.4 or later.")
        use_fused_multi_head_attention()

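    # Dimensions left at 0 are exported under symbolic names, so one ONNX
    # graph can serve any batch size and image resolution; the match outputs
    # are always dynamic along dim 0 because the number of matches is
    # data-dependent.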
    dynamic_axes = {"images": {}, "keypoints": {}}
    if batch_size == 0:
        dynamic_axes["images"][0] = "batch_size"
        dynamic_axes["keypoints"][0] = "batch_size"
    if height == 0:
        dynamic_axes["images"][2] = "height"
    if width == 0:
        dynamic_axes["images"][3] = "width"
    dynamic_axes |= {"matches": {0: "num_matches"}, "mscores": {0: "num_matches"}}
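    # The exporter still needs a concrete dummy input: dimensions marked as
    # dynamic fall back to representative sizes (a batch of 2, i.e. one
    # image pair, of 256x256 grayscale images).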
    torch.onnx.export(
        pipeline,
        torch.zeros(batch_size or 2, 1, height or 256, width or 256),
        str(output),
        input_names=["images"],
        output_names=["keypoints", "matches", "mscores"],
        opset_version=opset,
        dynamic_axes=dynamic_axes,
    )
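    # Symbolic shape inference (auto_merge=True) propagates the dynamic
    # batch/height/width symbols through the graph so downstream tools see
    # consistent shapes for every intermediate tensor.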
    onnx.checker.check_model(output)
    onnx.save_model(SymbolicShapeInference.infer_shapes(onnx.load_model(output), auto_merge=True), output)  # type: ignore
    if fp16:
        typer.echo(
            "Converting to FP16. Warning: This FP16 model should NOT be used for TensorRT. TRT provides its own fp16 option."
        )
        onnx.save_model(convert_float_to_float16(onnx.load_model(output)), output.with_suffix(".fp16.onnx"))


@app.command()
def infer(
    model_path: Annotated[Path, typer.Argument(exists=True, dir_okay=False, readable=True, help="Path to ONNX model.")],
    left_image_path: Annotated[
        Path, typer.Argument(exists=True, dir_okay=False, readable=True, help="Path to first image.")
    ],
    right_image_path: Annotated[
        Path, typer.Argument(exists=True, dir_okay=False, readable=True, help="Path to second image.")
    ],
    extractor_type: Annotated[Extractor, typer.Argument()] = Extractor.superpoint,
    output_path: Annotated[
        Optional[Path],  # noqa: UP007
        typer.Option(
            "-o",
            "--output",
            dir_okay=False,
            writable=True,
            help="Path to save output matches figure. If not given, show visualization.",
        ),
    ] = None,
    height: Annotated[
        int,
        typer.Option("-h", "--height", min=1, help="Height of input image at which to perform inference."),
    ] = 1024,
    width: Annotated[
        int,
        typer.Option("-w", "--width", min=1, help="Width of input image at which to perform inference."),
    ] = 1024,
    device: Annotated[
        InferenceDevice, typer.Option("-d", "--device", help="Device to run inference on.")
    ] = InferenceDevice.cpu,
    fp16: Annotated[bool, typer.Option("--fp16", help="Whether model uses FP16 precision.")] = False,
    profile: Annotated[bool, typer.Option("--profile", help="Whether to profile model execution.")] = False,
):
    """Run inference for LightGlue ONNX model."""
    import numpy as np
    import onnxruntime as ort

    from lightglue_dynamo import viz
    from lightglue_dynamo.preprocessors import SuperPointPreprocessor

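    # OpenCV loads images as uint8 BGR; both images are resized to a common
    # (width, height) so they can be stacked into a single batched tensor.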
    raw_images = [left_image_path, right_image_path]
    raw_images = [cv2.resize(cv2.imread(str(i)), (width, height)) for i in raw_images]
    images = np.stack(raw_images)
    match extractor_type:
        case Extractor.superpoint:
            images = SuperPointPreprocessor.preprocess(images).astype(
                np.float16 if fp16 and device != InferenceDevice.tensorrt else np.float32
            )

    session_options = ort.SessionOptions()
    session_options.enable_profiling = profile
    # session_options.optimized_model_filepath = "weights/ort_optimized.onnx"

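    # ONNX Runtime tries execution providers in list order, so TensorRT
    # (when requested) takes precedence, with CUDA and finally CPU as
    # fallbacks for any unsupported operators.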
    providers = [("CPUExecutionProvider", {})]
    if device == InferenceDevice.cuda:
        providers.insert(0, ("CUDAExecutionProvider", {}))
    elif device == InferenceDevice.tensorrt:
        providers.insert(0, ("CUDAExecutionProvider", {}))
        providers.insert(
            0,
            (
                "TensorrtExecutionProvider",
                {
                    "trt_engine_cache_enable": True,
                    "trt_engine_cache_path": "weights/.trtcache_engines",
                    "trt_timing_cache_enable": True,
                    "trt_timing_cache_path": "weights/.trtcache_timings",
                    "trt_fp16_enable": fp16,
                },
            ),
        )

    session = ort.InferenceSession(model_path, session_options, providers)

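    # Run once for a normal inference; loop 100x under --profile so the ORT
    # profiler records a representative sample of kernel timings.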
    for _ in range(100 if profile else 1):
        keypoints, matches, mscores = session.run(None, {"images": images})

    viz.plot_images(raw_images)
    viz.plot_matches(keypoints[0][matches[..., 1]], keypoints[1][matches[..., 2]], color="lime", lw=0.2)
    if output_path is None:
        viz.plt.show()
    else:
        viz.save_plot(output_path)


if __name__ == "__main__":
    app()
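The `multiple_of` validator imported from `lightglue_dynamo.cli_utils` is not part of this diff. A minimal sketch of what such a typer callback factory could look like (this implementation is an assumption, not the repository's code):

from typing import Callable

import typer


def multiple_of(k: int) -> Callable[[int], int]:
    # Factory returning a typer option callback. Note that 0 % k == 0, so 0
    # passes the check, which is what lets 0 double as the "dynamic
    # dimension" sentinel in the export options above.
    def check(value: int) -> int:
        if value % k != 0:
            raise typer.BadParameter(f"Value must be a multiple of {k}.")
        return value

    return check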
lightglue/__init__.py (filename inferred from the relative imports):
@@ -1,4 +1,4 @@
-from .disk import DISK
 from .lightglue import LightGlue
 from .superpoint import SuperPoint
+from .disk import DISK
-from .utils import match_pair
+from .utils import match_pair