|
from pathlib import Path |
|
import open3d as o3d |
|
from threadpoolctl import threadpool_limits |
|
import multiprocess as mp |
|
import threading |
|
from threading import Lock |
|
|
|
import os |
|
import numpy as np |
|
from copy import deepcopy |
|
from functools import partial |
|
import torch |
|
from torch import Tensor |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import transforms3d |
|
import kornia |
|
|
|
from pgnd.utils import get_root |
|
root: Path = get_root(__file__) |
|
|
|
from modules_planning.dynamics_module import DynamicsModule, fps |
|
from utils.planning_utils import batch_chamfer_dist |
|
|
|
|
|
class PlanningModule: |
|
|
|
def __init__(
    self,
    cfg,
    device,
    exp_root,
    ckpt_path,
    use_robot=False,
    bimanual=True,
    bbox=None,
    eef_point=None,
    debug=False,
):
    """Create the sampling-based planner and the dynamics model it rolls out.

    Args:
        cfg: Experiment configuration. It is forwarded to ``DynamicsModule``
            and the other methods of this class read
            ``cfg.sim.num_grippers`` from it.
        device: Torch device on which the planner allocates its tensors.
        exp_root: Forwarded to ``DynamicsModule``.
        ckpt_path: Forwarded to ``DynamicsModule``.
        use_robot: Stored on the instance; not used by this constructor.
        bimanual: Stored on the instance; ``get_command`` uses it to pick
            the two-arm or single-arm layout of the incoming state.
        bbox: Required workspace bounds. Converted to a float32 tensor and
            indexed by the rest of the class as ``bbox[:, 0]`` (lower
            bound) and ``bbox[:, 1]`` (upper bound).
        eef_point: Stored on the instance; not used by this constructor.
        debug: Stored on the instance; not used by this constructor.

    Raises:
        AssertionError: If ``bbox`` is ``None``.
    """
    super().__init__()

    # Handles and flags supplied by the caller.
    self.cfg, self.exp_root, self.ckpt_path = cfg, exp_root, ckpt_path
    self.bimanual, self.use_robot, self.debug = bimanual, use_robot, debug
    self.eef_point = eef_point
    self.torch_device = device

    assert bbox is not None
    self.bbox = torch.tensor(bbox, dtype=torch.float32, device=self.torch_device)

    # When True, sample_action_seq draws one perturbation for the whole
    # horizon instead of one per horizon segment.
    self.repeated_action = False

    # Planner hyper-parameters.
    self.n_look_ahead = 10       # planning horizon, in steps
    self.n_sample = 32           # total number of sampled action sequences
    self.n_sample_chunk = 32     # sequences evaluated per batch
    self.n_chunk = np.ceil(self.n_sample / self.n_sample_chunk).astype(int)
    self.n_update_iter = 20      # refinement iterations per chunk
    self.reward_weight = 10.0    # scale applied to rewards before the softmax

    # Sampling noise scales (position, quaternion components, gripper).
    self.xyz_noise_level = 0.002
    self.quat_noise_level = 0.001
    self.gripper_noise_level = 0.0

    # Placeholders; set_target / get_command fill these in.
    self.target_state = torch.empty(0)
    self.state = None
    self.pts = torch.empty(0)
    self.eef_xyz = torch.empty(0)
    self.eef_rot = torch.empty(0)
    self.eef_gripper = torch.empty(0)

    # Dynamics model sized to one chunk of samples over the full horizon.
    self.dynamics_module = DynamicsModule(
        cfg,
        exp_root=exp_root,
        ckpt_path=ckpt_path,
        batch_size=self.n_sample_chunk,
        num_steps_total=self.n_look_ahead,
    )
    self.dynamics_module.reset_model(self.n_look_ahead)
