# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.


import argparse
import logging
import math
import os

import cv2
import numpy as np
import nvdiffrast.torch as dr
import spaces
import torch
import torch.nn.functional as F
import trimesh
import xatlas
from PIL import Image
from embodied_gen.data.mesh_operator import MeshFixer
from embodied_gen.data.utils import (
    CameraSetting,
    DiffrastRender,
    get_images_from_grid,
    init_kal_camera,
    normalize_vertices_array,
    post_process_texture,
    save_mesh_with_mtl,
)
from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.sr_model import ImageRealESRGAN

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


__all__ = [
    "TextureBacker",
]


def _transform_vertices(
    mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
) -> torch.Tensor:
    """Transform 3D vertices using a projection matrix.

    Args:
        mtx: Matrix applied to the (homogeneous) positions as
            ``pos @ mtx.T``. It is converted to the device and dtype of
            ``pos``.
        pos: Vertex positions. When the last dimension has size 3, a
            homogeneous coordinate of 1 is appended before multiplying.
        keepdim: If False, a leading dimension of size 1 is added to the
            result.

    Returns:
        The transformed positions.
    """
    t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
    if pos.size(-1) == 3:
        pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)

    result = pos @ t_mtx.T

    return result if keepdim else result.unsqueeze(0)


def _bilinear_interpolation_scattering(
    image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
) -> torch.Tensor:
    """Bilinear interpolation scattering for grid-based value accumulation.

    Every sample is splatted onto the four grid cells surrounding its
    continuous position with bilinear weights; each cell is then divided
    by the total weight it received. Cells that received no weight stay
    zero.

    Args:
        image_h: Number of grid rows.
        image_w: Number of grid columns.
        coords: ``(N, 2)`` sample positions in ``(row, col)`` order. They
            are multiplied by ``(image_h - 1, image_w - 1)``, so they are
            presumably normalized to ``[0, 1]`` by the caller.
        values: ``(N, C)`` values to scatter.

    Returns:
        An ``(image_h, image_w, C)`` tensor with the device and dtype of
        ``values``.
    """
    device = values.device
    dtype = values.dtype
    C = values.shape[-1]

    indices = coords * torch.tensor(
        [image_h - 1, image_w - 1], dtype=dtype, device=device
    )
    i, j = indices.unbind(-1)

    # Clamp rows and columns against their own extent so that the
    # "+ 1" neighbours stay inside the grid for non-square textures too.
    base = indices.floor().long()
    i0 = base[..., 0].clamp(0, image_h - 2)
    j0 = base[..., 1].clamp(0, image_w - 2)
    i1, j1 = i0 + 1, j0 + 1

    w_i = i - i0.to(dtype)
    w_j = j - j0.to(dtype)

    # The four corner cells and their matching bilinear weights.
    corners = ((i0, j0), (i0, j1), (i1, j0), (i1, j1))
    corner_weights = (
        (1 - w_i) * (1 - w_j),
        (1 - w_i) * w_j,
        w_i * (1 - w_j),
        w_i * w_j,
    )

    grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
    cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
    # Flat views share storage with `grid` / `cnt`.
    flat_grid = grid.view(-1, C)
    flat_cnt = cnt.view(-1, 1)

    for (row, col), weight in zip(corners, corner_weights):
        flat_idx = (row * image_w + col).unsqueeze(-1)
        weight = weight.unsqueeze(-1)
        flat_grid.scatter_add_(0, flat_idx.expand(-1, C), values * weight)
        flat_cnt.scatter_add_(0, flat_idx, weight)

    # Normalize only the cells that actually received contributions.
    mask = cnt.squeeze(-1) > 0
    grid[mask] = grid[mask] / cnt[mask].repeat(1, C)

    return grid


