File size: 4,315 Bytes
3380376
 
 
 
 
 
 
63ed3a7
 
 
 
 
 
3380376
 
 
 
 
 
 
 
 
 
 
63ed3a7
3380376
 
 
 
 
 
 
 
63ed3a7
 
3380376
 
 
 
 
 
 
 
 
63ed3a7
3380376
 
 
 
 
 
 
 
 
63ed3a7
 
 
3380376
63ed3a7
3380376
63ed3a7
 
3380376
 
63ed3a7
 
3380376
63ed3a7
 
3380376
 
63ed3a7
3380376
 
63ed3a7
3380376
 
 
 
63ed3a7
3380376
 
 
 
 
 
 
 
 
63ed3a7
3380376
63ed3a7
3380376
63ed3a7
 
3380376
63ed3a7
3380376
63ed3a7
 
3380376
63ed3a7
 
3380376
63ed3a7
 
3380376
 
 
 
63ed3a7
3380376
 
 
 
 
 
63ed3a7
 
 
 
 
3380376
63ed3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3380376
 
 
63ed3a7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Joint configuration and mapping utilities for RobotHub Inference Server.

This module handles joint data parsing and normalization between different
robot configurations and the standardized training data format.
"""

from typing import ClassVar

import numpy as np


class JointConfig:
    """Joint configuration and mapping utilities."""

    # Joint name mapping from AI server names to standard names
    AI_TO_STANDARD_NAMES: ClassVar = {
        "Rotation": "shoulder_pan",
        "Pitch": "shoulder_lift",
        "Elbow": "elbow_flex",
        "Wrist_Pitch": "wrist_flex",
        "Wrist_Roll": "wrist_roll",
        "Jaw": "gripper",
    }

    # Standard joint names in order
    STANDARD_JOINT_NAMES: ClassVar = [
        "shoulder_pan",
        "shoulder_lift",
        "elbow_flex",
        "wrist_flex",
        "wrist_roll",
        "gripper",
    ]

    # AI server joint names in order
    AI_JOINT_NAMES: ClassVar = [
        "Rotation",
        "Pitch",
        "Elbow",
        "Wrist_Pitch",
        "Wrist_Roll",
        "Jaw",
    ]

    # Normalization ranges for robot joints
    # Most joints: [-100, 100], Gripper: [0, 100]
    ROBOT_NORMALIZATION_RANGES: ClassVar = {
        "shoulder_pan": (-100, 100),
        "shoulder_lift": (-100, 100),
        "elbow_flex": (-100, 100),
        "wrist_flex": (-100, 100),
        "wrist_roll": (-100, 100),
        "gripper": (0, 100),
    }

    @classmethod
    def parse_joint_data(cls, joints_data, policy_type: str = "act") -> list[float]:
        """
        Parse joint data from Transport Server message into standard order.

        Args:
            joints_data: Joint data from Transport Server message
            policy_type: Type of policy (for logging purposes)

        Returns:
            List of 6 normalized joint values in standard order

        """
        # Handle different possible data formats
        joint_dict = joints_data.data if hasattr(joints_data, "data") else joints_data

        if not isinstance(joint_dict, dict):
            return [0.0] * 6

        # Extract joint values in standard order
        joint_values = []
        for standard_name in cls.STANDARD_JOINT_NAMES:
            value = 0.0  # Default value

            # Try standard name first
            if standard_name in joint_dict:
                value = float(joint_dict[standard_name])
            else:
                # Try AI name
                for ai_name, std_name in cls.AI_TO_STANDARD_NAMES.items():
                    if std_name == standard_name and ai_name in joint_dict:
                        value = float(joint_dict[ai_name])
                        break

            joint_values.append(value)

        return joint_values

    @classmethod
    def create_joint_commands(cls, action_values: np.ndarray | list) -> list[dict]:
        """
        Create joint command messages from action values.

        Args:
            action_values: Array of 6 joint values in standard order

        Returns:
            List of joint command dictionaries with AI server names

        """
        if len(action_values) != 6:
            msg = f"Expected 6 joint values, got {len(action_values)}"
            raise ValueError(msg)

        commands = []
        for i, ai_name in enumerate(cls.AI_JOINT_NAMES):
            commands.append({
                "name": ai_name,
                "value": float(action_values[i]),
                "index": i,
            })
        return commands

    @classmethod
    def validate_joint_values(cls, joint_values: np.ndarray) -> np.ndarray:
        """
        Validate and clamp joint values to their normalized limits.

        Args:
            joint_values: Array of joint values

        Returns:
            Clamped joint values

        """
        if len(joint_values) != 6:
            # Pad or truncate to 6 values
            padded = np.zeros(6, dtype=np.float32)
            n = min(len(joint_values), 6)
            padded[:n] = joint_values[:n]
            joint_values = padded

        # Clamp to normalized limits
        for i, standard_name in enumerate(cls.STANDARD_JOINT_NAMES):
            min_val, max_val = cls.ROBOT_NORMALIZATION_RANGES[standard_name]
            joint_values[i] = np.clip(joint_values[i], min_val, max_val)

        return joint_values