|
|
|
def set_target(self, pcd_path):
    """Load the goal point cloud and register it as the planning target.

    The cloud at ``pcd_path`` is read with Open3D, moved to the planner
    device as float32, reduced with ``fps`` (called with a budget of 1000
    and ``random_start=False``), stored in ``self.target_state`` and passed
    to ``self.dynamics_module.set_target_state``.

    If the file yields no points, a message is printed and the previously
    stored target is left untouched.

    Args:
        pcd_path: Path of the point-cloud file to read.
    """
    cloud = o3d.io.read_point_cloud(pcd_path)
    target_np = np.array(cloud.points)
    if len(target_np) == 0:
        print('target state is empty')
        return

    device = self.torch_device
    target = torch.tensor(target_np, dtype=torch.float32, device=device)
    # presumably farthest-point sampling; `fps` is a project helper - confirm
    keep_idx = fps(target, 1000, device=device, random_start=False)
    target = target[keep_idx]

    # Keep a private copy so the dynamics module and the planner do not
    # share storage.
    self.target_state = target.clone()
    self.dynamics_module.set_target_state(target)
|
|
|
def model_rollout(self, act_seqs, visualize_pv=False):
    """Roll the dynamics model forward for a batch of action sequences.

    Args:
        act_seqs: Tensor of shape ``(n_sample, n_look_ahead, n_grippers * 13)``.
            The last axis holds, in order, the xyz of every gripper, the
            flattened 3x3 rotation of every gripper, and one gripper value
            per gripper.
        visualize_pv: Forwarded to ``dynamics_module.rollout``.

    Returns:
        ``{'x': x}`` where ``x`` is the first value returned by
        ``dynamics_module.rollout``.
    """
    n_grippers = self.cfg.sim.num_grippers
    n_sample = act_seqs.shape[0]
    horizon = self.n_look_ahead
    xyz_end = n_grippers * 3
    rot_end = n_grippers * (3 + 3 * 3)

    # Split the flat action vector into its three components.
    future_xyz = act_seqs[:, :, :xyz_end].reshape(n_sample, horizon, n_grippers, 3)
    future_rot = act_seqs[:, :, xyz_end:rot_end].reshape(n_sample, horizon, n_grippers, 3, 3)
    future_gripper = act_seqs[:, :, rot_end:].reshape(n_sample, horizon, n_grippers, 1)

    # Prepend the current end-effector state as step 0 of every sample.
    now_xyz = self.eef_xyz.clone().repeat(n_sample, 1, 1, 1)
    now_rot = self.eef_rot.clone().repeat(n_sample, 1, 1, 1, 1)
    now_gripper = self.eef_gripper.clone()[:, None].repeat(n_sample, 1, 1, 1)

    eef_xyz = torch.cat([now_xyz, future_xyz], dim=1)
    eef_rot = torch.cat([now_rot, future_rot], dim=1)
    eef_gripper = torch.cat([now_gripper, future_gripper], dim=1)
    assert eef_xyz.shape[1] == eef_rot.shape[1] == eef_gripper.shape[1] == self.n_look_ahead + 1

    # Only the first rollout output is used by the planner.
    x, _ = self.dynamics_module.rollout(
        self.pts.clone(),
        eef_xyz,
        eef_rot,
        eef_gripper,
        pts_his=None,
        visualize_pv=visualize_pv,
    )
    return {'x': x}
