From fa4e9fa0837e5916450fff1834a816b00d14ec66 Mon Sep 17 00:00:00 2001 From: wangyida1 Date: Tue, 30 Jul 2024 01:33:06 +0000 Subject: [PATCH] standalone rasted depth --- configs/neus-blender.yaml | 2 +- configs/neus-colmap.yaml | 1 + datasets/blender.py | 57 ++++++++++++++++++++++++++++++--------- datasets/colmap.py | 47 ++++++++++---------------------- datasets/hmvs.py | 4 +-- utils/rast.py | 44 ++++++++++++++++++++++++++++++ 6 files changed, 106 insertions(+), 49 deletions(-) create mode 100644 utils/rast.py diff --git a/configs/neus-blender.yaml b/configs/neus-blender.yaml index c9a91bd..2b2c8f4 100644 --- a/configs/neus-blender.yaml +++ b/configs/neus-blender.yaml @@ -23,7 +23,7 @@ dataset: model: name: neus - radius: 1.5 + radius: 1.0 # 1.5 num_samples_per_ray: 2048 train_num_rays: 256 max_train_num_rays: 8192 diff --git a/configs/neus-colmap.yaml b/configs/neus-colmap.yaml index de3e85e..bc29744 100644 --- a/configs/neus-colmap.yaml +++ b/configs/neus-colmap.yaml @@ -16,6 +16,7 @@ dataset: apply_depth: false load_data_on_gpu: false max_imgs: 400 + preprocess_only: false model: name: neus diff --git a/datasets/blender.py b/datasets/blender.py index 9271666..7c89270 100644 --- a/datasets/blender.py +++ b/datasets/blender.py @@ -14,6 +14,7 @@ import datasets from models.ray_utils import get_ray_directions from utils.misc import get_rank +from utils.rast import rasterize class BlenderDatasetBase(): @@ -48,16 +49,18 @@ def setup(self, config, split): self.near, self.far = self.config.near_plane, self.config.far_plane try: - self.focal_x = meta['fl_x'] * self.factor - self.focal_y = meta['fl_y'] * self.factor + self.fx = meta['fl_x'] * self.factor + self.fy = meta['fl_y'] * self.factor self.cx = meta['cx'] * self.factor self.cy = meta['cy'] * self.factor except: - self.focal_x = 0.5 * w / math.tan(0.5 * meta['camera_angle_x']) * self.factor # scaled focal length - self.focal_y = self.focal_x + self.fx = 0.5 * w / math.tan(0.5 * meta['camera_angle_x']) * self.factor # scaled focal length + self.fy = self.fx self.cx = self.w//2 * self.factor self.cy = self.h//2 * self.factor + intrinsic = np.array([[self.fx, 0, self.cx],[0, self.fy, self.cy],[0, 0, 1]]) + try: self.k1 = meta['k1'] self.k2 = meta['k2'] @@ -74,7 +77,7 @@ def setup(self, config, split): self.k4 = 0.0 # ray directions for all pixels, same for all images (same H, W, focal) - self.directions = get_ray_directions(self.w, self.h, self.focal_x, self.focal_y, self.cx, self.cy, k1=self.k1, k2=self.k2, k3=self.k3, k4=self.k4).to(self.rank) + self.directions = get_ray_directions(self.w, self.h, self.fx, self.fy, self.cx, self.cy, k1=self.k1, k2=self.k2, k3=self.k3, k4=self.k4).to(self.rank) if not self.config.load_data_on_gpu: self.directions = self.directions.cpu() # (h, w, 3) @@ -86,7 +89,7 @@ def setup(self, config, split): for i, frame in enumerate(meta['frames']): c2w_npy = np.array(frame['transform_matrix']) # NOTE: only specific dataset, e.g. Baoru's medical images needs to convert - # c2w_npy[:3, 1:3] *= -1. # COLMAP => OpenGL + # c2w_npy[:3, 1:3] *= -1. 
# COLMAP => OpenGL c2w = torch.from_numpy(c2w_npy[:3, :4]) self.all_c2w.append(c2w) @@ -117,10 +120,9 @@ def setup(self, config, split): import cv2 if depth.max() != 0.0: # depth_o3d = o3d.geometry.Image(depth.numpy() * self.config.cam_downscale) - intrin_mtx = np.array([[self.focal_x, 0, self.cx], [0, self.focal_y, self.cy], [0, 0, 1]]) distort = np.array([self.k1, self.k2, self.p1, self.p2, self.k3, self.k4, 0, 0]) - depth_o3d = o3d.geometry.Image(cv2.undistort(depth.numpy(), intrin_mtx, distort)) - img_o3d = o3d.geometry.Image(cv2.undistort(img.numpy(), intrin_mtx, distort)) + depth_o3d = o3d.geometry.Image(cv2.undistort(depth.numpy(), intrinsic, distort)) + img_o3d = o3d.geometry.Image(cv2.undistort(img.numpy(), intrinsic, distort)) rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( img_o3d, depth_o3d, @@ -128,7 +130,7 @@ def setup(self, config, split): depth_scale=1, convert_rgb_to_intensity=False) intrin_o3d = o3d.camera.PinholeCameraIntrinsic( - w, h, self.focal_x, self.focal_y, self.cx, self.cy) + w, h, self.fx, self.fy, self.cx, self.cy) # Open3D uses world-to-camera extrinsics pts_frm = o3d.geometry.PointCloud.create_from_depth_image( depth_o3d, intrinsic=intrin_o3d, extrinsic=np.linalg.inv(c2w_npy), depth_scale=1) @@ -157,7 +159,8 @@ def setup(self, config, split): self.all_fg_masks.append(torch.ones_like(img[...,0], device=img.device)) # (h, w) self.all_images.append(img[...,:3]) - if self.apply_depth: + if self.apply_depth and 'depth_path' in frame: + # Using the provided depth images depth_path = os.path.join(self.config.root_dir, f"{frame['depth_path']}") if os.path.isfile(depth_path): self.all_depths.append(depth) @@ -176,6 +179,34 @@ def setup(self, config, split): else: depth_mask = (depth > 0.0).to(bool) self.all_depth_masks.append(depth_mask) + elif self.apply_depth and 'depth_path' not in frame: + # Rasterizing depth and norms from a mesh + pcd = o3d.geometry.PointCloud() + mesh_init_path = os.path.join(self.config.root_dir, 'points3D_mesh.ply') + if os.path.exists(mesh_init_path): + print(colored( + 'GT surface mesh is directly loaded', + 'blue')) + mesh_o3d = o3d.t.geometry.TriangleMesh.from_legacy(o3d.io.read_triangle_mesh(mesh_init_path)) + depth_rast_path = os.path.join( + self.config.root_dir, f'depths_{self.split}', + f"rasted_{self.config.img_downscale}") + norm_rast_path = os.path.join( + self.config.root_dir, f'normals_{self.split}', + f"rasted_{self.config.img_downscale}") + c2w_col = c2w + c2w_col[:3, 1:3] *= -1. 
# OpenGL => COLMAP + depth_rast, _ = rasterize(frame['file_path'], mesh_o3d, intrinsic, c2w_col, w, h, depth_rast_path, norm_rast_path) + depth_rast = TF.pil_to_tensor(depth_rast).permute( + 1, 2, 0) / self.config.cam_downscale + inf_mask = (depth_rast == float("Inf")) + depth_rast[inf_mask] = 0 + depth_rast = depth_rast.to( + self.rank + ) if self.config.load_data_on_gpu else depth_rast.cpu() + depth_rast_mask = (depth_rast > 0.0).to(bool) # trim points outside the contraction box off + self.all_depths.append(depth_rast) + self.all_depth_masks.append(depth_rast_mask) else: self.all_depths.append(torch.zeros_like(img[...,0], device=img.device)) self.all_depth_masks.append(torch.zeros_like(img[...,0], device=img.device)) @@ -188,7 +219,7 @@ def setup(self, config, split): vis_mask = torch.ones_like(img[...,0], device=img.device) self.all_vis_masks.append(vis_mask) - if self.apply_depth: + if self.apply_depth and "depth_path" in frame: pts_clt = pts_clt.voxel_down_sample(voxel_size=0.5) pts_clt.estimate_normals( search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=1.0, max_nn=20)) @@ -219,7 +250,7 @@ def setup(self, config, split): # translate - self.all_c2w[...,3] -= self.all_c2w[...,3].mean(0) + # self.all_c2w[...,3] -= self.all_c2w[...,3].mean(0) # rescale if 'cam_downscale' not in self.config: diff --git a/datasets/colmap.py b/datasets/colmap.py index a978a28..2667921 100644 --- a/datasets/colmap.py +++ b/datasets/colmap.py @@ -16,6 +16,7 @@ from models.ray_utils import get_ray_directions from utils.misc import get_rank from utils.pose_utils import get_center, normalize_poses, create_spheric_poses +from utils.rast import rasterize class ColmapDatasetBase(): # the data only has to be processed once @@ -138,8 +139,17 @@ def setup(self, config, split): pcd = o3d.geometry.PointCloud() mesh_init_path = os.path.join(self.config.root_dir, 'sparse/0/points3D_mesh.ply') if os.path.exists(mesh_init_path): + print(colored( + 'GT surface mesh is directly loaded', + 'blue')) mesh_o3d = o3d.t.geometry.TriangleMesh.from_legacy(o3d.io.read_triangle_mesh(mesh_init_path)) else: + print(colored( + 'GT surface mesh is not provided', + 'cyan')) + print(colored( + 'Processing with Poisson surface on top of given GT points', + 'blue')) mesh_dir = os.path.join(self.config.root_dir, 'meshes') os.makedirs(mesh_dir, exist_ok=True) if not os.path.exists(os.path.join(mesh_dir, 'layout_pcd_gt.ply')): @@ -174,45 +184,13 @@ def setup(self, config, split): 'blue')) mesh_o3d = o3d.t.geometry.TriangleMesh.from_legacy(o3d.io.read_triangle_mesh(mesh_poisson_path)) - # Create scene and add the mesh - scene = o3d.t.geometry.RaycastingScene() - scene.add_triangles(mesh_o3d) - - # Rays are 6D vectors with origin and ray direction. - # Here we use a helper function to create rays - rays_mesh = scene.create_rays_pinhole(intrinsic_matrix=intrinsic, extrinsic_matrix=np.linalg.inv(np.concatenate((c2w.numpy(), np.array([[0, 0, 0, 1.]])))), width_px=w, height_px=h) - - # Compute the ray intersections. 
- rays_rast = scene.cast_rays(rays_mesh) - - # visualize the hit distance (depth) - # save rasterized depth depth_rast_path = os.path.join( self.config.root_dir, 'depths', f"rasted_{self.config.img_downscale}") - os.makedirs(depth_rast_path, exist_ok=True) - np.save(os.path.join(depth_rast_path, d.name.split("/")[-1][:-3] + 'npy'), rays_rast['t_hit'].numpy()) - - # save rasterized norm norm_rast_path = os.path.join( self.config.root_dir, 'normals', f"rasted_{self.config.img_downscale}") - os.makedirs(norm_rast_path, exist_ok=True) - rays_rast['primitive_normals'].numpy()[:,:,1:3] *= -1 # OpenGL => COLMAP - np.save(os.path.join(norm_rast_path, d.name.split("/")[-1][:-3] + 'npy'), rays_rast['primitive_normals'].numpy()) - depth_rast = Image.fromarray(rays_rast['t_hit'].numpy()) - - save_norm_dep_vis = True - if save_norm_dep_vis: - # visualize the hit distance (depth) - norm_rast = Image.fromarray(((rays_rast['primitive_normals'].numpy() + 1) * 128).astype(np.uint8)) - depth_rast.save( - os.path.join(depth_rast_path, - d.name.split("/")[-1][:-3] + 'tiff')) - norm_rast.save( - os.path.join( - norm_rast_path, - d.name.split("/")[-1][:-3] + 'png')) + depth_rast, _ = rasterize(d.name, mesh_o3d, intrinsic, c2w, w, h, depth_rast_path, norm_rast_path) depth_rast = TF.pil_to_tensor(depth_rast).permute( 1, 2, 0) / self.config.cam_downscale inf_mask = (depth_rast == float("Inf")) @@ -227,6 +205,9 @@ def setup(self, config, split): all_depths.append(torch.zeros_like(img[...,0], device=img.device)) all_depth_masks.append(torch.zeros_like(img[...,0], device=img.device)) + if self.config.apply_depth and self.config.preprocess_only: + print(colored('Finish preprocessing.', 'green')) + exit() all_c2w, all_images, all_fg_masks, all_depths, all_depth_masks, all_vis_masks = \ torch.stack(all_c2w, dim=0).float(), \ diff --git a/datasets/hmvs.py b/datasets/hmvs.py index 33d3a58..0472e6f 100644 --- a/datasets/hmvs.py +++ b/datasets/hmvs.py @@ -226,10 +226,10 @@ def setup(self, config, split): os.makedirs(os.path.join(self.config.root_dir, valid_mask_dir), exist_ok=True) - vis_mask = Image.fromarray( + vis_mask_pil = Image.fromarray( vis_mask.cpu().detach().numpy().astype(np.uint8) * 255) - vis_mask.save(fns[i].replace(f"{img_folder}", + vis_mask_pil.save(fns[i].replace(f"{img_folder}", f"/{valid_mask_dir}")) depth_folder = 'lidar_depth' diff --git a/utils/rast.py b/utils/rast.py new file mode 100644 index 0000000..b7ad383 --- /dev/null +++ b/utils/rast.py @@ -0,0 +1,44 @@ +import os +import numpy as np +from PIL import Image +import open3d as o3d + +def rasterize(img_name, mesh_o3d, intrinsic, c2w, w, h, depth_rast_path, norm_rast_path): + # Create scene and add the mesh + scene = o3d.t.geometry.RaycastingScene() + scene.add_triangles(mesh_o3d) + + # Rays are 6D vectors with origin and ray direction. + # Here we use a helper function to create rays + rays_mesh = scene.create_rays_pinhole(intrinsic_matrix=intrinsic, extrinsic_matrix=np.linalg.inv(np.concatenate((c2w.numpy(), np.array([[0, 0, 0, 1.]])))), width_px=w, height_px=h) + + # Compute the ray intersections. 
+    rays_rast = scene.cast_rays(rays_mesh)
+
+    # 't_hit' is the distance along each ray to the first mesh intersection
+    # save rasterized depth
+    os.makedirs(depth_rast_path, exist_ok=True)
+    if img_name.lower().endswith(('.png', '.jpg')):
+        img_name = img_name[:-4]
+    elif img_name.lower().endswith('.jpeg'):
+        img_name = img_name[:-5]
+    np.save(os.path.join(depth_rast_path, img_name.split("/")[-1] + '.npy'), rays_rast['t_hit'].numpy())
+
+    # save rasterized normals
+    os.makedirs(norm_rast_path, exist_ok=True)
+    # rays_rast['primitive_normals'].numpy()[:,:,1:3] *= -1 # OpenGL => COLMAP
+    rays_rast['primitive_normals'].numpy()[:,:,:] *= -1  # flip all normal components (previous version flipped only y/z)
+    np.save(os.path.join(norm_rast_path, img_name.split("/")[-1] + '.npy'), rays_rast['primitive_normals'].numpy())
+    depth_rast = Image.fromarray(rays_rast['t_hit'].numpy())
+
+    # save depth (float tiff) and normal (8-bit png) visualizations
+    norm_rast = Image.fromarray(((rays_rast['primitive_normals'].numpy() + 1) * 128).astype(np.uint8))
+    depth_rast.save(
+        os.path.join(depth_rast_path,
+                     img_name.split("/")[-1] + '.tiff'))
+    norm_rast.save(
+        os.path.join(
+            norm_rast_path,
+            img_name.split("/")[-1] + '.png'))
+
+    return depth_rast, norm_rast
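
A minimal standalone sketch of how the new utils/rast.py helper can be driven outside the dataset loaders; the mesh path, image name, image size, intrinsics and pose below are placeholders for illustration, not values taken from this patch:

import numpy as np
import torch
import open3d as o3d

from utils.rast import rasterize

# Load the ground-truth surface mesh and convert it to the tensor API,
# as datasets/colmap.py does before calling rasterize().
mesh_o3d = o3d.t.geometry.TriangleMesh.from_legacy(
    o3d.io.read_triangle_mesh('scene/sparse/0/points3D_mesh.ply'))  # placeholder path

w, h = 1920, 1080                                # placeholder image size
intrinsic = np.array([[1000.0, 0.0, w / 2.0],    # placeholder pinhole intrinsics
                      [0.0, 1000.0, h / 2.0],
                      [0.0, 0.0, 1.0]])
c2w = torch.eye(4)[:3, :4]                       # placeholder camera-to-world pose (COLMAP/OpenCV convention)

# Writes <name>.npy/.tiff depth and <name>.npy/.png normal maps into the two
# output folders and returns the PIL images.
depth_rast, norm_rast = rasterize('images/0001.png', mesh_o3d, intrinsic, c2w, w, h,
                                  'scene/depths/rasted_1', 'scene/normals/rasted_1')

With apply_depth: true and the new preprocess_only: true flag in configs/neus-colmap.yaml, the COLMAP loader runs this rasterization for every registered view and exits right after writing the maps.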