blanchon's picture
Update
3380376
import logging
import numpy as np
import torch
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.common.utils.utils import init_logging
from torchvision import transforms
from .base_inference import BaseInferenceEngine
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class SmolVLAInferenceEngine(BaseInferenceEngine):
    """
    SmolVLA (Small Vision-Language-Action) inference engine.

    Loads a SmolVLA policy, registers one image transform per camera, and
    turns camera images plus joint positions (optionally accompanied by a
    language instruction) into predicted actions. Raw-input preprocessing is
    delegated to the ``preprocess_images`` / ``preprocess_joint_positions``
    methods inherited from ``BaseInferenceEngine``.
    """

    # Channel statistics used by the fallback transform
    # (the commonly used ImageNet mean / std values).
    _DEFAULT_MEAN = (0.485, 0.456, 0.406)
    _DEFAULT_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        policy_path: str,
        camera_names: list[str],
        device: str | None = None,
        language_instruction: str | None = None,
    ):
        """
        Create the engine; the policy itself is loaded by ``load_policy``.

        Args:
            policy_path: Location handed to ``SmolVLAPolicy.from_pretrained``.
            camera_names: Cameras for which image transforms are registered.
            device: Forwarded unchanged to the base class.
            language_instruction: Default instruction used by ``predict``
                when the caller supplies no ``task``.
        """
        super().__init__(policy_path, camera_names, device)

        # SmolVLA-specific configuration
        self.language_instruction = language_instruction
        self.supports_language = True

    async def load_policy(self) -> None:
        """Load the SmolVLA policy from ``self.policy_path`` and mark it ready."""
        logger.info("Loading SmolVLA policy from: %s", self.policy_path)

        # LeRobot's logging setup helper (this does not touch Hydra config).
        init_logging()

        # Load the weights, move them to the target device and switch to
        # evaluation mode.
        self.policy = SmolVLAPolicy.from_pretrained(self.policy_path)
        self.policy.to(self.device)
        self.policy.eval()

        # Transforms are only configured when the policy exposes a config.
        if hasattr(self.policy, "config"):
            self._setup_image_transforms()

        self.is_loaded = True
        logger.info("SmolVLA policy loaded successfully on %s", self.device)

    def _setup_image_transforms(self) -> None:
        """
        Register one image transform per camera in ``self.image_transforms``.

        The policy's own ``image_processor`` is used when the policy has that
        attribute; otherwise a default resize / to-tensor / normalize
        pipeline is built, sized from ``config.image_size`` (224 if absent).
        """
        image_size = getattr(self.policy.config, "image_size", 224)

        # The attribute check does not depend on the camera, so do it once.
        has_policy_processor = hasattr(self.policy, "image_processor")

        for camera_name in self.camera_names:
            if has_policy_processor:
                transform = self.policy.image_processor
            else:
                transform = self._build_default_transform(image_size)
            self.image_transforms[camera_name] = transform

    @classmethod
    def _build_default_transform(cls, image_size: int):
        """Build the fallback resize -> tensor -> normalize pipeline."""
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=list(cls._DEFAULT_MEAN), std=list(cls._DEFAULT_STD)
            ),
        ])

    async def predict(
        self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
    ) -> np.ndarray:
        """
        Run SmolVLA inference to predict actions.

        Args:
            images: Dictionary of {camera_name: rgb_image_array}
            joint_positions: Current joint positions in LeRobot standard order
            task: Optional language instruction passed as a keyword argument.
                When omitted or ``None``, the instance's
                ``language_instruction`` is used instead.

        Returns:
            The policy output, converted to a NumPy array when the policy
            returns a ``torch.Tensor``.

        Raises:
            RuntimeError: If ``load_policy`` has not completed yet.
        """
        if not self.is_loaded:
            msg = "Policy not loaded. Call load_policy() first."
            raise RuntimeError(msg)

        # Preprocess inputs (implemented by the base class).
        processed_images = self.preprocess_images(images)
        processed_joints = self.preprocess_joint_positions(joint_positions)

        # An explicit ``task=None`` must behave like "not provided" so that
        # callers forwarding an optional argument keep the configured default.
        task = kwargs.get("task")
        if task is None:
            task = self.language_instruction

        batch = self._prepare_batch(processed_images, processed_joints, task)

        # Keep using ``predict`` whenever the policy offers it; otherwise fall
        # back to ``select_action``, the inference entry point of LeRobot's
        # policy interface, instead of failing with AttributeError.
        infer = getattr(self.policy, "predict", None)
        if infer is None:
            infer = self.policy.select_action

        with torch.no_grad():
            action = infer(batch)

        # Convert to numpy (tensor is moved to CPU first).
        if isinstance(action, torch.Tensor):
            action = action.cpu().numpy()

        return action

    def _prepare_batch(
        self,
        images: dict[str, torch.Tensor],
        joints: torch.Tensor,
        task: str | None = None,
    ) -> dict:
        """
        Prepare batch inputs for SmolVLA model.

        Args:
            images: Preprocessed images keyed by camera name. A 3-D tensor
                gets a leading batch dimension; other ranks pass through.
            joints: Preprocessed joint positions. A 1-D tensor gets a leading
                batch dimension; other ranks pass through.
            task: Language instruction; only added to the batch when truthy.

        Returns:
            Dictionary with ``observation.images.<camera>`` entries,
            ``observation.state`` and, when given, ``task``.
        """
        batch: dict = {}

        for camera_name, image_tensor in images.items():
            batched_image = (
                image_tensor.unsqueeze(0) if image_tensor.ndim == 3 else image_tensor
            )
            batch[f"observation.images.{camera_name}"] = batched_image

        batch["observation.state"] = joints.unsqueeze(0) if joints.ndim == 1 else joints

        # NOTE(review): the "task" key is omitted entirely when no instruction
        # is available; confirm the policy tolerates a batch without it.
        if task:
            batch["task"] = task

        return batch