xinjie.wang
update
43c5d2f
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import math
import os
import cv2
import numpy as np
import nvdiffrast.torch as dr
import spaces
import torch
import torch.nn.functional as F
import trimesh
import xatlas
from PIL import Image
from embodied_gen.data.mesh_operator import MeshFixer
from embodied_gen.data.utils import (
CameraSetting,
DiffrastRender,
get_images_from_grid,
init_kal_camera,
normalize_vertices_array,
post_process_texture,
save_mesh_with_mtl,
)
from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.sr_model import ImageRealESRGAN
# Module-wide logging: INFO level with a timestamped, level-tagged format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Public API of this module.
__all__ = ["TextureBacker"]
def _transform_vertices(
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
) -> torch.Tensor:
"""Transform 3D vertices using a projection matrix."""
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
if pos.size(-1) == 3:
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
result = pos @ t_mtx.T
return result if keepdim else result.unsqueeze(0)
def _bilinear_interpolation_scattering(
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
) -> torch.Tensor:
"""Bilinear interpolation scattering for grid-based value accumulation."""
device = values.device
dtype = values.dtype
C = values.shape[-1]
indices = coords * torch.tensor(
[image_h - 1, image_w - 1], dtype=dtype, device=device
)
i, j = indices.unbind(-1)
i0, j0 = (
indices.floor()
.long()
.clamp(0, image_h - 2)
.clamp(0, image_w - 2)
.unbind(-1)
)
i1, j1 = i0 + 1, j0 + 1
w_i = i - i0.float()
w_j = j - j0.float()
weights = torch.stack(
[(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
dim=1,
)
indices_comb = torch.stack(
[
torch.stack([i0, j0], dim=1),
torch.stack([i0, j1], dim=1),
torch.stack([i1, j0], dim=1),
torch.stack([i1, j1], dim=1),
],
dim=1,
)
grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
for k in range(4):
idx = indices_comb[:, k]
w = weights[:, k].unsqueeze(-1)
stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
flat_idx = (idx * stride).sum(-1)
grid.view(-1, C).scatter_add_(
0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
)
cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
mask = cnt.squeeze(-1) > 0
grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
return grid
def _texture_inpaint_smooth(
texture: np.ndarray,
mask: np.ndarray,
vertices: np.ndarray,
faces: np.ndarray,
uv_map: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""Perform texture inpainting using vertex-based color propagation."""
image_h, image_w, C = texture.shape
N = vertices.shape[0]
# Initialize vertex data structures
vtx_mask = np.zeros(N, dtype=np.float32)
vtx_colors = np.zeros((N, C), dtype=np.float32)
unprocessed = []
adjacency = [[] for _ in range(N)]
# Build adjacency graph and initial color assignment
for face_idx in range(faces.shape[0]):
for k in range(3):
uv_idx_k = faces[face_idx, k]
v_idx = faces[face_idx, k]
# Convert UV to pixel coordinates with boundary clamping
u = np.clip(
int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
)
v = np.clip(
int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
0,
image_h - 1,
)
if mask[v, u]:
vtx_mask[v_idx] = 1.0
vtx_colors[v_idx] = texture[v, u]
elif v_idx not in unprocessed:
unprocessed.append(v_idx)
# Build undirected adjacency graph
neighbor = faces[face_idx, (k + 1) % 3]
if neighbor not in adjacency[v_idx]:
adjacency[v_idx].append(neighbor)
if v_idx not in adjacency[neighbor]:
adjacency[neighbor].append(v_idx)
# Color propagation with dynamic stopping
remaining_iters, prev_count = 2, 0
while remaining_iters > 0:
current_unprocessed = []
for v_idx in unprocessed:
valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
if not valid_neighbors:
current_unprocessed.append(v_idx)
continue
# Calculate inverse square distance weights
neighbors_pos = vertices[valid_neighbors]
dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
weights = 1 / np.maximum(dist_sq, 1e-8)
vtx_colors[v_idx] = np.average(
vtx_colors[valid_neighbors], weights=weights, axis=0
)
vtx_mask[v_idx] = 1.0
# Update iteration control
if len(current_unprocessed) == prev_count:
remaining_iters -= 1
else:
remaining_iters = min(remaining_iters + 1, 2)
prev_count = len(current_unprocessed)
unprocessed = current_unprocessed
# Generate output texture
inpainted_texture, updated_mask = texture.copy(), mask.copy()
for face_idx in range(faces.shape[0]):
for k in range(3):
v_idx = faces[face_idx, k]
if not vtx_mask[v_idx]:
continue
# UV coordinate conversion
uv_idx_k = faces[face_idx, k]
u = np.clip(
int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
)
v = np.clip(
int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
0,
image_h - 1,
)
inpainted_texture[v, u] = vtx_colors[v_idx]
updated_mask[v, u] = 255
return inpainted_texture, updated_mask
class TextureBacker:
    """Texture baking pipeline for multi-view projection and fusion.

    This class performs UV-based texture generation for a 3D mesh using
    multi-view color images, depth, and normal information. The pipeline
    includes mesh normalization and UV unwrapping, visibility-aware
    back-projection, confidence-weighted texture fusion, and inpainting
    of missing texture regions.

    Args:
        camera_params (CameraSetting): Camera intrinsics and extrinsics used
            for rendering each view.
        view_weights (list[float]): A list of weights for each view, used
            to blend confidence maps during texture fusion.
        render_wh (tuple[int, int], optional): Resolution (width, height) of
            the rendered views. Here it only sizes the border-erosion kernel
            used in back-projection. Defaults to (2048, 2048).
        texture_wh (tuple[int, int], optional): Output texture resolution
            (width, height). Defaults to (2048, 2048).
        bake_angle_thresh (int, optional): Maximum angle (in degrees) between
            the view direction and the surface normal for a pixel to count
            as valid. Defaults to 75.
        mask_thresh (float, optional): Threshold applied to visibility masks
            during rendering. Defaults to 0.5.
        smooth_texture (bool, optional): If True, apply
            ``post_process_texture`` to the final texture. Defaults to True.
    """

    def __init__(
        self,
        camera_params: CameraSetting,
        view_weights: list[float],
        render_wh: tuple[int, int] = (2048, 2048),
        texture_wh: tuple[int, int] = (2048, 2048),
        bake_angle_thresh: int = 75,
        mask_thresh: float = 0.5,
        smooth_texture: bool = True,
    ) -> None:
        self.camera_params = camera_params
        # Created lazily by `_lazy_init_render` on first use.
        self.renderer = None
        self.view_weights = view_weights
        self.device = camera_params.device
        self.render_wh = render_wh
        self.texture_wh = texture_wh
        self.mask_thresh = mask_thresh
        self.smooth_texture = smooth_texture
        self.bake_angle_thresh = bake_angle_thresh
        # Width (pixels) of the unreliable band removed around silhouettes
        # and depth edges in `back_project`: 2 px per 512 px of resolution.
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.render_wh[0], self.render_wh[1])
        )
        # Normalization parameters; set by `load_mesh`, used by `__call__`.
        self.scale = None
        self.center = None

    def _lazy_init_render(self, camera_params, mask_thresh):
        """Create the ``DiffrastRender`` on first call; later calls no-op."""
        if self.renderer is not None:
            return
        camera = init_kal_camera(camera_params)
        mv = camera.view_matrix()  # (n 4 4) world2cam
        p = camera.intrinsics.projection_matrix()
        # NOTE: negate P[1, 1] because the y axis is flipped in the
        # `nvdiffrast` output.
        p[:, 1, 1] = -p[:, 1, 1]
        self.renderer = DiffrastRender(
            p_matrix=p,
            mv_matrix=mv,
            resolution_hw=camera_params.resolution_hw,
            context=dr.RasterizeCudaContext(),
            mask_thresh=mask_thresh,
            grad_db=False,
            device=self.device,
            antialias_mask=True,
        )

    def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        """Normalize ``mesh`` and give it a fresh xatlas UV parametrization.

        The mesh is modified in place and also returned. The normalization
        ``scale`` and ``center`` are stored on ``self`` so ``__call__`` can
        restore the original placement when exporting.
        """
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        self.scale, self.center = scale, center
        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        # Flip v; `get_mesh_np_attrs` flips it back.
        uvs[:, 1] = 1 - uvs[:, 1]
        # `vmapping` maps each xatlas output vertex to an input vertex, so
        # re-index vertices to match the per-vertex UVs.
        mesh.vertices = mesh.vertices[vmapping]
        mesh.faces = indices
        mesh.visual.uv = uvs
        return mesh

    def get_mesh_np_attrs(
        self,
        mesh: trimesh.Trimesh,
        scale: float = None,
        center: np.ndarray = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return copies of the mesh's vertices, faces, and UVs.

        The v coordinate of the UVs is flipped back (undoing the flip in
        ``load_mesh``). When given, ``scale`` and ``center`` undo the
        normalization: vertices are divided by ``scale``, then shifted by
        ``center``.
        """
        vertices = mesh.vertices.copy()
        faces = mesh.faces.copy()
        uv_map = mesh.visual.uv.copy()
        uv_map[:, 1] = 1.0 - uv_map[:, 1]
        if scale is not None:
            vertices = vertices / scale
        if center is not None:
            vertices = vertices + center
        return vertices, faces, uv_map

    def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
        """Detect depth discontinuities with Canny.

        ``depth_image`` is scaled by 255 and cast to uint8, so values are
        presumably expected in [0, 1]. Returns a float tensor on the input
        device that is 1 on edges and 0 elsewhere, with a trailing channel
        dimension.
        """
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        sketch_image = (
            torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
        )
        sketch_image = sketch_image.unsqueeze(-1)
        return sketch_image

    def compute_enhanced_viewnormal(
        self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
    ) -> torch.Tensor:
        """Render per-pixel camera-space normals for every view.

        For each view, the vertices are transformed to camera space. Face
        normals are computed and averaged into vertex normals with
        ``trimesh.geometry.mean_vertex_normals``. These are then interpolated
        over the rasterization with ``nvdiffrast``. The per-view results are
        concatenated along dim 0.
        """
        rast, _ = self.renderer.compute_dr_raster(vertices, faces)
        rendered_view_normals = []
        for idx in range(len(mv_mtx)):
            pos_cam = _transform_vertices(mv_mtx[idx], vertices, keepdim=True)
            pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
            v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
            face_norm = F.normalize(
                torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
            )
            vertex_norm = (
                torch.from_numpy(
                    trimesh.geometry.mean_vertex_normals(
                        len(pos_cam), faces.cpu(), face_norm.cpu()
                    )
                )
                .to(vertices.device)
                .contiguous()
            )
            im_base_normals, _ = dr.interpolate(
                vertex_norm[None, ...].float(),
                rast[idx : idx + 1],
                faces.to(torch.int32),
            )
            rendered_view_normals.append(im_base_normals)
        rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
        return rendered_view_normals

    def back_project(
        self, image, vis_mask, depth, normal, uv
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Splat one view's colors and confidences into UV texture space.

        Confidence is the cosine between the pixel normal and the +z axis.
        It is zeroed beyond ``bake_angle_thresh``, within a band of
        ``bake_unreliable_kernel_size`` pixels around the visibility border,
        and near depth edges.

        Args:
            image: View image (anything ``np.array`` accepts, e.g. a PIL
                image); values are divided by 255.
            vis_mask: Visibility mask with a trailing channel dimension.
            depth: Normalized depth map with the same layout as
                ``vis_mask`` (presumably in [0, 1]).
            normal: Per-pixel camera-space normals, trailing dimension 3.
            uv: Per-pixel UV coordinates, trailing dimension 2.

        Returns:
            Tuple ``(texture, confidence)`` scattered into a texture of
            size ``texture_wh``.
        """
        image = np.array(image)
        image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        image = image / 255

        depth_inv = (1.0 - depth) * vis_mask
        sketch_image = self._render_depth_edges(depth_inv)

        # Use a floating reference vector matching the normals' dtype.
        view_axis = torch.tensor(
            [[0.0, 0.0, 1.0]], device=self.device, dtype=normal.dtype
        )
        cos = F.cosine_similarity(view_axis, normal.view(-1, 3)).view_as(
            normal[..., :1]
        )
        cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0

        k = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones((1, 1, k, k), device=self.device)

        # Erode the visibility mask: any invisible pixel within the kernel
        # marks the pixel as unreliable.
        vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
        vis_mask = F.conv2d(
            1.0 - vis_mask,
            kernel,
            padding=k // 2,
        )
        vis_mask = 1.0 - (vis_mask > 0).float()
        vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)

        # Dilate the depth edges and drop pixels near them.
        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
        sketch_image = (sketch_image > 0).float()
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        vis_mask = vis_mask * (sketch_image < 0.5)

        cos[vis_mask == 0] = 0
        valid_pixels = (vis_mask != 0).view(-1)

        return (
            self._scatter_texture(uv, image, valid_pixels),
            self._scatter_texture(uv, cos, valid_pixels),
        )

    def _scatter_texture(self, uv, data, mask):
        """Scatter the masked pixels of ``data`` into UV texture space.

        UVs are reordered to ``(v, u)`` so they act as ``(row, col)``
        coordinates of a ``(texture_h, texture_w, C)`` grid.
        """

        def _flatten_masked(tensor, pixel_mask):
            return tensor.view(-1, tensor.shape[-1])[pixel_mask]

        return _bilinear_interpolation_scattering(
            self.texture_wh[1],
            self.texture_wh[0],
            _flatten_masked(uv, mask)[..., [1, 0]],
            _flatten_masked(data, mask),
        )

    @torch.no_grad()
    def fast_bake_texture(
        self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fuse per-view UV textures by confidence-weighted averaging.

        A view is skipped when more than 99% of its confident texels are
        already covered by earlier views, or when it has no confident texel
        at all. Buffers are ``(texture_h, texture_w, C)``, matching
        ``_scatter_texture``.

        Returns:
            Tuple of the fused texture and a boolean mask of texels whose
            accumulated confidence exceeds 1e-8.
        """
        channel = textures[0].shape[-1]
        tex_w, tex_h = self.texture_wh
        # Build the shape explicitly: `texture_wh` may be a tuple or a list,
        # and the grid is laid out (height, width).
        texture_merge = torch.zeros((tex_h, tex_w, channel), device=self.device)
        trust_map_merge = torch.zeros((tex_h, tex_w, 1), device=self.device)
        for texture, cos_map in zip(textures, confidence_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            if view_sum == 0 or painted_sum / view_sum > 0.99:
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)

        return texture_merge, trust_map_merge > 1e-8

    def uv_inpaint(
        self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
    ) -> np.ndarray:
        """Fill unpainted texels.

        First, colors are propagated along mesh edges with
        ``_texture_inpaint_smooth``. Then OpenCV Navier-Stokes inpainting
        (radius 3) fills every texel whose mask is 0. Returns a uint8
        texture.
        """
        vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)

        texture, mask = _texture_inpaint_smooth(
            texture, mask, vertices, faces, uv_map
        )
        texture = texture.clip(0, 1)
        texture = cv2.inpaint(
            (texture * 255).astype(np.uint8),
            255 - mask,
            3,
            cv2.INPAINT_NS,
        )

        return texture

    @spaces.GPU
    def compute_texture(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Render geometry buffers, back-project every view, and fuse them.

        Returns:
            Tuple ``(texture_np, mask_np)``: the fused float texture and a
            uint8 mask (255 where any view contributed).
        """
        self._lazy_init_render(self.camera_params, self.mask_thresh)

        vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
        faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
        uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()

        rendered_depth, masks = self.renderer.render_depth(vertices, faces)
        norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
        render_uvs, _ = self.renderer.render_uv(vertices, faces, uv_map)
        view_normals = self.compute_enhanced_viewnormal(
            self.renderer.mv_mtx, vertices, faces
        )

        textures, weighted_cos_maps = [], []
        for color, mask, dep, normal, uv, weight in zip(
            colors,
            masks,
            norm_deps,
            view_normals,
            render_uvs,
            self.view_weights,
        ):
            texture, cos_map = self.back_project(color, mask, dep, normal, uv)
            textures.append(texture)
            # Sharpen confidence toward front-facing texels.
            weighted_cos_maps.append(weight * (cos_map**4))

        texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)

        texture_np = texture.cpu().numpy()
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)

        return texture_np, mask_np

    def __call__(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
        output_path: str,
    ) -> trimesh.Trimesh:
        """Runs the texture baking and exports the textured mesh.

        Args:
            colors (list[Image.Image]): List of input view images.
            mesh (trimesh.Trimesh): Input mesh to be textured. It is
                modified in place by ``load_mesh``.
            output_path (str): Path to save the output textured mesh (.obj or .glb).

        Returns:
            trimesh.Trimesh: The textured mesh with UV and texture image.
        """
        mesh = self.load_mesh(mesh)
        texture_np, mask_np = self.compute_texture(colors, mesh)

        texture_np = self.uv_inpaint(mesh, texture_np, mask_np)
        if self.smooth_texture:
            texture_np = post_process_texture(texture_np)

        # Undo the normalization applied in `load_mesh` before exporting.
        vertices, faces, uv_map = self.get_mesh_np_attrs(
            mesh, self.scale, self.center
        )
        textured_mesh = save_mesh_with_mtl(
            vertices, faces, uv_map, texture_np, output_path
        )

        return textured_mesh
def parse_args():
    """Parse command-line options for the texture-baking entry point.

    Unknown arguments are ignored (``parse_known_args``), so the parser
    tolerates extra flags on the process command line.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="Backproject texture")
    parser.add_argument(
        "--color_path",
        type=str,
        help="Multiview color image in 6x512x512 file path",
    )
    parser.add_argument(
        "--mesh_path",
        type=str,
        help="Mesh path, .obj, .glb or .ply",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output mesh path with suffix",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        nargs=2,
        type=float,
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(2048, 2048),
        help="Resolution of the output images (default: (2048, 2048))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    # The flag disables mesh fixing; the help text states that explicitly.
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip fixing mesh geometry.",
    )
    parser.add_argument(
        "--texture_wh",
        nargs=2,
        type=int,
        default=[2048, 2048],
        help="Texture resolution width and height",
    )
    # NOTE: option name is misspelled ("sipmlify"); kept as-is because
    # callers pass it by attribute name through `entrypoint(**kwargs)`.
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    parser.add_argument(
        "--no_smooth_texture",
        action="store_true",
        help="Do not smooth the texture.",
    )
    parser.add_argument(
        "--save_glb_path", type=str, default=None, help="Save glb path."
    )
    parser.add_argument(
        "--no_save_delight_img",
        action="store_true",
        help="Disable saving delight image",
    )

    args, unknown = parser.parse_known_args()

    return args
def entrypoint(
    delight_model: DelightingModel = None,
    imagesr_model: ImageRealESRGAN = None,
    **kwargs,
) -> trimesh.Trimesh:
    """Run the full multi-view texture-baking pipeline.

    Options come from ``parse_args``. Any keyword argument that names an
    existing option and is not ``None`` overrides the parsed value, so the
    pipeline can be driven from Python as well as from the command line.

    Args:
        delight_model: Optional pre-built delighting model. One is created
            on demand when ``--delight`` is set and none is given.
        imagesr_model: Optional pre-built super-resolution model. If none
            is given, ``ImageRealESRGAN(outscale=4)`` is created.
        **kwargs: Overrides for parsed command-line options.

    Returns:
        trimesh.Trimesh: The textured mesh produced by ``TextureBacker``.
    """
    args = parse_args()
    for key, value in kwargs.items():
        if hasattr(args, key) and value is not None:
            setattr(args, key, value)

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
    if len(view_weights) != args.num_images:
        # Views and weights are zipped during baking, so the extras of the
        # longer one are silently dropped.
        logger.warning(
            "Got %d view weights for num_images=%d; only the first %d views "
            "will be baked.",
            len(view_weights),
            args.num_images,
            min(len(view_weights), args.num_images),
        )

    color_grid = Image.open(args.color_path)
    if args.delight:
        if delight_model is None:
            delight_model = DelightingModel()
        save_dir = os.path.dirname(args.output_path)
        # `dirname` is "" for a bare filename; `os.makedirs("")` would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        color_grid = delight_model(color_grid)
        if not args.no_save_delight_img:
            # `os.path.join` avoids writing to "/" when `save_dir` is "".
            color_grid.save(os.path.join(save_dir, "color_grid_delight.png"))

    multiviews = get_images_from_grid(color_grid, img_size=512)

    # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    if imagesr_model is None:
        imagesr_model = ImageRealESRGAN(outscale=4)
    multiviews = [imagesr_model(img) for img in multiviews]
    multiviews = [img.convert("RGB") for img in multiviews]

    mesh = trimesh.load(args.mesh_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    if not args.skip_fix_mesh:
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
        mesh.vertices, mesh.faces = mesh_fixer(
            filter_ratio=args.mesh_sipmlify_ratio,
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )
        # Restore scale.
        mesh.vertices = mesh.vertices / scale
        mesh.vertices = mesh.vertices + center

    # Baking texture to mesh.
    texture_backer = TextureBacker(
        camera_params=camera_params,
        view_weights=view_weights,
        render_wh=camera_params.resolution_hw,
        texture_wh=args.texture_wh,
        smooth_texture=not args.no_smooth_texture,
    )

    textured_mesh = texture_backer(multiviews, mesh, args.output_path)

    if args.save_glb_path is not None:
        glb_dir = os.path.dirname(args.save_glb_path)
        if glb_dir:
            os.makedirs(glb_dir, exist_ok=True)
        textured_mesh.export(args.save_glb_path)

    return textured_mesh
if __name__ == "__main__":
    # Script mode: every option is read from the command line.
    entrypoint()