|
|
|
|
|
def sample_action_seq(self, act_seq, iter_index=0):
    """Sample perturbed action sequences around a nominal one.

    Per-step perturbations are drawn for position, quaternion components
    and gripper value, accumulated along the horizon with a cumulative sum
    and added to the nominal sequence. Quaternions are re-normalised before
    being converted back to rotation matrices, and the result is passed
    through ``clip_actions``.

    Args:
        act_seq: Nominal sequence of shape ``(n_look_ahead, n_grippers * 13)``
            laid out as xyz, flattened 3x3 rotations, gripper values.
        iter_index: Refinement iteration. When ``self.repeated_action`` is
            false the noise is scaled by ``1 / (iter_index + 1)``.

    Returns:
        Tensor of shape ``(n_sample_chunk, n_look_ahead, n_grippers * 13)``.
    """
    device = self.torch_device
    horizon = self.n_look_ahead
    n_sample = self.n_sample_chunk
    n_grippers = self.cfg.sim.num_grippers
    xyz_end = n_grippers * 3
    rot_end = n_grippers * (3 + 3 * 3)

    mean_xyz = act_seq[:, :xyz_end].reshape(horizon, n_grippers, 3)
    mean_rot = act_seq[:, xyz_end:rot_end].reshape(horizon, n_grippers, 3, 3)
    mean_quat = kornia.geometry.conversions.rotation_matrix_to_quaternion(mean_rot)
    mean_gripper = act_seq[:, rot_end:].reshape(horizon, n_grippers, 1)

    def draw(template, n_steps, *factors):
        # One Gaussian step per (sample, gripper, channel), scaled by the
        # given factors in order and held constant for `n_steps` steps.
        step = torch.randn(
            (n_sample, template.shape[1], template.shape[2]),
            device=device,
            dtype=torch.float32,
        )
        for factor in factors:
            step = step * factor
        return step[:, None].repeat(1, n_steps, 1, 1)

    if self.repeated_action:
        # A single step reused over the whole horizon, without decay.
        segments = [(horizon, ())]
    else:
        # Four segments, each with its own step; the last segment absorbs
        # the remainder when the horizon is not divisible by four.
        n_parts = 4
        base_len = horizon // n_parts
        decay = 1. / (iter_index + 1) ** 1
        lengths = [base_len] * (n_parts - 1) + [horizon - (n_parts - 1) * base_len]
        segments = [(length, (decay,)) for length in lengths]

    xyz_steps, quat_steps, gripper_steps = [], [], []
    for length, extra in segments:
        # Draw order (xyz, quat, gripper) is kept fixed so the random
        # stream is consumed in the same order for every segment.
        xyz_steps.append(draw(mean_xyz, length, self.xyz_noise_level, *extra))
        quat_steps.append(draw(mean_quat, length, self.quat_noise_level, *extra))
        gripper_steps.append(draw(mean_gripper, length, self.gripper_noise_level, *extra))

    xyz_offset = torch.cumsum(torch.cat(xyz_steps, dim=1), dim=1)
    quat_offset = torch.cumsum(torch.cat(quat_steps, dim=1), dim=1)
    gripper_offset = torch.cumsum(torch.cat(gripper_steps, dim=1), dim=1)

    eef_xyz = mean_xyz[None] + xyz_offset
    eef_quat = mean_quat[None] + quat_offset
    eef_gripper = mean_gripper[None] + gripper_offset

    # Additive noise leaves the quaternions off the unit sphere.
    eef_quat = eef_quat / (torch.norm(eef_quat, dim=-1, keepdim=True) + 1e-6)
    eef_rot = kornia.geometry.conversions.quaternion_to_rotation_matrix(eef_quat)

    act_seqs = torch.zeros(
        (n_sample, horizon, n_grippers * (3 + 9 + 1)),
        device=device,
        dtype=torch.float32,
    )
    act_seqs[:, :, :xyz_end] = eef_xyz.reshape(n_sample, horizon, -1)
    act_seqs[:, :, xyz_end:rot_end] = eef_rot.reshape(n_sample, horizon, -1)
    act_seqs[:, :, rot_end:] = eef_gripper.reshape(n_sample, horizon, -1)

    # clip_actions writes into the 3-D tensor it is given and returns that
    # same tensor.
    return self.clip_actions(act_seqs)
