import gradio as gr import sys import site from PIL import Image from pathlib import Path from omegaconf import DictConfig, OmegaConf from tqdm import tqdm, trange import random import math import hydra import numpy as np import glob import os import subprocess import time import cv2 import copy import yaml import matplotlib.pyplot as plt from sklearn.neighbors import NearestNeighbors import spaces from spaces import zero zero.startup() import torch import torch.nn as nn from torch.utils.data import DataLoader import kornia from diff_gaussian_rasterization import GaussianRasterizer from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera import sys sys.path.insert(0, str(Path(__file__).parent / "src")) sys.path.append(str(Path(__file__).parent / "src" / "experiments")) from real_world.utils.render_utils import interpolate_motions from real_world.gs.helpers import setup_camera from real_world.gs.convert import save_to_splat, read_splat root = Path(__file__).parent / "src" / "experiments" def make_video( image_root: Path, video_path: Path, image_pattern: str = '%04d.png', frame_rate: int = 10): subprocess.run([ 'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error', '-framerate', str(frame_rate), '-i', str(image_root / image_pattern), '-c:v', 'libx264', '-pix_fmt', 'yuv420p', str(video_path) ]) def quat2mat(quat): import kornia return kornia.geometry.conversions.quaternion_to_rotation_matrix(quat) def mat2quat(mat): import kornia return kornia.geometry.conversions.rotation_matrix_to_quaternion(mat) def fps(x, enabled, n, device, random_start=False): import torch from dgl.geometry import farthest_point_sampler assert torch.diff(enabled * 1.0).sum() in [0.0, -1.0] start_idx = random.randint(0, enabled.sum() - 1) if random_start else 0 fps_idx = farthest_point_sampler(x[enabled][None], n, start_idx=start_idx)[0] fps_idx = fps_idx.to(x.device) return fps_idx class DynamicsVisualizer: def __init__(self, wp_device='cuda', torch_device='cuda'): self.best_models = { 'cloth': ['cloth', 'train', 100000, [610, 650]], 'rope': ['rope', 'train', 100000, [651, 691]], 'paperbag': ['paperbag', 'train', 100000, [200, 220]], 'sloth': ['sloth', 'train', 100000, [113, 133]], 'box': ['box', 'train', 100000, [306, 323]], 'bread': ['bread', 'train', 100000, [143, 163]], } task_name = 'rope' self.init(task_name) def init(self, task_name): self.width = 640 self.height = 480 self.task_name = task_name with open(root / f'log/{self.best_models[task_name][0]}/{self.best_models[task_name][1]}/hydra.yaml', 'r') as f: config = yaml.load(f, Loader=yaml.CLoader) cfg = OmegaConf.create(config) cfg.iteration = self.best_models[task_name][2] cfg.start_episode = self.best_models[task_name][3][0] cfg.end_episode = self.best_models[task_name][3][1] cfg.sim.num_steps = 1000 cfg.sim.gripper_forcing = False cfg.sim.uniform = True cfg.sim.use_pv = False device = torch.device('cuda') self.cfg = cfg self.device = device self.k_rel = 8 # knn for relations self.k_wgt = 16 # knn for weights self.with_bg = True self.render_gripper = True self.render_direction = True self.verbose = False self.dt_base = cfg.sim.dt self.high_freq_pred = True seed = cfg.seed random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) # torch.autograd.set_detect_anomaly(True) # torch.backends.cudnn.benchmark = True self.clear() def clear(self, clear_params=True): self.metadata = {} self.config = {} if clear_params: self.params = None self.state = { # object 'x': None, 'v': None, 'x_his': None, 'v_his': None, 'x_pred': None, 'v_pred': None, 
'clip_bound': None, 'enabled': None, # robot 'prev_key_pos': None, 'prev_key_pos_timestamp': None, 'sub_pos': None, # filling in between key positions 'sub_pos_timestamps': None, 'gripper_radius': None, } self.preprocess_metadata = None self.table_params = None self.gripper_params = None self.sim = None self.statics = None self.colliders = None self.material = None self.friction = None def load_scaniverse(self, data_path): ### load splat params params_obj = read_splat(data_path / 'object.splat') params_table = read_splat(data_path / 'table.splat') params_robot = read_splat(data_path / 'gripper.splat') pts, colors, scales, quats, opacities = params_obj self.params = { 'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device), 'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device), 'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)), 'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device), 'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device)) } t_pts, t_colors, t_scales, t_quats, t_opacities = params_table t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device) t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device) t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device) t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device) t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device) g_pts, g_colors, g_scales, g_quats, g_opacities = params_robot g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device) g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device) g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device) g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device) g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device) self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities # data frame self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities # data frame n_particles = self.cfg.sim.n_particles self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32) self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool) ### load preprocess metadata cfg = self.cfg dx = cfg.sim.num_grids[-1] p_x = torch.tensor(pts).to(torch.float32).to(self.device) R = torch.tensor( [[1, 0, 0], [0, 0, -1], [0, 1, 0]] ).to(p_x.device).to(p_x.dtype) p_x_rotated = p_x @ R.T scale = 1.0 p_x_rotated_scaled = p_x_rotated * scale global_translation = torch.tensor([ 0.5 - p_x_rotated_scaled[:, 0].mean(), dx * (cfg.model.clip_bound + 0.5) - p_x_rotated_scaled[:, 1].min(), 0.5 - p_x_rotated_scaled[:, 2].mean(), ], dtype=p_x_rotated_scaled.dtype, device=p_x_rotated_scaled.device) R_viewer = torch.tensor( [[1, 0, 0], [0, 0, -1], [0, 1, 0]] ).to(p_x.device).to(p_x.dtype) t_viewer = torch.tensor([0, 0, 0]).to(p_x.device).to(p_x.dtype) self.preprocess_metadata = { 'R': R, 'R_viewer': R_viewer, 't_viewer': t_viewer, 'scale': scale, 'global_translation': global_translation, } ### load eef grippers = np.loadtxt(data_path / 'eef_xyz.txt')[None] assert grippers.shape == (1, 3) if grippers is not None: grippers = torch.tensor(grippers).to(self.device).to(torch.float32) # transform # data frame to model frame R = self.preprocess_metadata['R'] scale = self.preprocess_metadata['scale'] global_translation = self.preprocess_metadata['global_translation'] grippers[:, :3] = grippers[:, :3] @ R.T grippers[:, :3] = grippers[:, :3] * scale grippers[:, :3] += 
global_translation assert grippers.shape[0] == 1 self.state['prev_key_pos'] = grippers[:, :3] # (1, 3) # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) self.state['gripper_radius'] = cfg.model.gripper_radius def load_eef(self, grippers=None, eef_t=None): assert self.state['prev_key_pos'] is None if grippers is not None: grippers = torch.tensor(grippers).to(self.device).to(torch.float32) eef_t = torch.tensor(eef_t).to(self.device).to(torch.float32) grippers[:, :3] = grippers[:, :3] + eef_t # transform # data frame to model frame R = self.preprocess_metadata['R'] scale = self.preprocess_metadata['scale'] global_translation = self.preprocess_metadata['global_translation'] grippers[:, :3] = grippers[:, :3] @ R.T grippers[:, :3] = grippers[:, :3] * scale grippers[:, :3] += global_translation assert grippers.shape[0] == 1 self.state['prev_key_pos'] = grippers[:, :3] # (1, 3) # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) + 0.001 self.state['gripper_radius'] = self.cfg.model.gripper_radius def load_preprocess_metadata(self, p_x_orig): cfg = self.cfg dx = cfg.sim.num_grids[-1] p_x_orig = p_x_orig.to(self.device) R = torch.tensor( [[1, 0, 0], [0, 0, -1], [0, 1, 0]] ).to(p_x_orig.device).to(p_x_orig.dtype) p_x_orig_rotated = torch.einsum('nij,jk->nik', p_x_orig, R.T) scale = 1.0 p_x_orig_rotated_scaled = p_x_orig_rotated * scale global_translation = torch.tensor([ 0.5 - p_x_orig_rotated_scaled[:, :, 0].mean(), dx * (cfg.model.clip_bound + 0.5) - p_x_orig_rotated_scaled[:, :, 1].min(), 0.5 - p_x_orig_rotated_scaled[:, :, 2].mean(), ], dtype=p_x_orig_rotated_scaled.dtype, device=p_x_orig_rotated_scaled.device) R_viewer = torch.tensor( [[1, 0, 0], [0, 0, -1], [0, 1, 0]] ).to(p_x_orig.device).to(p_x_orig.dtype) t_viewer = torch.tensor([0, 0, 0]).to(p_x_orig.device).to(p_x_orig.dtype) self.preprocess_metadata = { 'R': R, 'R_viewer': R_viewer, 't_viewer': t_viewer, 'scale': scale, 'global_translation': global_translation, } # @torch.no_grad def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]): render_data = {k: v.to(self.device) for k, v in render_data.items()} w, h = self.metadata['w'], self.metadata['h'] k, w2c = self.metadata['k'], self.metadata['w2c'] cam = setup_camera(w, h, k, w2c, self.config['near'], self.config['far'], bg) im, _, depth, = GaussianRasterizer(raster_settings=cam)(**render_data) return im, depth def knn_relations(self, bones): k = self.k_rel knn = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(bones.detach().cpu().numpy()) _, indices = knn.kneighbors(bones.detach().cpu().numpy()) # (N, k) indices = indices[:, 1:] # exclude self return indices def knn_weights_brute(self, bones, pts): k = self.k_wgt dist = torch.norm(pts[:, None] - bones, dim=-1) # (n_pts, n_bones) _, indices = torch.topk(dist, k, dim=-1, largest=False) bones_selected = bones[indices] # (N, k, 3) dist = torch.norm(bones_selected - pts[:, None], dim=-1) # (N, k) weights = 1 / (dist + 1e-6) weights = weights / weights.sum(dim=-1, keepdim=True) # (N, k) weights_all = torch.zeros((pts.shape[0], bones.shape[0]), device=pts.device) weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights return weights_all def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0): self.metadata['k'] = k self.metadata['w2c'] = w2c if w is not None: self.metadata['w'] = w if h is not None: self.metadata['h'] = h self.config['near'] = near self.config['far'] = far def init_model(self, batch_size, num_steps, num_particles, 
ckpt_path=None): from pgnd.sim import Friction, CacheDiffSimWithFrictionBatch, StaticsBatch, CollidersBatch from pgnd.material import PGNDModel self.cfg.sim.num_steps = num_steps cfg = self.cfg sim = CacheDiffSimWithFrictionBatch(cfg, num_steps, batch_size, self.wp_device, requires_grad=True) statics = StaticsBatch() statics.init(shape=(batch_size, num_particles), device=self.wp_device) statics.update_clip_bound(self.state['clip_bound'].detach().cpu()) statics.update_enabled(self.state['enabled'][None].detach().cpu()) colliders = CollidersBatch() colliders.init(shape=(batch_size, cfg.sim.num_grippers), device=self.wp_device) self.sim = sim self.statics = statics self.colliders = colliders # load ckpt ckpt_path = root / f'log/{self.task_name}/train/ckpt/100000.pt' ckpt = torch.load(ckpt_path, map_location=self.torch_device) material: nn.Module = PGNDModel(cfg) material.to(self.torch_device) material.load_state_dict(ckpt['material']) material.requires_grad_(False) material.eval() if 'friction' in ckpt: friction = ckpt['friction']['mu'].reshape(-1, 1) else: friction = torch.tensor(cfg.model.friction.value, device=self.torch_device).reshape(-1, 1) self.material = material self.friction = friction def reload_model(self, num_steps): # only change num_steps from pgnd.sim import CacheDiffSimWithFrictionBatch self.cfg.sim.num_steps = num_steps sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True) self.sim = sim # @torch.no_grad def step(self): cfg = self.cfg batch_size = 1 num_steps = 1 num_particles = cfg.sim.n_particles # update state by previous prediction self.state['x_his'] = torch.cat([self.state['x_his'][1:], self.state['x'][None]], dim=0) self.state['v_his'] = torch.cat([self.state['v_his'][1:], self.state['v'][None]], dim=0) self.state['x'] = self.state['x_pred'].clone() self.state['v'] = self.state['v_pred'].clone() eef_xyz_key = self.state['prev_key_pos'] # (1, 3), model frame eef_xyz_sub = self.state['sub_pos'] # (T, 1, 3), model frame if eef_xyz_sub is None: return # eef_xyz_key_timestamp = self.state['prev_key_pos_timestamp'] # eef_xyz_sub_timestamps = self.state['sub_pos_timestamps'] # assert eef_xyz_key_timestamp.item() > 0 # delta_t = (eef_xyz_sub_timestamps[-1] - eef_xyz_key_timestamp).item() # if (not self.high_freq_pred) and delta_t < self.dt_base * 0.9: # return # cfg.sim.dt = delta_t eef_xyz_key_next = eef_xyz_sub[-1] # (1, 3), model frame eef_v = (eef_xyz_key_next - eef_xyz_key) / cfg.sim.dt if self.verbose: print('delta_t:', np.round(cfg.sim.dt, 4)) print('eef_xyz_key_next:', eef_xyz_key_next.cpu().numpy().tolist()) print('eef_xyz_key:', eef_xyz_key.cpu().numpy().tolist()) print('v:', eef_v.cpu().numpy().tolist()) # load model, sim, statics, colliders # self.reload_model(num_steps) # initialize colliders if cfg.sim.num_grippers > 0: grippers = torch.zeros((batch_size, cfg.sim.num_grippers, 15), device=self.torch_device) eef_quat = torch.tensor([1, 0, 0, 0], dtype=torch.float32, device=self.torch_device).repeat(batch_size, cfg.sim.num_grippers, 1) # (B, G, 4) eef_quat_vel = torch.zeros((batch_size, cfg.sim.num_grippers, 3), dtype=torch.float32, device=self.torch_device) eef_gripper = torch.zeros((batch_size, cfg.sim.num_grippers), dtype=torch.float32, device=self.torch_device) grippers[:, :, :3] = eef_xyz_key grippers[:, :, 3:6] = eef_v grippers[:, :, 6:10] = eef_quat grippers[:, :, 10:13] = eef_quat_vel grippers[:, :, 13] = cfg.model.gripper_radius grippers[:, :, 14] = eef_gripper self.colliders.initialize_grippers(grippers) x = 
self.state['x'].clone()[None].repeat(batch_size, 1, 1)
        v = self.state['v'].clone()[None].repeat(batch_size, 1, 1)
        x_his = self.state['x_his'].permute(1, 0, 2).clone()
        assert x_his.shape[0] == num_particles
        x_his = x_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
        v_his = self.state['v_his'].permute(1, 0, 2).clone()
        assert v_his.shape[0] == num_particles
        v_his = v_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
        enabled = self.state['enabled'].clone().to(self.torch_device)[None].repeat(batch_size, 1)

        for t in range(num_steps):
            x_in = x.clone()
            pred = self.material(x, v, x_his, v_his, enabled)
            # x_his = torch.cat([x_his.reshape(batch_size, num_particles, -1, 3)[:, :, 1:], x[:, :, None].detach()], dim=2)
            # v_his = torch.cat([v_his.reshape(batch_size, num_particles, -1, 3)[:, :, 1:], v[:, :, None].detach()], dim=2)
            # x_his = x_his.reshape(batch_size, num_particles, -1)
            # v_his = v_his.reshape(batch_size, num_particles, -1)
            x, v = self.sim(self.statics, self.colliders, t, x, v, self.friction, pred)

        # calculate new x_pred, v_pred, eef_xyz_key and eef_xyz_sub
        x_pred = x[0].clone()
        v_pred = v[0].clone()
        self.state['x_pred'] = x_pred
        self.state['v_pred'] = v_pred
        # self.state['x_his'] = x_his[0].reshape(num_particles, self.cfg.sim.n_history, 3).permute(1, 0, 2).clone()
        # self.state['v_his'] = v_his[0].reshape(num_particles, self.cfg.sim.n_history, 3).permute(1, 0, 2).clone()
        self.state['prev_key_pos'] = eef_xyz_key_next
        # self.state['prev_key_pos_timestamp'] = eef_xyz_sub_timestamps[-1]
        self.state['sub_pos'] = None
        # self.state['sub_pos_timestamps'] = None

    def preprocess_x(self, p_x):
        # viewer frame to model frame (not data frame)
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']
        # viewer frame to model frame
        p_x = (p_x - t_viewer) @ R_viewer
        # model frame to data frame
        # p_x -= global_translation
        # p_x = p_x / scale
        # p_x = p_x @ torch.linalg.inv(R).T
        return p_x

    def preprocess_gripper(self, grippers):
        # viewer frame to model frame (not data frame)
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']
        # viewer frame to model frame
        grippers[:, :3] = grippers[:, :3] @ R_viewer
        return grippers

    def inverse_preprocess_x(self, p_x):
        # model frame (not data frame) to viewer frame
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']
        # model frame to viewer frame
        p_x = p_x @ R_viewer.T + t_viewer
        return p_x

    def inverse_preprocess_gripper(self, grippers):
        # model frame (not data frame) to viewer frame
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']
        # model frame to viewer frame
        grippers[:, :3] = grippers[:, :3] @ R_viewer.T + t_viewer
        return grippers

    def rotate(self, params, rot_mat):
        # NOTE: this helper is not called anywhere in this file, and its original body
        # referenced undefined `pts` and `quats`. The lines below are a reconstruction
        # (assumption): rotate the Gaussian means and orientations by `rot_mat`,
        # normalizing out any per-row scale contained in it.
        rot_mat = torch.as_tensor(rot_mat, dtype=torch.float32, device=self.device)
        scale = torch.linalg.norm(rot_mat, dim=1, keepdim=True)
        rot_only = rot_mat / scale
        pts = params['means3D'] @ rot_mat.T
        quat = torch.nn.functional.normalize(params['unnorm_rotations'], dim=-1)
        quats = mat2quat(rot_only @ quat2mat(quat))
        params = {
            'means3D': pts,
            'rgb_colors': params['rgb_colors'],
            'log_scales': params['log_scales'],
'unnorm_rotations': quats, 'logit_opacities': params['logit_opacities'], } return params def preprocess_gs(self, params): if isinstance(params, dict): xyz = params['means3D'] rgb = params['rgb_colors'] quat = torch.nn.functional.normalize(params['unnorm_rotations']) opa = torch.sigmoid(params['logit_opacities']) scales = torch.exp(params['log_scales']) else: assert isinstance(params, tuple) xyz, rgb, quat, opa, scales = params quat = torch.nn.functional.normalize(quat, dim=-1) # transform R = self.preprocess_metadata['R'] R_viewer = self.preprocess_metadata['R_viewer'] scale = self.preprocess_metadata['scale'] global_translation = self.preprocess_metadata['global_translation'] mat = quat2mat(quat) mat = R @ mat xyz = xyz @ R.T xyz = xyz * scale xyz += global_translation quat = mat2quat(mat) scales = scales * scale # viewer-specific transform (flip y and z) # model frame to viewer frame xyz = xyz @ R_viewer.T quat = mat2quat(R_viewer @ quat2mat(quat)) t_viewer = -xyz.mean(dim=0) t_viewer[2] = 0 xyz += t_viewer print('Overwriting t_viewer to be the planar mean of the object') self.preprocess_metadata['t_viewer'] = t_viewer if isinstance(params, dict): params['means3D'] = xyz params['rgb_colors'] = rgb params['unnorm_rotations'] = quat params['logit_opacities'] = opa params['log_scales'] = torch.log(scales) else: params = xyz, rgb, quat, opa, scales return params def preprocess_bg_gs(self): t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params # identify tip first g_pts_tip_z = g_pts[:, 2].max() g_pts_tip_mask = (g_pts[:, 2] > g_pts_tip_z - 0.04) & (g_pts[:, 2] < g_pts_tip_z) R = self.preprocess_metadata['R'] R_viewer = self.preprocess_metadata['R_viewer'] t_viewer = self.preprocess_metadata['t_viewer'] scale = self.preprocess_metadata['scale'] global_translation = self.preprocess_metadata['global_translation'] t_mat = quat2mat(t_quats) t_mat = R @ t_mat t_pts = t_pts @ R.T t_pts = t_pts * scale t_pts += global_translation t_quats = mat2quat(t_mat) t_scales = t_scales * scale t_pts = t_pts @ R_viewer.T t_quats = mat2quat(R_viewer @ quat2mat(t_quats)) t_pts += t_viewer axes = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] dirs = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] # x, y, z axes for ee in range(3): gripper_direction = torch.tensor(dirs[ee], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) # normalize R = self.preprocess_metadata['R'] # model frame to data frame direction = gripper_direction @ R.T n_grippers = 1 N = 200 length = 0.2 kk = 5 xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) if self.task_name == 'rope': pos = torch.tensor([0.0, 0.0, 1.2], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # gripper position in model frame else: pos = torch.tensor([1.2, 0.0, 0.7], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) gripper_now_inv_xyz = self.inverse_preprocess_gripper(pos) gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1) center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # center point in gripper frame gripper_center_inv_xyz = gripper_now_inv_xyz + \ torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) # (n_grippers, 3) for i in range(N): offset = i / N * length * direction xyz_test[:, i] = gripper_center_inv_xyz + 
offset if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: # not vertical direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=t_pts.dtype) direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=t_pts.dtype) direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize else: direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype) direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype) direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize for i in range(N, N + N // kk): offset = length * direction + (i - N) / N * length * direction_up xyz_test[:, i] = gripper_center_inv_xyz + offset for i in range(N + N // kk, N + N // kk + N // kk): offset = length * direction + (i - N - N // kk) / N * length * direction_down xyz_test[:, i] = gripper_center_inv_xyz + offset color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=t_pts.dtype) color_test[:, :, 0] = axes[ee][0] color_test[:, :, 1] = axes[ee][1] color_test[:, :, 2] = axes[ee][2] quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=t_pts.dtype) quat_test[:, :, 0] = 1.0 # identity quaternion opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=t_pts.dtype) scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) * 0.002 t_pts = torch.cat([t_pts, xyz_test.reshape(-1, 3)], dim=0) t_colors = torch.cat([t_colors, color_test.reshape(-1, 3)], dim=0) t_quats = torch.cat([t_quats, quat_test.reshape(-1, 4)], dim=0) t_opacities = torch.cat([t_opacities, opa_test.reshape(-1, 1)], dim=0) t_scales = torch.cat([t_scales, scales_test.reshape(-1, 3)], dim=0) t_pts = t_pts.reshape(-1, 3) t_colors = t_colors.reshape(-1, 3) t_quats = t_quats.reshape(-1, 4) t_opacities = t_opacities.reshape(-1, 1) g_mat = quat2mat(g_quats) g_mat = R @ g_mat g_pts = g_pts @ R.T g_pts = g_pts * scale g_pts += global_translation g_quats = mat2quat(g_mat) g_scales = g_scales * scale g_pts = g_pts @ R_viewer.T g_quats = mat2quat(R_viewer @ quat2mat(g_quats)) g_pts += t_viewer # TODO: center gripper in the viewer frame g_pts_tip = g_pts[g_pts_tip_mask] g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0) if self.task_name == 'rope': g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device) elif self.task_name == 'sloth': g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.32]).to(torch.float32).to(self.device) else: raise NotImplementedError(f"Task {self.task_name} not implemented for gripper translation.") g_pts = g_pts + g_pts_translation self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities def update_rendervar(self, rendervar): p_x = self.state['x'] p_x_viewer = self.inverse_preprocess_x(p_x) p_x_pred = self.state['x_pred'] p_x_pred_viewer = self.inverse_preprocess_x(p_x_pred) xyz = rendervar['means3D'] rgb = rendervar['colors_precomp'] quat = rendervar['rotations'] opa = rendervar['opacities'] scales = 
rendervar['scales'] relations = self.knn_relations(p_x_viewer) weights = self.knn_weights_brute(p_x_viewer, xyz) xyz, quat, _ = interpolate_motions( bones=p_x_viewer, motions=p_x_pred_viewer - p_x_viewer, relations=relations, weights=weights, xyz=xyz, quat=quat, ) # normalize quat = torch.nn.functional.normalize(quat, dim=-1) rendervar = { 'means3D': xyz, 'colors_precomp': rgb, 'rotations': quat, 'opacities': opa, 'scales': scales, 'means2D': torch.zeros_like(xyz), } if self.with_bg: t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params # merge xyz = torch.cat([xyz, t_pts], dim=0) rgb = torch.cat([rgb, t_colors], dim=0) quat = torch.cat([quat, t_quats], dim=0) opa = torch.cat([opa, t_opacities], dim=0) scales = torch.cat([scales, t_scales], dim=0) if self.render_gripper: g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params # add gripper pos g_pts = g_pts + self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())[0] # merge xyz = torch.cat([xyz, g_pts], dim=0) rgb = torch.cat([rgb, g_colors], dim=0) quat = torch.cat([quat, g_quats], dim=0) opa = torch.cat([opa, g_opacities], dim=0) scales = torch.cat([scales, g_scales], dim=0) if self.render_direction: gripper_direction = self.gripper_direction gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) # normalize R = self.preprocess_metadata['R'] # model frame to data frame direction = gripper_direction @ R.T n_grippers = 1 N = 200 length = 0.2 kk = 5 xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=xyz.dtype) gripper_now_inv_xyz = self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone()) gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1) center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=xyz.dtype).reshape(1, 3) # center point in gripper frame gripper_center_inv_xyz = gripper_now_inv_xyz + \ torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) # (n_grippers, 3) for i in range(N): offset = i / N * length * direction xyz_test[:, i] = gripper_center_inv_xyz + offset if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: # not vertical direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=xyz.dtype) direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=xyz.dtype) direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize else: direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=xyz.dtype) direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=xyz.dtype) direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize for i in range(N, N + N // kk): offset = length * direction + (i - N) / N * length * direction_up xyz_test[:, i] = gripper_center_inv_xyz + offset for i in range(N + N // kk, N + N // kk + N // kk): offset = length * direction + (i - N - N // kk) / N * length * direction_down xyz_test[:, i] = gripper_center_inv_xyz + offset color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=xyz.dtype) color_test[:, :, 0] = 255 / 255 # red color_test[:, :, 1] = 80 / 255 
# green color_test[:, :, 2] = 110 / 255 # blue quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=xyz.dtype) quat_test[:, :, 0] = 1.0 # identity quaternion opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=xyz.dtype) scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=xyz.dtype) * 0.002 xyz = torch.cat([xyz, xyz_test.reshape(-1, 3)], dim=0) rgb = torch.cat([rgb, color_test.reshape(-1, 3)], dim=0) quat = torch.cat([quat, quat_test.reshape(-1, 4)], dim=0) opa = torch.cat([opa, opa_test.reshape(-1, 1)], dim=0) scales = torch.cat([scales, scales_test.reshape(-1, 3)], dim=0) # normalize quat = torch.nn.functional.normalize(quat, dim=-1) rendervar_full = { 'means3D': xyz, 'colors_precomp': rgb, 'rotations': quat, 'opacities': opa, 'scales': scales, 'means2D': torch.zeros_like(xyz), } else: rendervar_full = rendervar return rendervar, rendervar_full def reset_state(self, params, visualize_image=False, init=False): xyz_0 = params['means3D'] rgb_0 = params['rgb_colors'] quat_0 = torch.nn.functional.normalize(params['unnorm_rotations']) opa_0 = torch.sigmoid(params['logit_opacities']) scales_0 = torch.exp(params['log_scales']) rendervar_init = { 'means3D': xyz_0, 'colors_precomp': rgb_0, 'rotations': quat_0, 'opacities': opa_0, 'scales': scales_0, 'means2D': torch.zeros_like(xyz_0), } # before preprocess w = self.width h = self.height center = (0, 0, 0.1) distance = 0.7 elevation = 20 azimuth = 180.0 if self.task_name == 'rope' else 120.0 target = np.array(center) theta = 90 + azimuth z = distance * math.sin(math.radians(elevation)) y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation)) x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation)) origin = target + np.array([x, y, z]) look_at = target - origin look_at /= np.linalg.norm(look_at) up = np.array([0.0, 0.0, 1.0]) right = np.cross(look_at, up) right /= np.linalg.norm(right) up = np.cross(right, look_at) w2c = np.eye(4) w2c[:3, 0] = right w2c[:3, 1] = -up w2c[:3, 2] = look_at w2c[:3, 3] = origin w2c = np.linalg.inv(w2c) k = np.array( [[w / 2 * 1.0, 0., w / 2], [0., w / 2 * 1.0, h / 2], [0., 0., 1.]], ) self.metadata = {} self.config = {} self.update_camera(k, w2c, w, h) n_particles = self.cfg.sim.n_particles downsample_indices = fps(xyz_0, torch.ones_like(xyz_0[:, 0]).to(torch.bool), n_particles, self.torch_device) p_x_viewer = xyz_0[downsample_indices] p_x = self.preprocess_x(p_x_viewer) self.state['x'] = p_x self.state['v'] = torch.zeros_like(p_x) self.state['x_his'] = p_x[None].repeat(self.cfg.sim.n_history, 1, 1) self.state['v_his'] = torch.zeros_like(p_x[None].repeat(self.cfg.sim.n_history, 1, 1)) self.state['x_pred'] = p_x self.state['v_pred'] = torch.zeros_like(p_x) rendervar_init, rendervar_init_full = self.update_rendervar(rendervar_init) im, depth = self.render(rendervar_init_full, 0, bg=[0.0, 0.0, 0.0]) im_vis = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8) return rendervar_init def reset(self, task_name, scene_name): self.init(task_name) import warp as wp wp.init() gpus = [int(gpu) for gpu in self.cfg.gpus] wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus] torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus] device_count = len(torch_devices) assert device_count == 1 self.wp_device = wp_devices[0] self.torch_device = torch_devices[0] in_dir = root / f'log/gs/ckpts/{scene_name}' batch_size = 1 num_steps = 1 num_particles 
= self.cfg.sim.n_particles self.load_scaniverse(in_dir) self.init_model(batch_size, num_steps, num_particles, ckpt_path=None) self.render_direction = False params = self.preprocess_gs(self.params) if self.with_bg: self.preprocess_bg_gs() rendervar = self.reset_state(params, visualize_image=False, init=True) rendervar, rendervar_full = self.update_rendervar(rendervar) # self.rendervar = rendervar im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0]) im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy() cv2.imwrite(str(root / 'log/temp_init/0000.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR)) make_video(root / 'log/temp_init', root / f'log/gs/temp/form_video_init.mp4', '%04d.png', 1) gs_pred = save_to_splat( rendervar_full['means3D'].cpu().numpy(), rendervar_full['colors_precomp'].cpu().numpy(), rendervar_full['scales'].cpu().numpy(), rendervar_full['rotations'].cpu().numpy(), rendervar_full['opacities'].cpu().numpy(), root / 'log/gs/temp/gs_pred.splat', rot_rev=True, ) for k, v in self.preprocess_metadata.items(): self.preprocess_metadata[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in self.state.items(): self.state[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in self.params.items(): if isinstance(v, dict): for k2, v2 in v.items(): self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2 else: self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v self.table_params = tuple( v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params ) self.gripper_params = tuple( v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params ) for k, v in rendervar.items(): rendervar[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v form_video = gr.Video( label='Predicted video', value=root / f'log/gs/temp/form_video_init.mp4', format='mp4', width=self.width, height=self.height, ) form_3dgs_pred = gr.Model3D( label='Predicted Gaussian Splats', height=self.height, value=root / 'log/gs/temp/gs_pred.splat', clear_color=[0, 0, 0, 0], ) return form_video, form_3dgs_pred, \ self.preprocess_metadata, self.state, self.params, \ self.table_params, self.gripper_params, rendervar, task_name def run_command(self, unit_command, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): self.task_name = task_name import warp as wp wp.init() gpus = [int(gpu) for gpu in self.cfg.gpus] wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus] torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus] device_count = len(torch_devices) assert device_count == 1 self.wp_device = wp_devices[0] self.torch_device = torch_devices[0] os.system('rm -rf ' + str(root / 'log/temp/*')) w = 640 h = 480 center = (0, 0, 0.1) distance = 0.7 elevation = 20 azimuth = 180.0 if self.task_name == 'rope' else 120.0 target = np.array(center) theta = 90 + azimuth z = distance * math.sin(math.radians(elevation)) y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation)) x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation)) origin = target + np.array([x, y, z]) look_at = target - origin look_at /= np.linalg.norm(look_at) up = np.array([0.0, 0.0, 1.0]) right = np.cross(look_at, up) right /= np.linalg.norm(right) up = np.cross(right, look_at) w2c = np.eye(4) w2c[:3, 0] = right w2c[:3, 1] = -up w2c[:3, 2] = look_at w2c[:3, 3] = origin w2c = np.linalg.inv(w2c) k = np.array( [[w / 2 * 1.0, 0., w / 2], 
[0., w / 2 * 1.0, h / 2], [0., 0., 1.]], ) self.update_camera(k, w2c, w, h) self.preprocess_metadata = preprocess_metadata self.state = state self.params = params self.table_params = table_params self.gripper_params = gripper_params for k, v in self.preprocess_metadata.items(): self.preprocess_metadata[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for k, v in self.state.items(): self.state[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for k, v in self.params.items(): if isinstance(v, dict): for k2, v2 in v.items(): self.params[k][k2] = v2.to(self.torch_device) if isinstance(v2, torch.Tensor) else v2 else: self.params[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v self.table_params = tuple( v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.table_params ) self.gripper_params = tuple( v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.gripper_params ) for k, v in rendervar.items(): rendervar[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v num_steps = 15 batch_size = 1 num_particles = self.cfg.sim.n_particles self.init_model(batch_size, num_steps, num_particles, ckpt_path=None) self.render_direction = True # im_list = [] for i in range(num_steps): dt = 0.1 # 100ms command = torch.tensor([unit_command]).to(self.device).to(torch.float32) # 5cm/s command = self.preprocess_gripper(command) # command_timestamp = torch.tensor([self.state['prev_key_pos_timestamp'] + (i+1) * dt]).to(self.device).to(torch.float32) # print(command_timestamp) if self.verbose: print('command:', command.cpu().numpy().tolist()) self.gripper_direction = command.clone() assert self.state['sub_pos'] is None if self.state['sub_pos'] is None: eef_xyz_latest = self.state['prev_key_pos'] # eef_xyz_timestamp_latest = self.state['prev_key_pos_timestamp'] else: eef_xyz_latest = self.state['sub_pos'][-1] # (1, 3), model frame # eef_xyz_timestamp_latest = self.state['sub_pos_timestamps'][-1].item() eef_xyz_updated = eef_xyz_latest + command * dt * 0.01 # cm to m if self.state['sub_pos'] is None: self.state['sub_pos'] = eef_xyz_updated[None] # self.state['sub_pos_timestamps'] = command_timestamp else: self.state['sub_pos'] = torch.cat([self.state['sub_pos'], eef_xyz_updated[None]], dim=0) # self.state['sub_pos_timestamps'] = torch.cat([self.state['sub_pos_timestamps'], command_timestamp], dim=0) # if self.state['sub_pos'] is None: # eef_xyz = self.state['prev_key_pos'] # else: # eef_xyz = self.state['sub_pos'][-1] # (1, 3), model frame # if self.verbose: # print(eef_xyz.cpu().numpy().tolist(), end=' ') self.step() rendervar, rendervar_full = self.update_rendervar(rendervar) # self.rendervar = rendervar im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0]) im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy() # im_list.append(im_show) cv2.imwrite(str(root / f'log/temp/{i:04}.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR)) # self.state['prev_key_pos_timestamp'] = self.state['prev_key_pos_timestamp'] + 20 * dt self.state['v'] *= 0.0 self.state['x'] = self.state['x_pred'].clone() self.state['x_his'] = self.state['x'][None].repeat(self.cfg.sim.n_history, 1, 1) self.state['v_his'] *= 0.0 self.state['v_pred'] *= 0.0 for k, v in self.preprocess_metadata.items(): self.preprocess_metadata[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in self.state.items(): self.state[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in self.params.items(): if 
isinstance(v, dict): for k2, v2 in v.items(): self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2 else: self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v self.table_params = tuple( v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params ) self.gripper_params = tuple( v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params ) for k, v in rendervar.items(): rendervar[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v make_video(root / 'log/temp', root / f'log/gs/temp/form_video.mp4', '%04d.png', 5) form_video = gr.Video( label='Predicted video', value=root / f'log/gs/temp/form_video.mp4', format='mp4', width=self.width, height=self.height, ) im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0]) im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy() gs_pred = save_to_splat( rendervar_full['means3D'].cpu().numpy(), rendervar_full['colors_precomp'].cpu().numpy(), rendervar_full['scales'].cpu().numpy(), rendervar_full['rotations'].cpu().numpy(), rendervar_full['opacities'].cpu().numpy(), root / 'log/gs/temp/gs_pred.splat', rot_rev=True, ) form_3dgs_pred = gr.Model3D( label='Predicted Gaussian Splats', height=self.height, value=root / 'log/gs/temp/gs_pred.splat', clear_color=[0, 0, 0, 0], ) return form_video, form_3dgs_pred, \ self.preprocess_metadata, self.state, self.params, \ self.table_params, self.gripper_params, rendervar, task_name @spaces.GPU def reset_rope(self): return self.reset('rope', 'rope_scene_1') @spaces.GPU def reset_plush(self): return self.reset('sloth', 'sloth_scene_1') @spaces.GPU def on_click_run_xplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): return self.run_command([5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) @spaces.GPU def on_click_run_xminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): return self.run_command([-5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) @spaces.GPU def on_click_run_yplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): return self.run_command([0, 5.0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) @spaces.GPU def on_click_run_yminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): return self.run_command([0, -5.0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) @spaces.GPU def on_click_run_zplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): return self.run_command([0, 0, 5.0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) @spaces.GPU def on_click_run_zminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): return self.run_command([0, 0, -5.0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) def launch(self, share=False): with gr.Blocks() as app: preprocess_metadata = gr.State(self.preprocess_metadata) state = gr.State(self.state) params = gr.State(self.params) table_params = gr.State(self.table_params) gripper_params = gr.State(self.gripper_params) rendervar = gr.State(None) task_name = gr.State(self.task_name) with gr.Row(): gr.Markdown("# 
Particle-Grid Neural Dynamics for Learning Deformable Object Models from RGB-D Videos")
            with gr.Row():
                gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')
            with gr.Row():
                gr.Markdown('### Instructions:')
            with gr.Row():
                gr.Markdown(' '.join([
                    '- Click one of the "Reset" buttons to initialize the simulation and display the initial predicted video and Gaussian splats. Due to compute limitations on Hugging Face Spaces, each run may take a while (up to 30 seconds).\n',
                    '- Use the buttons to move the gripper in the x, y, z directions. The gripper moves a fixed distance per click, and the predicted video and Gaussian splats are updated accordingly.\n',
                    '- The X-Y plane is the table surface, and Z is the height.\n',
                    '- The predicted video from the previous step to the current step is shown in the "Predicted video" section.\n',
                    '- The Gaussian splats after the current step are shown in the "Predicted Gaussians" section.\n',
                    '- The simulation may drift from the initial shape due to accumulated prediction error. Click a Reset button to reset the simulation state and reinitialize the predicted video and Gaussian splats.\n',
                ]))
            with gr.Row():
                gr.Markdown('### Select a scene to reset the simulation:')
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Column():
                            run_reset_plush = gr.Button("Reset - Plush")
                        with gr.Column():
                            run_reset_rope = gr.Button("Reset - Rope")
                with gr.Column(scale=2):
                    _ = gr.Button(visible=False)  # empty placeholder
            with gr.Row():
                with gr.Column(scale=2):
                    form_video = gr.Video(
                        label='Predicted video',
                        value=None,
                        format='mp4',
                        width=self.width,
                        height=self.height,
                    )
                with gr.Column(scale=2):
                    form_3dgs_pred = gr.Model3D(
                        label='Predicted Gaussians',
                        height=self.height,
                        value=None,
                        clear_color=[0, 0, 0, 0],
                    )

            # Layout
            with gr.Row():
                gr.Markdown('### Control the gripper to move in the x, y, z directions:')
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Column():
                            run_xminus = gr.Button("x-")
                        with gr.Column():
                            run_xplus = gr.Button("x+")
                    with gr.Row():
                        with gr.Column():
                            run_yminus = gr.Button("y-")
                        with gr.Column():
                            run_yplus = gr.Button("y+")
                    with gr.Row():
                        with gr.Column():
                            run_zminus = gr.Button("z-")
                        with gr.Column():
                            run_zplus = gr.Button("z+")
                with gr.Column(scale=2):
                    _ = gr.Button(visible=False)  # empty placeholder

            # Set up callbacks
            run_reset_rope.click(self.reset_rope, inputs=[], outputs=[form_video, form_3dgs_pred, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name])
            run_reset_plush.click(self.reset_plush, inputs=[], outputs=[form_video, form_3dgs_pred, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name])
            run_xplus.click(self.on_click_run_xplus, inputs=[preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name], outputs=[form_video, form_3dgs_pred, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name])
            run_xminus.click(self.on_click_run_xminus, inputs=[preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name], outputs=[form_video, form_3dgs_pred, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name])
            run_yplus.click(self.on_click_run_yplus, inputs=[preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name], outputs=[form_video, form_3dgs_pred, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name])
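            # All six motion buttons share the same wiring: the gr.State objects go in
            # as inputs and come back as detached CPU tensors in the outputs, so each
            # @spaces.GPU handler can rebuild the simulator on the GPU, step it, and
            # hand the updated per-session state back without holding GPU tensors
            # between clicks.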
run_yminus.click(self.on_click_run_yminus, inputs=[preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name], outputs=[form_video, form_3dgs_pred, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name]) run_zplus.click(self.on_click_run_zplus, inputs=[preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name], outputs=[form_video, form_3dgs_pred, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name]) run_zminus.click(self.on_click_run_zminus, inputs=[preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name], outputs=[form_video, form_3dgs_pred, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name]) app.launch(share=share) if __name__ == '__main__': visualizer = DynamicsVisualizer() visualizer.launch(share=True)
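
# ---------------------------------------------------------------------------
# Reference sketch (not part of the app, never called): the code above moves
# points between three frames -- data -> model (R, scale, global_translation,
# cf. load_scaniverse / preprocess_gs) and model -> viewer (R_viewer, t_viewer,
# cf. preprocess_x / inverse_preprocess_x). The helper below, written under the
# assumption that R and R_viewer are orthonormal and using placeholder metadata
# values, documents that convention by round-tripping a point through both maps.
def _frame_roundtrip_example():
    meta = {
        'R': torch.tensor([[1., 0., 0.], [0., 0., -1.], [0., 1., 0.]]),
        'R_viewer': torch.tensor([[1., 0., 0.], [0., 0., -1.], [0., 1., 0.]]),
        't_viewer': torch.zeros(3),
        'scale': 1.0,
        'global_translation': torch.tensor([0.5, 0.1, 0.5]),  # placeholder values
    }
    p_data = torch.tensor([[0.1, 0.2, 0.3]])
    # data frame -> model frame: rotate, scale, then translate
    p_model = p_data @ meta['R'].T * meta['scale'] + meta['global_translation']
    # model frame -> viewer frame: rotate by R_viewer, then shift by t_viewer
    p_viewer = p_model @ meta['R_viewer'].T + meta['t_viewer']
    # invert both maps: viewer -> model -> data
    p_model_rec = (p_viewer - meta['t_viewer']) @ meta['R_viewer']
    p_data_rec = ((p_model_rec - meta['global_translation']) / meta['scale']) @ torch.linalg.inv(meta['R']).T
    assert torch.allclose(p_data, p_data_rec, atol=1e-6)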