# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
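
"""Parallel grasping simulation for EmbodiedGen scenes.

Spawns multiple PickEmbodiedGen environments from a layout JSON file, plans a
grasp reach pose for every manipulated object with ``FrankaPandaGrasper``,
replays the plans across all environments, and records videos (and optionally
trajectories) to the output directory.
"""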

from embodied_gen.utils.monkey_patches import monkey_patch_maniskill

# Apply the ManiSkill patch before the imports below so they pick up the
# patched behavior.
monkey_patch_maniskill()

import json
from collections import defaultdict
from dataclasses import dataclass
from typing import Literal

import gymnasium as gym
import numpy as np
import torch
import tyro
from mani_skill.utils.wrappers import RecordEpisode
from tqdm import tqdm

import embodied_gen.envs.pick_embodiedgen  # noqa: F401 (registers the Gym env)
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
from embodied_gen.utils.log import logger
from embodied_gen.utils.simulation import FrankaPandaGrasper


@dataclass
class ParallelSimConfig:
    """CLI parameters for parallel SAPIEN simulation."""

    # Environment configuration
    layout_file: str
    """Path to the layout JSON file."""
    output_dir: str
    """Directory to save recorded videos."""
    gym_env_name: str = "PickEmbodiedGen-v1"
    """Name of the Gym environment to use."""
    num_envs: int = 4
    """Number of parallel environments."""
    render_mode: Literal["rgb_array", "hybrid"] = "hybrid"
    """Rendering mode: "rgb_array" or "hybrid"."""
    enable_shadow: bool = True
    """Whether to enable shadows in rendering."""
    control_mode: str = "pd_joint_pos"
    """Control mode for the agent."""

    # Recording configuration
    max_steps_per_video: int = 1000
    """Maximum number of steps to record per video."""
    save_trajectory: bool = False
    """Whether to save trajectory data."""

    # Simulation parameters
    seed: int = 0
    """Random seed for the environment reset."""
    warmup_steps: int = 50
    """Number of warmup steps before action computation."""
    reach_target_only: bool = True
    """Whether to only reach the grasp target instead of executing the full grasp."""


def entrypoint(**kwargs) -> None:
    # ``**kwargs`` always binds to a dict, so a truthiness check suffices.
    if not kwargs:
        cfg = tyro.cli(ParallelSimConfig)
    else:
        cfg = ParallelSimConfig(**kwargs)

    env = gym.make(
        cfg.gym_env_name,
        num_envs=cfg.num_envs,
        render_mode=cfg.render_mode,
        enable_shadow=cfg.enable_shadow,
        layout_file=cfg.layout_file,
        control_mode=cfg.control_mode,
    )
    env = RecordEpisode(
        env,
        cfg.output_dir,
        max_steps_per_video=cfg.max_steps_per_video,
        save_trajectory=cfg.save_trajectory,
    )
    env.reset(seed=cfg.seed)

    # Hold the initial joint pose (first 8 DoF) while the scene settles.
    default_action = env.unwrapped.agent.init_qpos[:, :8]
    for _ in tqdm(range(cfg.warmup_steps), desc="SIM Warmup"):
        # action = env.action_space.sample()  # Random action
        obs, reward, terminated, truncated, info = env.step(default_action)

    grasper = FrankaPandaGrasper(
        env.unwrapped.agent,
        env.unwrapped.sim_config.control_freq,
    )
    with open(cfg.layout_file, "r") as f:
        layout_data = LayoutInfo.from_dict(json.load(f))
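
    # ``layout_data.relation[Scene3DItemEnum.MANIPULATED_OBJS.value]`` is
    # expected to list the node names of the objects to grasp; one plan is
    # computed per object per environment below.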
    actions = defaultdict(list)
    # Plan a grasp reach pose for each manipulated object in each env.
    for env_idx in range(env.num_envs):
        actors = env.unwrapped.env_actors[f"env{env_idx}"]
        for node in layout_data.relation[
            Scene3DItemEnum.MANIPULATED_OBJS.value
        ]:
            action = grasper.compute_grasp_action(
                actor=actors[node]._objs[0],
                reach_target_only=cfg.reach_target_only,
                env_idx=env_idx,
            )
            actions[node].append(action)
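
    # At this point ``actions[node][env_idx]`` is a per-step action sequence,
    # or None when no feasible grasp was found for that env.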
    # Execute the planned actions for each manipulated object in each env.
    for node in actions:
        # Plans can differ in length across envs; pad to the longest one.
        max_env_steps = 0
        for env_idx in range(env.num_envs):
            if actions[node][env_idx] is None:
                continue
            max_env_steps = max(max_env_steps, len(actions[node][env_idx]))

        # Envs without a valid plan keep holding the default action.
        action_tensor = np.ones(
            (max_env_steps, env.num_envs, env.action_space.shape[-1])
        )
        action_tensor *= default_action[None, ...]
        for env_idx in range(env.num_envs):
            action = actions[node][env_idx]
            if action is None:
                continue
            action_tensor[: len(action), env_idx, :] = action

        for step in tqdm(range(max_env_steps), desc=f"Grasping: {node}"):
            action = torch.Tensor(action_tensor[step]).to(env.unwrapped.device)
            env.unwrapped.agent.set_action(action)
            obs, reward, terminated, truncated, info = env.step(action)

    env.close()
    logger.info(f"Results saved in {cfg.output_dir}")


if __name__ == "__main__":
    entrypoint()
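
# ``entrypoint`` also accepts keyword overrides when called from Python,
# e.g. (hypothetical values):
#   entrypoint(layout_file="layout.json", output_dir="outputs", num_envs=2)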