# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from dataclasses import dataclass, field
from typing import List, Optional, Union
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from typing_extensions import Literal, assert_never
__all__ = [
"Pano2MeshSRConfig",
"GsplatTrainConfig",
]
@dataclass
class Pano2MeshSRConfig:
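    """Settings for the panorama-to-mesh super-resolution (Pano2MeshSR) stage:
    input mesh and Gaussian-splat files, panorama/cubemap resolutions,
    rasterization options, camera trajectory, and the upscaling factor."""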
mesh_file: str = "mesh_model.ply"
gs_data_file: str = "gs_data.pt"
device: str = "cuda"
blur_radius: int = 0
faces_per_pixel: int = 8
fov: int = 90
pano_w: int = 2048
pano_h: int = 1024
cubemap_w: int = 512
cubemap_h: int = 512
pose_scale: float = 0.6
pano_center_offset: tuple = (-0.2, 0.3)
inpaint_frame_stride: int = 20
trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
visualize: bool = False
depth_scale_factor: float = 3.4092
kernel_size: tuple = (9, 9)
upscale_factor: int = 4
@dataclass
class GsplatTrainConfig:
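    """Training configuration for 3D Gaussian Splatting via gsplat."""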
    # Paths to checkpoint .pt files. If provided, training is skipped and only evaluation is run.
ckpt: Optional[List[str]] = None
# Render trajectory path
render_traj_path: str = "interp"
    # Path to the training dataset directory
data_dir: str = "outputs/bg"
# Downsample factor for the dataset
data_factor: int = 4
# Directory to save results
result_dir: str = "outputs/bg"
    # Hold out every N-th image as a test image
test_every: int = 8
# Random crop size for training (experimental)
patch_size: Optional[int] = None
    # A global scale factor applied to scene-size-related parameters
global_scale: float = 1.0
# Normalize the world space
normalize_world_space: bool = True
# Camera model
camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
# Port for the viewer server
port: int = 8080
# Batch size for training. Learning rates are scaled automatically
batch_size: int = 1
# A global factor to scale the number of training steps
steps_scaler: float = 1.0
# Number of training steps
max_steps: int = 30_000
# Steps to evaluate the model
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Whether to save ply file (storage size can be large)
save_ply: bool = True
# Steps to save the model as ply
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Whether to disable video generation during training and evaluation
disable_video: bool = False
# Initial number of GSs. Ignored if using sfm
init_num_pts: int = 100_000
# Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
init_extent: float = 3.0
# Degree of spherical harmonics
sh_degree: int = 1
    # Enable one more SH degree every this many steps
sh_degree_interval: int = 1000
# Initial opacity of GS
init_opa: float = 0.1
# Initial scale of GS
init_scale: float = 1.0
# Weight for SSIM loss
ssim_lambda: float = 0.2
# Near plane clipping distance
near_plane: float = 0.01
# Far plane clipping distance
far_plane: float = 1e10
# Strategy for GS densification
strategy: Union[DefaultStrategy, MCMCStrategy] = field(
default_factory=DefaultStrategy
)
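    # e.g. pass `strategy=MCMCStrategy()` to use MCMC-style densification instead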
    # Use packed mode for rasterization; uses less memory but is slightly slower.
packed: bool = False
# Use sparse gradients for optimization. (experimental)
sparse_grad: bool = False
# Use visible adam from Taming 3DGS. (experimental)
visible_adam: bool = False
# Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
antialiased: bool = False
# Use random background for training to discourage transparency
random_bkgd: bool = False
# LR for 3D point positions
means_lr: float = 1.6e-4
# LR for Gaussian scale factors
scales_lr: float = 5e-3
# LR for alpha blending weights
opacities_lr: float = 5e-2
# LR for orientation (quaternions)
quats_lr: float = 1e-3
# LR for SH band 0 (brightness)
sh0_lr: float = 2.5e-3
# LR for higher-order SH (detail)
shN_lr: float = 2.5e-3 / 20
# Opacity regularization
opacity_reg: float = 0.0
# Scale regularization
scale_reg: float = 0.0
# Enable depth loss. (experimental)
depth_loss: bool = False
# Weight for depth loss
depth_lambda: float = 1e-2
    # Write information to TensorBoard every this many steps
tb_every: int = 200
# Save training images to tensorboard
tb_save_image: bool = False
    # Backbone network for the LPIPS metric
    lpips_net: Literal["vgg", "alex"] = "alex"
    # 3DGUT (unscented transform + evaluation in 3D)
with_ut: bool = False
with_eval3d: bool = False
scene_scale: float = 1.0
def adjust_steps(self, factor: float):
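        """Scale all step-based schedules (eval/save/ply steps, max steps, SH
        interval, and the densification strategy's refine schedule) by `factor`."""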
self.eval_steps = [int(i * factor) for i in self.eval_steps]
self.save_steps = [int(i * factor) for i in self.save_steps]
self.ply_steps = [int(i * factor) for i in self.ply_steps]
self.max_steps = int(self.max_steps * factor)
self.sh_degree_interval = int(self.sh_degree_interval * factor)
strategy = self.strategy
if isinstance(strategy, DefaultStrategy):
strategy.refine_start_iter = int(
strategy.refine_start_iter * factor
)
strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
strategy.reset_every = int(strategy.reset_every * factor)
strategy.refine_every = int(strategy.refine_every * factor)
elif isinstance(strategy, MCMCStrategy):
strategy.refine_start_iter = int(
strategy.refine_start_iter * factor
)
strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
strategy.refine_every = int(strategy.refine_every * factor)
else:
assert_never(strategy)
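

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): build a training config and
    # rescale every step-based schedule. Driving `adjust_steps` with
    # `steps_scaler` is an assumption about how the training script calls it.
    train_cfg = GsplatTrainConfig(
        data_dir="outputs/bg",
        result_dir="outputs/bg",
        steps_scaler=0.5,
    )
    train_cfg.adjust_steps(train_cfg.steps_scaler)
    print(train_cfg.max_steps, train_cfg.eval_steps, train_cfg.sh_degree_interval)

    pano_cfg = Pano2MeshSRConfig(upscale_factor=4)
    print(pano_cfg.pano_w, pano_cfg.pano_h, pano_cfg.cubemap_w)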