File size: 6,060 Bytes
63ed3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import logging
from collections.abc import Mapping
from typing import ClassVar

import numpy as np


class JointConfig:
    # Standard joint names used in LeRobot training data
    LEROBOT_JOINT_NAMES: ClassVar = [
        "shoulder_pan_joint",
        "shoulder_lift_joint",
        "elbow_joint",
        "wrist_1_joint",
        "wrist_2_joint",
        "wrist_3_joint",
    ]

    # Our custom joint names (more intuitive for users)
    CUSTOM_JOINT_NAMES: ClassVar = [
        "base_rotation",
        "shoulder_tilt",
        "elbow_bend",
        "wrist_rotate",
        "wrist_tilt",
        "wrist_twist",
    ]

    # Mapping from our custom names to LeRobot standard names
    CUSTOM_TO_LEROBOT_NAMES: ClassVar = {
        "base_rotation": "shoulder_pan_joint",
        "shoulder_tilt": "shoulder_lift_joint",
        "elbow_bend": "elbow_joint",
        "wrist_rotate": "wrist_1_joint",
        "wrist_tilt": "wrist_2_joint",
        "wrist_twist": "wrist_3_joint",
    }

    # Reverse mapping for convenience
    LEROBOT_TO_CUSTOM_NAMES: ClassVar = {
        v: k for k, v in CUSTOM_TO_LEROBOT_NAMES.items()
    }

    # Joint limits in normalized values (-100 to +100 for most joints, 0 to 100 for gripper)
    JOINT_LIMITS: ClassVar = {
        "base_rotation": (-100.0, 100.0),
        "shoulder_tilt": (-100.0, 100.0),
        "elbow_bend": (-100.0, 100.0),
        "wrist_rotate": (-100.0, 100.0),
        "wrist_tilt": (-100.0, 100.0),
        "wrist_twist": (-100.0, 100.0),
    }

    @classmethod
    def get_joint_index(cls, joint_name: str) -> int | None:
        """
        Get the index of a joint in the standard joint order.

        Args:
            joint_name: Name of the joint (can be custom or LeRobot name)

        Returns:
            Index of the joint, or None if not found

        """
        # Try custom names first
        if joint_name in cls.CUSTOM_JOINT_NAMES:
            return cls.CUSTOM_JOINT_NAMES.index(joint_name)

        # Try LeRobot names
        if joint_name in cls.LEROBOT_JOINT_NAMES:
            return cls.LEROBOT_JOINT_NAMES.index(joint_name)

        # Try case-insensitive matching
        joint_name_lower = joint_name.lower()
        for i, name in enumerate(cls.CUSTOM_JOINT_NAMES):
            if name.lower() == joint_name_lower:
                return i

        for i, name in enumerate(cls.LEROBOT_JOINT_NAMES):
            if name.lower() == joint_name_lower:
                return i

        return None

    @classmethod
    def parse_joint_data(cls, joints_data, policy_type: str = "act") -> list[float]:
        """
        Parse joint data from Arena message into standard order.

        Expected format: dict with joint names as keys and normalized values.
        All values are already normalized from the training data pipeline.

        Args:
            joints_data: Joint data from Arena message
            policy_type: Type of policy (for logging purposes)

        Returns:
            List of 6 normalized joint values in LeRobot standard order

        """
        try:
            # Handle different possible data formats
            if hasattr(joints_data, "data"):
                joint_dict = joints_data.data
            else:
                joint_dict = joints_data

            if not isinstance(joint_dict, dict):
                return [0.0] * 6

            # Extract joint values in LeRobot standard order
            joint_values = []
            for lerobot_name in cls.LEROBOT_JOINT_NAMES:
                value = None

                # Try LeRobot name directly
                if lerobot_name in joint_dict:
                    value = float(joint_dict[lerobot_name])
                else:
                    # Try custom name
                    custom_name = cls.LEROBOT_TO_CUSTOM_NAMES.get(lerobot_name)
                    if custom_name and custom_name in joint_dict:
                        value = float(joint_dict[custom_name])
                    else:
                        # Try various common formats
                        for key in [
                            lerobot_name,
                            f"joint_{lerobot_name}",
                            lerobot_name.upper(),
                            custom_name,
                            f"joint_{custom_name}" if custom_name else None,
                        ]:
                            if key and key in joint_dict:
                                value = float(joint_dict[key])
                                break

                joint_values.append(value if value is not None else 0.0)

            return joint_values

        except Exception:
            # Return zeros if parsing fails
            return [0.0] * 6

    @classmethod
    def create_joint_commands(cls, action_values: np.ndarray) -> list[dict]:
        """
        Create joint command dictionaries from action values.

        Args:
            action_values: Array of 6 joint values in LeRobot standard order

        Returns:
            List of joint command dictionaries with custom names

        """
        commands = []
        for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES):
            if i < len(action_values):
                commands.append({"name": custom_name, "value": float(action_values[i])})
        return commands

    @classmethod
    def validate_joint_values(cls, joint_values: np.ndarray) -> np.ndarray:
        """
        Validate and clamp joint values to their limits.

        Args:
            joint_values: Array of joint values

        Returns:
            Clamped joint values

        """
        if len(joint_values) != 6:
            # Pad or truncate to 6 values
            padded = np.zeros(6, dtype=np.float32)
            n = min(len(joint_values), 6)
            padded[:n] = joint_values[:n]
            joint_values = padded

        # Clamp to limits
        for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES):
            min_val, max_val = cls.JOINT_LIMITS[custom_name]
            joint_values[i] = np.clip(joint_values[i], min_val, max_val)

        return joint_values