Spaces:
Sleeping
Sleeping
import asyncio | |
import logging | |
from abc import ABC, abstractmethod | |
import numpy as np | |
import torch | |
from PIL import Image | |
from .joint_config import JointConfig | |
logger = logging.getLogger(__name__) | |
class BaseInferenceEngine(ABC): | |
""" | |
Base class for all inference engines. | |
This class provides common functionality for: | |
- Image preprocessing and normalization | |
- Joint data handling and validation | |
- Model loading and management | |
- Action prediction interface | |
""" | |
def __init__( | |
self, | |
policy_path: str, | |
camera_names: list[str], | |
device: str | None = None, | |
): | |
self.policy_path = policy_path | |
self.camera_names = camera_names | |
# Device selection | |
if device is None: | |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
else: | |
self.device = torch.device(device) | |
logger.info(f"Using device: {self.device}") | |
# Model and preprocessing | |
self.policy = None | |
self.image_transforms = {} # {camera_name: transform} | |
self.stats = None # Dataset statistics for normalization | |
# State tracking | |
self.is_loaded = False | |
self.last_images = {} | |
self.last_joint_positions = None | |
async def load_policy(self): | |
"""Load the policy model. Must be implemented by subclasses.""" | |
async def predict( | |
self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs | |
) -> np.ndarray: | |
"""Run inference. Must be implemented by subclasses.""" | |
def preprocess_images( | |
self, images: dict[str, np.ndarray] | |
) -> dict[str, torch.Tensor]: | |
""" | |
Preprocess images for inference. | |
Args: | |
images: Dictionary of {camera_name: rgb_image_array} | |
Returns: | |
Dictionary of {camera_name: preprocessed_tensor} | |
""" | |
processed_images = {} | |
for camera_name, image in images.items(): | |
if camera_name not in self.camera_names: | |
logger.warning(f"Unexpected camera: {camera_name}") | |
continue | |
# Convert numpy array to PIL Image if needed | |
if isinstance(image, np.ndarray): | |
if image.dtype != np.uint8: | |
image = (image * 255).astype(np.uint8) | |
pil_image = Image.fromarray(image) | |
else: | |
pil_image = image | |
# Apply transforms if available | |
if camera_name in self.image_transforms: | |
tensor = self.image_transforms[camera_name](pil_image) | |
else: | |
# Default preprocessing: resize to 224x224 and normalize | |
tensor = self._default_image_transform(pil_image) | |
processed_images[camera_name] = tensor.to(self.device) | |
return processed_images | |
def _default_image_transform(self, image: Image.Image) -> torch.Tensor: | |
"""Default image preprocessing.""" | |
# Resize to 224x224 (common size for vision models) | |
image = image.resize((224, 224), Image.Resampling.LANCZOS) | |
# Convert to tensor and normalize to [0, 1] | |
tensor = torch.from_numpy(np.array(image)).float() / 255.0 | |
# Rearrange from HWC to CHW | |
if len(tensor.shape) == 3: | |
tensor = tensor.permute(2, 0, 1) | |
return tensor | |
def preprocess_joint_positions(self, joint_positions: np.ndarray) -> torch.Tensor: | |
""" | |
Preprocess joint positions for inference. | |
Args: | |
joint_positions: Array of joint positions in standard order | |
Returns: | |
Preprocessed joint tensor | |
""" | |
# Validate and clamp joint values | |
joint_positions = JointConfig.validate_joint_values(joint_positions) | |
# Convert to tensor | |
joint_tensor = torch.from_numpy(joint_positions).float().to(self.device) | |
# Normalize if we have dataset statistics | |
if self.stats and hasattr(self.stats, "joint_stats"): | |
joint_tensor = self._normalize_joints(joint_tensor) | |
return joint_tensor | |
def _normalize_joints(self, joint_tensor: torch.Tensor) -> torch.Tensor: | |
"""Normalize joint values using dataset statistics.""" | |
# This would use the actual dataset statistics | |
# For now, we assume joints are already normalized | |
return joint_tensor | |
def get_joint_commands_with_names(self, action: np.ndarray) -> list[dict]: | |
""" | |
Convert action array to joint commands with names. | |
Args: | |
action: Array of joint actions in standard order | |
Returns: | |
List of joint command dictionaries | |
""" | |
# Validate action values | |
action = JointConfig.validate_joint_values(action) | |
# Create commands with AI names (always use AI names for output) | |
return JointConfig.create_joint_commands(action) | |
def reset(self): | |
"""Reset the inference engine state.""" | |
self.last_images = {} | |
self.last_joint_positions = None | |
# Clear any model-specific state | |
if hasattr(self.policy, "reset"): | |
self.policy.reset() | |
async def cleanup(self): | |
"""Clean up resources.""" | |
if self.policy: | |
del self.policy | |
self.policy = None | |
# Clear GPU memory | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
self.is_loaded = False | |
logger.info(f"Cleaned up inference engine for {self.policy_path}") | |
def __del__(self): | |
"""Destructor to ensure cleanup.""" | |
if hasattr(self, "policy") and self.policy: | |
asyncio.create_task(self.cleanup()) | |