import logging

import cv2
import numpy as np
import torch
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.utils.utils import init_logging

from .base_inference import BaseInferenceEngine

logger = logging.getLogger(__name__)


class ACTInferenceEngine(BaseInferenceEngine):
    """
    ACT (Action Chunking Transformer) inference engine.

    Handles image preprocessing, joint normalization, and action prediction
    for ACT models; action chunking is managed by the policy's internal
    action queue.
    """

    def __init__(
        self,
        policy_path: str,
        camera_names: list[str],
        use_custom_joint_names: bool = True,
        device: str | None = None,
    ):
        super().__init__(policy_path, camera_names, use_custom_joint_names, device)

        # ACT-specific configuration; chunk_size is a fallback that is synced
        # with the policy config once the policy is loaded
        self.chunk_size = 10
        self.action_history = []  # Store recent predictions for debugging

    async def load_policy(self):
        """Load the ACT policy from the specified path."""
        logger.info(f"Loading ACT policy from: {self.policy_path}")

        try:
            # Set up LeRobot's logging configuration
            init_logging()

            # Load the ACT policy
            self.policy = ACTPolicy.from_pretrained(self.policy_path)
            self.policy.to(self.device)
            self.policy.eval()

            # Set up image transforms and chunk size based on policy config
            if hasattr(self.policy, "config"):
                self.chunk_size = getattr(
                    self.policy.config, "chunk_size", self.chunk_size
                )
                self._setup_image_transforms()

            self.is_loaded = True
            logger.info(f"✅ ACT policy loaded successfully on {self.device}")

        except Exception as e:
            logger.exception(f"Failed to load ACT policy from {self.policy_path}")
            msg = f"Failed to load ACT policy: {e}"
            raise RuntimeError(msg) from e

    def _setup_image_transforms(self):
        """Set up image transforms based on the policy configuration."""
        try:
            # Get image size from policy config
            config = self.policy.config
            image_size = getattr(config, "image_size", 224)

            from torchvision import transforms

            # Create transforms for each camera
            for camera_name in self.camera_names:
                # Use policy-specific transforms if available
                if hasattr(self.policy, "image_processor"):
                    self.image_transforms[camera_name] = self.policy.image_processor
                else:
                    # Fall back to a default pipeline (expects PIL images):
                    # resize, convert to tensor, normalize with ImageNet stats
                    self.image_transforms[camera_name] = transforms.Compose([
                        transforms.Resize((image_size, image_size)),
                        transforms.ToTensor(),
                        transforms.Normalize(
                            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                        ),
                    ])

        except Exception as e:
            logger.warning(f"Could not set up image transforms: {e}. Using defaults.")

    async def predict(
        self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
    ) -> np.ndarray:
        """
        Run ACT inference to predict actions.

        Args:
            images: Dictionary of {camera_name: rgb_image_array}
            joint_positions: Current joint positions in LeRobot standard order
            **kwargs: Additional arguments (unused for ACT)

        Returns:
            The next predicted action; ACT maintains an internal queue of
            chunked actions and returns one action per call

        """
        if not self.is_loaded:
            msg = "Policy not loaded. Call load_policy() first."
            raise RuntimeError(msg)

        try:
            # Preprocess inputs
            processed_images = self.preprocess_images(images)
            processed_joints = self.preprocess_joint_positions(joint_positions)

            # Prepare batch inputs for ACT
            batch = self._prepare_batch(processed_images, processed_joints)

            # Run inference. ACTPolicy.select_action pops the next action from
            # the policy's internal queue, refilling it with a fresh chunk of
            # chunk_size actions whenever the queue runs empty.
            with torch.no_grad():
                action = self.policy.select_action(batch)

                # Convert to numpy and drop the batch dimension
                if isinstance(action, torch.Tensor):
                    action = action.squeeze(0).cpu().numpy()

                # Store in action history
                self.action_history.append(action)
                if len(self.action_history) > 10:  # Keep last 10 predictions
                    self.action_history.pop(0)

                logger.debug(f"ACT predicted action shape: {action.shape}")
                return action

        except Exception as e:
            logger.exception("ACT inference failed")
            msg = f"ACT inference failed: {e}"
            raise RuntimeError(msg) from e

    def _prepare_batch(
        self, images: dict[str, torch.Tensor], joints: torch.Tensor
    ) -> dict:
        """
        Prepare batch inputs for ACT model.

        Args:
            images: Preprocessed images
            joints: Preprocessed joint positions

        Returns:
            Batch dictionary for ACT model

        """
        batch = {}

        # Add images to batch, ensuring a leading batch dimension and that
        # tensors live on the inference device
        for camera_name, image_tensor in images.items():
            if image_tensor.ndim == 3:
                image_tensor = image_tensor.unsqueeze(0)
            batch[f"observation.images.{camera_name}"] = image_tensor.to(self.device)

        # Add joint positions as the proprioceptive state
        if joints.ndim == 1:
            joints = joints.unsqueeze(0)
        batch["observation.state"] = joints.to(self.device)

        return batch

    def reset(self):
        """Reset ACT-specific state."""
        super().reset()
        self.action_history = []

        # Reset ACT model state if it has one
        if self.policy and hasattr(self.policy, "reset"):
            self.policy.reset()

    def get_model_info(self) -> dict:
        """Get ACT-specific model information."""
        info = super().get_model_info()
        info.update({
            "policy_type": "act",
            "chunk_size": self.chunk_size,
            "action_history_length": len(self.action_history),
        })
        return info


# Utility functions for data transformation


def image_bgr_to_rgb(image: np.ndarray) -> np.ndarray:
    """Convert BGR image to RGB (useful for OpenCV cameras)."""
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def resize_image(image: np.ndarray, target_size: tuple[int, int]) -> np.ndarray:
    """Resize image to target size (width, height)."""
    return cv2.resize(image, target_size)
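

# Minimal usage sketch. The checkpoint path and camera names below are
# hypothetical placeholders, and the zero arrays stand in for real camera
# frames and joint readings; adapt them to your robot setup. Because of the
# relative import above, run this as a module (python -m <package>.<module>).
if __name__ == "__main__":
    import asyncio

    async def _demo():
        engine = ACTInferenceEngine(
            policy_path="outputs/train/act_policy",  # hypothetical path
            camera_names=["top", "wrist"],
        )
        await engine.load_policy()

        # Dummy observation: 480x640 RGB frames (as an OpenCV capture would
        # provide after image_bgr_to_rgb) and a 6-DoF joint vector
        images = {
            name: np.zeros((480, 640, 3), dtype=np.uint8)
            for name in engine.camera_names
        }
        joints = np.zeros(6, dtype=np.float32)

        action = await engine.predict(images, joints)
        print("Predicted action:", action)

        # Reset between episodes so the policy's action queue starts fresh
        engine.reset()

    asyncio.run(_demo())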