|
|
|
def evaluate_traj(self, model_out, act_seqs):
    """Score every sampled action sequence.

    The reward of a sample is the negated ``batch_chamfer_dist`` between
    the last predicted state ``model_out['x'][:, -1]`` and the stored
    target, minus fixed penalties of 100 for each of:

    * (two grippers only) the grippers being more than 0.3 apart at any
      step;
    * any gripper z coordinate exceeding -0.02 at any step;
    * per gripper, any coordinate leaving the bounding box shrunk by 0.02
      on every side.

    Args:
        model_out: Dict with key ``'x'`` as returned by ``model_rollout``.
        act_seqs: Tensor of shape
            ``(n_sample_chunk, n_look_ahead, n_grippers * 13)``.

    Returns:
        ``{'reward_seqs': reward}`` with one reward per sample.

    Raises:
        AssertionError: If the batch size differs from ``n_sample_chunk``
            or the number of grippers is neither 1 nor 2.
    """
    target = self.target_state.clone()
    x = model_out['x']
    chamfer_final = batch_chamfer_dist(x[:, -1], target)
    chamfer_start = batch_chamfer_dist(x[:, 0], target)
    print('curr chamfer_distance:', chamfer_start.min().item(), end=' ')
    print('best chamfer_distance:', chamfer_final.min().item())

    n_sample = self.n_sample_chunk
    n_grippers = self.cfg.sim.num_grippers
    assert act_seqs.shape[0] == n_sample
    eef_xyz = act_seqs[:, :, :n_grippers * 3].reshape(n_sample, self.n_look_ahead, n_grippers, 3)

    def above_z_limit(traj):
        # traj: (n_sample, n_look_ahead, 3) -> bool per sample
        return traj[:, :, 2].max(dim=1).values > -0.02

    def bbox_penalty(traj):
        # 100 for every sample that leaves the shrunken box on any axis.
        over = (traj.max(dim=1).values > (self.bbox[:, 1] - 0.02)).any(dim=-1)
        under = (traj.min(dim=1).values < (self.bbox[:, 0] + 0.02)).any(dim=-1)
        return torch.logical_or(over, under).to(torch.float32) * 100.0

    if eef_xyz.shape[2] == 2:
        left = eef_xyz[:, :, 0]
        right = eef_xyz[:, :, 1]

        separation = torch.norm(left - right, dim=-1)
        eef_xyz_dist_penalty = (separation.max(dim=1).values > 0.3).float() * 100.0

        # The z-limit penalty is charged once even if both arms violate it;
        # the box penalty is charged per arm.
        eef_height_penalty = torch.logical_or(
            above_z_limit(left), above_z_limit(right)
        ).to(torch.float32) * 100.0
        eef_height_penalty += bbox_penalty(left)
        eef_height_penalty += bbox_penalty(right)
    else:
        assert eef_xyz.shape[2] == 1
        single = eef_xyz[:, :, 0]
        eef_xyz_dist_penalty = 0
        eef_height_penalty = above_z_limit(single).to(torch.float32) * 100.0
        eef_height_penalty += bbox_penalty(single)

    reward = -chamfer_final - eef_xyz_dist_penalty - eef_height_penalty

    print('best reward:', reward.max().item())
    return {'reward_seqs': reward}
