# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

from gsplat.strategy import DefaultStrategy, MCMCStrategy
from typing_extensions import Literal, assert_never

__all__ = [
    "Pano2MeshSRConfig",
    "GsplatTrainConfig",
]


@dataclass
class Pano2MeshSRConfig:
    """Config for panorama-to-mesh reconstruction with super-resolution."""

    # Input mesh and Gaussian-splat data produced by the upstream stage
    mesh_file: str = "mesh_model.ply"
    gs_data_file: str = "gs_data.pt"
    # Device used for rendering
    device: str = "cuda"
    # Mesh rasterization settings (names match PyTorch3D's RasterizationSettings)
    blur_radius: int = 0
    faces_per_pixel: int = 8
    # Field of view (degrees) of each cubemap face camera
    fov: int = 90
    # Panorama resolution in pixels
    pano_w: int = 2048
    pano_h: int = 1024
    # Cubemap face resolution in pixels
    cubemap_w: int = 512
    cubemap_h: int = 512
    # Scale factor applied to camera poses along the trajectory
    pose_scale: float = 0.6
    # (x, y) offset applied to the panorama center
    pano_center_offset: Tuple[float, float] = (-0.2, 0.3)
    # Run inpainting on every N-th frame of the trajectory
    inpaint_frame_stride: int = 20
    # Directory containing the example camera trajectories
    trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
    # Save intermediate visualizations for debugging
    visualize: bool = False
    # Scale factor applied to predicted depth
    depth_scale_factor: float = 3.4092
    # Blur kernel size (height, width)
    kernel_size: Tuple[int, int] = (9, 9)
    # Super-resolution upscale factor
    upscale_factor: int = 4
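

# Illustrative usage: Pano2MeshSRConfig is a plain dataclass, so individual
# fields can be overridden at construction time (the values below are
# examples only, not recommended settings):
#
#   cfg = Pano2MeshSRConfig(
#       pano_w=4096,
#       pano_h=2048,
#       inpaint_frame_stride=10,
#   )
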
@dataclass
class GsplatTrainConfig:
    """Config for 3D Gaussian Splatting training with gsplat."""

    # Paths to checkpoint .pt files. If provided, training is skipped and
    # only evaluation is run.
    ckpt: Optional[List[str]] = None
    # Render trajectory type (e.g., "interp")
render_traj_path: str = "interp"
    # Path to the dataset directory (COLMAP / Mip-NeRF 360 format)
data_dir: str = "outputs/bg"
# Downsample factor for the dataset
data_factor: int = 4
# Directory to save results
result_dir: str = "outputs/bg"
    # Every N-th image is held out as a test image
test_every: int = 8
# Random crop size for training (experimental)
patch_size: Optional[int] = None
    # A global scale factor applied to scene-size-related parameters
global_scale: float = 1.0
# Normalize the world space
normalize_world_space: bool = True
# Camera model
camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
# Port for the viewer server
port: int = 8080
    # Batch size for training. Learning rates are scaled automatically with
    # the batch size
batch_size: int = 1
# A global factor to scale the number of training steps
steps_scaler: float = 1.0
# Number of training steps
max_steps: int = 30_000
# Steps to evaluate the model
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Whether to save ply file (storage size can be large)
save_ply: bool = True
# Steps to save the model as ply
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Whether to disable video generation during training and evaluation
disable_video: bool = False
    # Initial number of GSs. Ignored if SfM points are used
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the camera extent. Ignored if
    # SfM points are used
    init_extent: float = 3.0
# Degree of spherical harmonics
sh_degree: int = 1
    # Turn on one more SH degree every this many steps
sh_degree_interval: int = 1000
# Initial opacity of GS
init_opa: float = 0.1
# Initial scale of GS
init_scale: float = 1.0
# Weight for SSIM loss
ssim_lambda: float = 0.2
# Near plane clipping distance
near_plane: float = 0.01
# Far plane clipping distance
far_plane: float = 1e10
# Strategy for GS densification
strategy: Union[DefaultStrategy, MCMCStrategy] = field(
default_factory=DefaultStrategy
)
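    # Illustrative: an MCMC-based densification strategy can be selected
    # instead of the default one, e.g.
    #
    #   cfg = GsplatTrainConfig(strategy=MCMCStrategy(cap_max=1_000_000))
    #
    # where `cap_max` is MCMCStrategy's upper bound on the total number of
    # Gaussians.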
    # Use packed mode for rasterization; this uses less memory but is
    # slightly slower.
packed: bool = False
# Use sparse gradients for optimization. (experimental)
sparse_grad: bool = False
    # Use visible Adam from Taming 3DGS. (experimental)
visible_adam: bool = False
# Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
antialiased: bool = False
# Use random background for training to discourage transparency
random_bkgd: bool = False
# LR for 3D point positions
means_lr: float = 1.6e-4
# LR for Gaussian scale factors
scales_lr: float = 5e-3
# LR for alpha blending weights
opacities_lr: float = 5e-2
# LR for orientation (quaternions)
quats_lr: float = 1e-3
# LR for SH band 0 (brightness)
sh0_lr: float = 2.5e-3
# LR for higher-order SH (detail)
shN_lr: float = 2.5e-3 / 20
# Opacity regularization
opacity_reg: float = 0.0
# Scale regularization
scale_reg: float = 0.0
# Enable depth loss. (experimental)
depth_loss: bool = False
# Weight for depth loss
depth_lambda: float = 1e-2
    # Write metrics to TensorBoard every this many steps
    tb_every: int = 200
    # Save training images to TensorBoard
    tb_save_image: bool = False
    # Backbone network for the LPIPS perceptual metric
    lpips_net: Literal["vgg", "alex"] = "alex"
    # 3DGUT (unscented transform + 3D evaluation)
    with_ut: bool = False
    with_eval3d: bool = False
    # Global scene scale
    scene_scale: float = 1.0

    def adjust_steps(self, factor: float):
        """Scale all step-based schedule parameters by ``factor``."""
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
self.save_steps = [int(i * factor) for i in self.save_steps]
self.ply_steps = [int(i * factor) for i in self.ply_steps]
self.max_steps = int(self.max_steps * factor)
self.sh_degree_interval = int(self.sh_degree_interval * factor)
strategy = self.strategy
if isinstance(strategy, DefaultStrategy):
strategy.refine_start_iter = int(
strategy.refine_start_iter * factor
)
strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
strategy.reset_every = int(strategy.reset_every * factor)
strategy.refine_every = int(strategy.refine_every * factor)
elif isinstance(strategy, MCMCStrategy):
strategy.refine_start_iter = int(
strategy.refine_start_iter * factor
)
strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
strategy.refine_every = int(strategy.refine_every * factor)
else:
assert_never(strategy)
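

# Illustrative usage: rescale the whole training schedule, including the
# densification strategy's refine schedule, with `adjust_steps`. For example,
# to halve it:
#
#   cfg = GsplatTrainConfig(steps_scaler=0.5)
#   cfg.adjust_steps(cfg.steps_scaler)
#   assert cfg.max_steps == 15_000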