def _texture_inpaint_smooth(
    texture: np.ndarray,
    mask: np.ndarray,
    vertices: np.ndarray,
    faces: np.ndarray,
    uv_map: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Perform texture inpainting using vertex-based color propagation.

    Vertices whose UV pixel is covered by ``mask`` take the texture color
    at that pixel. The remaining vertices are filled iteratively from
    already colored neighbours, weighted by the inverse squared distance.
    The resulting vertex colors are finally written back to the texture at
    each vertex's UV pixel.

    Args:
        texture: ``(H, W, C)`` texture image.
        mask: ``(H, W)`` validity mask; non-zero means the pixel is valid.
        vertices: ``(N, 3)`` vertex positions.
        faces: ``(F, 3)`` triangle indices. The same index is used for the
            vertex and for its row in ``uv_map``.
        uv_map: Per-vertex UV coordinates.

    Returns:
        A tuple ``(inpainted_texture, updated_mask)``. Both are copies of
        the inputs; pixels written from vertex colors are set to 255 in
        the mask.
    """
    image_h, image_w, C = texture.shape
    N = vertices.shape[0]

    def _uv_to_pixel(uv_idx: int) -> tuple[int, int]:
        # Convert UV to pixel coordinates with boundary clamping.
        u = np.clip(
            int(round(uv_map[uv_idx, 0] * (image_w - 1))), 0, image_w - 1
        )
        v = np.clip(
            int(round((1.0 - uv_map[uv_idx, 1]) * (image_h - 1))),
            0,
            image_h - 1,
        )
        return u, v

    # Initialize vertex data structures.
    vtx_mask = np.zeros(N, dtype=np.float32)
    vtx_colors = np.zeros((N, C), dtype=np.float32)
    unprocessed = []
    # `queued` mirrors `unprocessed` for O(1) membership tests; the list
    # keeps the first-seen order, which the propagation below relies on.
    queued = set()
    # Dicts are used as insertion-ordered sets of neighbour indices.
    adjacency = [{} for _ in range(N)]

    # Build adjacency graph and initial color assignment.
    for face in faces:
        for k in range(3):
            v_idx = int(face[k])
            u, v = _uv_to_pixel(v_idx)

            if mask[v, u]:
                vtx_mask[v_idx] = 1.0
                vtx_colors[v_idx] = texture[v, u]
            elif v_idx not in queued:
                queued.add(v_idx)
                unprocessed.append(v_idx)

            # Build undirected adjacency graph.
            neighbor = int(face[(k + 1) % 3])
            adjacency[v_idx][neighbor] = None
            adjacency[neighbor][v_idx] = None

    # Color propagation with dynamic stopping: stop once the number of
    # uncolored vertices stops changing for consecutive passes.
    remaining_iters, prev_count = 2, 0
    while remaining_iters > 0:
        current_unprocessed = []

        for v_idx in unprocessed:
            valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
            if not valid_neighbors:
                current_unprocessed.append(v_idx)
                continue

            # Calculate inverse square distance weights.
            neighbors_pos = vertices[valid_neighbors]
            dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
            weights = 1 / np.maximum(dist_sq, 1e-8)

            vtx_colors[v_idx] = np.average(
                vtx_colors[valid_neighbors], weights=weights, axis=0
            )
            # Marked immediately, so later vertices in the same pass can
            # already use this color.
            vtx_mask[v_idx] = 1.0

        # Update iteration control.
        if len(current_unprocessed) == prev_count:
            remaining_iters -= 1
        else:
            remaining_iters = min(remaining_iters + 1, 2)
        prev_count = len(current_unprocessed)
        unprocessed = current_unprocessed

    # Generate output texture.
    inpainted_texture, updated_mask = texture.copy(), mask.copy()
    for face in faces:
        for k in range(3):
            v_idx = int(face[k])
            if not vtx_mask[v_idx]:
                continue

            u, v = _uv_to_pixel(v_idx)
            inpainted_texture[v, u] = vtx_colors[v_idx]
            updated_mask[v, u] = 255

    return inpainted_texture, updated_mask


class TextureBacker:
    """Texture baking pipeline for multi-view projection and fusion.

    This class performs UV-based texture generation for a 3D mesh using
    multi-view color images, depth, and normal information. The pipeline
    includes mesh normalization and UV unwrapping, visibility-aware
    back-projection, confidence-weighted texture fusion, and inpainting
    of missing texture regions.

    Args:
        camera_params (CameraSetting): Camera intrinsics and extrinsics used
            for rendering each view.
        view_weights (list[float]): A list of weights for each view, used
            to blend confidence maps during texture fusion.
        render_wh (tuple[int, int], optional): Resolution (width, height) for
            intermediate rendering passes. Defaults to (2048, 2048).
        texture_wh (tuple[int, int], optional): Output texture resolution
            (width, height). Defaults to (2048, 2048).
        bake_angle_thresh (int, optional): Maximum angle (in degrees) between
            view direction and surface normal for projection to be considered
            valid. Defaults to 75.
        mask_thresh (float, optional): Threshold applied to visibility masks
            during rendering. Defaults to 0.5.
        smooth_texture (bool, optional): If True, apply post-processing (e.g.,
            blurring) to the final texture. Defaults to True.
    """

    def __init__(
        self,
        camera_params: CameraSetting,
        view_weights: list[float],
        render_wh: tuple[int, int] = (2048, 2048),
        texture_wh: tuple[int, int] = (2048, 2048),
        bake_angle_thresh: int = 75,
        mask_thresh: float = 0.5,
        smooth_texture: bool = True,
    ) -> None:
        self.camera_params = camera_params
        # The renderer is created on first use, see `_lazy_init_render`.
        self.renderer = None
        self.view_weights = view_weights
        self.device = camera_params.device
        self.render_wh = render_wh
        self.texture_wh = texture_wh
        self.mask_thresh = mask_thresh
        self.smooth_texture = smooth_texture

        self.bake_angle_thresh = bake_angle_thresh
        # Half-size of the erosion/dilation kernel used in `back_project`;
        # scales with the render resolution (2 px at 512).
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.render_wh[0], self.render_wh[1])
        )

    def _lazy_init_render(self, camera_params, mask_thresh):
        """Create `self.renderer` on first call; later calls are no-ops."""
        if self.renderer is None:
            camera = init_kal_camera(camera_params)
            mv = camera.view_matrix()  # (n 4 4) world2cam
            p = camera.intrinsics.projection_matrix()
            # NOTE: negate P[1, 1] as the y axis is flipped in `nvdiffrast` output.  # noqa
            p[:, 1, 1] = -p[:, 1, 1]
            self.renderer = DiffrastRender(
                p_matrix=p,
                mv_matrix=mv,
                resolution_hw=camera_params.resolution_hw,
                context=dr.RasterizeCudaContext(),
                mask_thresh=mask_thresh,
                grad_db=False,
                device=self.device,
                antialias_mask=True,
            )

    def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        """Normalize the mesh in place and UV-unwrap it with xatlas.

        The normalization scale and center are stored on ``self.scale`` and
        ``self.center`` so the original placement can be restored later.
        Vertices are re-indexed by the xatlas vertex mapping and the V
        coordinate of the generated UVs is flipped.

        Args:
            mesh: Mesh to prepare; it is modified in place.

        Returns:
            The same mesh object with updated vertices, faces and UVs.
        """
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        self.scale, self.center = scale, center
        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        uvs[:, 1] = 1 - uvs[:, 1]
        mesh.vertices = mesh.vertices[vmapping]
        mesh.faces = indices
        mesh.visual.uv = uvs

        return mesh

    def get_mesh_np_attrs(
        self,
        mesh: trimesh.Trimesh,
        scale: float = None,
        center: np.ndarray = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return copies of the mesh vertices, faces and V-flipped UVs.

        Args:
            mesh: Source mesh with ``visual.uv`` set.
            scale: If given, vertices are divided by this value.
            center: If given, it is added to the vertices (after the
                optional scaling).

        Returns:
            A tuple ``(vertices, faces, uv_map)``.
        """
        vertices = mesh.vertices.copy()
        faces = mesh.faces.copy()
        uv_map = mesh.visual.uv.copy()
        uv_map[:, 1] = 1.0 - uv_map[:, 1]

        if scale is not None:
            vertices = vertices / scale
        if center is not None:
            vertices = vertices + center

        return vertices, faces, uv_map

    def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
        """Run Canny edge detection on a depth map.

        The input is scaled by 255 and cast to ``uint8`` before edge
        detection, so it is presumably expected to lie in ``[0, 1]``.

        Returns:
            The edge map scaled to ``[0, 1]`` as a float tensor on the input
            device, with a trailing channel dimension appended.
        """
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        sketch_image = (
            torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
        )
        sketch_image = sketch_image.unsqueeze(-1)

        return sketch_image

    def compute_enhanced_viewnormal(
        self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
    ) -> torch.Tensor:
        """Render per-view normal maps in camera space.

        For every view matrix the vertices are moved to camera space, face
        normals are averaged into vertex normals, and the vertex normals
        are interpolated over the rasterized image.

        Args:
            mv_mtx: Sequence of model-view matrices, one per view.
            vertices: Mesh vertex positions.
            faces: Triangle indices.

        Returns:
            The rendered normal maps of all views concatenated along dim 0.
        """
        rast, _ = self.renderer.compute_dr_raster(vertices, faces)
        rendered_view_normals = []
        for idx in range(len(mv_mtx)):
            pos_cam = _transform_vertices(mv_mtx[idx], vertices, keepdim=True)
            pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
            v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
            face_norm = F.normalize(
                torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
            )
            vertex_norm = (
                torch.from_numpy(
                    trimesh.geometry.mean_vertex_normals(
                        len(pos_cam), faces.cpu(), face_norm.cpu()
                    )
                )
                .to(vertices.device)
                .contiguous()
            )
            im_base_normals, _ = dr.interpolate(
                vertex_norm[None, ...].float(),
                rast[idx : idx + 1],
                faces.to(torch.int32),
            )
            rendered_view_normals.append(im_base_normals)

        rendered_view_normals = torch.cat(rendered_view_normals, dim=0)

        return rendered_view_normals

    def back_project(
        self, image, vis_mask, depth, normal, uv
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project one view's colors and confidence into texture space.

        Pixels are rejected when they are close to the border of the
        visibility mask, close to a depth edge, or when the cosine between
        the surface normal and the ``+z`` axis is below
        ``cos(bake_angle_thresh)``.

        Args:
            image: View image (anything ``np.array`` accepts); values are
                divided by 255.
            vis_mask: Visibility mask of the view, channel-last.
            depth: Depth map of the view; ``1 - depth`` is used for edge
                detection.
            normal: View-space normal map, last dimension of size 3.
            uv: Per-pixel UV coordinates of the view.

        Returns:
            A tuple ``(texture, cos_map)`` scattered into the texture grid.
        """
        image = np.array(image)
        image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        image = image / 255

        depth_inv = (1.0 - depth) * vis_mask
        sketch_image = self._render_depth_edges(depth_inv)

        # Use the dtype of `normal` explicitly instead of relying on
        # integer/float type promotion.
        view_axis = torch.tensor(
            [[0, 0, 1]], device=self.device, dtype=normal.dtype
        )
        cos = F.cosine_similarity(
            view_axis,
            normal.view(-1, 3),
        ).view_as(normal[..., :1])
        cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0

        k = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones((1, 1, k, k), device=self.device)

        # Erode the visibility mask: any pixel with an invisible pixel in
        # its k x k neighbourhood becomes invisible.
        vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
        vis_mask = F.conv2d(
            1.0 - vis_mask,
            kernel,
            padding=k // 2,
        )
        vis_mask = 1.0 - (vis_mask > 0).float()
        vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)

        # Dilate the depth edges with the same kernel and drop those pixels.
        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
        sketch_image = (sketch_image > 0).float()
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        vis_mask = vis_mask * (sketch_image < 0.5)

        cos[vis_mask == 0] = 0
        valid_pixels = (vis_mask != 0).view(-1)

        return (
            self._scatter_texture(uv, image, valid_pixels),
            self._scatter_texture(uv, cos, valid_pixels),
        )

    def _scatter_texture(self, uv, data, mask):
        """Scatter the masked pixels of ``data`` into the texture grid.

        ``uv`` and ``data`` are flattened to ``(-1, channels)``, filtered by
        the flat boolean ``mask``, and the UV columns are swapped to
        ``(v, u)`` to match the ``(row, col)`` order of the scatter helper.
        """

        def __filter_data(data, mask):
            return data.view(-1, data.shape[-1])[mask]

        return _bilinear_interpolation_scattering(
            self.texture_wh[1],
            self.texture_wh[0],
            __filter_data(uv, mask)[..., [1, 0]],
            __filter_data(data, mask),
        )

    @torch.no_grad()
    def fast_bake_texture(
        self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fuse per-view textures with confidence weighting.

        Views are accumulated in order. A view is skipped when more than
        99% of its confident texels are already covered by previously
        accumulated views, or when it has no confident texel at all.

        Args:
            textures: Per-view textures of shape ``(H, W, C)``.
            confidence_maps: Per-view confidence maps of shape ``(H, W, 1)``.

        Returns:
            A tuple ``(texture, mask)``: the confidence-weighted average
            texture and a boolean map of texels that received any weight.
        """
        channel = textures[0].shape[-1]
        # `texture_wh` is (width, height) while the scattered textures are
        # laid out as (height, width, channel). Unpacking also accepts both
        # tuples and lists.
        tex_w, tex_h = self.texture_wh
        texture_merge = torch.zeros(
            (tex_h, tex_w, channel), device=self.device
        )
        trust_map_merge = torch.zeros((tex_h, tex_w, 1), device=self.device)
        for texture, cos_map in zip(textures, confidence_maps):
            view_sum = (cos_map > 0).sum()
            if view_sum == 0:
                # Nothing to contribute; also avoids a 0 / 0 ratio below.
                continue
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            if painted_sum / view_sum > 0.99:
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)

        return texture_merge, trust_map_merge > 1e-8

    def uv_inpaint(
        self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
    ) -> np.ndarray:
        """Fill the unpainted texture regions.

        Vertex-based color propagation is applied first, followed by
        OpenCV Navier-Stokes inpainting on the remaining holes.

        Args:
            mesh: Mesh providing vertices, faces and UVs.
            texture: Float texture; it is clipped to ``[0, 1]`` before being
                converted to ``uint8``.
            mask: ``uint8`` mask where 255 marks valid texels.

        Returns:
            The inpainted texture as a ``uint8`` image.
        """
        vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)

        texture, mask = _texture_inpaint_smooth(
            texture, mask, vertices, faces, uv_map
        )
        texture = texture.clip(0, 1)
        texture = cv2.inpaint(
            (texture * 255).astype(np.uint8),
            255 - mask,
            3,
            cv2.INPAINT_NS,
        )

        return texture

    @spaces.GPU
    def compute_texture(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Back-project all views and fuse them into a single texture.

        Args:
            colors: One image per view.
            mesh: Mesh with UVs, as prepared by `load_mesh`.

        Returns:
            A tuple ``(texture_np, mask_np)``: the fused float texture and a
            ``uint8`` mask with 255 on texels that received color.
        """
        self._lazy_init_render(self.camera_params, self.mask_thresh)

        vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
        faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
        uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()

        rendered_depth, masks = self.renderer.render_depth(vertices, faces)
        norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
        render_uvs, _ = self.renderer.render_uv(vertices, faces, uv_map)
        view_normals = self.compute_enhanced_viewnormal(
            self.renderer.mv_mtx, vertices, faces
        )

        textures, weighted_cos_maps = [], []
        for color, mask, dep, normal, uv, weight in zip(
            colors,
            masks,
            norm_deps,
            view_normals,
            render_uvs,
            self.view_weights,
        ):
            texture, cos_map = self.back_project(color, mask, dep, normal, uv)
            textures.append(texture)
            # Raising to the 4th power sharpens the preference for
            # front-facing texels before applying the per-view weight.
            weighted_cos_maps.append(weight * (cos_map**4))

        texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)

        texture_np = texture.cpu().numpy()
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)

        return texture_np, mask_np

    def __call__(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
        output_path: str,
    ) -> trimesh.Trimesh:
        """Runs the texture baking and exports the textured mesh.

        Args:
            colors (list[Image.Image]): List of input view images.
            mesh (trimesh.Trimesh): Input mesh to be textured.
            output_path (str): Path to save the output textured mesh (.obj or .glb).

        Returns:
            trimesh.Trimesh: The textured mesh with UV and texture image.
        """
        mesh = self.load_mesh(mesh)
        texture_np, mask_np = self.compute_texture(colors, mesh)

        texture_np = self.uv_inpaint(mesh, texture_np, mask_np)
        if self.smooth_texture:
            texture_np = post_process_texture(texture_np)

        # Undo the normalization applied in `load_mesh` before saving.
        vertices, faces, uv_map = self.get_mesh_np_attrs(
            mesh, self.scale, self.center
        )
        textured_mesh = save_mesh_with_mtl(
            vertices, faces, uv_map, texture_np, output_path
        )

        return textured_mesh


def parse_args():
    """Parse the command-line options; unknown arguments are ignored."""
    parser = argparse.ArgumentParser(description="Backproject texture")
    parser.add_argument(
        "--color_path",
        type=str,
        help="Multiview color image in 6x512x512 file path",
    )
    parser.add_argument(
        "--mesh_path",
        type=str,
        help="Mesh path, .obj, .glb or .ply",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output mesh path with suffix",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        nargs=2,
        type=float,
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(2048, 2048),
        help="Resolution of the output images (default: (2048, 2048))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip fixing mesh geometry.",
    )
    parser.add_argument(
        "--texture_wh",
        nargs=2,
        type=int,
        default=[2048, 2048],
        help="Texture resolution width and height",
    )
    # NOTE: the option name is misspelled ("sipmlify"); it is kept as is
    # because callers address it by this name.
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    parser.add_argument(
        "--no_smooth_texture",
        action="store_true",
        help="Do not smooth the texture.",
    )
    parser.add_argument(
        "--save_glb_path", type=str, default=None, help="Save glb path."
    )
    parser.add_argument(
        "--no_save_delight_img",
        action="store_true",
        help="Disable saving delight image",
    )

    args, unknown = parser.parse_known_args()

    return args


def entrypoint(
    delight_model: DelightingModel = None,
    imagesr_model: ImageRealESRGAN = None,
    **kwargs,
) -> trimesh.Trimesh:
    """Run the full back-projection pipeline from CLI args and overrides.

    Args:
        delight_model: Optional pre-built delighting model; one is created
            when ``--delight`` is set and none is given.
        imagesr_model: Optional pre-built super-resolution model; one is
            created when none is given.
        **kwargs: Overrides for parsed arguments. A value is applied only
            when the argument exists and the value is not None.

    Returns:
        The textured mesh.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]

    color_grid = Image.open(args.color_path)
    if args.delight:
        if delight_model is None:
            delight_model = DelightingModel()
        # An output path without a directory component means the current
        # directory; `os.makedirs("")` would raise.
        save_dir = os.path.dirname(args.output_path) or "."
        os.makedirs(save_dir, exist_ok=True)
        color_grid = delight_model(color_grid)
        if not args.no_save_delight_img:
            color_grid.save(f"{save_dir}/color_grid_delight.png")

    multiviews = get_images_from_grid(color_grid, img_size=512)

    # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    if imagesr_model is None:
        imagesr_model = ImageRealESRGAN(outscale=4)
    multiviews = [imagesr_model(img) for img in multiviews]
    multiviews = [img.convert("RGB") for img in multiviews]
    mesh = trimesh.load(args.mesh_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    if not args.skip_fix_mesh:
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
        mesh.vertices, mesh.faces = mesh_fixer(
            filter_ratio=args.mesh_sipmlify_ratio,
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )
        # Restore scale.
        mesh.vertices = mesh.vertices / scale
        mesh.vertices = mesh.vertices + center

    # Baking texture to mesh.
    texture_backer = TextureBacker(
        camera_params=camera_params,
        view_weights=view_weights,
        render_wh=camera_params.resolution_hw,
        texture_wh=args.texture_wh,
        smooth_texture=not args.no_smooth_texture,
    )

    textured_mesh = texture_backer(multiviews, mesh, args.output_path)

    if args.save_glb_path is not None:
        glb_dir = os.path.dirname(args.save_glb_path)
        if glb_dir:
            os.makedirs(glb_dir, exist_ok=True)
        textured_mesh.export(args.save_glb_path)

    return textured_mesh


if __name__ == "__main__":
    entrypoint()