|
|
|
def optimize_action_mppi(self, act_seqs, reward_seqs):
    """Blend sampled action sequences into one, weighted by their reward.

    Weights are ``softmax(reward_seqs * self.reward_weight)`` over the
    sample axis. Positions and gripper values are averaged directly.
    Rotations are converted to quaternions, averaged, re-normalised and
    converted back to matrices. The blended sequence is finally passed
    through ``clip_actions``.

    Args:
        act_seqs: Tensor of shape
            ``(n_sample_chunk, n_look_ahead, n_grippers * 13)`` laid out as
            xyz, flattened 3x3 rotations, gripper values.
        reward_seqs: 1-D tensor with one reward per sample.

    Returns:
        Tensor of shape ``(n_look_ahead, n_grippers * 13)``.

    Raises:
        AssertionError: If the number of samples differs from
            ``n_sample_chunk`` or ``reward_seqs`` is not 1-D.
    """
    weight_seqs = F.softmax(reward_seqs * self.reward_weight, dim=0)
    assert len(weight_seqs.shape) == 1 and weight_seqs.shape[0] == self.n_sample_chunk

    n_sample = self.n_sample_chunk
    n_grippers = self.cfg.sim.num_grippers
    horizon = self.n_look_ahead
    xyz_end = n_grippers * 3
    rot_end = n_grippers * (3 + 3 * 3)
    assert act_seqs.shape[0] == n_sample

    eef_xyz = act_seqs[:, :, :xyz_end].reshape(n_sample, horizon, n_grippers, 3)
    eef_rot = act_seqs[:, :, xyz_end:rot_end].reshape(n_sample, horizon, n_grippers, 3, 3)
    eef_quat = kornia.geometry.conversions.rotation_matrix_to_quaternion(eef_rot)
    eef_gripper = act_seqs[:, :, rot_end:].reshape(n_sample, horizon, n_grippers, 1)

    # q and -q encode the same rotation, so a plain weighted sum is only
    # meaningful when all samples lie in the same hemisphere. Flip every
    # sample onto the side of the highest-weight one before averaging;
    # samples that already agree are multiplied by exactly 1.0.
    ref_quat = eef_quat[torch.argmax(weight_seqs)]
    agreement = (eef_quat * ref_quat[None]).sum(dim=-1, keepdim=True)
    hemisphere_sign = 1.0 - 2.0 * (agreement < 0).to(eef_quat.dtype)
    eef_quat = eef_quat * hemisphere_sign

    weights = weight_seqs[:, None, None, None]
    eef_xyz = torch.sum(weights * eef_xyz, dim=0)
    eef_gripper = torch.sum(weights * eef_gripper, dim=0)
    eef_quat = torch.sum(weights * eef_quat, dim=0)

    # A weighted sum of unit quaternions is generally not unit length.
    eef_quat = eef_quat / (torch.norm(eef_quat, dim=-1, keepdim=True) + 1e-6)
    eef_rot = kornia.geometry.conversions.quaternion_to_rotation_matrix(eef_quat)

    act_seq = torch.zeros(
        (horizon, n_grippers * (3 + 9 + 1)),
        device=self.torch_device,
        dtype=torch.float32,
    )
    act_seq[:, :xyz_end] = eef_xyz.reshape(horizon, -1)
    act_seq[:, xyz_end:rot_end] = eef_rot.reshape(horizon, -1)
    act_seq[:, rot_end:] = eef_gripper.reshape(horizon, -1)

    act_seq = self.clip_actions(act_seq[None])[0]
    return act_seq
|
|
|
def clip_actions(self, act_seqs):
    """Project action sequences back into the allowed range, in place.

    Positions are clamped to ``self.bbox``, rotations are limited to an
    axis-angle magnitude of at most 1.0 rad, and gripper values are clamped
    to ``[0, 1]``. The clipped values are written back into ``act_seqs``.

    Args:
        act_seqs: Tensor of shape ``(n_sample, n_look_ahead, n_grippers * 13)``
            or, without a sample axis, ``(n_look_ahead, n_grippers * 13)``.

    Returns:
        The clipped tensor, with the same number of dimensions as the input.
    """
    added_sample_dim = len(act_seqs.shape) == 2
    if added_sample_dim:
        # View with a leading sample axis; writes below reach the original.
        act_seqs = act_seqs[None]

    n_sample = act_seqs.shape[0]
    horizon = self.n_look_ahead
    n_grippers = self.cfg.sim.num_grippers
    xyz_end = n_grippers * 3
    rot_end = n_grippers * (3 + 3 * 3)

    eef_xyz = act_seqs[:, :, :xyz_end].reshape(n_sample, horizon, n_grippers, 3)
    eef_rot = act_seqs[:, :, xyz_end:rot_end].reshape(n_sample, horizon, n_grippers, 3, 3)
    eef_quat = kornia.geometry.conversions.rotation_matrix_to_quaternion(eef_rot)
    eef_gripper = act_seqs[:, :, rot_end:].reshape(n_sample, horizon, n_grippers, 1)

    # Position: per-axis clamp to the workspace box.
    eef_xyz = torch.clamp(eef_xyz, self.bbox[:, 0], self.bbox[:, 1])

    # Rotation: shrink any axis-angle vector longer than max_rad to exactly
    # max_rad, keeping its direction.
    max_rad = 1.0
    eef_aa = kornia.geometry.conversions.quaternion_to_axis_angle(eef_quat)
    too_large = torch.norm(eef_aa, dim=-1) > max_rad
    oversized = eef_aa[too_large]
    eef_aa[too_large] = oversized / torch.norm(oversized, dim=-1, keepdim=True) * max_rad
    eef_quat = kornia.geometry.conversions.axis_angle_to_quaternion(eef_aa)
    eef_quat = eef_quat / (torch.norm(eef_quat, dim=-1, keepdim=True) + 1e-6)
    eef_rot = kornia.geometry.conversions.quaternion_to_rotation_matrix(eef_quat)

    # Gripper: normalised opening.
    eef_gripper = torch.clamp(eef_gripper, 0.0, 1.0)

    act_seqs[:, :, :xyz_end] = eef_xyz.reshape(n_sample, horizon, -1)
    act_seqs[:, :, xyz_end:rot_end] = eef_rot.reshape(n_sample, horizon, -1)
    act_seqs[:, :, rot_end:] = eef_gripper.reshape(n_sample, horizon, -1)

    if added_sample_dim:
        assert act_seqs.shape[0] == 1
        act_seqs = act_seqs[0]
    return act_seqs
