Skip to content

Commit

Permalink
using open3d
Browse files Browse the repository at this point in the history
  • Loading branch information
simone-ferrari committed Nov 15, 2024
1 parent b3353da commit 217691d
Showing 1 changed file with 49 additions and 28 deletions.
77 changes: 49 additions & 28 deletions mad_icp/apps/utils/tools/mad_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import open3d as o3d
from scipy.spatial.transform import Rotation as R
import typer
from typing_extensions import Annotated
import time

# binded vectors and madtree
from mad_icp.src.pybind.pyvector import VectorEigen3d
Expand All @@ -41,12 +41,11 @@

from mad_icp.apps.utils.tools.tools_utils import generate_four_walls_pointcloud


MAX_ITERATIONS = 15
app = typer.Typer()
T_guess = np.eye(4)

def main(viz: Annotated[bool, typer.Option(help="if true visualizer on (very slow)", show_default=True)] = False) -> None:
def main(viz: Annotated[bool, typer.Option(help="if true visualizer on", show_default=True)] = False) -> None:

# initialize reference and query clouds
np.random.seed(42)
Expand All @@ -69,22 +68,27 @@ def main(viz: Annotated[bool, typer.Option(help="if true visualizer on (very slo
print("estimate \n", T_est)
exit(0)

# create figure and 3d axis for visualization
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ref_scatter = ax.scatter(ref_cloud[:, 0], ref_cloud[:, 1], ref_cloud[:, 2], color='blue', s=1, label='reference cloud')
# create open3d point clouds for visualization
ref_pcd = o3d.geometry.PointCloud()
ref_pcd.points = o3d.utility.Vector3dVector(ref_cloud)
ref_pcd.paint_uniform_color([0, 0, 1]) # blue

query_pcd = o3d.geometry.PointCloud()
query_cloud_transformed = np.array([T_guess[:3, :3] @ point + T_guess[:3, 3] for point in query_cloud])
query_scatter = ax.scatter(query_cloud_transformed[:, 0], query_cloud_transformed[:, 1], query_cloud_transformed[:, 2], color='red', s=1, label='query cloud')
query_pcd.points = o3d.utility.Vector3dVector(query_cloud_transformed)
query_pcd.paint_uniform_color([1, 0, 0]) # red

# set axis limits
ax.legend()
vis = o3d.visualization.Visualizer()
vis.create_window()
vis.add_geometry(ref_pcd)
vis.add_geometry(query_pcd)

# this is just to visualize convergence of ICP
# and matches during iterations
tree = MADtree()
tree.build(VectorEigen3d(ref_cloud))

# function to update the animation for each iteration
# function to update the visualization for each iteration
def update(iteration):
global T_guess
# perform one iteration of icp
Expand All @@ -94,26 +98,43 @@ def update(iteration):
ref_cloud_matched = tree.searchCloud(VectorEigen3d(query_cloud_transformed))
ref_cloud_points = np.array([ref_point_matched for ref_point_matched, _ in ref_cloud_matched])

# remove existing lines from plot
[line.remove() for line in ax.lines]
# update the query cloud point cloud data
query_pcd.points = o3d.utility.Vector3dVector(query_cloud_transformed)
vis.update_geometry(query_pcd)
vis.poll_events()
vis.update_renderer()
vis.get_render_option().point_size = 5
vis.get_render_option().background_color = np.asarray([1, 1, 1])
vis.get_render_option().show_coordinate_frame = False
vis.get_render_option().line_width = 2.0

# draw lines between each matched point pair
[ax.plot([qp[0], rp[0]], [qp[1], rp[1]], [qp[2], rp[2]], color='green', linestyle='-', linewidth=0.2)
for qp, rp in zip(query_cloud_transformed, ref_cloud_points)]

# update the query cloud scatter plot data
query_scatter._offsets3d = (query_cloud_transformed[:, 0], query_cloud_transformed[:, 1], query_cloud_transformed[:, 2])
ax.set_title(f"MAD registration: iteration {iteration+1}")

# create animation object, updating for each iteration
ani = FuncAnimation(fig, update, frames=MAX_ITERATIONS, interval=0, repeat=False)

# display the animation
plt.show()
lines = [[i, i + len(query_cloud_transformed)] for i in range(len(query_cloud_transformed))]
colors = [[0, 1, 0] for _ in range(len(query_cloud_transformed))] # green
line_set = o3d.geometry.LineSet(
points=o3d.utility.Vector3dVector(np.vstack((query_cloud_transformed, ref_cloud_points))),
lines=o3d.utility.Vector2iVector(lines),
)
line_set.colors = o3d.utility.Vector3dVector(colors)
vis.add_geometry(line_set)
vis.poll_events()
vis.update_renderer()
return line_set

# update the visualization for each iteration
line_set = None
for iteration in range(MAX_ITERATIONS):
if line_set is not None:
vis.remove_geometry(line_set)
line_set = update(iteration)

time.sleep(0.1) # sleep for 100 ms

vis.run()
vis.destroy_window()

def run():
typer.run(main)

if __name__ == '__main__':
run()

run()

0 comments on commit 217691d

Please sign in to comment.