# xinjie.wang
# update
# efa2e5d
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import os
import sys
from glob import glob
from shutil import copy, copytree, rmtree
import numpy as np
import trimesh
from PIL import Image
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
from embodied_gen.data.utils import delete_dir, trellis_preprocess
from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.models.segment_model import (
BMGG14Remover,
RembgRemover,
SAMPredictor,
)
from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.process_media import merge_images_video, render_video
from embodied_gen.utils.tags import VERSION
from embodied_gen.validators.quality_checkers import (
BaseChecker,
ImageAestheticChecker,
ImageSegChecker,
MeshGeoChecker,
)
from embodied_gen.validators.urdf_convertor import URDFGenerator
# Make the repository root importable so the vendored TRELLIS package under
# `thirdparty/` can be found by the import below.
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

# Environment knobs must be set before the models below are constructed.
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
    "~/.cache/torch_extensions"
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
os.environ["SPCONV_ALGO"] = "native"

# Heavyweight models are instantiated once at import time and reused for
# every input image processed by the __main__ loop.
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover()
RBG14_REMOVER = BMGG14Remover()
SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
    "microsoft/TRELLIS-image-large"
)
PIPELINE.cuda()
# GPT-backed quality checkers applied to the generated assets.
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
AESTHETIC_CHECKER = ImageAestheticChecker()
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
)
def parse_args():
parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
parser.add_argument(
"--image_path", type=str, nargs="+", help="Path to the input images."
)
parser.add_argument(
"--image_root", type=str, help="Path to the input images folder."
)
parser.add_argument(
"--output_root",
type=str,
required=True,
help="Root directory for saving outputs.",
)
parser.add_argument(
"--height_range",
type=str,
default=None,
help="The hight in meter to restore the mesh real size.",
)
parser.add_argument(
"--mass_range",
type=str,
default=None,
help="The mass in kg to restore the mesh real weight.",
)
parser.add_argument("--asset_type", type=str, default=None)
parser.add_argument("--skip_exists", action="store_true")
parser.add_argument("--strict_seg", action="store_true")
parser.add_argument("--version", type=str, default=VERSION)
parser.add_argument("--remove_intermediate", type=bool, default=True)
args = parser.parse_args()
assert (
args.image_path or args.image_root
), "Please provide either --image_path or --image_root."
if not args.image_path:
args.image_path = glob(os.path.join(args.image_root, "*.png"))
args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
return args
if __name__ == "__main__":
    args = parse_args()

    for image_path in args.image_path:
        try:
            # Asset name derived from the image file stem; it keys every
            # output artifact below.
            # NOTE(review): the original file contained literal "(unknown)"
            # inside f-strings where `filename` was clearly intended (the
            # variable was otherwise unused) — restored as "{filename}".
            filename = os.path.basename(image_path).split(".")[0]
            output_root = args.output_root
            # Give each image its own sub-folder when processing a batch.
            if args.image_root is not None or len(args.image_path) > 1:
                output_root = os.path.join(output_root, filename)
            os.makedirs(output_root, exist_ok=True)

            mesh_out = f"{output_root}/{filename}.obj"
            if args.skip_exists and os.path.exists(mesh_out):
                logger.info(
                    f"Skip {image_path}, already processed in {mesh_out}"
                )
                continue

            image = Image.open(image_path)
            image.save(f"{output_root}/{filename}_raw.png")

            # Segmentation: Get segmented image using SAM or Rembg.
            seg_path = f"{output_root}/{filename}_cond.png"
            if image.mode != "RGBA":
                seg_image = RBG_REMOVER(image, save_path=seg_path)
                seg_image = trellis_preprocess(seg_image)
            else:
                # Already has an alpha channel — treat it as pre-segmented.
                seg_image = image
                seg_image.save(seg_path)

            # Run the TRELLIS image-to-3D pipeline.
            try:
                outputs = PIPELINE.run(
                    seg_image,
                    preprocess_image=False,
                    # Optional parameters
                    # seed=1,
                    # sparse_structure_sampler_params={
                    #     "steps": 12,
                    #     "cfg_strength": 7.5,
                    # },
                    # slat_sampler_params={
                    #     "steps": 12,
                    #     "cfg_strength": 3,
                    # },
                )
            except Exception as e:
                logger.error(
                    f"[Pipeline Failed] process {image_path}: {e}, skip."
                )
                continue

            # Render and save color and mesh videos
            gs_model = outputs["gaussian"][0]
            mesh_model = outputs["mesh"][0]
            color_images = render_video(gs_model)["color"]
            normal_images = render_video(mesh_model)["normal"]
            video_path = os.path.join(output_root, "gs_mesh.mp4")
            merge_images_video(color_images, normal_images, video_path)

            # Save the raw Gaussian model
            gs_path = mesh_out.replace(".obj", "_gs.ply")
            gs_model.save_ply(gs_path)

            # Rotate mesh and GS by 90 degrees around Z-axis.
            rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
            gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
            mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

            # Addtional rotation for GS to align mesh.
            gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
            pose = GaussianOperator.trans_to_quatpose(gs_rot)
            aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
            GaussianOperator.resave_ply(
                in_ply=gs_path,
                out_ply=aligned_gs_path,
                instance_pose=pose,
                device="cpu",
            )
            # Render the aligned Gaussians to a color image used for
            # back-projecting texture onto the mesh.
            color_path = os.path.join(output_root, "color.png")
            render_gs_api(aligned_gs_path, color_path)

            mesh = trimesh.Trimesh(
                vertices=mesh_model.vertices.cpu().numpy(),
                faces=mesh_model.faces.cpu().numpy(),
            )
            mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
            mesh.vertices = mesh.vertices @ np.array(rot_matrix)

            mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
            mesh.export(mesh_obj_path)

            # Bake texture onto the mesh from the rendered GS color image.
            mesh = backproject_api(
                delight_model=DELIGHT,
                imagesr_model=IMAGESR_MODEL,
                color_path=color_path,
                mesh_path=mesh_obj_path,
                output_path=mesh_obj_path,
                skip_fix_mesh=False,
                delight=True,
                texture_wh=[2048, 2048],
            )

            mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
            mesh.export(mesh_glb_path)

            # Convert the textured mesh into a URDF asset with physical
            # attributes (size, mass, category, version).
            urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
            asset_attrs = {
                "version": VERSION,
                "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
            }
            if args.height_range:
                min_height, max_height = map(
                    float, args.height_range.split("-")
                )
                asset_attrs["min_height"] = min_height
                asset_attrs["max_height"] = max_height
            if args.mass_range:
                min_mass, max_mass = map(float, args.mass_range.split("-"))
                asset_attrs["min_mass"] = min_mass
                asset_attrs["max_mass"] = max_mass
            if args.asset_type:
                asset_attrs["category"] = args.asset_type
            if args.version:
                asset_attrs["version"] = args.version

            urdf_root = f"{output_root}/URDF_{filename}"
            urdf_path = urdf_convertor(
                mesh_path=mesh_obj_path,
                output_root=urdf_root,
                **asset_attrs,
            )

            # Rescale GS and save to URDF/mesh folder.
            real_height = urdf_convertor.get_attr_from_urdf(
                urdf_path, attr_name="real_height"
            )
            out_gs = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply"  # noqa
            GaussianOperator.resave_ply(
                in_ply=aligned_gs_path,
                out_ply=out_gs,
                real_height=real_height,
                device="cpu",
            )

            # Quality check and update .urdf file.
            mesh_out = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}.obj"  # noqa
            trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
            image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color"  # noqa
            image_paths = glob(f"{image_dir}/*.png")
            images_list = []
            for checker in CHECKERS:
                images = image_paths
                # The segmentation checker compares the raw image against
                # its segmented counterpart instead of the renders.
                if isinstance(checker, ImageSegChecker):
                    images = [
                        f"{output_root}/{filename}_raw.png",
                        f"{output_root}/{filename}_cond.png",
                    ]
                images_list.append(images)
            results = BaseChecker.validate(CHECKERS, images_list)
            urdf_convertor.add_quality_tag(urdf_path, results)

            # Organize the final result files
            result_dir = f"{output_root}/result"
            if os.path.exists(result_dir):
                rmtree(result_dir, ignore_errors=True)
            os.makedirs(result_dir, exist_ok=True)
            copy(urdf_path, f"{result_dir}/{os.path.basename(urdf_path)}")
            copytree(
                f"{urdf_root}/{urdf_convertor.output_mesh_dir}",
                f"{result_dir}/{urdf_convertor.output_mesh_dir}",
            )
            copy(video_path, f"{result_dir}/video.mp4")
            if args.remove_intermediate:
                delete_dir(output_root, keep_subs=["result"])

        except Exception as e:
            # Best-effort batch processing: log the failure and move on to
            # the next image rather than aborting the whole run.
            logger.error(f"Failed to process {image_path}: {e}, skip.")
            continue

    logger.info(f"Processing complete. Outputs saved to {args.output_root}")