from pathlib import Path
import os
import sys
import subprocess

import torch
import numpy as np
import cv2
import open3d as o3d
from tqdm import tqdm

from pgnd.utils import get_root
from pgnd.ffmpeg import make_video

root: Path = get_root(__file__)

from diff_gaussian_rasterization import GaussianRasterizer
from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera
from gs.helpers import setup_camera
from gs.external import densify
from gs.train_utils import (
    get_custom_dataset,
    initialize_params,
    initialize_optimizer,
    get_loss,
    report_progress,
    get_batch,
)


def Rt_to_w2c(R, t):
    # Stack (R, t) into a 4x4 pose matrix and invert it to obtain the
    # world-to-camera transform.
    w2c = np.concatenate([
        np.concatenate([R, t.reshape(3, 1)], axis=1),
        np.array([[0, 0, 0, 1]]),
    ], axis=0)
    w2c = np.linalg.inv(w2c)
    return w2c


class GSTrainer:

    def __init__(self, config, device):
        self.config = config
        self.device = device
        self.clear()

    def clear(self, clear_params=True):
        # training data
        self.init_pt_cld = None
        self.metadata = None
        self.img_list = None
        self.seg_list = None
        # training results
        if clear_params:
            self.params = None

    @torch.no_grad()
    def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
        # Rasterize the current Gaussians from the view of camera `cam_id`.
        render_data = {k: v.to(self.device) for k, v in render_data.items()}
        w, h = self.metadata['w'], self.metadata['h']
        k = self.metadata['k'][cam_id]
        w2c = self.metadata['w2c'][cam_id]
        cam = setup_camera(w, h, k, w2c, self.config['near'], self.config['far'], bg)
        im, _, depth = GaussianRasterizer(raster_settings=cam)(**render_data)
        return im, depth

    def update_state_env(self, pcd, env, imgs, masks):
        # RGB observations and camera parameters come from the environment.
        R_list, t_list = env.get_extrinsics()
        intr_list = env.get_intrinsics()
        rgb_list = []
        seg_list = []
        for c in range(env.n_fixed_cameras):
            rgb_list.append(imgs[c] * masks[c][:, :, None])
            seg_list.append(masks[c] * 1.0)
        self.update_state(pcd, rgb_list, seg_list, R_list, t_list, intr_list)

    def update_state_no_env(self, pcd, imgs, masks, R_list, t_list, intr_list, n_cameras=4):
        # Same as update_state_env, but with camera parameters passed in directly.
        rgb_list = []
        seg_list = []
        for c in range(n_cameras):
            rgb_list.append(imgs[c] * masks[c][:, :, None])
            seg_list.append(masks[c] * 1.0)
        self.update_state(pcd, rgb_list, seg_list, R_list, t_list, intr_list)

    def update_state(self, pcd, img_list, seg_list, R_list, t_list, intr_list):
        # Initial point cloud: xyz (3) + rgb (3) + segmentation flag (1).
        pts = np.array(pcd.points).astype(np.float32)
        colors = np.array(pcd.colors).astype(np.float32)
        seg = np.ones_like(pts[:, 0:1])
        self.init_pt_cld = np.concatenate([pts, colors, seg], axis=1)
        w, h = img_list[0].shape[1], img_list[0].shape[0]
        assert np.all([img.shape[1] == w and img.shape[0] == h for img in img_list])
        self.metadata = {
            'w': w,
            'h': h,
            'k': [intr for intr in intr_list],
            'w2c': [Rt_to_w2c(R, t) for R, t in zip(R_list, t_list)],
        }
        self.img_list = img_list
        self.seg_list = seg_list

    def train(self, vis_dir):
        params, variables = initialize_params(self.init_pt_cld, self.metadata)
        optimizer = initialize_optimizer(params, variables)
        dataset = get_custom_dataset(self.img_list, self.seg_list, self.metadata)
        todo_dataset = []
        num_iters = self.config['num_iters']
        loss_weights = {'im': self.config['weight_im'], 'seg': self.config['weight_seg']}
        densify_params = {
            'grad_thresh': self.config['grad_thresh'],
            'remove_thresh': self.config['remove_threshold'],
            'remove_thresh_5k': self.config['remove_thresh_5k'],
            'scale_scene_radius': self.config['scale_scene_radius'],
        }
        progress_bar = tqdm(range(num_iters), dynamic_ncols=True)
        for i in range(num_iters):
            curr_data = get_batch(todo_dataset, dataset)
            loss, variables = get_loss(params, curr_data, variables, loss_weights)
            loss.backward()
            with torch.no_grad():
                params, variables, num_pts = densify(params, variables, optimizer, i, **densify_params)
                report_progress(params, dataset[0], i, progress_bar, num_pts, vis_dir=vis_dir)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
        progress_bar.close()
        params = {k: v.detach() for k, v in params.items()}
        self.params = params

    def rollout_and_render(self, dm, action, vis_dir=None, save_images=True,
                           overwrite_params=True, remove_black=False):
        assert vis_dir is not None
        assert self.params is not None
        xyz_0 = self.params['means3D']
        rgb_0 = self.params['rgb_colors']
        quat_0 = torch.nn.functional.normalize(self.params['unnorm_rotations'])
        opa_0 = torch.sigmoid(self.params['logit_opacities'])
        scales_0 = torch.exp(self.params['log_scales'])

        # drop nearly transparent Gaussians
        low_opa_idx = opa_0[:, 0] < 0.1
        xyz_0 = xyz_0[~low_opa_idx]
        rgb_0 = rgb_0[~low_opa_idx]
        quat_0 = quat_0[~low_opa_idx]
        opa_0 = opa_0[~low_opa_idx]
        scales_0 = scales_0[~low_opa_idx]

        # optionally drop nearly black Gaussians
        if remove_black:
            low_color_idx = rgb_0.sum(dim=-1) < 0.5
            xyz_0 = xyz_0[~low_color_idx]
            rgb_0 = rgb_0[~low_color_idx]
            quat_0 = quat_0[~low_color_idx]
            opa_0 = opa_0[~low_color_idx]
            scales_0 = scales_0[~low_color_idx]

        # discretize the action into end-effector waypoints ~5 mm apart, then
        # pad with the end position so the dynamics model sees a full history
        eef_xyz_start = action[0]
        eef_xyz_end = action[1]
        dist_thresh = 0.005  # 5 mm
        n_steps = int((eef_xyz_end - eef_xyz_start).norm().item() / dist_thresh)
        eef_xyz = torch.lerp(eef_xyz_start, eef_xyz_end,
                             torch.linspace(0, 1, n_steps).to(self.device)[:, None])
        eef_xyz_pad = torch.cat([eef_xyz, eef_xyz_end[None].repeat(dm.n_his, 1)], dim=0)
        eef_xyz_pad = eef_xyz_pad[:, None]  # (n_steps, 1, 3)
        n_steps = eef_xyz_pad.shape[0]

        inlier_idx_all = np.arange(len(xyz_0))  # no outlier removal
        xyz, rgb, quat, opa, xyz_bones, eef = dm.rollout(
            xyz_0, rgb_0, quat_0, opa_0, eef_xyz_pad, n_steps, inlier_idx_all)

        # interpolate smoothly between the frames where the state actually changed
        change_points = (xyz - torch.cat([xyz[0:1], xyz[:-1]], dim=0)).norm(dim=-1).sum(dim=-1).nonzero().squeeze(1)
        change_points = torch.cat([torch.tensor([0]), change_points])
        for i in range(1, len(change_points)):
            start = change_points[i - 1].item()
            end = change_points[i].item()
            if end - start < 2:  # 0 or 1
                continue
            xyz[start:end] = torch.lerp(xyz[start][None], xyz[end][None],
                torch.linspace(0, 1, end - start + 1).to(xyz.device)[:, None, None])[:-1]
            rgb[start:end] = torch.lerp(rgb[start][None], rgb[end][None],
                torch.linspace(0, 1, end - start + 1).to(rgb.device)[:, None, None])[:-1]
            quat[start:end] = torch.lerp(quat[start][None], quat[end][None],
                torch.linspace(0, 1, end - start + 1).to(quat.device)[:, None, None])[:-1]
            opa[start:end] = torch.lerp(opa[start][None], opa[end][None],
                torch.linspace(0, 1, end - start + 1).to(opa.device)[:, None, None])[:-1]
            xyz_bones[start:end] = torch.lerp(xyz_bones[start][None], xyz_bones[end][None],
                torch.linspace(0, 1, end - start + 1).to(xyz_bones.device)[:, None, None])[:-1]
            eef[start:end] = torch.lerp(eef[start][None], eef[end][None],
                torch.linspace(0, 1, end - start + 1).to(eef.device)[:, None, None])[:-1]

        # temporal smoothing of positions; quaternions are re-normalized after
        for _ in range(3):
            xyz[1:-1] = (xyz[:-2] + 2 * xyz[1:-1] + xyz[2:]) / 4
            # quat[1:-1] = (quat[:-2] + 2 * quat[1:-1] + quat[2:]) / 4
        quat = torch.nn.functional.normalize(quat, dim=-1)

        rendervar_list = []
        visvar_list = []
        # im_list = []
        for t in range(n_steps):
            rendervar = {
                'means3D': xyz[t],
                'colors_precomp': rgb[t],
                'rotations': quat[t],
                'opacities': opa[t],
                'scales': scales_0,
                'means2D': torch.zeros_like(xyz[t]),
            }
            rendervar_list.append(rendervar)
            visvar = {
                'xyz_bones': xyz_bones[t].numpy(),  # params['means3D'][t][fps_idx].detach().cpu().numpy(),
                'eef': eef[t].numpy(),  # eef_xyz[t].detach().cpu().numpy(),
            }
            visvar_list.append(visvar)
            if save_images:
                im, _ = self.render(rendervar, 0, bg=[0, 0, 0])
                im = im.cpu().numpy().transpose(1, 2, 0)
                im = (im * 255).astype(np.uint8)
                # im_list.append(im)
                cv2.imwrite(os.path.join(vis_dir, f'{t:04d}.png'), im[:, :, ::-1].copy())  # RGB -> BGR

        if save_images:
            make_video(vis_dir, os.path.join(os.path.dirname(vis_dir), f"{vis_dir.split('/')[-1]}.mp4"), '%04d.png', 5)

        # carry the final state over so subsequent rollouts start from it
        if overwrite_params:
            self.params['means3D'] = xyz[-1].to(self.device)
            self.params['rgb_colors'] = rgb[-1].to(self.device)
            self.params['unnorm_rotations'] = quat[-1].to(self.device)
            self.params['logit_opacities'] = torch.logit(opa[-1]).to(self.device)
            self.params['log_scales'] = torch.log(scales_0).to(self.device)
            self.params['means2D'] = torch.zeros_like(xyz[-1]).to(self.device)

        return rendervar_list, visvar_list
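

# ---------------------------------------------------------------------------
# Minimal usage sketch, for illustration only. The config keys mirror the ones
# GSTrainer reads above; the values, paths, and the `dm` dynamics model
# (assumed to provide `n_his` and `rollout`) are hypothetical stand-ins, not
# defaults from this codebase.
if __name__ == '__main__':
    config = {
        'near': 0.01, 'far': 100.0,           # clipping planes for setup_camera
        'num_iters': 2000,                    # optimization steps
        'weight_im': 1.0, 'weight_seg': 1.0,  # image / segmentation loss weights
        'grad_thresh': 0.0002,                # densification gradient threshold
        'remove_threshold': 0.005,            # opacity pruning threshold
        'remove_thresh_5k': 0.25,             # pruning threshold after 5k iters
        'scale_scene_radius': 0.1,            # densification scale cap
    }  # illustrative values, not tuned
    trainer = GSTrainer(config, device=torch.device('cuda'))

    # Typical flow (real inputs omitted here):
    #   pcd: o3d.geometry.PointCloud of the object
    #   imgs / masks: per-camera RGB images and foreground masks
    #   R_list / t_list / intr_list: per-camera extrinsics and intrinsics
    # trainer.update_state_no_env(pcd, imgs, masks, R_list, t_list, intr_list)
    # trainer.train(vis_dir='vis/train')
    #
    #   action: (2, 3) tensor with end-effector start and end positions
    # rendervars, visvars = trainer.rollout_and_render(dm, action, vis_dir='vis/rollout')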