Spaces:
Sleeping
Sleeping
File size: 6,060 Bytes
63ed3a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
from typing import ClassVar
import numpy as np
class JointConfig:
# Standard joint names used in LeRobot training data
LEROBOT_JOINT_NAMES: ClassVar = [
"shoulder_pan_joint",
"shoulder_lift_joint",
"elbow_joint",
"wrist_1_joint",
"wrist_2_joint",
"wrist_3_joint",
]
# Our custom joint names (more intuitive for users)
CUSTOM_JOINT_NAMES: ClassVar = [
"base_rotation",
"shoulder_tilt",
"elbow_bend",
"wrist_rotate",
"wrist_tilt",
"wrist_twist",
]
# Mapping from our custom names to LeRobot standard names
CUSTOM_TO_LEROBOT_NAMES: ClassVar = {
"base_rotation": "shoulder_pan_joint",
"shoulder_tilt": "shoulder_lift_joint",
"elbow_bend": "elbow_joint",
"wrist_rotate": "wrist_1_joint",
"wrist_tilt": "wrist_2_joint",
"wrist_twist": "wrist_3_joint",
}
# Reverse mapping for convenience
LEROBOT_TO_CUSTOM_NAMES: ClassVar = {
v: k for k, v in CUSTOM_TO_LEROBOT_NAMES.items()
}
# Joint limits in normalized values (-100 to +100 for most joints, 0 to 100 for gripper)
JOINT_LIMITS: ClassVar = {
"base_rotation": (-100.0, 100.0),
"shoulder_tilt": (-100.0, 100.0),
"elbow_bend": (-100.0, 100.0),
"wrist_rotate": (-100.0, 100.0),
"wrist_tilt": (-100.0, 100.0),
"wrist_twist": (-100.0, 100.0),
}
@classmethod
def get_joint_index(cls, joint_name: str) -> int | None:
"""
Get the index of a joint in the standard joint order.
Args:
joint_name: Name of the joint (can be custom or LeRobot name)
Returns:
Index of the joint, or None if not found
"""
# Try custom names first
if joint_name in cls.CUSTOM_JOINT_NAMES:
return cls.CUSTOM_JOINT_NAMES.index(joint_name)
# Try LeRobot names
if joint_name in cls.LEROBOT_JOINT_NAMES:
return cls.LEROBOT_JOINT_NAMES.index(joint_name)
# Try case-insensitive matching
joint_name_lower = joint_name.lower()
for i, name in enumerate(cls.CUSTOM_JOINT_NAMES):
if name.lower() == joint_name_lower:
return i
for i, name in enumerate(cls.LEROBOT_JOINT_NAMES):
if name.lower() == joint_name_lower:
return i
return None
@classmethod
def parse_joint_data(cls, joints_data, policy_type: str = "act") -> list[float]:
"""
Parse joint data from Arena message into standard order.
Expected format: dict with joint names as keys and normalized values.
All values are already normalized from the training data pipeline.
Args:
joints_data: Joint data from Arena message
policy_type: Type of policy (for logging purposes)
Returns:
List of 6 normalized joint values in LeRobot standard order
"""
try:
# Handle different possible data formats
if hasattr(joints_data, "data"):
joint_dict = joints_data.data
else:
joint_dict = joints_data
if not isinstance(joint_dict, dict):
return [0.0] * 6
# Extract joint values in LeRobot standard order
joint_values = []
for lerobot_name in cls.LEROBOT_JOINT_NAMES:
value = None
# Try LeRobot name directly
if lerobot_name in joint_dict:
value = float(joint_dict[lerobot_name])
else:
# Try custom name
custom_name = cls.LEROBOT_TO_CUSTOM_NAMES.get(lerobot_name)
if custom_name and custom_name in joint_dict:
value = float(joint_dict[custom_name])
else:
# Try various common formats
for key in [
lerobot_name,
f"joint_{lerobot_name}",
lerobot_name.upper(),
custom_name,
f"joint_{custom_name}" if custom_name else None,
]:
if key and key in joint_dict:
value = float(joint_dict[key])
break
joint_values.append(value if value is not None else 0.0)
return joint_values
except Exception:
# Return zeros if parsing fails
return [0.0] * 6
@classmethod
def create_joint_commands(cls, action_values: np.ndarray) -> list[dict]:
"""
Create joint command dictionaries from action values.
Args:
action_values: Array of 6 joint values in LeRobot standard order
Returns:
List of joint command dictionaries with custom names
"""
commands = []
for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES):
if i < len(action_values):
commands.append({"name": custom_name, "value": float(action_values[i])})
return commands
@classmethod
def validate_joint_values(cls, joint_values: np.ndarray) -> np.ndarray:
"""
Validate and clamp joint values to their limits.
Args:
joint_values: Array of joint values
Returns:
Clamped joint values
"""
if len(joint_values) != 6:
# Pad or truncate to 6 values
padded = np.zeros(6, dtype=np.float32)
n = min(len(joint_values), 6)
padded[:n] = joint_values[:n]
joint_values = padded
# Clamp to limits
for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES):
min_val, max_val = cls.JOINT_LIMITS[custom_name]
joint_values[i] = np.clip(joint_values[i], min_val, max_val)
return joint_values
|