Spaces:
Sleeping
Sleeping
File size: 4,315 Bytes
3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
"""
Joint configuration and mapping utilities for RobotHub Inference Server.
This module handles joint data parsing and normalization between different
robot configurations and the standardized training data format.
"""
from typing import ClassVar
import numpy as np
class JointConfig:
"""Joint configuration and mapping utilities."""
# Joint name mapping from AI server names to standard names
AI_TO_STANDARD_NAMES: ClassVar = {
"Rotation": "shoulder_pan",
"Pitch": "shoulder_lift",
"Elbow": "elbow_flex",
"Wrist_Pitch": "wrist_flex",
"Wrist_Roll": "wrist_roll",
"Jaw": "gripper",
}
# Standard joint names in order
STANDARD_JOINT_NAMES: ClassVar = [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
]
# AI server joint names in order
AI_JOINT_NAMES: ClassVar = [
"Rotation",
"Pitch",
"Elbow",
"Wrist_Pitch",
"Wrist_Roll",
"Jaw",
]
# Normalization ranges for robot joints
# Most joints: [-100, 100], Gripper: [0, 100]
ROBOT_NORMALIZATION_RANGES: ClassVar = {
"shoulder_pan": (-100, 100),
"shoulder_lift": (-100, 100),
"elbow_flex": (-100, 100),
"wrist_flex": (-100, 100),
"wrist_roll": (-100, 100),
"gripper": (0, 100),
}
@classmethod
def parse_joint_data(cls, joints_data, policy_type: str = "act") -> list[float]:
"""
Parse joint data from Transport Server message into standard order.
Args:
joints_data: Joint data from Transport Server message
policy_type: Type of policy (for logging purposes)
Returns:
List of 6 normalized joint values in standard order
"""
# Handle different possible data formats
joint_dict = joints_data.data if hasattr(joints_data, "data") else joints_data
if not isinstance(joint_dict, dict):
return [0.0] * 6
# Extract joint values in standard order
joint_values = []
for standard_name in cls.STANDARD_JOINT_NAMES:
value = 0.0 # Default value
# Try standard name first
if standard_name in joint_dict:
value = float(joint_dict[standard_name])
else:
# Try AI name
for ai_name, std_name in cls.AI_TO_STANDARD_NAMES.items():
if std_name == standard_name and ai_name in joint_dict:
value = float(joint_dict[ai_name])
break
joint_values.append(value)
return joint_values
@classmethod
def create_joint_commands(cls, action_values: np.ndarray | list) -> list[dict]:
"""
Create joint command messages from action values.
Args:
action_values: Array of 6 joint values in standard order
Returns:
List of joint command dictionaries with AI server names
"""
if len(action_values) != 6:
msg = f"Expected 6 joint values, got {len(action_values)}"
raise ValueError(msg)
commands = []
for i, ai_name in enumerate(cls.AI_JOINT_NAMES):
commands.append({
"name": ai_name,
"value": float(action_values[i]),
"index": i,
})
return commands
@classmethod
def validate_joint_values(cls, joint_values: np.ndarray) -> np.ndarray:
"""
Validate and clamp joint values to their normalized limits.
Args:
joint_values: Array of joint values
Returns:
Clamped joint values
"""
if len(joint_values) != 6:
# Pad or truncate to 6 values
padded = np.zeros(6, dtype=np.float32)
n = min(len(joint_values), 6)
padded[:n] = joint_values[:n]
joint_values = padded
# Clamp to normalized limits
for i, standard_name in enumerate(cls.STANDARD_JOINT_NAMES):
min_val, max_val = cls.ROBOT_NORMALIZATION_RANGES[standard_name]
joint_values[i] = np.clip(joint_values[i], min_val, max_val)
return joint_values
|