blanchon's picture
Update
63ed3a7
raw
history blame
6.06 kB
from typing import ClassVar
import numpy as np
class JointConfig:
# Standard joint names used in LeRobot training data
LEROBOT_JOINT_NAMES: ClassVar = [
"shoulder_pan_joint",
"shoulder_lift_joint",
"elbow_joint",
"wrist_1_joint",
"wrist_2_joint",
"wrist_3_joint",
]
# Our custom joint names (more intuitive for users)
CUSTOM_JOINT_NAMES: ClassVar = [
"base_rotation",
"shoulder_tilt",
"elbow_bend",
"wrist_rotate",
"wrist_tilt",
"wrist_twist",
]
# Mapping from our custom names to LeRobot standard names
CUSTOM_TO_LEROBOT_NAMES: ClassVar = {
"base_rotation": "shoulder_pan_joint",
"shoulder_tilt": "shoulder_lift_joint",
"elbow_bend": "elbow_joint",
"wrist_rotate": "wrist_1_joint",
"wrist_tilt": "wrist_2_joint",
"wrist_twist": "wrist_3_joint",
}
# Reverse mapping for convenience
LEROBOT_TO_CUSTOM_NAMES: ClassVar = {
v: k for k, v in CUSTOM_TO_LEROBOT_NAMES.items()
}
# Joint limits in normalized values (-100 to +100 for most joints, 0 to 100 for gripper)
JOINT_LIMITS: ClassVar = {
"base_rotation": (-100.0, 100.0),
"shoulder_tilt": (-100.0, 100.0),
"elbow_bend": (-100.0, 100.0),
"wrist_rotate": (-100.0, 100.0),
"wrist_tilt": (-100.0, 100.0),
"wrist_twist": (-100.0, 100.0),
}
@classmethod
def get_joint_index(cls, joint_name: str) -> int | None:
"""
Get the index of a joint in the standard joint order.
Args:
joint_name: Name of the joint (can be custom or LeRobot name)
Returns:
Index of the joint, or None if not found
"""
# Try custom names first
if joint_name in cls.CUSTOM_JOINT_NAMES:
return cls.CUSTOM_JOINT_NAMES.index(joint_name)
# Try LeRobot names
if joint_name in cls.LEROBOT_JOINT_NAMES:
return cls.LEROBOT_JOINT_NAMES.index(joint_name)
# Try case-insensitive matching
joint_name_lower = joint_name.lower()
for i, name in enumerate(cls.CUSTOM_JOINT_NAMES):
if name.lower() == joint_name_lower:
return i
for i, name in enumerate(cls.LEROBOT_JOINT_NAMES):
if name.lower() == joint_name_lower:
return i
return None
@classmethod
def parse_joint_data(cls, joints_data, policy_type: str = "act") -> list[float]:
"""
Parse joint data from Arena message into standard order.
Expected format: dict with joint names as keys and normalized values.
All values are already normalized from the training data pipeline.
Args:
joints_data: Joint data from Arena message
policy_type: Type of policy (for logging purposes)
Returns:
List of 6 normalized joint values in LeRobot standard order
"""
try:
# Handle different possible data formats
if hasattr(joints_data, "data"):
joint_dict = joints_data.data
else:
joint_dict = joints_data
if not isinstance(joint_dict, dict):
return [0.0] * 6
# Extract joint values in LeRobot standard order
joint_values = []
for lerobot_name in cls.LEROBOT_JOINT_NAMES:
value = None
# Try LeRobot name directly
if lerobot_name in joint_dict:
value = float(joint_dict[lerobot_name])
else:
# Try custom name
custom_name = cls.LEROBOT_TO_CUSTOM_NAMES.get(lerobot_name)
if custom_name and custom_name in joint_dict:
value = float(joint_dict[custom_name])
else:
# Try various common formats
for key in [
lerobot_name,
f"joint_{lerobot_name}",
lerobot_name.upper(),
custom_name,
f"joint_{custom_name}" if custom_name else None,
]:
if key and key in joint_dict:
value = float(joint_dict[key])
break
joint_values.append(value if value is not None else 0.0)
return joint_values
except Exception:
# Return zeros if parsing fails
return [0.0] * 6
@classmethod
def create_joint_commands(cls, action_values: np.ndarray) -> list[dict]:
"""
Create joint command dictionaries from action values.
Args:
action_values: Array of 6 joint values in LeRobot standard order
Returns:
List of joint command dictionaries with custom names
"""
commands = []
for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES):
if i < len(action_values):
commands.append({"name": custom_name, "value": float(action_values[i])})
return commands
@classmethod
def validate_joint_values(cls, joint_values: np.ndarray) -> np.ndarray:
"""
Validate and clamp joint values to their limits.
Args:
joint_values: Array of joint values
Returns:
Clamped joint values
"""
if len(joint_values) != 6:
# Pad or truncate to 6 values
padded = np.zeros(6, dtype=np.float32)
n = min(len(joint_values), 6)
padded[:n] = joint_values[:n]
joint_values = padded
# Clamp to limits
for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES):
min_val, max_val = cls.JOINT_LIMITS[custom_name]
joint_values[i] = np.clip(joint_values[i], min_val, max_val)
return joint_values