Spaces:
Sleeping
Sleeping
from typing import ClassVar | |
import numpy as np | |
class JointConfig: | |
# Standard joint names used in LeRobot training data | |
LEROBOT_JOINT_NAMES: ClassVar = [ | |
"shoulder_pan_joint", | |
"shoulder_lift_joint", | |
"elbow_joint", | |
"wrist_1_joint", | |
"wrist_2_joint", | |
"wrist_3_joint", | |
] | |
# Our custom joint names (more intuitive for users) | |
CUSTOM_JOINT_NAMES: ClassVar = [ | |
"base_rotation", | |
"shoulder_tilt", | |
"elbow_bend", | |
"wrist_rotate", | |
"wrist_tilt", | |
"wrist_twist", | |
] | |
# Mapping from our custom names to LeRobot standard names | |
CUSTOM_TO_LEROBOT_NAMES: ClassVar = { | |
"base_rotation": "shoulder_pan_joint", | |
"shoulder_tilt": "shoulder_lift_joint", | |
"elbow_bend": "elbow_joint", | |
"wrist_rotate": "wrist_1_joint", | |
"wrist_tilt": "wrist_2_joint", | |
"wrist_twist": "wrist_3_joint", | |
} | |
# Reverse mapping for convenience | |
LEROBOT_TO_CUSTOM_NAMES: ClassVar = { | |
v: k for k, v in CUSTOM_TO_LEROBOT_NAMES.items() | |
} | |
# Joint limits in normalized values (-100 to +100 for most joints, 0 to 100 for gripper) | |
JOINT_LIMITS: ClassVar = { | |
"base_rotation": (-100.0, 100.0), | |
"shoulder_tilt": (-100.0, 100.0), | |
"elbow_bend": (-100.0, 100.0), | |
"wrist_rotate": (-100.0, 100.0), | |
"wrist_tilt": (-100.0, 100.0), | |
"wrist_twist": (-100.0, 100.0), | |
} | |
def get_joint_index(cls, joint_name: str) -> int | None: | |
""" | |
Get the index of a joint in the standard joint order. | |
Args: | |
joint_name: Name of the joint (can be custom or LeRobot name) | |
Returns: | |
Index of the joint, or None if not found | |
""" | |
# Try custom names first | |
if joint_name in cls.CUSTOM_JOINT_NAMES: | |
return cls.CUSTOM_JOINT_NAMES.index(joint_name) | |
# Try LeRobot names | |
if joint_name in cls.LEROBOT_JOINT_NAMES: | |
return cls.LEROBOT_JOINT_NAMES.index(joint_name) | |
# Try case-insensitive matching | |
joint_name_lower = joint_name.lower() | |
for i, name in enumerate(cls.CUSTOM_JOINT_NAMES): | |
if name.lower() == joint_name_lower: | |
return i | |
for i, name in enumerate(cls.LEROBOT_JOINT_NAMES): | |
if name.lower() == joint_name_lower: | |
return i | |
return None | |
def parse_joint_data(cls, joints_data, policy_type: str = "act") -> list[float]: | |
""" | |
Parse joint data from Arena message into standard order. | |
Expected format: dict with joint names as keys and normalized values. | |
All values are already normalized from the training data pipeline. | |
Args: | |
joints_data: Joint data from Arena message | |
policy_type: Type of policy (for logging purposes) | |
Returns: | |
List of 6 normalized joint values in LeRobot standard order | |
""" | |
try: | |
# Handle different possible data formats | |
if hasattr(joints_data, "data"): | |
joint_dict = joints_data.data | |
else: | |
joint_dict = joints_data | |
if not isinstance(joint_dict, dict): | |
return [0.0] * 6 | |
# Extract joint values in LeRobot standard order | |
joint_values = [] | |
for lerobot_name in cls.LEROBOT_JOINT_NAMES: | |
value = None | |
# Try LeRobot name directly | |
if lerobot_name in joint_dict: | |
value = float(joint_dict[lerobot_name]) | |
else: | |
# Try custom name | |
custom_name = cls.LEROBOT_TO_CUSTOM_NAMES.get(lerobot_name) | |
if custom_name and custom_name in joint_dict: | |
value = float(joint_dict[custom_name]) | |
else: | |
# Try various common formats | |
for key in [ | |
lerobot_name, | |
f"joint_{lerobot_name}", | |
lerobot_name.upper(), | |
custom_name, | |
f"joint_{custom_name}" if custom_name else None, | |
]: | |
if key and key in joint_dict: | |
value = float(joint_dict[key]) | |
break | |
joint_values.append(value if value is not None else 0.0) | |
return joint_values | |
except Exception: | |
# Return zeros if parsing fails | |
return [0.0] * 6 | |
def create_joint_commands(cls, action_values: np.ndarray) -> list[dict]: | |
""" | |
Create joint command dictionaries from action values. | |
Args: | |
action_values: Array of 6 joint values in LeRobot standard order | |
Returns: | |
List of joint command dictionaries with custom names | |
""" | |
commands = [] | |
for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES): | |
if i < len(action_values): | |
commands.append({"name": custom_name, "value": float(action_values[i])}) | |
return commands | |
def validate_joint_values(cls, joint_values: np.ndarray) -> np.ndarray: | |
""" | |
Validate and clamp joint values to their limits. | |
Args: | |
joint_values: Array of joint values | |
Returns: | |
Clamped joint values | |
""" | |
if len(joint_values) != 6: | |
# Pad or truncate to 6 values | |
padded = np.zeros(6, dtype=np.float32) | |
n = min(len(joint_values), 6) | |
padded[:n] = joint_values[:n] | |
joint_values = padded | |
# Clamp to limits | |
for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES): | |
min_val, max_val = cls.JOINT_LIMITS[custom_name] | |
joint_values[i] = np.clip(joint_values[i], min_val, max_val) | |
return joint_values | |