Spaces:
Sleeping
Sleeping
File size: 5,758 Bytes
63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import asyncio
import logging
from abc import ABC, abstractmethod
import numpy as np
import torch
from PIL import Image
from .joint_config import JointConfig
# Module-level logger named after this module, shared by every engine defined here.
logger = logging.getLogger(__name__)
class BaseInferenceEngine(ABC):
"""
Base class for all inference engines.
This class provides common functionality for:
- Image preprocessing and normalization
- Joint data handling and validation
- Model loading and management
- Action prediction interface
"""
def __init__(
self,
policy_path: str,
camera_names: list[str],
device: str | None = None,
):
self.policy_path = policy_path
self.camera_names = camera_names
# Device selection
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
logger.info(f"Using device: {self.device}")
# Model and preprocessing
self.policy = None
self.image_transforms = {} # {camera_name: transform}
self.stats = None # Dataset statistics for normalization
# State tracking
self.is_loaded = False
self.last_images = {}
self.last_joint_positions = None
@abstractmethod
async def load_policy(self):
"""Load the policy model. Must be implemented by subclasses."""
@abstractmethod
async def predict(
self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
) -> np.ndarray:
"""Run inference. Must be implemented by subclasses."""
def preprocess_images(
self, images: dict[str, np.ndarray]
) -> dict[str, torch.Tensor]:
"""
Preprocess images for inference.
Args:
images: Dictionary of {camera_name: rgb_image_array}
Returns:
Dictionary of {camera_name: preprocessed_tensor}
"""
processed_images = {}
for camera_name, image in images.items():
if camera_name not in self.camera_names:
logger.warning(f"Unexpected camera: {camera_name}")
continue
# Convert numpy array to PIL Image if needed
if isinstance(image, np.ndarray):
if image.dtype != np.uint8:
image = (image * 255).astype(np.uint8)
pil_image = Image.fromarray(image)
else:
pil_image = image
# Apply transforms if available
if camera_name in self.image_transforms:
tensor = self.image_transforms[camera_name](pil_image)
else:
# Default preprocessing: resize to 224x224 and normalize
tensor = self._default_image_transform(pil_image)
processed_images[camera_name] = tensor.to(self.device)
return processed_images
def _default_image_transform(self, image: Image.Image) -> torch.Tensor:
"""Default image preprocessing."""
# Resize to 224x224 (common size for vision models)
image = image.resize((224, 224), Image.Resampling.LANCZOS)
# Convert to tensor and normalize to [0, 1]
tensor = torch.from_numpy(np.array(image)).float() / 255.0
# Rearrange from HWC to CHW
if len(tensor.shape) == 3:
tensor = tensor.permute(2, 0, 1)
return tensor
def preprocess_joint_positions(self, joint_positions: np.ndarray) -> torch.Tensor:
"""
Preprocess joint positions for inference.
Args:
joint_positions: Array of joint positions in standard order
Returns:
Preprocessed joint tensor
"""
# Validate and clamp joint values
joint_positions = JointConfig.validate_joint_values(joint_positions)
# Convert to tensor
joint_tensor = torch.from_numpy(joint_positions).float().to(self.device)
# Normalize if we have dataset statistics
if self.stats and hasattr(self.stats, "joint_stats"):
joint_tensor = self._normalize_joints(joint_tensor)
return joint_tensor
def _normalize_joints(self, joint_tensor: torch.Tensor) -> torch.Tensor:
"""Normalize joint values using dataset statistics."""
# This would use the actual dataset statistics
# For now, we assume joints are already normalized
return joint_tensor
def get_joint_commands_with_names(self, action: np.ndarray) -> list[dict]:
"""
Convert action array to joint commands with names.
Args:
action: Array of joint actions in standard order
Returns:
List of joint command dictionaries
"""
# Validate action values
action = JointConfig.validate_joint_values(action)
# Create commands with AI names (always use AI names for output)
return JointConfig.create_joint_commands(action)
def reset(self):
"""Reset the inference engine state."""
self.last_images = {}
self.last_joint_positions = None
# Clear any model-specific state
if hasattr(self.policy, "reset"):
self.policy.reset()
async def cleanup(self):
"""Clean up resources."""
if self.policy:
del self.policy
self.policy = None
# Clear GPU memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.is_loaded = False
logger.info(f"Cleaned up inference engine for {self.policy_path}")
def __del__(self):
"""Destructor to ensure cleanup."""
if hasattr(self, "policy") and self.policy:
asyncio.create_task(self.cleanup())
|