|
|
|
|
|
def get_command(self, state, save_dir=None, is_first_iter=False, rm_outliers=False):
    """Plan an end-effector trajectory for the observed state.

    The observed point cloud and end-effector poses are read from ``state``.
    A nominal "hold the current pose" plan is refined by repeated
    sample / rollout / evaluate / reward-weighted-average iterations.
    The best sampled sequence is converted to base-frame robot commands.

    Args:
        state: Dict providing ``state["perception_out"]["value"]["pts"]``
            (a sequence of point arrays, concatenated along axis 0) and the
            per-arm entries read by ``_read_eef_state``.
        save_dir: If given, the directory is created and the current chamfer
            distance, the observed points and the target are written to it.
        is_first_iter: If true, the dynamics module's preprocess meta and
            downsample indices are reset from the observed points.
        rm_outliers: If true, statistical outliers are removed from the
            observed points first. Defaults to ``False``, which matches the
            previously hard-coded behaviour.

    Returns:
        A list with one entry per gripper. Each entry is a list of
        ``n_look_ahead`` commands ``[x, y, z, rx, ry, rz, gripper]``: the
        base-frame translation times 1000, the Euler angles from
        ``transforms3d.euler.mat2euler`` in degrees, and the gripper value
        times 800.
    """
    self.state = state
    device = self.torch_device

    pts = state["perception_out"]["value"]["pts"].copy()
    pts = np.concatenate(pts, axis=0)
    if rm_outliers:
        pts = self._remove_statistical_outliers(pts)
    pts = torch.tensor(pts, dtype=torch.float32, device=device)

    if is_first_iter:
        # NOTE(review): both resets are taken to be first-iteration-only -
        # confirm against the dynamics module.
        self.dynamics_module.reset_preprocess_meta(pts)
        self.dynamics_module.reset_downsample_indices(pts)
    self.pts = pts

    if save_dir is not None:
        self._save_planning_inputs(pts, save_dir)

    eef_xyz, eef_rot, eef_gripper, b2w_list = self._read_eef_state(state)
    n_grippers = eef_xyz.shape[0]
    self.eef_xyz = eef_xyz
    self.eef_rot = eef_rot
    self.eef_gripper = eef_gripper

    # Nominal plan: hold the current pose for the whole horizon (the flat
    # current state is broadcast over the time axis).
    xyz_end = n_grippers * 3
    rot_end = n_grippers * (3 + 3 * 3)
    act_seq = torch.zeros((self.n_look_ahead, n_grippers * (3 + 9 + 1)), device=device)
    act_seq[:, :xyz_end] = eef_xyz.reshape(-1)
    act_seq[:, xyz_end:rot_end] = eef_rot.reshape(-1)
    act_seq[:, rot_end:] = eef_gripper.reshape(-1)

    best_act_seq = self._run_mppi(act_seq)
    torch.cuda.empty_cache()

    return self._plan_to_commands(best_act_seq, b2w_list)


