# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
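"""Parallel simulation of EmbodiedGen pick tasks.

Builds a batch of PickEmbodiedGen environments from a layout JSON file,
plans a grasp (reach) action for every manipulated object with
FrankaPandaGrasper, executes the planned actions in all environments, and
records the rollouts as videos via the RecordEpisode wrapper.
"""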
from embodied_gen.utils.monkey_patches import monkey_patch_maniskill

# Apply the ManiSkill patches before any other ManiSkill modules are imported.
monkey_patch_maniskill()
import json
from collections import defaultdict
from dataclasses import dataclass
from typing import Literal

import gymnasium as gym
import numpy as np
import torch
import tyro
from mani_skill.utils.wrappers import RecordEpisode
from tqdm import tqdm

# Imported for its side effect of registering the PickEmbodiedGen envs.
import embodied_gen.envs.pick_embodiedgen  # noqa: F401
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
from embodied_gen.utils.log import logger
from embodied_gen.utils.simulation import FrankaPandaGrasper


@dataclass
class ParallelSimConfig:
    """CLI parameters for parallel SAPIEN simulation."""

    # Environment configuration
    layout_file: str
    """Path to the layout JSON file"""
    output_dir: str
    """Directory to save recorded videos"""
    gym_env_name: str = "PickEmbodiedGen-v1"
    """Name of the Gym environment to use"""
    num_envs: int = 4
    """Number of parallel environments"""
    render_mode: Literal["rgb_array", "hybrid"] = "hybrid"
    """Rendering mode: rgb_array or hybrid"""
    enable_shadow: bool = True
    """Whether to enable shadows in rendering"""
    control_mode: str = "pd_joint_pos"
    """Control mode for the agent"""

    # Recording configuration
    max_steps_per_video: int = 1000
    """Maximum number of steps to record per video"""
    save_trajectory: bool = False
    """Whether to save trajectory data"""

    # Simulation parameters
    seed: int = 0
    """Random seed for environment reset"""
    warmup_steps: int = 50
    """Number of warmup steps before grasp planning"""
    reach_target_only: bool = True
    """Whether to stop at the grasp reach pose instead of executing the full grasp action"""
def entrypoint(**kwargs):
    """Run the parallel grasp simulation from CLI args or keyword overrides."""
    if not kwargs:
        cfg = tyro.cli(ParallelSimConfig)
    else:
        cfg = ParallelSimConfig(**kwargs)

    env = gym.make(
        cfg.gym_env_name,
        num_envs=cfg.num_envs,
        render_mode=cfg.render_mode,
        enable_shadow=cfg.enable_shadow,
        layout_file=cfg.layout_file,
        control_mode=cfg.control_mode,
    )
    env = RecordEpisode(
        env,
        cfg.output_dir,
        max_steps_per_video=cfg.max_steps_per_video,
        save_trajectory=cfg.save_trajectory,
    )
    env.reset(seed=cfg.seed)
    # Default action: hold the initial joint configuration while the scene settles.
    default_action = env.unwrapped.agent.init_qpos[:, :8]
    for _ in tqdm(range(cfg.warmup_steps), desc="SIM Warmup"):
        # action = env.action_space.sample()  # Random action
        obs, reward, terminated, truncated, info = env.step(default_action)

    grasper = FrankaPandaGrasper(
        env.unwrapped.agent,
        env.unwrapped.sim_config.control_freq,
    )
    with open(cfg.layout_file, "r") as f:
        layout_data = LayoutInfo.from_dict(json.load(f))

    actions = defaultdict(list)
    # Plan a grasp reach action for each manipulated object in each env.
    for env_idx in range(env.num_envs):
        actors = env.unwrapped.env_actors[f"env{env_idx}"]
        for node in layout_data.relation[
            Scene3DItemEnum.MANIPULATED_OBJS.value
        ]:
            action = grasper.compute_grasp_action(
                actor=actors[node]._objs[0],
                reach_target_only=cfg.reach_target_only,
                env_idx=env_idx,
            )
            actions[node].append(action)
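    # actions[node] now holds one planned action sequence per env (None where
    # planning failed), so plans must be padded to a common length before
    # batched execution.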
    # Execute the planned actions for each manipulated object in each env.
    for node in actions:
        max_env_steps = 0
        for env_idx in range(env.num_envs):
            if actions[node][env_idx] is None:
                continue
            max_env_steps = max(max_env_steps, len(actions[node][env_idx]))

        # Pad shorter plans with the default (hold) action so all envs step together.
        action_tensor = np.ones(
            (max_env_steps, env.num_envs, env.action_space.shape[-1])
        )
        action_tensor *= default_action[None, ...]
        for env_idx in range(env.num_envs):
            action = actions[node][env_idx]
            if action is None:
                continue
            action_tensor[: len(action), env_idx, :] = action

        for step in tqdm(range(max_env_steps), desc=f"Grasping: {node}"):
            action = torch.Tensor(action_tensor[step]).to(env.unwrapped.device)
            env.unwrapped.agent.set_action(action)
            obs, reward, terminated, truncated, info = env.step(action)

    env.close()
    logger.info(f"Results saved in {cfg.output_dir}")
if __name__ == "__main__":
    entrypoint()