xinjie.wang
update
049066b
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import math
from collections.abc import Sequence

import cv2
import spaces
import torch
from PIL import Image
from tqdm import tqdm

from embodied_gen.data.utils import (
    CameraSetting,
    init_kal_camera,
    normalize_vertices_array,
)
from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.utils.process_media import combine_images_to_grid
# Root logging configuration for this script: INFO level, timestamped lines.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-scoped logger used for the script's status messages.
logger = logging.getLogger(__name__)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the options for rendering GS color images into a grid.

    Unrecognized arguments are silently ignored (``parse_known_args``), so
    callers may pass through command lines that carry unrelated flags.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        The parsed options. ``input_gs`` and ``output_path`` have no default
        and are ``None`` when not supplied.
    """
    parser = argparse.ArgumentParser(description="Render GS color images")
    parser.add_argument(
        "--input_gs", type=str, help="Input render GS.ply path."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output grid image path for rendered GS color images.",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        type=float,
        nargs="+",
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the output images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=512,
        help="Output image size for single view in color grid (default: 512)",
    )
    # Extra/unknown arguments are discarded on purpose.
    args, _ = parser.parse_known_args(argv)
    return args
def load_gs_model(
    input_gs: str,
    pre_quat: Sequence[float] = (0.0, 0.7071, 0.0, -0.7071),
) -> GaussianOperator:
    """Load a Gaussian Splatting model from a PLY file and normalize it.

    The Gaussian means are passed through ``normalize_vertices_array``. The
    returned center and ``pre_quat`` are packed into one instance-pose vector
    for ``get_gaussians``, and the resulting model is rescaled by the
    returned scale.

    Args:
        input_gs: Path to the Gaussian Splatting ``.ply`` file.
        pre_quat: Quaternion components appended to the pose vector after
            the center. The default is immutable so it cannot be shared and
            mutated across calls. NOTE(review): component order (wxyz vs
            xyzw) is defined by ``GaussianOperator.get_gaussians`` — confirm
            there.

    Returns:
        The transformed and rescaled ``GaussianOperator``.
    """
    gs_model = GaussianOperator.load_from_ply(input_gs)
    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, scale, center = normalize_vertices_array(gs_model._means)
    scale, center = float(scale), center.tolist()
    # Pose vector layout: [cx, cy, cz, *pre_quat].
    pose = torch.tensor([*center, *pre_quat], device=gs_model.device)
    gs_model = gs_model.get_gaussians(instance_pose=pose)
    gs_model.rescale(scale)
    return gs_model
@spaces.GPU
def entrypoint(**kwargs) -> None:
    """Render a Gaussian Splatting model from several views into a grid image.

    Options come from ``parse_args``. Each keyword argument whose name
    matches a parsed option and whose value is not ``None`` overrides that
    option, so the function can be run from the shell or called directly,
    e.g. ``entrypoint(input_gs="a.ply", output_path="grid.png")``.

    Args:
        **kwargs: Overrides for the options defined in ``parse_args``.
            Non-``None`` values whose names match no option are ignored
            with a warning.

    Raises:
        ValueError: If ``input_gs`` or ``output_path`` is still unset after
            applying the overrides.
    """
    args = parse_args()
    for key, value in kwargs.items():
        if value is None:
            continue
        if hasattr(args, key):
            setattr(args, key, value)
        else:
            logger.warning(f"Ignoring unknown option `{key}`.")

    # Both paths default to None on the command line; fail early with a
    # clear message instead of deep inside model loading or image saving.
    missing = [
        name for name in ("input_gs", "output_path") if not getattr(args, name)
    ]
    if missing:
        raise ValueError(f"Missing required option(s): {', '.join(missing)}")

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    camera = init_kal_camera(camera_params, flip_az=True)
    matrix_mv = camera.view_matrix()  # (n_cam, 4, 4) world2cam
    # NOTE(review): the translation column is negated before inversion,
    # presumably to match the GS renderer's convention — confirm.
    matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
    w2cs = matrix_mv.to(camera_params.device)
    # Batched inverse over the leading dim: one cam2world matrix per view.
    c2ws = torch.linalg.inv(w2cs)
    Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
    height, width = camera_params.resolution_hw

    # Load GS model and normalize.
    gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])

    # Render each view, downscale it to a square tile, and convert channels
    # from BGRA to RGBA before handing the array to PIL.
    images = []
    for c2w in tqdm(c2ws, desc="Rendering GS"):
        result = gs_model.render(
            c2w,
            Ks=Ks,
            image_width=width,
            image_height=height,
        )
        color = cv2.resize(
            result.rgba,
            (args.image_size, args.image_size),
            interpolation=cv2.INTER_AREA,
        )
        color = cv2.cvtColor(color, cv2.COLOR_BGRA2RGBA)
        images.append(Image.fromarray(color))

    grid = combine_images_to_grid(images, image_mode="RGBA")[0]
    grid.save(args.output_path)
    logger.info(f"Saved grid image to {args.output_path}")


if __name__ == "__main__":
    entrypoint()