def _remove_statistical_outliers(self, pts):
    """Repeatedly drop statistical outliers from ``pts`` until none remain."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pts)
    rm_iter = 0
    while True:
        # The threshold is relaxed by 0.5 on every pass.
        _, inlier_idx = cloud.remove_statistical_outlier(
            nb_neighbors=30, std_ratio=2.5 + rm_iter * 0.5
        )
        removed = cloud.select_by_index(inlier_idx, invert=True)
        cloud = cloud.select_by_index(inlier_idx)
        rm_iter += 1
        if len(removed.points) == 0:
            break
    return np.array(cloud.points)


def _save_planning_inputs(self, pts, save_dir):
    """Write the current chamfer distance, observation and target to disk."""
    os.makedirs(save_dir, exist_ok=True)
    out_dir = Path(save_dir)

    chamfer_now = batch_chamfer_dist(pts[None], self.target_state.clone())
    print('chamfer_now:', chamfer_now.item())
    with open(out_dir / 'chamfer.txt', 'w') as f:
        f.write(str(chamfer_now.item()))

    for file_name, cloud in (('pts.ply', pts), ('target_state.ply', self.target_state)):
        o3d_cloud = o3d.geometry.PointCloud()
        o3d_cloud.points = o3d.utility.Vector3dVector(cloud.cpu().numpy())
        o3d.io.write_point_cloud(str(out_dir / file_name), o3d_cloud)


def _read_eef_state(self, state):
    """Read the current world-frame end-effector state of every arm.

    Reads ``state["robot_out"]`` / ``state["gripper_out"]`` under the keys
    ``"left_value"`` and ``"right_value"`` plus ``state["b2w_l"]`` /
    ``state["b2w_r"]`` when ``self.bimanual`` is true, otherwise the key
    ``"value"`` plus ``state["b2w"]``. The robot and ``b2w`` entries are
    used as 4x4 homogeneous transforms (end-effector-to-base and
    base-to-world).

    Returns:
        ``(eef_xyz, eef_rot, eef_gripper, b2w_list)`` with shapes
        ``(G, 3)``, ``(G, 3, 3)`` and ``(G,)``; ``b2w_list`` holds the
        base-to-world transform of each arm, in the same order.
    """
    device = self.torch_device

    def to_tensor(value):
        return torch.tensor(value, dtype=torch.float32, device=device)

    if self.bimanual:
        arm_keys = ["left_value", "right_value"]
        b2w_list = [to_tensor(state["b2w_l"]), to_tensor(state["b2w_r"])]
    else:
        arm_keys = ["value"]
        b2w_list = [to_tensor(state["b2w"])]
    e2b_list = [to_tensor(state["robot_out"][key]) for key in arm_keys]
    gripper_list = [to_tensor(state["gripper_out"][key]) for key in arm_keys]

    # Homogeneous point 0.175 along the end-effector z axis
    # (presumably the tool tip - confirm).
    tool_point = torch.tensor([[0.0, 0.0, 0.175, 1]], device=device)

    xyz_rows, rot_rows = [], []
    for e2b, b2w in zip(e2b_list, b2w_list):
        xyz_rows.append((b2w @ e2b @ tool_point.T).T[:, :3])
        # Stored transposed: each row is one end-effector axis expressed in
        # the world frame.
        rot_rows.append((b2w[:3, :3] @ e2b[:3, :3]).T[None])
    eef_xyz = torch.cat(xyz_rows, dim=0)
    eef_rot = torch.cat(rot_rows, dim=0)

    # One scalar reading per arm; reshape(1) rejects anything that is not a
    # single value. Readings are normalised by 800 and clamped to [0, 1].
    gripper_raw = torch.cat([reading.reshape(1) for reading in gripper_list], dim=0)
    eef_gripper = torch.clamp(gripper_raw / 800.0, 0, 1)

    return eef_xyz, eef_rot, eef_gripper, b2w_list


def _run_mppi(self, act_seq):
    """Refine ``act_seq`` and return the best sampled sequence.

    For every chunk, ``n_update_iter`` rounds of sampling, rollout,
    evaluation and reward-weighted averaging are run. The nominal sequence
    is carried over from one chunk to the next. The highest-reward sample
    seen in each chunk is remembered and the best across chunks is
    returned.
    """
    chunk_results = []
    for ci in range(self.n_chunk):
        best_act_seq = act_seq
        best_reward_seq = None
        for ti in range(self.n_update_iter):
            print(f'chunk: {ci}/{self.n_chunk}, iter: {ti}/{self.n_update_iter}')

            with torch.no_grad():
                act_seqs = self.sample_action_seq(act_seq, iter_index=ti)
                model_out = self.model_rollout(act_seqs, visualize_pv=False)
                eval_out = self.evaluate_traj(model_out, act_seqs)
                reward_seqs = eval_out["reward_seqs"]
                act_seq = self.optimize_action_mppi(act_seqs, reward_seqs)

            top_idx = torch.argmax(reward_seqs)
            if best_reward_seq is None or reward_seqs[top_idx] > best_reward_seq:
                best_act_seq = act_seqs[top_idx]
                best_reward_seq = reward_seqs[top_idx]

            # NOTE(review): placement inside the iteration loop was
            # reconstructed; it only affects allocator behaviour.
            torch.cuda.empty_cache()

        chunk_results.append((best_act_seq, best_reward_seq))

    reward_list = [reward.item() for _, reward in chunk_results]
    return chunk_results[np.argmax(reward_list)][0]


def _plan_to_commands(self, act_seq, b2w_list):
    """Convert a world-frame plan into per-arm base-frame command lists.

    ``act_seq`` has shape ``(n_look_ahead, G * 13)`` with ``G`` equal to
    ``len(b2w_list)``. For every step and arm the world-frame pose is
    mapped through the inverse of that arm's base-to-world transform, and
    the 0.175 tool offset is removed to recover the end-effector origin.
    """
    device = self.torch_device
    horizon = self.n_look_ahead
    n_grippers = len(b2w_list)
    xyz_end = n_grippers * 3
    rot_end = n_grippers * (3 + 3 * 3)

    eef_xyz = act_seq[:, :xyz_end].reshape(horizon, n_grippers, 3)
    eef_rot = act_seq[:, xyz_end:rot_end].reshape(horizon, n_grippers, 3, 3)
    eef_gripper = act_seq[:, rot_end:].reshape(horizon, n_grippers, 1)

    # Same tool point as used when reading the state, as a (3, 1) column.
    tool_offset = torch.tensor([[0.0, 0.0, 0.175]], device=device).T

    command = [[] for _ in range(n_grippers)]
    for li in range(horizon):
        for gripper_id, b2w in enumerate(b2w_list):
            world_to_base_rot = b2w[:3, :3].T

            # Planned rotations are stored transposed; undo that first.
            e2b_rot = world_to_base_rot @ eef_rot[li, gripper_id].T

            tool_world = eef_xyz[li:li + 1, gripper_id].T
            tool_base = world_to_base_rot @ (tool_world - b2w[:3, 3].reshape(-1, 1))
            e2b_trans = tool_base - e2b_rot @ tool_offset

            e2b = torch.eye(4, device=device)
            e2b[:3, :3] = e2b_rot
            e2b[:3, 3:4] = e2b_trans
            fk_trans_mat = e2b.cpu().numpy()

            cur_xyzrpy = np.zeros(6)
            cur_xyzrpy[:3] = fk_trans_mat[:3, 3] * 1000  # presumably m -> mm
            cur_xyzrpy[3:] = transforms3d.euler.mat2euler(fk_trans_mat[:3, :3])
            cur_xyzrpy[3:] = cur_xyzrpy[3:] / np.pi * 180

            # Undo the /800 normalisation applied when the state was read.
            gripper = eef_gripper[li, gripper_id].item() * 800.0

            command[gripper_id].append(list(cur_xyzrpy) + [gripper])

    return command
|
|