# pgnd / app.py
import gradio as gr
import sys
import site
from PIL import Image
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm, trange
import random
import math
import hydra
import numpy as np
import glob
import os
import subprocess
import time
import cv2
import copy
import yaml
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
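# Hugging Face Spaces (ZeroGPU) runtime setup; kept before the torch import below.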
import spaces
from spaces import zero
zero.startup()
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import kornia
from diff_gaussian_rasterization import GaussianRasterizer
from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera
sys.path.insert(0, str(Path(__file__).parent / "src"))
sys.path.append(str(Path(__file__).parent / "src" / "experiments"))
from real_world.utils.render_utils import interpolate_motions
from real_world.gs.helpers import setup_camera
from real_world.gs.convert import save_to_splat, read_splat
root = Path(__file__).parent / "src" / "experiments"
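# Stitch numbered PNG frames under image_root into an mp4 at video_path using ffmpeg.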
def make_video(
image_root: Path,
video_path: Path,
image_pattern: str = '%04d.png',
frame_rate: int = 10):
subprocess.run([
'ffmpeg',
'-y',
'-hide_banner',
'-loglevel', 'error',
'-framerate', str(frame_rate),
'-i', str(image_root / image_pattern),
'-c:v', 'libx264',
'-pix_fmt', 'yuv420p',
str(video_path)
])
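# Quaternion <-> rotation-matrix conversions; quaternions are in wxyz order throughout this file.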
def quat2mat(quat):
return kornia.geometry.conversions.quaternion_to_rotation_matrix(quat)
def mat2quat(mat):
return kornia.geometry.conversions.rotation_matrix_to_quaternion(mat)
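# Farthest point sampling (via dgl) over the enabled points; returns n indices into x[enabled].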
def fps(x, enabled, n, device, random_start=False):
from dgl.geometry import farthest_point_sampler
assert torch.diff(enabled * 1.0).sum() in [0.0, -1.0]
start_idx = random.randint(0, enabled.sum() - 1) if random_start else 0
fps_idx = farthest_point_sampler(x[enabled][None], n, start_idx=start_idx)[0]
fps_idx = fps_idx.to(x.device)
return fps_idx
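# Gradio app wrapping a trained PGND model: loads Gaussian-splat scenes, rolls out the
# particle dynamics under gripper commands, and renders videos / splats for the UI.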
class DynamicsVisualizer:
def __init__(self, wp_device='cuda', torch_device='cuda'):
# default devices; reset() / run_command() replace these with warp / torch device handles
self.wp_device = wp_device
self.torch_device = torch_device
self.best_models = {
'cloth': ['cloth', 'train', 100000, [610, 650]],
'rope': ['rope', 'train', 100000, [651, 691]],
'paperbag': ['paperbag', 'train', 100000, [200, 220]],
'sloth': ['sloth', 'train', 100000, [113, 133]],
'box': ['box', 'train', 100000, [306, 323]],
'bread': ['bread', 'train', 100000, [143, 163]],
}
task_name = 'rope'
self.init(task_name)
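# Load the hydra config for the chosen task and set simulation / rendering options.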
def init(self, task_name):
self.width = 640
self.height = 480
self.task_name = task_name
with open(root / f'log/{self.best_models[task_name][0]}/{self.best_models[task_name][1]}/hydra.yaml', 'r') as f:
config = yaml.load(f, Loader=yaml.CLoader)
cfg = OmegaConf.create(config)
cfg.iteration = self.best_models[task_name][2]
cfg.start_episode = self.best_models[task_name][3][0]
cfg.end_episode = self.best_models[task_name][3][1]
cfg.sim.num_steps = 1000
cfg.sim.gripper_forcing = False
cfg.sim.uniform = True
cfg.sim.use_pv = False
device = torch.device('cuda')
self.cfg = cfg
self.device = device
self.k_rel = 8 # knn for relations
self.k_wgt = 16 # knn for weights
self.with_bg = True
self.render_gripper = True
self.render_direction = True
self.verbose = False
self.dt_base = cfg.sim.dt
self.high_freq_pred = True
seed = cfg.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# torch.autograd.set_detect_anomaly(True)
# torch.backends.cudnn.benchmark = True
self.clear()
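# Reset per-scene metadata, particle state, and cached simulator handles.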
def clear(self, clear_params=True):
self.metadata = {}
self.config = {}
if clear_params:
self.params = None
self.state = {
# object
'x': None,
'v': None,
'x_his': None,
'v_his': None,
'x_pred': None,
'v_pred': None,
'clip_bound': None,
'enabled': None,
# robot
'prev_key_pos': None,
'prev_key_pos_timestamp': None,
'sub_pos': None, # filling in between key positions
'sub_pos_timestamps': None,
'gripper_radius': None,
}
self.preprocess_metadata = None
self.table_params = None
self.gripper_params = None
self.sim = None
self.statics = None
self.colliders = None
self.material = None
self.friction = None
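# Load object / table / gripper splats exported from Scaniverse, build the data-to-model
# frame transform, and initialize the end-effector key position from eef_xyz.txt.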
def load_scaniverse(self, data_path):
### load splat params
params_obj = read_splat(data_path / 'object.splat')
params_table = read_splat(data_path / 'table.splat')
params_robot = read_splat(data_path / 'gripper.splat')
pts, colors, scales, quats, opacities = params_obj
self.params = {
'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device))
}
t_pts, t_colors, t_scales, t_quats, t_opacities = params_table
t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)
g_pts, g_colors, g_scales, g_quats, g_opacities = params_robot
g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)
self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities # data frame
self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities # data frame
n_particles = self.cfg.sim.n_particles
self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
### load preprocess metadata
cfg = self.cfg
dx = cfg.sim.num_grids[-1]
p_x = torch.tensor(pts).to(torch.float32).to(self.device)
R = torch.tensor(
[[1, 0, 0],
[0, 0, -1],
[0, 1, 0]]
).to(p_x.device).to(p_x.dtype)
p_x_rotated = p_x @ R.T
scale = 1.0
p_x_rotated_scaled = p_x_rotated * scale
global_translation = torch.tensor([
0.5 - p_x_rotated_scaled[:, 0].mean(),
dx * (cfg.model.clip_bound + 0.5) - p_x_rotated_scaled[:, 1].min(),
0.5 - p_x_rotated_scaled[:, 2].mean(),
], dtype=p_x_rotated_scaled.dtype, device=p_x_rotated_scaled.device)
R_viewer = torch.tensor(
[[1, 0, 0],
[0, 0, -1],
[0, 1, 0]]
).to(p_x.device).to(p_x.dtype)
t_viewer = torch.tensor([0, 0, 0]).to(p_x.device).to(p_x.dtype)
self.preprocess_metadata = {
'R': R,
'R_viewer': R_viewer,
't_viewer': t_viewer,
'scale': scale,
'global_translation': global_translation,
}
### load eef
grippers = np.loadtxt(data_path / 'eef_xyz.txt')[None]
assert grippers.shape == (1, 3)
if grippers is not None:
grippers = torch.tensor(grippers).to(self.device).to(torch.float32)
# transform
# data frame to model frame
R = self.preprocess_metadata['R']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
grippers[:, :3] = grippers[:, :3] @ R.T
grippers[:, :3] = grippers[:, :3] * scale
grippers[:, :3] += global_translation
assert grippers.shape[0] == 1
self.state['prev_key_pos'] = grippers[:, :3] # (1, 3)
# self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
self.state['gripper_radius'] = cfg.model.gripper_radius
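# Set the initial end-effector position from an explicit gripper pose plus offset eef_t,
# transformed from the data frame into the model frame.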
def load_eef(self, grippers=None, eef_t=None):
assert self.state['prev_key_pos'] is None
if grippers is not None:
grippers = torch.tensor(grippers).to(self.device).to(torch.float32)
eef_t = torch.tensor(eef_t).to(self.device).to(torch.float32)
grippers[:, :3] = grippers[:, :3] + eef_t
# transform
# data frame to model frame
R = self.preprocess_metadata['R']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
grippers[:, :3] = grippers[:, :3] @ R.T
grippers[:, :3] = grippers[:, :3] * scale
grippers[:, :3] += global_translation
assert grippers.shape[0] == 1
self.state['prev_key_pos'] = grippers[:, :3] # (1, 3)
# self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) + 0.001
self.state['gripper_radius'] = self.cfg.model.gripper_radius
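# Recompute the data-to-model transform (rotation, scale, translation) from original particle positions.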
def load_preprocess_metadata(self, p_x_orig):
cfg = self.cfg
dx = cfg.sim.num_grids[-1]
p_x_orig = p_x_orig.to(self.device)
R = torch.tensor(
[[1, 0, 0],
[0, 0, -1],
[0, 1, 0]]
).to(p_x_orig.device).to(p_x_orig.dtype)
p_x_orig_rotated = torch.einsum('nij,jk->nik', p_x_orig, R.T)
scale = 1.0
p_x_orig_rotated_scaled = p_x_orig_rotated * scale
global_translation = torch.tensor([
0.5 - p_x_orig_rotated_scaled[:, :, 0].mean(),
dx * (cfg.model.clip_bound + 0.5) - p_x_orig_rotated_scaled[:, :, 1].min(),
0.5 - p_x_orig_rotated_scaled[:, :, 2].mean(),
], dtype=p_x_orig_rotated_scaled.dtype, device=p_x_orig_rotated_scaled.device)
R_viewer = torch.tensor(
[[1, 0, 0],
[0, 0, -1],
[0, 1, 0]]
).to(p_x_orig.device).to(p_x_orig.dtype)
t_viewer = torch.tensor([0, 0, 0]).to(p_x_orig.device).to(p_x_orig.dtype)
self.preprocess_metadata = {
'R': R,
'R_viewer': R_viewer,
't_viewer': t_viewer,
'scale': scale,
'global_translation': global_translation,
}
# @torch.no_grad
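# Rasterize the Gaussians with the current camera; returns the rendered image and depth.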
def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
render_data = {k: v.to(self.device) for k, v in render_data.items()}
w, h = self.metadata['w'], self.metadata['h']
k, w2c = self.metadata['k'], self.metadata['w2c']
cam = setup_camera(w, h, k, w2c, self.config['near'], self.config['far'], bg)
im, _, depth = GaussianRasterizer(raster_settings=cam)(**render_data)
return im, depth
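# k-nearest-neighbor indices over the bones (self excluded), used as relations for skinning.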
def knn_relations(self, bones):
k = self.k_rel
knn = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(bones.detach().cpu().numpy())
_, indices = knn.kneighbors(bones.detach().cpu().numpy()) # (N, k)
indices = indices[:, 1:] # exclude self
return indices
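# Inverse-distance weights of each point w.r.t. its k nearest bones; rows sum to 1.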
def knn_weights_brute(self, bones, pts):
k = self.k_wgt
dist = torch.norm(pts[:, None] - bones, dim=-1) # (n_pts, n_bones)
_, indices = torch.topk(dist, k, dim=-1, largest=False)
bones_selected = bones[indices] # (N, k, 3)
dist = torch.norm(bones_selected - pts[:, None], dim=-1) # (N, k)
weights = 1 / (dist + 1e-6)
weights = weights / weights.sum(dim=-1, keepdim=True) # (N, k)
weights_all = torch.zeros((pts.shape[0], bones.shape[0]), device=pts.device)
weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
return weights_all
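# Store the camera intrinsics / extrinsics and near-far planes used by the rasterizer.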
def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
self.metadata['k'] = k
self.metadata['w2c'] = w2c
if w is not None:
self.metadata['w'] = w
if h is not None:
self.metadata['h'] = h
self.config['near'] = near
self.config['far'] = far
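# Build the warp simulator, statics, and colliders, and load the trained PGND material
# (and friction) from the task checkpoint.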
def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
from pgnd.sim import Friction, CacheDiffSimWithFrictionBatch, StaticsBatch, CollidersBatch
from pgnd.material import PGNDModel
self.cfg.sim.num_steps = num_steps
cfg = self.cfg
sim = CacheDiffSimWithFrictionBatch(cfg, num_steps, batch_size, self.wp_device, requires_grad=True)
statics = StaticsBatch()
statics.init(shape=(batch_size, num_particles), device=self.wp_device)
statics.update_clip_bound(self.state['clip_bound'].detach().cpu())
statics.update_enabled(self.state['enabled'][None].detach().cpu())
colliders = CollidersBatch()
colliders.init(shape=(batch_size, cfg.sim.num_grippers), device=self.wp_device)
self.sim = sim
self.statics = statics
self.colliders = colliders
# load ckpt (fall back to the task's trained checkpoint when none is given)
if ckpt_path is None:
ckpt_path = root / f'log/{self.task_name}/train/ckpt/100000.pt'
ckpt = torch.load(ckpt_path, map_location=self.torch_device)
material: nn.Module = PGNDModel(cfg)
material.to(self.torch_device)
material.load_state_dict(ckpt['material'])
material.requires_grad_(False)
material.eval()
if 'friction' in ckpt:
friction = ckpt['friction']['mu'].reshape(-1, 1)
else:
friction = torch.tensor(cfg.model.friction.value, device=self.torch_device).reshape(-1, 1)
self.material = material
self.friction = friction
def reload_model(self, num_steps): # only change num_steps
from pgnd.sim import CacheDiffSimWithFrictionBatch
self.cfg.sim.num_steps = num_steps
sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True)
self.sim = sim
# @torch.no_grad
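# Advance the particle state by one model step: shift the history buffers, derive the
# gripper velocity from the latest sub-position, and run the learned material + simulator.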
def step(self):
cfg = self.cfg
batch_size = 1
num_steps = 1
num_particles = cfg.sim.n_particles
# update state by previous prediction
self.state['x_his'] = torch.cat([self.state['x_his'][1:], self.state['x'][None]], dim=0)
self.state['v_his'] = torch.cat([self.state['v_his'][1:], self.state['v'][None]], dim=0)
self.state['x'] = self.state['x_pred'].clone()
self.state['v'] = self.state['v_pred'].clone()
eef_xyz_key = self.state['prev_key_pos'] # (1, 3), model frame
eef_xyz_sub = self.state['sub_pos'] # (T, 1, 3), model frame
if eef_xyz_sub is None:
return
# eef_xyz_key_timestamp = self.state['prev_key_pos_timestamp']
# eef_xyz_sub_timestamps = self.state['sub_pos_timestamps']
# assert eef_xyz_key_timestamp.item() > 0
# delta_t = (eef_xyz_sub_timestamps[-1] - eef_xyz_key_timestamp).item()
# if (not self.high_freq_pred) and delta_t < self.dt_base * 0.9:
# return
# cfg.sim.dt = delta_t
eef_xyz_key_next = eef_xyz_sub[-1] # (1, 3), model frame
eef_v = (eef_xyz_key_next - eef_xyz_key) / cfg.sim.dt
if self.verbose:
print('delta_t:', np.round(cfg.sim.dt, 4))
print('eef_xyz_key_next:', eef_xyz_key_next.cpu().numpy().tolist())
print('eef_xyz_key:', eef_xyz_key.cpu().numpy().tolist())
print('v:', eef_v.cpu().numpy().tolist())
# load model, sim, statics, colliders
# self.reload_model(num_steps)
# initialize colliders
if cfg.sim.num_grippers > 0:
grippers = torch.zeros((batch_size, cfg.sim.num_grippers, 15), device=self.torch_device)
eef_quat = torch.tensor([1, 0, 0, 0], dtype=torch.float32, device=self.torch_device).repeat(batch_size, cfg.sim.num_grippers, 1) # (B, G, 4)
eef_quat_vel = torch.zeros((batch_size, cfg.sim.num_grippers, 3), dtype=torch.float32, device=self.torch_device)
eef_gripper = torch.zeros((batch_size, cfg.sim.num_grippers), dtype=torch.float32, device=self.torch_device)
grippers[:, :, :3] = eef_xyz_key
grippers[:, :, 3:6] = eef_v
grippers[:, :, 6:10] = eef_quat
grippers[:, :, 10:13] = eef_quat_vel
grippers[:, :, 13] = cfg.model.gripper_radius
grippers[:, :, 14] = eef_gripper
self.colliders.initialize_grippers(grippers)
x = self.state['x'].clone()[None].repeat(batch_size, 1, 1)
v = self.state['v'].clone()[None].repeat(batch_size, 1, 1)
x_his = self.state['x_his'].permute(1, 0, 2).clone()
assert x_his.shape[0] == num_particles
x_his = x_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
v_his = self.state['v_his'].permute(1, 0, 2).clone()
assert v_his.shape[0] == num_particles
v_his = v_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
enabled = self.state['enabled'].clone().to(self.torch_device)[None].repeat(batch_size, 1)
for t in range(num_steps):
x_in = x.clone()
pred = self.material(x, v, x_his, v_his, enabled)
# x_his = torch.cat([x_his.reshape(batch_size, num_particles, -1, 3)[:, :, 1:], x[:, :, None].detach()], dim=2)
# v_his = torch.cat([v_his.reshape(batch_size, num_particles, -1, 3)[:, :, 1:], v[:, :, None].detach()], dim=2)
# x_his = x_his.reshape(batch_size, num_particles, -1)
# v_his = v_his.reshape(batch_size, num_particles, -1)
x, v = self.sim(self.statics, self.colliders, t, x, v, self.friction, pred)
# calculate new x_pred, v_pred, eef_xyz_key and eef_xyz_sub
x_pred = x[0].clone()
v_pred = v[0].clone()
self.state['x_pred'] = x_pred
self.state['v_pred'] = v_pred
# self.state['x_his'] = x_his[0].reshape(num_particles, self.cfg.sim.n_history, 3).permute(1, 0, 2).clone()
# self.state['v_his'] = v_his[0].reshape(num_particles, self.cfg.sim.n_history, 3).permute(1, 0, 2).clone()
self.state['prev_key_pos'] = eef_xyz_key_next
# self.state['prev_key_pos_timestamp'] = eef_xyz_sub_timestamps[-1]
self.state['sub_pos'] = None
# self.state['sub_pos_timestamps'] = None
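# Frame conventions: 'data' is the raw splat frame, 'model' is the normalized simulation
# frame, and 'viewer' is the frame used for rendering and the 3D viewer.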
def preprocess_x(self, p_x): # viewer frame to model frame (not data frame)
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
# viewer frame to model frame
p_x = (p_x - t_viewer) @ R_viewer
# model frame to data frame
# p_x -= global_translation
# p_x = p_x / scale
# p_x = p_x @ torch.linalg.inv(R).T
return p_x
def preprocess_gripper(self, grippers): # viewer frame to model frame (not data frame)
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
# viewer frame to model frame
grippers[:, :3] = grippers[:, :3] @ R_viewer
return grippers
def inverse_preprocess_x(self, p_x): # model frame (not data frame) to viewer frame
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
# model frame to viewer frame
p_x = p_x @ R_viewer.T + t_viewer
return p_x
def inverse_preprocess_gripper(self, grippers): # model frame (not data frame) to viewer frame
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
# model frame to viewer frame
grippers[:, :3] = grippers[:, :3] @ R_viewer.T + t_viewer
return grippers
def rotate(self, params, rot_mat):
# Rotate Gaussian means and orientations by rot_mat (3, 3).
# Assumption: rot_mat is a pure rotation; colors, scales and opacities are left unchanged.
rot_mat = torch.as_tensor(rot_mat, dtype=torch.float32, device=self.device)
quat = torch.nn.functional.normalize(params['unnorm_rotations'], dim=-1)
pts = params['means3D'] @ rot_mat.T
quats = mat2quat(rot_mat @ quat2mat(quat))
params = {
'means3D': pts,
'rgb_colors': params['rgb_colors'],
'log_scales': params['log_scales'],
'unnorm_rotations': quats,
'logit_opacities': params['logit_opacities'],
}
return params
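# Transform raw Gaussian params from the data frame to the viewer frame, recentering them
# so the object's planar mean sits at the origin (this overwrites t_viewer).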
def preprocess_gs(self, params):
if isinstance(params, dict):
xyz = params['means3D']
rgb = params['rgb_colors']
quat = torch.nn.functional.normalize(params['unnorm_rotations'])
opa = torch.sigmoid(params['logit_opacities'])
scales = torch.exp(params['log_scales'])
else:
assert isinstance(params, tuple)
xyz, rgb, quat, opa, scales = params
quat = torch.nn.functional.normalize(quat, dim=-1)
# transform
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
mat = quat2mat(quat)
mat = R @ mat
xyz = xyz @ R.T
xyz = xyz * scale
xyz += global_translation
quat = mat2quat(mat)
scales = scales * scale
# viewer-specific transform (flip y and z)
# model frame to viewer frame
xyz = xyz @ R_viewer.T
quat = mat2quat(R_viewer @ quat2mat(quat))
t_viewer = -xyz.mean(dim=0)
t_viewer[2] = 0
xyz += t_viewer
print('Setting t_viewer to recenter the object at its planar mean')
self.preprocess_metadata['t_viewer'] = t_viewer
if isinstance(params, dict):
params['means3D'] = xyz
params['rgb_colors'] = rgb
params['unnorm_rotations'] = quat
params['logit_opacities'] = opa
params['log_scales'] = torch.log(scales)
else:
params = xyz, rgb, quat, opa, scales
return params
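# Apply the same data-to-viewer transform to the table and gripper splats, append colored
# x / y / z axis markers, and recenter the gripper splat at its fingertip.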
def preprocess_bg_gs(self):
t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params
# identify tip first
g_pts_tip_z = g_pts[:, 2].max()
g_pts_tip_mask = (g_pts[:, 2] > g_pts_tip_z - 0.04) & (g_pts[:, 2] < g_pts_tip_z)
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
t_mat = quat2mat(t_quats)
t_mat = R @ t_mat
t_pts = t_pts @ R.T
t_pts = t_pts * scale
t_pts += global_translation
t_quats = mat2quat(t_mat)
t_scales = t_scales * scale
t_pts = t_pts @ R_viewer.T
t_quats = mat2quat(R_viewer @ quat2mat(t_quats))
t_pts += t_viewer
axes = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
dirs = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] # x, y, z axes
for ee in range(3):
gripper_direction = torch.tensor(dirs[ee], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) # normalize
R = self.preprocess_metadata['R']
# model frame to data frame
direction = gripper_direction @ R.T
n_grippers = 1
N = 200
length = 0.2
kk = 5
xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype)
if self.task_name == 'rope':
pos = torch.tensor([0.0, 0.0, 1.2], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # gripper position in model frame
else:
pos = torch.tensor([1.2, 0.0, 0.7], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
gripper_now_inv_xyz = self.inverse_preprocess_gripper(pos)
gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1)
center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # center point in gripper frame
gripper_center_inv_xyz = gripper_now_inv_xyz + \
torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) # (n_grippers, 3)
for i in range(N):
offset = i / N * length * direction
xyz_test[:, i] = gripper_center_inv_xyz + offset
if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: # not vertical
direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=t_pts.dtype)
direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=t_pts.dtype)
direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
else:
direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
for i in range(N, N + N // kk):
offset = length * direction + (i - N) / N * length * direction_up
xyz_test[:, i] = gripper_center_inv_xyz + offset
for i in range(N + N // kk, N + N // kk + N // kk):
offset = length * direction + (i - N - N // kk) / N * length * direction_down
xyz_test[:, i] = gripper_center_inv_xyz + offset
color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=t_pts.dtype)
color_test[:, :, 0] = axes[ee][0]
color_test[:, :, 1] = axes[ee][1]
color_test[:, :, 2] = axes[ee][2]
quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=t_pts.dtype)
quat_test[:, :, 0] = 1.0 # identity quaternion
opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=t_pts.dtype)
scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) * 0.002
t_pts = torch.cat([t_pts, xyz_test.reshape(-1, 3)], dim=0)
t_colors = torch.cat([t_colors, color_test.reshape(-1, 3)], dim=0)
t_quats = torch.cat([t_quats, quat_test.reshape(-1, 4)], dim=0)
t_opacities = torch.cat([t_opacities, opa_test.reshape(-1, 1)], dim=0)
t_scales = torch.cat([t_scales, scales_test.reshape(-1, 3)], dim=0)
t_pts = t_pts.reshape(-1, 3)
t_colors = t_colors.reshape(-1, 3)
t_quats = t_quats.reshape(-1, 4)
t_opacities = t_opacities.reshape(-1, 1)
g_mat = quat2mat(g_quats)
g_mat = R @ g_mat
g_pts = g_pts @ R.T
g_pts = g_pts * scale
g_pts += global_translation
g_quats = mat2quat(g_mat)
g_scales = g_scales * scale
g_pts = g_pts @ R_viewer.T
g_quats = mat2quat(R_viewer @ quat2mat(g_quats))
g_pts += t_viewer
# TODO: center gripper in the viewer frame
g_pts_tip = g_pts[g_pts_tip_mask]
g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0)
if self.task_name == 'rope':
g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
elif self.task_name == 'sloth':
g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.32]).to(torch.float32).to(self.device)
else:
raise NotImplementedError(f"Task {self.task_name} not implemented for gripper translation.")
g_pts = g_pts + g_pts_translation
self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities
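# Skin the Gaussians from the current to the predicted particle positions and, if enabled,
# append the table, gripper, and command-direction splats for full-scene rendering.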
def update_rendervar(self, rendervar):
p_x = self.state['x']
p_x_viewer = self.inverse_preprocess_x(p_x)
p_x_pred = self.state['x_pred']
p_x_pred_viewer = self.inverse_preprocess_x(p_x_pred)
xyz = rendervar['means3D']
rgb = rendervar['colors_precomp']
quat = rendervar['rotations']
opa = rendervar['opacities']
scales = rendervar['scales']
relations = self.knn_relations(p_x_viewer)
weights = self.knn_weights_brute(p_x_viewer, xyz)
xyz, quat, _ = interpolate_motions(
bones=p_x_viewer,
motions=p_x_pred_viewer - p_x_viewer,
relations=relations,
weights=weights,
xyz=xyz,
quat=quat,
)
# normalize
quat = torch.nn.functional.normalize(quat, dim=-1)
rendervar = {
'means3D': xyz,
'colors_precomp': rgb,
'rotations': quat,
'opacities': opa,
'scales': scales,
'means2D': torch.zeros_like(xyz),
}
if self.with_bg:
t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
# merge
xyz = torch.cat([xyz, t_pts], dim=0)
rgb = torch.cat([rgb, t_colors], dim=0)
quat = torch.cat([quat, t_quats], dim=0)
opa = torch.cat([opa, t_opacities], dim=0)
scales = torch.cat([scales, t_scales], dim=0)
if self.render_gripper:
g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params
# add gripper pos
g_pts = g_pts + self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())[0]
# merge
xyz = torch.cat([xyz, g_pts], dim=0)
rgb = torch.cat([rgb, g_colors], dim=0)
quat = torch.cat([quat, g_quats], dim=0)
opa = torch.cat([opa, g_opacities], dim=0)
scales = torch.cat([scales, g_scales], dim=0)
if self.render_direction:
gripper_direction = self.gripper_direction
gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) # normalize
R = self.preprocess_metadata['R']
# model frame to data frame
direction = gripper_direction @ R.T
n_grippers = 1
N = 200
length = 0.2
kk = 5
xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=xyz.dtype)
gripper_now_inv_xyz = self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())
gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1)
center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=xyz.dtype).reshape(1, 3) # center point in gripper frame
gripper_center_inv_xyz = gripper_now_inv_xyz + \
torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) # (n_grippers, 3)
for i in range(N):
offset = i / N * length * direction
xyz_test[:, i] = gripper_center_inv_xyz + offset
if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: # not vertical
direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=xyz.dtype)
direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=xyz.dtype)
direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
else:
direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=xyz.dtype)
direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=xyz.dtype)
direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
for i in range(N, N + N // kk):
offset = length * direction + (i - N) / N * length * direction_up
xyz_test[:, i] = gripper_center_inv_xyz + offset
for i in range(N + N // kk, N + N // kk + N // kk):
offset = length * direction + (i - N - N // kk) / N * length * direction_down
xyz_test[:, i] = gripper_center_inv_xyz + offset
color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=xyz.dtype)
color_test[:, :, 0] = 255 / 255 # red
color_test[:, :, 1] = 80 / 255 # green
color_test[:, :, 2] = 110 / 255 # blue
quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=xyz.dtype)
quat_test[:, :, 0] = 1.0 # identity quaternion
opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=xyz.dtype)
scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=xyz.dtype) * 0.002
xyz = torch.cat([xyz, xyz_test.reshape(-1, 3)], dim=0)
rgb = torch.cat([rgb, color_test.reshape(-1, 3)], dim=0)
quat = torch.cat([quat, quat_test.reshape(-1, 4)], dim=0)
opa = torch.cat([opa, opa_test.reshape(-1, 1)], dim=0)
scales = torch.cat([scales, scales_test.reshape(-1, 3)], dim=0)
# normalize
quat = torch.nn.functional.normalize(quat, dim=-1)
rendervar_full = {
'means3D': xyz,
'colors_precomp': rgb,
'rotations': quat,
'opacities': opa,
'scales': scales,
'means2D': torch.zeros_like(xyz),
}
else:
rendervar_full = rendervar
return rendervar, rendervar_full
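# Initialize particle state from the preprocessed Gaussians: set the default camera,
# farthest-point-sample n_particles as simulation particles, and zero velocities and history.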
def reset_state(self, params, visualize_image=False, init=False):
xyz_0 = params['means3D']
rgb_0 = params['rgb_colors']
quat_0 = torch.nn.functional.normalize(params['unnorm_rotations'])
opa_0 = torch.sigmoid(params['logit_opacities'])
scales_0 = torch.exp(params['log_scales'])
rendervar_init = {
'means3D': xyz_0,
'colors_precomp': rgb_0,
'rotations': quat_0,
'opacities': opa_0,
'scales': scales_0,
'means2D': torch.zeros_like(xyz_0),
} # before preprocess
w = self.width
h = self.height
center = (0, 0, 0.1)
distance = 0.7
elevation = 20
azimuth = 180.0 if self.task_name == 'rope' else 120.0
target = np.array(center)
theta = 90 + azimuth
z = distance * math.sin(math.radians(elevation))
y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation))
x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation))
origin = target + np.array([x, y, z])
look_at = target - origin
look_at /= np.linalg.norm(look_at)
up = np.array([0.0, 0.0, 1.0])
right = np.cross(look_at, up)
right /= np.linalg.norm(right)
up = np.cross(right, look_at)
w2c = np.eye(4)
w2c[:3, 0] = right
w2c[:3, 1] = -up
w2c[:3, 2] = look_at
w2c[:3, 3] = origin
w2c = np.linalg.inv(w2c)
k = np.array(
[[w / 2 * 1.0, 0., w / 2],
[0., w / 2 * 1.0, h / 2],
[0., 0., 1.]],
)
self.metadata = {}
self.config = {}
self.update_camera(k, w2c, w, h)
n_particles = self.cfg.sim.n_particles
downsample_indices = fps(xyz_0, torch.ones_like(xyz_0[:, 0]).to(torch.bool), n_particles, self.torch_device)
p_x_viewer = xyz_0[downsample_indices]
p_x = self.preprocess_x(p_x_viewer)
self.state['x'] = p_x
self.state['v'] = torch.zeros_like(p_x)
self.state['x_his'] = p_x[None].repeat(self.cfg.sim.n_history, 1, 1)
self.state['v_his'] = torch.zeros_like(p_x[None].repeat(self.cfg.sim.n_history, 1, 1))
self.state['x_pred'] = p_x
self.state['v_pred'] = torch.zeros_like(p_x)
rendervar_init, rendervar_init_full = self.update_rendervar(rendervar_init)
im, depth = self.render(rendervar_init_full, 0, bg=[0.0, 0.0, 0.0])
im_vis = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8)
return rendervar_init
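# Reset the scene for a task: reload config and splats, rebuild the model, render and export
# the initial video / splat, and move all state to CPU so it can be stored in gr.State.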
def reset(self, task_name, scene_name):
self.init(task_name)
import warp as wp
wp.init()
gpus = [int(gpu) for gpu in self.cfg.gpus]
wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
device_count = len(torch_devices)
assert device_count == 1
self.wp_device = wp_devices[0]
self.torch_device = torch_devices[0]
in_dir = root / f'log/gs/ckpts/{scene_name}'
batch_size = 1
num_steps = 1
num_particles = self.cfg.sim.n_particles
self.load_scaniverse(in_dir)
self.init_model(batch_size, num_steps, num_particles, ckpt_path=None)
self.render_direction = False
params = self.preprocess_gs(self.params)
if self.with_bg:
self.preprocess_bg_gs()
rendervar = self.reset_state(params, visualize_image=False, init=True)
rendervar, rendervar_full = self.update_rendervar(rendervar)
# self.rendervar = rendervar
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
cv2.imwrite(str(root / 'log/temp_init/0000.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))
make_video(root / 'log/temp_init', root / f'log/gs/temp/form_video_init.mp4', '%04d.png', 1)
gs_pred = save_to_splat(
rendervar_full['means3D'].cpu().numpy(),
rendervar_full['colors_precomp'].cpu().numpy(),
rendervar_full['scales'].cpu().numpy(),
rendervar_full['rotations'].cpu().numpy(),
rendervar_full['opacities'].cpu().numpy(),
root / 'log/gs/temp/gs_pred.splat',
rot_rev=True,
)
for k, v in self.preprocess_metadata.items():
self.preprocess_metadata[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in self.state.items():
self.state[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in self.params.items():
if isinstance(v, dict):
for k2, v2 in v.items():
self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2
else:
self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
self.table_params = tuple(
v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params
)
self.gripper_params = tuple(
v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params
)
for k, v in rendervar.items():
rendervar[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
form_video = gr.Video(
label='Predicted video',
value=root / f'log/gs/temp/form_video_init.mp4',
format='mp4',
width=self.width,
height=self.height,
)
form_3dgs_pred = gr.Model3D(
label='Predicted Gaussian Splats',
height=self.height,
value=root / 'log/gs/temp/gs_pred.splat',
clear_color=[0, 0, 0, 0],
)
return form_video, form_3dgs_pred, \
self.preprocess_metadata, self.state, self.params, \
self.table_params, self.gripper_params, rendervar, task_name
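# Roll out num_steps model steps under a constant gripper velocity command (cm/s, viewer
# frame), rendering each frame and exporting the resulting video and splat files.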
def run_command(self, unit_command, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
self.task_name = task_name
import warp as wp
wp.init()
gpus = [int(gpu) for gpu in self.cfg.gpus]
wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
device_count = len(torch_devices)
assert device_count == 1
self.wp_device = wp_devices[0]
self.torch_device = torch_devices[0]
os.system('rm -rf ' + str(root / 'log/temp/*'))
w = 640
h = 480
center = (0, 0, 0.1)
distance = 0.7
elevation = 20
azimuth = 180.0 if self.task_name == 'rope' else 120.0
target = np.array(center)
theta = 90 + azimuth
z = distance * math.sin(math.radians(elevation))
y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation))
x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation))
origin = target + np.array([x, y, z])
look_at = target - origin
look_at /= np.linalg.norm(look_at)
up = np.array([0.0, 0.0, 1.0])
right = np.cross(look_at, up)
right /= np.linalg.norm(right)
up = np.cross(right, look_at)
w2c = np.eye(4)
w2c[:3, 0] = right
w2c[:3, 1] = -up
w2c[:3, 2] = look_at
w2c[:3, 3] = origin
w2c = np.linalg.inv(w2c)
k = np.array(
[[w / 2 * 1.0, 0., w / 2],
[0., w / 2 * 1.0, h / 2],
[0., 0., 1.]],
)
self.update_camera(k, w2c, w, h)
self.preprocess_metadata = preprocess_metadata
self.state = state
self.params = params
self.table_params = table_params
self.gripper_params = gripper_params
for k, v in self.preprocess_metadata.items():
self.preprocess_metadata[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
for k, v in self.state.items():
self.state[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
for k, v in self.params.items():
if isinstance(v, dict):
for k2, v2 in v.items():
self.params[k][k2] = v2.to(self.torch_device) if isinstance(v2, torch.Tensor) else v2
else:
self.params[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
self.table_params = tuple(
v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.table_params
)
self.gripper_params = tuple(
v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.gripper_params
)
for k, v in rendervar.items():
rendervar[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
num_steps = 15
batch_size = 1
num_particles = self.cfg.sim.n_particles
self.init_model(batch_size, num_steps, num_particles, ckpt_path=None)
self.render_direction = True
# im_list = []
for i in range(num_steps):
dt = 0.1 # 100ms
command = torch.tensor([unit_command]).to(self.device).to(torch.float32) # 5cm/s
command = self.preprocess_gripper(command)
# command_timestamp = torch.tensor([self.state['prev_key_pos_timestamp'] + (i+1) * dt]).to(self.device).to(torch.float32)
# print(command_timestamp)
if self.verbose:
print('command:', command.cpu().numpy().tolist())
self.gripper_direction = command.clone()
assert self.state['sub_pos'] is None
if self.state['sub_pos'] is None:
eef_xyz_latest = self.state['prev_key_pos']
# eef_xyz_timestamp_latest = self.state['prev_key_pos_timestamp']
else:
eef_xyz_latest = self.state['sub_pos'][-1] # (1, 3), model frame
# eef_xyz_timestamp_latest = self.state['sub_pos_timestamps'][-1].item()
eef_xyz_updated = eef_xyz_latest + command * dt * 0.01 # cm to m
if self.state['sub_pos'] is None:
self.state['sub_pos'] = eef_xyz_updated[None]
# self.state['sub_pos_timestamps'] = command_timestamp
else:
self.state['sub_pos'] = torch.cat([self.state['sub_pos'], eef_xyz_updated[None]], dim=0)
# self.state['sub_pos_timestamps'] = torch.cat([self.state['sub_pos_timestamps'], command_timestamp], dim=0)
# if self.state['sub_pos'] is None:
# eef_xyz = self.state['prev_key_pos']
# else:
# eef_xyz = self.state['sub_pos'][-1] # (1, 3), model frame
# if self.verbose:
# print(eef_xyz.cpu().numpy().tolist(), end=' ')
self.step()
rendervar, rendervar_full = self.update_rendervar(rendervar)
# self.rendervar = rendervar
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
# im_list.append(im_show)
cv2.imwrite(str(root / f'log/temp/{i:04}.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))
# self.state['prev_key_pos_timestamp'] = self.state['prev_key_pos_timestamp'] + 20 * dt
self.state['v'] *= 0.0
self.state['x'] = self.state['x_pred'].clone()
self.state['x_his'] = self.state['x'][None].repeat(self.cfg.sim.n_history, 1, 1)
self.state['v_his'] *= 0.0
self.state['v_pred'] *= 0.0
for k, v in self.preprocess_metadata.items():
self.preprocess_metadata[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in self.state.items():
self.state[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in self.params.items():
if isinstance(v, dict):
for k2, v2 in v.items():
self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2
else:
self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
self.table_params = tuple(
v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params
)
self.gripper_params = tuple(
v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params
)
for k, v in rendervar.items():
rendervar[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
make_video(root / 'log/temp', root / f'log/gs/temp/form_video.mp4', '%04d.png', 5)
form_video = gr.Video(
label='Predicted video',
value=root / f'log/gs/temp/form_video.mp4',
format='mp4',
width=self.width,
height=self.height,
)
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
gs_pred = save_to_splat(
rendervar_full['means3D'].cpu().numpy(),
rendervar_full['colors_precomp'].cpu().numpy(),
rendervar_full['scales'].cpu().numpy(),
rendervar_full['rotations'].cpu().numpy(),
rendervar_full['opacities'].cpu().numpy(),
root / 'log/gs/temp/gs_pred.splat',
rot_rev=True,
)
form_3dgs_pred = gr.Model3D(
label='Predicted Gaussian Splats',
height=self.height,
value=root / 'log/gs/temp/gs_pred.splat',
clear_color=[0, 0, 0, 0],
)
return form_video, form_3dgs_pred, \
self.preprocess_metadata, self.state, self.params, \
self.table_params, self.gripper_params, rendervar, task_name
@spaces.GPU
def reset_rope(self):
return self.reset('rope', 'rope_scene_1')
@spaces.GPU
def reset_plush(self):
return self.reset('sloth', 'sloth_scene_1')
@spaces.GPU
def on_click_run_xplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
return self.run_command([5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
@spaces.GPU
def on_click_run_xminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
return self.run_command([-5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
@spaces.GPU
def on_click_run_yplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
return self.run_command([0, 5.0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
@spaces.GPU
def on_click_run_yminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
return self.run_command([0, -5.0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
@spaces.GPU
def on_click_run_zplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
return self.run_command([0, 0, 5.0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
@spaces.GPU
def on_click_run_zminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name):
return self.run_command([0, 0, -5.0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name)
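# Build the Gradio UI: per-scene reset buttons, x / y / z jog buttons for the gripper, and
# outputs for the predicted video and Gaussian splats.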
def launch(self, share=False):
with gr.Blocks() as app:
preprocess_metadata = gr.State(self.preprocess_metadata)
state = gr.State(self.state)
params = gr.State(self.params)
table_params = gr.State(self.table_params)
gripper_params = gr.State(self.gripper_params)
rendervar = gr.State(None)
task_name = gr.State(self.task_name)
with gr.Row():
gr.Markdown("# Particle-Grid Neural Dynamics for Learning Deformable Object Models from RGB-D Videos")
with gr.Row():
gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')
with gr.Row():
gr.Markdown('### Instructions:')
with gr.Row():
gr.Markdown(' '.join([
'- Click a "Reset - \<object\>" button to initialize the simulation and produce the initial predicted video and Gaussian splats. Due to the compute limitations of Hugging Face Spaces, each run may take up to 30 seconds.\n',
'- Use the buttons below to move the gripper along the x, y, and z directions. The gripper moves a fixed distance per click, and the predicted video and Gaussian splats are updated accordingly.\n',
'- The X-Y plane is the table surface, and Z is the height.\n',
'- The predicted video from the previous step to the current step is shown in the "Predicted video" section.\n',
'- The Gaussian splats after the current step are shown in the "Predicted Gaussians" section.\n',
'- The simulation may drift from the initial shape due to accumulated prediction errors. Click a Reset button to reset the simulation state and reinitialize the predicted video and Gaussian splats.\n',
]))
with gr.Row():
gr.Markdown('### Select a scene to reset the simulation:')
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
with gr.Column():
run_reset_plush = gr.Button("Reset - Plush")
with gr.Column():
run_reset_rope = gr.Button("Reset - Rope")
with gr.Column(scale=2):
_ = gr.Button(visible=False) # empty placeholder
with gr.Row():
with gr.Column(scale=2):
form_video = gr.Video(
label='Predicted video',
value=None,
format='mp4',
width=self.width,
height=self.height,
)
with gr.Column(scale=2):
form_3dgs_pred = gr.Model3D(
label='Predicted Gaussians',
height=self.height,
value=None,
clear_color=[0, 0, 0, 0],
)
# Layout
with gr.Row():
gr.Markdown('### Control the gripper to move in the x, y, z directions:')
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
with gr.Column():
run_xminus = gr.Button("x-")
with gr.Column():
run_xplus = gr.Button("x+")
with gr.Row():
with gr.Column():
run_yminus = gr.Button("y-")
with gr.Column():
run_yplus = gr.Button("y+")
with gr.Row():
with gr.Column():
run_zminus = gr.Button("z-")
with gr.Column():
run_zplus = gr.Button("z+")
with gr.Column(scale=2):
_ = gr.Button(visible=False) # empty placeholder
# Set up callbacks
run_reset_rope.click(self.reset_rope,
inputs=[],
outputs=[form_video, form_3dgs_pred,
preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name])
run_reset_plush.click(self.reset_plush,
inputs=[],
outputs=[form_video, form_3dgs_pred,
preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name])
run_xplus.click(self.on_click_run_xplus,
inputs=[preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name],
outputs=[form_video, form_3dgs_pred,
preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name])
run_xminus.click(self.on_click_run_xminus,
inputs=[preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name],
outputs=[form_video, form_3dgs_pred,
preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name])
run_yplus.click(self.on_click_run_yplus,
inputs=[preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name],
outputs=[form_video, form_3dgs_pred,
preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name])
run_yminus.click(self.on_click_run_yminus,
inputs=[preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name],
outputs=[form_video, form_3dgs_pred,
preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name])
run_zplus.click(self.on_click_run_zplus,
inputs=[preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name],
outputs=[form_video, form_3dgs_pred,
preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name])
run_zminus.click(self.on_click_run_zminus,
inputs=[preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name],
outputs=[form_video, form_3dgs_pred,
preprocess_metadata, state, params,
table_params, gripper_params, rendervar, task_name])
app.launch(share=share)
if __name__ == '__main__':
visualizer = DynamicsVisualizer()
visualizer.launch(share=True)