|
import gradio as gr |
|
import sys |
|
import site |
|
from PIL import Image |
|
from pathlib import Path |
|
from omegaconf import DictConfig, OmegaConf |
|
from tqdm import tqdm, trange |
|
import random |
|
import math |
|
import hydra |
|
import numpy as np |
|
import glob |
|
import os |
|
import subprocess |
|
import time |
|
import cv2 |
|
import copy |
|
import yaml |
|
import matplotlib.pyplot as plt |
|
from sklearn.neighbors import NearestNeighbors |
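
# Hugging Face Spaces ZeroGPU setup: `spaces` must be imported and started up
# before torch so GPU allocation is managed per request.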
|
|
|
import spaces |
|
from spaces import zero |
|
zero.startup() |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import DataLoader |
|
import kornia |
|
|
|
from diff_gaussian_rasterization import GaussianRasterizer |
|
from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera |
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent / "src")) |
|
sys.path.append(str(Path(__file__).parent / "src" / "experiments")) |
|
|
|
from real_world.utils.render_utils import interpolate_motions |
|
from real_world.gs.helpers import setup_camera |
|
from real_world.gs.convert import save_to_splat, read_splat |
|
|
|
root = Path(__file__).parent / "src" / "experiments" |
|
|
|
def make_video( |
|
image_root: Path, |
|
video_path: Path, |
|
image_pattern: str = '%04d.png', |
|
frame_rate: int = 10): |
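    # Encode numbered PNG frames from image_root into an H.264 MP4;
    # yuv420p keeps the output playable in browsers.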
|
|
|
subprocess.run([ |
|
'ffmpeg', |
|
'-y', |
|
'-hide_banner', |
|
'-loglevel', 'error', |
|
'-framerate', str(frame_rate), |
|
'-i', str(image_root / image_pattern), |
|
'-c:v', 'libx264', |
|
'-pix_fmt', 'yuv420p', |
|
str(video_path) |
|
]) |
|
|
|
def quat2mat(quat):
    # (w, x, y, z) quaternion -> 3x3 rotation matrix (kornia is imported at module level).
    return kornia.geometry.conversions.quaternion_to_rotation_matrix(quat)
|
|
|
|
|
def mat2quat(mat):
    # 3x3 rotation matrix -> (w, x, y, z) quaternion.
    return kornia.geometry.conversions.rotation_matrix_to_quaternion(mat)
|
|
|
|
|
def fps(x, enabled, n, device, random_start=False):
    # Farthest-point sampling over the enabled particles; dgl is imported
    # lazily since it is only needed here.
    from dgl.geometry import farthest_point_sampler
    # 'enabled' must be a contiguous prefix of True values
    # (at most one True -> False transition).
    assert torch.diff(enabled * 1.0).sum() in [0.0, -1.0]
    start_idx = random.randint(0, enabled.sum() - 1) if random_start else 0
    fps_idx = farthest_point_sampler(x[enabled][None], n, start_idx=start_idx)[0]
    fps_idx = fps_idx.to(device)
    return fps_idx
|
|
|
|
|
class DynamicsVisualizer: |
|
|
|
    def __init__(self, wp_device='cuda', torch_device='cuda'):
        # Store defaults; reset()/run_command() replace these with concrete
        # warp and torch device handles.
        self.wp_device = wp_device
        self.torch_device = torch_device
|
|
|
self.best_models = { |
|
'cloth': ['cloth', 'train', 100000, [610, 650]], |
|
'rope': ['rope', 'train', 100000, [651, 691]], |
|
'paperbag': ['paperbag', 'train', 100000, [200, 220]], |
|
'sloth': ['sloth', 'train', 100000, [113, 133]], |
|
'box': ['box', 'train', 100000, [306, 323]], |
|
'bread': ['bread', 'train', 100000, [143, 163]], |
|
} |
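        # Each entry: [log subdirectory, split, checkpoint iteration,
        # [start, end] episode range].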
|
task_name = 'rope' |
|
self.init(task_name) |
|
|
|
def init(self, task_name): |
|
self.width = 640 |
|
self.height = 480 |
|
self.task_name = task_name |
|
|
|
with open(root / f'log/{self.best_models[task_name][0]}/{self.best_models[task_name][1]}/hydra.yaml', 'r') as f: |
|
config = yaml.load(f, Loader=yaml.CLoader) |
|
cfg = OmegaConf.create(config) |
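
        # Override a few training-time settings for the interactive demo.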
|
|
|
cfg.iteration = self.best_models[task_name][2] |
|
cfg.start_episode = self.best_models[task_name][3][0] |
|
cfg.end_episode = self.best_models[task_name][3][1] |
|
cfg.sim.num_steps = 1000 |
|
cfg.sim.gripper_forcing = False |
|
cfg.sim.uniform = True |
|
cfg.sim.use_pv = False |
|
|
|
device = torch.device('cuda') |
|
|
|
self.cfg = cfg |
|
self.device = device |
|
self.k_rel = 8 |
|
self.k_wgt = 16 |
|
self.with_bg = True |
|
self.render_gripper = True |
|
self.render_direction = True |
|
self.verbose = False |
|
|
|
self.dt_base = cfg.sim.dt |
|
self.high_freq_pred = True |
|
|
|
seed = cfg.seed |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
|
|
|
|
|
|
self.clear() |
|
|
|
def clear(self, clear_params=True): |
|
self.metadata = {} |
|
self.config = {} |
|
if clear_params: |
|
self.params = None |
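        # Per-rollout simulation state: particle positions/velocities (current,
        # history, and one-step prediction) plus gripper keyframe bookkeeping.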
|
self.state = { |
|
|
|
'x': None, |
|
'v': None, |
|
'x_his': None, |
|
'v_his': None, |
|
'x_pred': None, |
|
'v_pred': None, |
|
'clip_bound': None, |
|
'enabled': None, |
|
|
|
'prev_key_pos': None, |
|
'prev_key_pos_timestamp': None, |
|
'sub_pos': None, |
|
'sub_pos_timestamps': None, |
|
'gripper_radius': None, |
|
} |
|
self.preprocess_metadata = None |
|
self.table_params = None |
|
self.gripper_params = None |
|
|
|
self.sim = None |
|
self.statics = None |
|
self.colliders = None |
|
self.material = None |
|
self.friction = None |
|
|
|
def load_scaniverse(self, data_path): |
|
|
|
|
|
|
|
params_obj = read_splat(data_path / 'object.splat') |
|
params_table = read_splat(data_path / 'table.splat') |
|
params_robot = read_splat(data_path / 'gripper.splat') |
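
        # Object Gaussians are kept in optimization space (log-scales,
        # logit-opacities); exp/sigmoid recover the raw values at render time.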
|
|
|
pts, colors, scales, quats, opacities = params_obj |
|
self.params = { |
|
'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device), |
|
'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device), |
|
'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)), |
|
'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device), |
|
'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device)) |
|
} |
|
|
|
t_pts, t_colors, t_scales, t_quats, t_opacities = params_table |
|
t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device) |
|
t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device) |
|
t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device) |
|
t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device) |
|
t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device) |
|
|
|
g_pts, g_colors, g_scales, g_quats, g_opacities = params_robot |
|
g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device) |
|
g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device) |
|
g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device) |
|
g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device) |
|
g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device) |
|
|
|
self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities |
|
self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities |
|
|
|
n_particles = self.cfg.sim.n_particles |
|
self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32) |
|
self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool) |
|
|
|
|
|
|
|
cfg = self.cfg |
|
dx = cfg.sim.num_grids[-1] |
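        # Normalize the scene into the simulation frame: rotate world z-up to
        # sim y-up, center x/z at 0.5, and lift the lowest point to the clip bound.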
|
|
|
p_x = torch.tensor(pts).to(torch.float32).to(self.device) |
|
R = torch.tensor( |
|
[[1, 0, 0], |
|
[0, 0, -1], |
|
[0, 1, 0]] |
|
).to(p_x.device).to(p_x.dtype) |
|
p_x_rotated = p_x @ R.T |
|
|
|
scale = 1.0 |
|
p_x_rotated_scaled = p_x_rotated * scale |
|
|
|
global_translation = torch.tensor([ |
|
0.5 - p_x_rotated_scaled[:, 0].mean(), |
|
dx * (cfg.model.clip_bound + 0.5) - p_x_rotated_scaled[:, 1].min(), |
|
0.5 - p_x_rotated_scaled[:, 2].mean(), |
|
], dtype=p_x_rotated_scaled.dtype, device=p_x_rotated_scaled.device) |
|
|
|
R_viewer = torch.tensor( |
|
[[1, 0, 0], |
|
[0, 0, -1], |
|
[0, 1, 0]] |
|
).to(p_x.device).to(p_x.dtype) |
|
t_viewer = torch.tensor([0, 0, 0]).to(p_x.device).to(p_x.dtype) |
|
|
|
self.preprocess_metadata = { |
|
'R': R, |
|
'R_viewer': R_viewer, |
|
't_viewer': t_viewer, |
|
'scale': scale, |
|
'global_translation': global_translation, |
|
} |
|
|
|
|
|
        # End-effector position from disk; shape (1, 3) in world coordinates.
        grippers = np.loadtxt(data_path / 'eef_xyz.txt')[None]
        assert grippers.shape == (1, 3)
        grippers = torch.tensor(grippers).to(self.device).to(torch.float32)

        # Map the end effector into the normalized simulation frame
        # (rotate, scale, translate), matching the particle preprocessing.
        R = self.preprocess_metadata['R']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']
        grippers[:, :3] = grippers[:, :3] @ R.T
        grippers[:, :3] = grippers[:, :3] * scale
        grippers[:, :3] += global_translation

        self.state['prev_key_pos'] = grippers[:, :3]

        self.state['gripper_radius'] = cfg.model.gripper_radius
|
|
|
def load_eef(self, grippers=None, eef_t=None): |
|
assert self.state['prev_key_pos'] is None |
|
|
|
if grippers is not None: |
|
grippers = torch.tensor(grippers).to(self.device).to(torch.float32) |
|
eef_t = torch.tensor(eef_t).to(self.device).to(torch.float32) |
|
grippers[:, :3] = grippers[:, :3] + eef_t |
|
|
|
|
|
|
|
R = self.preprocess_metadata['R'] |
|
scale = self.preprocess_metadata['scale'] |
|
global_translation = self.preprocess_metadata['global_translation'] |
|
grippers[:, :3] = grippers[:, :3] @ R.T |
|
grippers[:, :3] = grippers[:, :3] * scale |
|
grippers[:, :3] += global_translation |
|
|
|
assert grippers.shape[0] == 1 |
|
self.state['prev_key_pos'] = grippers[:, :3] |
|
|
|
self.state['gripper_radius'] = self.cfg.model.gripper_radius |
|
|
|
def load_preprocess_metadata(self, p_x_orig): |
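        # Same normalization as load_scaniverse(), but computed from a batched
        # set of particle positions with shape (B, N, 3).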
|
cfg = self.cfg |
|
dx = cfg.sim.num_grids[-1] |
|
|
|
p_x_orig = p_x_orig.to(self.device) |
|
R = torch.tensor( |
|
[[1, 0, 0], |
|
[0, 0, -1], |
|
[0, 1, 0]] |
|
).to(p_x_orig.device).to(p_x_orig.dtype) |
|
p_x_orig_rotated = torch.einsum('nij,jk->nik', p_x_orig, R.T) |
|
|
|
scale = 1.0 |
|
p_x_orig_rotated_scaled = p_x_orig_rotated * scale |
|
|
|
global_translation = torch.tensor([ |
|
0.5 - p_x_orig_rotated_scaled[:, :, 0].mean(), |
|
dx * (cfg.model.clip_bound + 0.5) - p_x_orig_rotated_scaled[:, :, 1].min(), |
|
0.5 - p_x_orig_rotated_scaled[:, :, 2].mean(), |
|
], dtype=p_x_orig_rotated_scaled.dtype, device=p_x_orig_rotated_scaled.device) |
|
|
|
R_viewer = torch.tensor( |
|
[[1, 0, 0], |
|
[0, 0, -1], |
|
[0, 1, 0]] |
|
).to(p_x_orig.device).to(p_x_orig.dtype) |
|
t_viewer = torch.tensor([0, 0, 0]).to(p_x_orig.device).to(p_x_orig.dtype) |
|
|
|
self.preprocess_metadata = { |
|
'R': R, |
|
'R_viewer': R_viewer, |
|
't_viewer': t_viewer, |
|
'scale': scale, |
|
'global_translation': global_translation, |
|
} |
|
|
|
|
|
def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]): |
|
render_data = {k: v.to(self.device) for k, v in render_data.items()} |
|
w, h = self.metadata['w'], self.metadata['h'] |
|
k, w2c = self.metadata['k'], self.metadata['w2c'] |
|
cam = setup_camera(w, h, k, w2c, self.config['near'], self.config['far'], bg) |
|
        im, _, depth = GaussianRasterizer(raster_settings=cam)(**render_data)
|
return im, depth |
|
|
|
def knn_relations(self, bones): |
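        # Indices of each bone's k nearest neighboring bones (self excluded).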
|
k = self.k_rel |
|
knn = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(bones.detach().cpu().numpy()) |
|
_, indices = knn.kneighbors(bones.detach().cpu().numpy()) |
|
indices = indices[:, 1:] |
|
return indices |
|
|
|
def knn_weights_brute(self, bones, pts): |
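        # Inverse-distance weights from each point to its k nearest bones,
        # scattered into a dense (num_points, num_bones) matrix.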
|
k = self.k_wgt |
|
dist = torch.norm(pts[:, None] - bones, dim=-1) |
|
_, indices = torch.topk(dist, k, dim=-1, largest=False) |
|
bones_selected = bones[indices] |
|
dist = torch.norm(bones_selected - pts[:, None], dim=-1) |
|
weights = 1 / (dist + 1e-6) |
|
weights = weights / weights.sum(dim=-1, keepdim=True) |
|
weights_all = torch.zeros((pts.shape[0], bones.shape[0]), device=pts.device) |
|
weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights |
|
return weights_all |
|
|
|
def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0): |
|
self.metadata['k'] = k |
|
self.metadata['w2c'] = w2c |
|
if w is not None: |
|
self.metadata['w'] = w |
|
if h is not None: |
|
self.metadata['h'] = h |
|
self.config['near'] = near |
|
self.config['far'] = far |
|
|
|
def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None): |
|
from pgnd.sim import Friction, CacheDiffSimWithFrictionBatch, StaticsBatch, CollidersBatch |
|
from pgnd.material import PGNDModel |
|
|
|
self.cfg.sim.num_steps = num_steps |
|
cfg = self.cfg |
|
|
|
sim = CacheDiffSimWithFrictionBatch(cfg, num_steps, batch_size, self.wp_device, requires_grad=True) |
|
|
|
statics = StaticsBatch() |
|
statics.init(shape=(batch_size, num_particles), device=self.wp_device) |
|
statics.update_clip_bound(self.state['clip_bound'].detach().cpu()) |
|
statics.update_enabled(self.state['enabled'][None].detach().cpu()) |
|
colliders = CollidersBatch() |
|
colliders.init(shape=(batch_size, cfg.sim.num_grippers), device=self.wp_device) |
|
|
|
self.sim = sim |
|
self.statics = statics |
|
self.colliders = colliders |
|
|
|
|
|
        if ckpt_path is None:
            ckpt_path = root / f'log/{self.task_name}/train/ckpt/{cfg.iteration}.pt'
|
ckpt = torch.load(ckpt_path, map_location=self.torch_device) |
|
|
|
material: nn.Module = PGNDModel(cfg) |
|
material.to(self.torch_device) |
|
material.load_state_dict(ckpt['material']) |
|
material.requires_grad_(False) |
|
material.eval() |
|
|
|
if 'friction' in ckpt: |
|
friction = ckpt['friction']['mu'].reshape(-1, 1) |
|
else: |
|
friction = torch.tensor(cfg.model.friction.value, device=self.torch_device).reshape(-1, 1) |
|
|
|
self.material = material |
|
self.friction = friction |
|
|
|
def reload_model(self, num_steps): |
|
from pgnd.sim import CacheDiffSimWithFrictionBatch |
|
self.cfg.sim.num_steps = num_steps |
|
sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True) |
|
self.sim = sim |
|
|
|
|
|
def step(self): |
|
cfg = self.cfg |
|
batch_size = 1 |
|
num_steps = 1 |
|
num_particles = cfg.sim.n_particles |
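
        # One key step: shift the history buffers, promote the previous
        # prediction to the current state, then run the learned material model
        # and the differentiable simulator toward the latest gripper sub-goal.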
|
|
|
|
|
self.state['x_his'] = torch.cat([self.state['x_his'][1:], self.state['x'][None]], dim=0) |
|
self.state['v_his'] = torch.cat([self.state['v_his'][1:], self.state['v'][None]], dim=0) |
|
self.state['x'] = self.state['x_pred'].clone() |
|
self.state['v'] = self.state['v_pred'].clone() |
|
|
|
eef_xyz_key = self.state['prev_key_pos'] |
|
eef_xyz_sub = self.state['sub_pos'] |
|
|
|
if eef_xyz_sub is None: |
|
return |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eef_xyz_key_next = eef_xyz_sub[-1] |
|
eef_v = (eef_xyz_key_next - eef_xyz_key) / cfg.sim.dt |
|
if self.verbose: |
|
print('delta_t:', np.round(cfg.sim.dt, 4)) |
|
print('eef_xyz_key_next:', eef_xyz_key_next.cpu().numpy().tolist()) |
|
print('eef_xyz_key:', eef_xyz_key.cpu().numpy().tolist()) |
|
print('v:', eef_v.cpu().numpy().tolist()) |
|
|
|
|
|
|
|
|
|
|
|
if cfg.sim.num_grippers > 0: |
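            # Collider features per gripper: [xyz(3), velocity(3), quaternion(4),
            # angular velocity(3), radius(1), open/close(1)] -> 15 dims.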
|
grippers = torch.zeros((batch_size, cfg.sim.num_grippers, 15), device=self.torch_device) |
|
eef_quat = torch.tensor([1, 0, 0, 0], dtype=torch.float32, device=self.torch_device).repeat(batch_size, cfg.sim.num_grippers, 1) |
|
eef_quat_vel = torch.zeros((batch_size, cfg.sim.num_grippers, 3), dtype=torch.float32, device=self.torch_device) |
|
eef_gripper = torch.zeros((batch_size, cfg.sim.num_grippers), dtype=torch.float32, device=self.torch_device) |
|
grippers[:, :, :3] = eef_xyz_key |
|
grippers[:, :, 3:6] = eef_v |
|
grippers[:, :, 6:10] = eef_quat |
|
grippers[:, :, 10:13] = eef_quat_vel |
|
grippers[:, :, 13] = cfg.model.gripper_radius |
|
grippers[:, :, 14] = eef_gripper |
|
self.colliders.initialize_grippers(grippers) |
|
|
|
x = self.state['x'].clone()[None].repeat(batch_size, 1, 1) |
|
v = self.state['v'].clone()[None].repeat(batch_size, 1, 1) |
|
x_his = self.state['x_his'].permute(1, 0, 2).clone() |
|
assert x_his.shape[0] == num_particles |
|
x_his = x_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1) |
|
v_his = self.state['v_his'].permute(1, 0, 2).clone() |
|
assert v_his.shape[0] == num_particles |
|
v_his = v_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1) |
|
enabled = self.state['enabled'].clone().to(self.torch_device)[None].repeat(batch_size, 1) |
|
|
|
for t in range(num_steps): |
|
x_in = x.clone() |
|
pred = self.material(x, v, x_his, v_his, enabled) |
|
|
|
|
|
|
|
|
|
|
|
|
|
x, v = self.sim(self.statics, self.colliders, t, x, v, self.friction, pred) |
|
|
|
|
|
x_pred = x[0].clone() |
|
v_pred = v[0].clone() |
|
self.state['x_pred'] = x_pred |
|
self.state['v_pred'] = v_pred |
|
|
|
|
|
|
|
self.state['prev_key_pos'] = eef_xyz_key_next |
|
|
|
self.state['sub_pos'] = None |
|
|
|
|
|
    def preprocess_x(self, p_x):
        # Viewer frame -> simulation frame. Only the viewer transform applies
        # here: particle states already live in the normalized simulation frame.
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']

        p_x = (p_x - t_viewer) @ R_viewer

        return p_x
|
|
|
    def preprocess_gripper(self, grippers):
        # Viewer frame -> simulation frame for gripper positions/commands
        # (rotation only; modifies grippers in place).
        R_viewer = self.preprocess_metadata['R_viewer']

        grippers[:, :3] = grippers[:, :3] @ R_viewer

        return grippers
|
|
|
    def inverse_preprocess_x(self, p_x):
        # Simulation frame -> viewer frame.
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']

        p_x = p_x @ R_viewer.T + t_viewer

        return p_x
|
|
|
    def inverse_preprocess_gripper(self, grippers):
        # Simulation frame -> viewer frame for gripper positions (in place).
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']

        grippers[:, :3] = grippers[:, :3] @ R_viewer.T + t_viewer

        return grippers
|
|
|
    def rotate(self, params, rot_mat):
        # rot_mat may carry a per-axis scale; split it into a pure rotation
        # before applying it to the Gaussian means and orientations.
        scale = np.linalg.norm(rot_mat, axis=1, keepdims=True)
        rot = torch.tensor(np.asarray(rot_mat) / scale, dtype=torch.float32, device=self.device)

        quats_0 = torch.nn.functional.normalize(params['unnorm_rotations'], dim=-1)
        pts = params['means3D'] @ rot.T
        quats = mat2quat(rot @ quat2mat(quats_0))

        params = {
            'means3D': pts,
            'rgb_colors': params['rgb_colors'],
            'log_scales': params['log_scales'],
            'unnorm_rotations': quats,
            'logit_opacities': params['logit_opacities'],
        }
        return params
|
|
|
def preprocess_gs(self, params): |
|
if isinstance(params, dict): |
|
xyz = params['means3D'] |
|
rgb = params['rgb_colors'] |
|
quat = torch.nn.functional.normalize(params['unnorm_rotations']) |
|
opa = torch.sigmoid(params['logit_opacities']) |
|
scales = torch.exp(params['log_scales']) |
|
else: |
|
assert isinstance(params, tuple) |
|
xyz, rgb, quat, opa, scales = params |
|
|
|
quat = torch.nn.functional.normalize(quat, dim=-1) |
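
        # Apply the same world -> simulation transform as the particles, then
        # the viewer transform, rotating both Gaussian means and orientations.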
|
|
|
|
|
R = self.preprocess_metadata['R'] |
|
R_viewer = self.preprocess_metadata['R_viewer'] |
|
scale = self.preprocess_metadata['scale'] |
|
global_translation = self.preprocess_metadata['global_translation'] |
|
|
|
mat = quat2mat(quat) |
|
mat = R @ mat |
|
xyz = xyz @ R.T |
|
xyz = xyz * scale |
|
xyz += global_translation |
|
quat = mat2quat(mat) |
|
scales = scales * scale |
|
|
|
|
|
|
|
xyz = xyz @ R_viewer.T |
|
quat = mat2quat(R_viewer @ quat2mat(quat)) |
|
|
|
t_viewer = -xyz.mean(dim=0) |
|
t_viewer[2] = 0 |
|
xyz += t_viewer |
|
print('Overwriting t_viewer to be the planar mean of the object') |
|
self.preprocess_metadata['t_viewer'] = t_viewer |
|
|
|
if isinstance(params, dict): |
|
params['means3D'] = xyz |
|
params['rgb_colors'] = rgb |
|
params['unnorm_rotations'] = quat |
|
params['logit_opacities'] = opa |
|
params['log_scales'] = torch.log(scales) |
|
else: |
|
params = xyz, rgb, quat, opa, scales |
|
|
|
return params |
|
|
|
def preprocess_bg_gs(self): |
|
t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params |
|
g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params |
|
|
|
|
|
g_pts_tip_z = g_pts[:, 2].max() |
|
g_pts_tip_mask = (g_pts[:, 2] > g_pts_tip_z - 0.04) & (g_pts[:, 2] < g_pts_tip_z) |
|
|
|
R = self.preprocess_metadata['R'] |
|
R_viewer = self.preprocess_metadata['R_viewer'] |
|
t_viewer = self.preprocess_metadata['t_viewer'] |
|
scale = self.preprocess_metadata['scale'] |
|
global_translation = self.preprocess_metadata['global_translation'] |
|
|
|
t_mat = quat2mat(t_quats) |
|
t_mat = R @ t_mat |
|
t_pts = t_pts @ R.T |
|
t_pts = t_pts * scale |
|
t_pts += global_translation |
|
t_quats = mat2quat(t_mat) |
|
t_scales = t_scales * scale |
|
|
|
t_pts = t_pts @ R_viewer.T |
|
t_quats = mat2quat(R_viewer @ quat2mat(t_quats)) |
|
t_pts += t_viewer |
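
        # Append three short splat "arrows" (x red, y green, z blue) to the
        # table geometry as a world-axis indicator.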
|
|
|
axes = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] |
|
dirs = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] |
|
for ee in range(3): |
|
gripper_direction = torch.tensor(dirs[ee], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) |
|
gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) |
|
|
|
R = self.preprocess_metadata['R'] |
|
|
|
direction = gripper_direction @ R.T |
|
|
|
n_grippers = 1 |
|
N = 200 |
|
length = 0.2 |
|
kk = 5 |
|
xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) |
|
|
|
if self.task_name == 'rope': |
|
pos = torch.tensor([0.0, 0.0, 1.2], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) |
|
else: |
|
pos = torch.tensor([1.2, 0.0, 0.7], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) |
|
gripper_now_inv_xyz = self.inverse_preprocess_gripper(pos) |
|
gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1) |
|
|
|
center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) |
|
gripper_center_inv_xyz = gripper_now_inv_xyz + \ |
|
torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) |
|
|
|
for i in range(N): |
|
offset = i / N * length * direction |
|
xyz_test[:, i] = gripper_center_inv_xyz + offset |
|
|
|
if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: |
|
direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=t_pts.dtype) |
|
direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) |
|
direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=t_pts.dtype) |
|
direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) |
|
else: |
|
direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype) |
|
direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) |
|
direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype) |
|
direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) |
|
|
|
for i in range(N, N + N // kk): |
|
offset = length * direction + (i - N) / N * length * direction_up |
|
xyz_test[:, i] = gripper_center_inv_xyz + offset |
|
|
|
for i in range(N + N // kk, N + N // kk + N // kk): |
|
offset = length * direction + (i - N - N // kk) / N * length * direction_down |
|
xyz_test[:, i] = gripper_center_inv_xyz + offset |
|
|
|
color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=t_pts.dtype) |
|
color_test[:, :, 0] = axes[ee][0] |
|
color_test[:, :, 1] = axes[ee][1] |
|
color_test[:, :, 2] = axes[ee][2] |
|
quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=t_pts.dtype) |
|
quat_test[:, :, 0] = 1.0 |
|
opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=t_pts.dtype) |
|
scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) * 0.002 |
|
|
|
t_pts = torch.cat([t_pts, xyz_test.reshape(-1, 3)], dim=0) |
|
t_colors = torch.cat([t_colors, color_test.reshape(-1, 3)], dim=0) |
|
t_quats = torch.cat([t_quats, quat_test.reshape(-1, 4)], dim=0) |
|
t_opacities = torch.cat([t_opacities, opa_test.reshape(-1, 1)], dim=0) |
|
t_scales = torch.cat([t_scales, scales_test.reshape(-1, 3)], dim=0) |
|
|
|
t_pts = t_pts.reshape(-1, 3) |
|
t_colors = t_colors.reshape(-1, 3) |
|
t_quats = t_quats.reshape(-1, 4) |
|
t_opacities = t_opacities.reshape(-1, 1) |
|
|
|
g_mat = quat2mat(g_quats) |
|
g_mat = R @ g_mat |
|
g_pts = g_pts @ R.T |
|
g_pts = g_pts * scale |
|
g_pts += global_translation |
|
g_quats = mat2quat(g_mat) |
|
g_scales = g_scales * scale |
|
|
|
g_pts = g_pts @ R_viewer.T |
|
g_quats = mat2quat(R_viewer @ quat2mat(g_quats)) |
|
g_pts += t_viewer |
|
|
|
|
|
g_pts_tip = g_pts[g_pts_tip_mask] |
|
g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0) |
|
|
|
if self.task_name == 'rope': |
|
g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device) |
|
elif self.task_name == 'sloth': |
|
g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.32]).to(torch.float32).to(self.device) |
|
else: |
|
raise NotImplementedError(f"Task {self.task_name} not implemented for gripper translation.") |
|
g_pts = g_pts + g_pts_translation |
|
|
|
self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities |
|
self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities |
|
|
|
def update_rendervar(self, rendervar): |
|
p_x = self.state['x'] |
|
p_x_viewer = self.inverse_preprocess_x(p_x) |
|
|
|
p_x_pred = self.state['x_pred'] |
|
p_x_pred_viewer = self.inverse_preprocess_x(p_x_pred) |
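
        # Skin the dense Gaussians to the sparse particles: kNN relations among
        # particles act as bones, and inverse-distance weights attach each
        # Gaussian to its nearest bones.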
|
|
|
xyz = rendervar['means3D'] |
|
rgb = rendervar['colors_precomp'] |
|
quat = rendervar['rotations'] |
|
opa = rendervar['opacities'] |
|
scales = rendervar['scales'] |
|
|
|
relations = self.knn_relations(p_x_viewer) |
|
weights = self.knn_weights_brute(p_x_viewer, xyz) |
|
xyz, quat, _ = interpolate_motions( |
|
bones=p_x_viewer, |
|
motions=p_x_pred_viewer - p_x_viewer, |
|
relations=relations, |
|
weights=weights, |
|
xyz=xyz, |
|
quat=quat, |
|
) |
|
|
|
|
|
quat = torch.nn.functional.normalize(quat, dim=-1) |
|
|
|
rendervar = { |
|
'means3D': xyz, |
|
'colors_precomp': rgb, |
|
'rotations': quat, |
|
'opacities': opa, |
|
'scales': scales, |
|
'means2D': torch.zeros_like(xyz), |
|
} |
|
|
|
if self.with_bg: |
|
t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params |
|
|
|
|
|
xyz = torch.cat([xyz, t_pts], dim=0) |
|
rgb = torch.cat([rgb, t_colors], dim=0) |
|
quat = torch.cat([quat, t_quats], dim=0) |
|
opa = torch.cat([opa, t_opacities], dim=0) |
|
scales = torch.cat([scales, t_scales], dim=0) |
|
|
|
if self.render_gripper: |
|
g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params |
|
|
|
|
|
g_pts = g_pts + self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())[0] |
|
|
|
|
|
xyz = torch.cat([xyz, g_pts], dim=0) |
|
rgb = torch.cat([rgb, g_colors], dim=0) |
|
quat = torch.cat([quat, g_quats], dim=0) |
|
opa = torch.cat([opa, g_opacities], dim=0) |
|
scales = torch.cat([scales, g_scales], dim=0) |
|
|
|
if self.render_direction: |
|
gripper_direction = self.gripper_direction |
|
gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) |
|
|
|
R = self.preprocess_metadata['R'] |
|
|
|
direction = gripper_direction @ R.T |
|
|
|
n_grippers = 1 |
|
N = 200 |
|
length = 0.2 |
|
kk = 5 |
|
xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=xyz.dtype) |
|
|
|
gripper_now_inv_xyz = self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone()) |
|
gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1) |
|
|
|
center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=xyz.dtype).reshape(1, 3) |
|
gripper_center_inv_xyz = gripper_now_inv_xyz + \ |
|
torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) |
|
|
|
for i in range(N): |
|
offset = i / N * length * direction |
|
xyz_test[:, i] = gripper_center_inv_xyz + offset |
|
|
|
if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: |
|
direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=xyz.dtype) |
|
direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) |
|
direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=xyz.dtype) |
|
direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) |
|
else: |
|
direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=xyz.dtype) |
|
direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) |
|
direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=xyz.dtype) |
|
direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) |
|
|
|
for i in range(N, N + N // kk): |
|
offset = length * direction + (i - N) / N * length * direction_up |
|
xyz_test[:, i] = gripper_center_inv_xyz + offset |
|
|
|
for i in range(N + N // kk, N + N // kk + N // kk): |
|
offset = length * direction + (i - N - N // kk) / N * length * direction_down |
|
xyz_test[:, i] = gripper_center_inv_xyz + offset |
|
|
|
color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=xyz.dtype) |
|
color_test[:, :, 0] = 255 / 255 |
|
color_test[:, :, 1] = 80 / 255 |
|
color_test[:, :, 2] = 110 / 255 |
|
quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=xyz.dtype) |
|
quat_test[:, :, 0] = 1.0 |
|
opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=xyz.dtype) |
|
scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=xyz.dtype) * 0.002 |
|
|
|
xyz = torch.cat([xyz, xyz_test.reshape(-1, 3)], dim=0) |
|
rgb = torch.cat([rgb, color_test.reshape(-1, 3)], dim=0) |
|
quat = torch.cat([quat, quat_test.reshape(-1, 4)], dim=0) |
|
opa = torch.cat([opa, opa_test.reshape(-1, 1)], dim=0) |
|
scales = torch.cat([scales, scales_test.reshape(-1, 3)], dim=0) |
|
|
|
|
|
quat = torch.nn.functional.normalize(quat, dim=-1) |
|
|
|
rendervar_full = { |
|
'means3D': xyz, |
|
'colors_precomp': rgb, |
|
'rotations': quat, |
|
'opacities': opa, |
|
'scales': scales, |
|
'means2D': torch.zeros_like(xyz), |
|
} |
|
|
|
else: |
|
rendervar_full = rendervar |
|
|
|
return rendervar, rendervar_full |
|
|
|
def reset_state(self, params, visualize_image=False, init=False): |
|
xyz_0 = params['means3D'] |
|
rgb_0 = params['rgb_colors'] |
|
quat_0 = torch.nn.functional.normalize(params['unnorm_rotations']) |
|
opa_0 = torch.sigmoid(params['logit_opacities']) |
|
scales_0 = torch.exp(params['log_scales']) |
|
|
|
rendervar_init = { |
|
'means3D': xyz_0, |
|
'colors_precomp': rgb_0, |
|
'rotations': quat_0, |
|
'opacities': opa_0, |
|
'scales': scales_0, |
|
'means2D': torch.zeros_like(xyz_0), |
|
} |
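
        # Place a virtual camera on a sphere around the scene center and build
        # the world-to-camera matrix from a look-at frame.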
|
|
|
w = self.width |
|
h = self.height |
|
center = (0, 0, 0.1) |
|
distance = 0.7 |
|
elevation = 20 |
|
azimuth = 180.0 if self.task_name == 'rope' else 120.0 |
|
target = np.array(center) |
|
theta = 90 + azimuth |
|
z = distance * math.sin(math.radians(elevation)) |
|
y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation)) |
|
x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation)) |
|
origin = target + np.array([x, y, z]) |
|
|
|
look_at = target - origin |
|
look_at /= np.linalg.norm(look_at) |
|
up = np.array([0.0, 0.0, 1.0]) |
|
right = np.cross(look_at, up) |
|
right /= np.linalg.norm(right) |
|
up = np.cross(right, look_at) |
|
w2c = np.eye(4) |
|
w2c[:3, 0] = right |
|
w2c[:3, 1] = -up |
|
w2c[:3, 2] = look_at |
|
w2c[:3, 3] = origin |
|
w2c = np.linalg.inv(w2c) |
|
|
|
k = np.array( |
|
[[w / 2 * 1.0, 0., w / 2], |
|
[0., w / 2 * 1.0, h / 2], |
|
[0., 0., 1.]], |
|
) |
|
self.metadata = {} |
|
self.config = {} |
|
self.update_camera(k, w2c, w, h) |
|
|
|
n_particles = self.cfg.sim.n_particles |
|
downsample_indices = fps(xyz_0, torch.ones_like(xyz_0[:, 0]).to(torch.bool), n_particles, self.torch_device) |
|
p_x_viewer = xyz_0[downsample_indices] |
|
p_x = self.preprocess_x(p_x_viewer) |
|
|
|
self.state['x'] = p_x |
|
self.state['v'] = torch.zeros_like(p_x) |
|
self.state['x_his'] = p_x[None].repeat(self.cfg.sim.n_history, 1, 1) |
|
self.state['v_his'] = torch.zeros_like(p_x[None].repeat(self.cfg.sim.n_history, 1, 1)) |
|
self.state['x_pred'] = p_x |
|
self.state['v_pred'] = torch.zeros_like(p_x) |
|
|
|
        rendervar_init, rendervar_init_full = self.update_rendervar(rendervar_init)
        if visualize_image:
            # Debug-only render of the initial state (HxWx3 uint8).
            im, _ = self.render(rendervar_init_full, 0, bg=[0.0, 0.0, 0.0])
            im_vis = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8)

        return rendervar_init
|
|
|
def reset(self, task_name, scene_name): |
|
self.init(task_name) |
|
|
|
import warp as wp |
|
wp.init() |
|
gpus = [int(gpu) for gpu in self.cfg.gpus] |
|
wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus] |
|
torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus] |
|
device_count = len(torch_devices) |
|
assert device_count == 1 |
|
self.wp_device = wp_devices[0] |
|
self.torch_device = torch_devices[0] |
|
|
|
in_dir = root / f'log/gs/ckpts/{scene_name}' |
|
batch_size = 1 |
|
num_steps = 1 |
|
num_particles = self.cfg.sim.n_particles |
|
self.load_scaniverse(in_dir) |
|
self.init_model(batch_size, num_steps, num_particles, ckpt_path=None) |
|
self.render_direction = False |
|
|
|
params = self.preprocess_gs(self.params) |
|
if self.with_bg: |
|
self.preprocess_bg_gs() |
|
rendervar = self.reset_state(params, visualize_image=False, init=True) |
|
rendervar, rendervar_full = self.update_rendervar(rendervar) |
|
|
|
|
|
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0]) |
|
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy() |
|
|
|
cv2.imwrite(str(root / 'log/temp_init/0000.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR)) |
|
|
|
make_video(root / 'log/temp_init', root / f'log/gs/temp/form_video_init.mp4', '%04d.png', 1) |
|
|
|
gs_pred = save_to_splat( |
|
rendervar_full['means3D'].cpu().numpy(), |
|
rendervar_full['colors_precomp'].cpu().numpy(), |
|
rendervar_full['scales'].cpu().numpy(), |
|
rendervar_full['rotations'].cpu().numpy(), |
|
rendervar_full['opacities'].cpu().numpy(), |
|
root / 'log/gs/temp/gs_pred.splat', |
|
rot_rev=True, |
|
) |
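
        # Gradio state must survive outside the GPU worker, so detach everything
        # to CPU before returning it.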
|
|
|
for k, v in self.preprocess_metadata.items(): |
|
self.preprocess_metadata[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v |
|
for k, v in self.state.items(): |
|
self.state[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v |
|
for k, v in self.params.items(): |
|
if isinstance(v, dict): |
|
for k2, v2 in v.items(): |
|
self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2 |
|
else: |
|
self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v |
|
self.table_params = tuple( |
|
v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params |
|
) |
|
self.gripper_params = tuple( |
|
v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params |
|
) |
|
for k, v in rendervar.items(): |
|
rendervar[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v |
|
|
|
form_video = gr.Video( |
|
label='Predicted video', |
|
value=root / f'log/gs/temp/form_video_init.mp4', |
|
format='mp4', |
|
width=self.width, |
|
height=self.height, |
|
) |
|
form_3dgs_pred = gr.Model3D( |
|
label='Predicted Gaussian Splats', |
|
height=self.height, |
|
value=root / 'log/gs/temp/gs_pred.splat', |
|
clear_color=[0, 0, 0, 0], |
|
) |
|
|
|
return form_video, form_3dgs_pred, \ |
|
self.preprocess_metadata, self.state, self.params, \ |
|
self.table_params, self.gripper_params, rendervar, task_name |
|
|
|
def run_command(self, unit_command, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): |
|
self.task_name = task_name |
|
import warp as wp |
|
wp.init() |
|
gpus = [int(gpu) for gpu in self.cfg.gpus] |
|
wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus] |
|
torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus] |
|
device_count = len(torch_devices) |
|
assert device_count == 1 |
|
self.wp_device = wp_devices[0] |
|
self.torch_device = torch_devices[0] |
|
os.system('rm -rf ' + str(root / 'log/temp/*')) |
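
        # Rebuild the same virtual camera as reset_state(); state arrives from
        # gr.State on CPU and is moved back to the GPU below.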
|
|
|
w = 640 |
|
h = 480 |
|
center = (0, 0, 0.1) |
|
distance = 0.7 |
|
elevation = 20 |
|
azimuth = 180.0 if self.task_name == 'rope' else 120.0 |
|
target = np.array(center) |
|
theta = 90 + azimuth |
|
z = distance * math.sin(math.radians(elevation)) |
|
y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation)) |
|
x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation)) |
|
origin = target + np.array([x, y, z]) |
|
|
|
look_at = target - origin |
|
look_at /= np.linalg.norm(look_at) |
|
up = np.array([0.0, 0.0, 1.0]) |
|
right = np.cross(look_at, up) |
|
right /= np.linalg.norm(right) |
|
up = np.cross(right, look_at) |
|
w2c = np.eye(4) |
|
w2c[:3, 0] = right |
|
w2c[:3, 1] = -up |
|
w2c[:3, 2] = look_at |
|
w2c[:3, 3] = origin |
|
w2c = np.linalg.inv(w2c) |
|
|
|
k = np.array( |
|
[[w / 2 * 1.0, 0., w / 2], |
|
[0., w / 2 * 1.0, h / 2], |
|
[0., 0., 1.]], |
|
) |
|
self.update_camera(k, w2c, w, h) |
|
|
|
self.preprocess_metadata = preprocess_metadata |
|
self.state = state |
|
self.params = params |
|
self.table_params = table_params |
|
self.gripper_params = gripper_params |
|
for k, v in self.preprocess_metadata.items(): |
|
self.preprocess_metadata[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v |
|
for k, v in self.state.items(): |
|
self.state[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v |
|
for k, v in self.params.items(): |
|
if isinstance(v, dict): |
|
for k2, v2 in v.items(): |
|
self.params[k][k2] = v2.to(self.torch_device) if isinstance(v2, torch.Tensor) else v2 |
|
else: |
|
self.params[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v |
|
self.table_params = tuple( |
|
v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.table_params |
|
) |
|
self.gripper_params = tuple( |
|
v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.gripper_params |
|
) |
|
for k, v in rendervar.items(): |
|
rendervar[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v |
|
|
|
num_steps = 15 |
|
batch_size = 1 |
|
num_particles = self.cfg.sim.n_particles |
|
self.init_model(batch_size, num_steps, num_particles, ckpt_path=None) |
|
self.render_direction = True |
|
|
|
|
|
for i in range(num_steps): |
|
dt = 0.1 |
|
command = torch.tensor([unit_command]).to(self.device).to(torch.float32) |
|
command = self.preprocess_gripper(command) |
|
|
|
|
|
if self.verbose: |
|
print('command:', command.cpu().numpy().tolist()) |
|
|
|
self.gripper_direction = command.clone() |
|
|
|
            # step() consumes sub_pos and resets it to None, so each iteration
            # extends the trajectory from the latest key position.
            assert self.state['sub_pos'] is None
            eef_xyz_latest = self.state['prev_key_pos']

            # Integrate the unit command into a new sub-goal position.
            eef_xyz_updated = eef_xyz_latest + command * dt * 0.01
            self.state['sub_pos'] = eef_xyz_updated[None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.step() |
|
rendervar, rendervar_full = self.update_rendervar(rendervar) |
|
|
|
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0]) |
|
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy() |
|
|
|
|
|
cv2.imwrite(str(root / f'log/temp/{i:04}.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR)) |
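
        # Zero velocities and collapse the position history so the next
        # prediction starts from rest (quasi-static rollout).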
|
|
|
|
|
self.state['v'] *= 0.0 |
|
self.state['x'] = self.state['x_pred'].clone() |
|
self.state['x_his'] = self.state['x'][None].repeat(self.cfg.sim.n_history, 1, 1) |
|
self.state['v_his'] *= 0.0 |
|
self.state['v_pred'] *= 0.0 |
|
|
|
for k, v in self.preprocess_metadata.items(): |
|
self.preprocess_metadata[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v |
|
for k, v in self.state.items(): |
|
self.state[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v |
|
for k, v in self.params.items(): |
|
if isinstance(v, dict): |
|
for k2, v2 in v.items(): |
|
self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2 |
|
else: |
|
self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v |
|
self.table_params = tuple( |
|
v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params |
|
) |
|
self.gripper_params = tuple( |
|
v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params |
|
) |
|
for k, v in rendervar.items(): |
|
rendervar[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v |
|
|
|
make_video(root / 'log/temp', root / f'log/gs/temp/form_video.mp4', '%04d.png', 5) |
|
|
|
form_video = gr.Video( |
|
label='Predicted video', |
|
value=root / f'log/gs/temp/form_video.mp4', |
|
format='mp4', |
|
width=self.width, |
|
height=self.height, |
|
) |
|
|
|
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0]) |
|
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy() |
|
|
|
gs_pred = save_to_splat( |
|
rendervar_full['means3D'].cpu().numpy(), |
|
rendervar_full['colors_precomp'].cpu().numpy(), |
|
rendervar_full['scales'].cpu().numpy(), |
|
rendervar_full['rotations'].cpu().numpy(), |
|
rendervar_full['opacities'].cpu().numpy(), |
|
root / 'log/gs/temp/gs_pred.splat', |
|
rot_rev=True, |
|
) |
|
form_3dgs_pred = gr.Model3D( |
|
label='Predicted Gaussian Splats', |
|
height=self.height, |
|
value=root / 'log/gs/temp/gs_pred.splat', |
|
clear_color=[0, 0, 0, 0], |
|
) |
|
return form_video, form_3dgs_pred, \ |
|
self.preprocess_metadata, self.state, self.params, \ |
|
self.table_params, self.gripper_params, rendervar, task_name |
|
|
|
@spaces.GPU |
|
def reset_rope(self): |
|
return self.reset('rope', 'rope_scene_1') |
|
|
|
@spaces.GPU |
|
def reset_plush(self): |
|
return self.reset('sloth', 'sloth_scene_1') |
|
|
|
@spaces.GPU |
|
def on_click_run_xplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): |
|
return self.run_command([5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) |
|
|
|
@spaces.GPU |
|
def on_click_run_xminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): |
|
return self.run_command([-5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) |
|
|
|
@spaces.GPU |
|
def on_click_run_yplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): |
|
return self.run_command([0, 5.0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) |
|
|
|
@spaces.GPU |
|
def on_click_run_yminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): |
|
return self.run_command([0, -5.0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) |
|
|
|
@spaces.GPU |
|
def on_click_run_zplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): |
|
return self.run_command([0, 0, 5.0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) |
|
|
|
@spaces.GPU |
|
def on_click_run_zminus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name): |
|
return self.run_command([0, 0, -5.0], preprocess_metadata, state, params, table_params, gripper_params, rendervar, task_name) |
|
|
|
def launch(self, share=False): |
|
|
|
with gr.Blocks() as app: |
|
preprocess_metadata = gr.State(self.preprocess_metadata) |
|
state = gr.State(self.state) |
|
params = gr.State(self.params) |
|
table_params = gr.State(self.table_params) |
|
gripper_params = gr.State(self.gripper_params) |
|
rendervar = gr.State(None) |
|
task_name = gr.State(self.task_name) |
|
|
|
with gr.Row(): |
|
gr.Markdown("# Particle-Grid Neural Dynamics for Learning Deformable Object Models from RGB-D Videos") |
|
|
|
with gr.Row(): |
|
gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)') |
|
|
|
with gr.Row(): |
|
gr.Markdown('### Instructions:') |
|
|
|
with gr.Row(): |
|
                gr.Markdown(' '.join([
                    '- Click a "Reset - \<object\>" button to initialize the simulation and render the initial predicted video and Gaussian splats. Due to the compute limits of Hugging Face Spaces, each run may take up to 30 seconds.\n',
                    '- Use the buttons to move the gripper along the x, y, and z directions. The gripper moves a fixed distance per click, and the predicted video and Gaussian splats are updated accordingly.\n',
                    '- The X-Y plane is the table surface; Z is the height above it.\n',
                    '- The "Predicted video" panel shows the prediction from the previous step to the current one.\n',
                    '- The "Predicted Gaussians" panel shows the Gaussian splats after the current step.\n',
                    '- Results may drift from the initial shape due to accumulated prediction error. Click a Reset button to reinitialize the simulation state, the predicted video, and the Gaussian splats.\n',
                ]))
|
|
|
with gr.Row(): |
|
gr.Markdown('### Select a scene to reset the simulation:') |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
with gr.Row(): |
|
with gr.Column(): |
|
run_reset_plush = gr.Button("Reset - Plush") |
|
with gr.Column(): |
|
run_reset_rope = gr.Button("Reset - Rope") |
|
|
|
with gr.Column(scale=2): |
|
_ = gr.Button(visible=False) |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(scale=2): |
|
form_video = gr.Video( |
|
label='Predicted video', |
|
value=None, |
|
format='mp4', |
|
width=self.width, |
|
height=self.height, |
|
) |
|
|
|
with gr.Column(scale=2): |
|
form_3dgs_pred = gr.Model3D( |
|
label='Predicted Gaussians', |
|
height=self.height, |
|
value=None, |
|
clear_color=[0, 0, 0, 0], |
|
) |
|
|
|
|
|
with gr.Row(): |
|
gr.Markdown('### Control the gripper to move in the x, y, z directions:') |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
run_xminus = gr.Button("x-") |
|
with gr.Column(): |
|
run_xplus = gr.Button("x+") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
run_yminus = gr.Button("y-") |
|
with gr.Column(): |
|
run_yplus = gr.Button("y+") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
run_zminus = gr.Button("z-") |
|
with gr.Column(): |
|
run_zplus = gr.Button("z+") |
|
|
|
with gr.Column(scale=2): |
|
_ = gr.Button(visible=False) |
|
|
|
|
|
run_reset_rope.click(self.reset_rope, |
|
inputs=[], |
|
outputs=[form_video, form_3dgs_pred, |
|
preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name]) |
|
|
|
run_reset_plush.click(self.reset_plush, |
|
inputs=[], |
|
outputs=[form_video, form_3dgs_pred, |
|
preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name]) |
|
|
|
run_xplus.click(self.on_click_run_xplus, |
|
inputs=[preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name], |
|
outputs=[form_video, form_3dgs_pred, |
|
preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name]) |
|
|
|
run_xminus.click(self.on_click_run_xminus, |
|
inputs=[preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name], |
|
outputs=[form_video, form_3dgs_pred, |
|
preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name]) |
|
|
|
run_yplus.click(self.on_click_run_yplus, |
|
inputs=[preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name], |
|
outputs=[form_video, form_3dgs_pred, |
|
preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name]) |
|
|
|
run_yminus.click(self.on_click_run_yminus, |
|
inputs=[preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name], |
|
outputs=[form_video, form_3dgs_pred, |
|
preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name]) |
|
|
|
run_zplus.click(self.on_click_run_zplus, |
|
inputs=[preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name], |
|
outputs=[form_video, form_3dgs_pred, |
|
preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name]) |
|
|
|
run_zminus.click(self.on_click_run_zminus, |
|
inputs=[preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name], |
|
outputs=[form_video, form_3dgs_pred, |
|
preprocess_metadata, state, params, |
|
table_params, gripper_params, rendervar, task_name]) |
|
|
|
app.launch(share=share) |
|
|
|
|
|
if __name__ == '__main__': |
|
visualizer = DynamicsVisualizer() |
|
visualizer.launch(share=True) |
|
|