File size: 5,758 Bytes
63ed3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3380376
63ed3a7
 
 
 
 
 
3380376
63ed3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3380376
63ed3a7
 
 
 
 
 
3380376
63ed3a7
3380376
 
63ed3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import asyncio
import logging
from abc import ABC, abstractmethod

import numpy as np
import torch
from PIL import Image

from .joint_config import JointConfig

logger = logging.getLogger(__name__)


class BaseInferenceEngine(ABC):
    """
    Base class for all inference engines.

    This class provides common functionality for:
    - Image preprocessing and normalization
    - Joint data handling and validation
    - Model loading and management
    - Action prediction interface

    Subclasses must implement ``load_policy`` and ``predict``.
    """

    def __init__(
        self,
        policy_path: str,
        camera_names: list[str],
        device: str | None = None,
    ):
        """
        Store the configuration and select the torch device.

        Args:
            policy_path: Identifier of the policy; stored as-is and used in
                log messages.
            camera_names: Camera names accepted by ``preprocess_images``.
            device: Torch device string. When ``None``, CUDA is selected if
                it is available, otherwise the CPU.

        """
        self.policy_path = policy_path
        self.camera_names = camera_names

        # Device selection
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        logger.info(f"Using device: {self.device}")

        # Model and preprocessing
        self.policy = None  # Set by subclasses in load_policy()
        self.image_transforms = {}  # {camera_name: transform}
        self.stats = None  # Dataset statistics for normalization

        # State tracking
        self.is_loaded = False
        self.last_images = {}
        self.last_joint_positions = None

    @abstractmethod
    async def load_policy(self):
        """Load the policy model. Must be implemented by subclasses."""

    @abstractmethod
    async def predict(
        self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
    ) -> np.ndarray:
        """Run inference. Must be implemented by subclasses."""

    def preprocess_images(
        self, images: dict[str, np.ndarray]
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess images for inference.

        Cameras that are not listed in ``self.camera_names`` are skipped with
        a warning. Each remaining image goes through the camera-specific
        transform from ``self.image_transforms`` when one is registered,
        otherwise through ``_default_image_transform``, and is then moved to
        ``self.device``.

        Args:
            images: Dictionary of {camera_name: rgb_image_array}. Values that
                are not numpy arrays are handed to the transform unchanged.

        Returns:
            Dictionary of {camera_name: preprocessed_tensor}

        """
        processed_images = {}

        for camera_name, image in images.items():
            if camera_name not in self.camera_names:
                logger.warning(f"Unexpected camera: {camera_name}")
                continue

            # Convert numpy array to PIL Image if needed
            if isinstance(image, np.ndarray):
                if image.dtype != np.uint8:
                    if np.issubdtype(image.dtype, np.floating):
                        # The x255 scaling treats floats as being in [0, 1].
                        # Clip first so values slightly outside that range
                        # saturate instead of wrapping around in uint8.
                        image = np.clip(image, 0.0, 1.0)
                    # NOTE(review): non-uint8 integer arrays are still scaled
                    # by 255 as before; confirm callers never pass 0-255
                    # data in a wider integer dtype.
                    image = (image * 255).astype(np.uint8)
                pil_image = Image.fromarray(image)
            else:
                pil_image = image

            # Apply transforms if available
            if camera_name in self.image_transforms:
                tensor = self.image_transforms[camera_name](pil_image)
            else:
                # Default preprocessing: resize to 224x224 and normalize
                tensor = self._default_image_transform(pil_image)

            processed_images[camera_name] = tensor.to(self.device)

        return processed_images

    def _default_image_transform(self, image: Image.Image) -> torch.Tensor:
        """
        Default image preprocessing.

        Resizes to 224x224 with Lanczos resampling, scales pixel values to
        [0, 1] as float32 and, for three-dimensional (HWC) data, moves the
        channel axis first (CHW). Two-dimensional data is returned as HW.
        """
        # Resize to 224x224 (common size for vision models)
        image = image.resize((224, 224), Image.Resampling.LANCZOS)

        # Convert to tensor and normalize to [0, 1]
        tensor = torch.from_numpy(np.array(image)).float() / 255.0

        # Rearrange from HWC to CHW
        if tensor.ndim == 3:
            tensor = tensor.permute(2, 0, 1)

        return tensor

    def preprocess_joint_positions(self, joint_positions: np.ndarray) -> torch.Tensor:
        """
        Preprocess joint positions for inference.

        Args:
            joint_positions: Array of joint positions in standard order

        Returns:
            Preprocessed float tensor on ``self.device``

        """
        # Validate and clamp joint values
        joint_positions = JointConfig.validate_joint_values(joint_positions)

        # Convert to tensor
        joint_tensor = torch.from_numpy(joint_positions).float().to(self.device)

        # Normalize if we have dataset statistics. Compare against None
        # explicitly: truthiness is ambiguous for array-like statistics and
        # would skip an object that merely reports a length of zero.
        if self.stats is not None and hasattr(self.stats, "joint_stats"):
            joint_tensor = self._normalize_joints(joint_tensor)

        return joint_tensor

    def _normalize_joints(self, joint_tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize joint values using dataset statistics.

        Placeholder: currently returns the input unchanged.
        """
        # This would use the actual dataset statistics
        # For now, we assume joints are already normalized
        return joint_tensor

    def get_joint_commands_with_names(self, action: np.ndarray) -> list[dict]:
        """
        Convert action array to joint commands with names.

        Args:
            action: Array of joint actions in standard order

        Returns:
            List of joint command dictionaries

        """
        # Validate action values
        action = JointConfig.validate_joint_values(action)

        # Create commands with AI names (always use AI names for output)
        return JointConfig.create_joint_commands(action)

    def reset(self) -> None:
        """Reset the inference engine state."""
        self.last_images = {}
        self.last_joint_positions = None

        # Clear any model-specific state (no-op while no policy is loaded,
        # because None has no "reset" attribute)
        if hasattr(self.policy, "reset"):
            self.policy.reset()

    def _release_resources(self) -> None:
        """Synchronously drop the policy and release cached GPU memory."""
        # Compare against None rather than truthiness: container-style
        # modules can be falsy when empty and would otherwise be kept.
        if self.policy is not None:
            self.policy = None

        # Clear GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.is_loaded = False

    async def cleanup(self) -> None:
        """Clean up resources."""
        self._release_resources()
        logger.info(f"Cleaned up inference engine for {self.policy_path}")

    def __del__(self):
        """
        Destructor to ensure cleanup.

        Releases resources synchronously. A destructor usually runs without
        a running event loop, where ``asyncio.create_task`` raises
        ``RuntimeError``, so no coroutine is scheduled here. Subclasses that
        add asynchronous cleanup steps must have ``cleanup()`` awaited
        explicitly before the engine is dropped.
        """
        # getattr guards against instances whose __init__ did not complete
        if getattr(self, "policy", None) is None:
            return
        try:
            self._release_resources()
        except Exception:
            # Best effort only: this may run during interpreter shutdown
            # when torch is already partially torn down, and a destructor
            # has no caller that could handle the error.
            pass