Spaces:
Sleeping
Sleeping
File size: 4,862 Bytes
63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import logging
import numpy as np
import torch
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
from lerobot.common.utils.utils import init_logging
from torchvision import transforms
from .base_inference import BaseInferenceEngine
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class Pi0FastInferenceEngine(BaseInferenceEngine):
    """
    Pi0Fast (Physical Intelligence Fast) inference engine.

    Loads a ``PI0FASTPolicy``, configures one image transform per camera, and
    assembles the observation batch (camera images, joint state and an
    optional language task) that is handed to the policy for action
    prediction.
    """

    def __init__(
        self,
        policy_path: str,
        camera_names: list[str],
        device: str | None = None,
        language_instruction: str | None = None,
    ):
        """
        Create the engine; the policy itself is loaded by ``load_policy()``.

        Args:
            policy_path: Location passed to ``PI0FASTPolicy.from_pretrained``.
            camera_names: Camera names for which image transforms are set up.
            device: Forwarded unchanged to the base engine.
            language_instruction: Default task text used by ``predict`` when
                the caller does not supply a ``task``.
        """
        super().__init__(policy_path, camera_names, device)

        # Pi0Fast-specific configuration
        self.language_instruction = language_instruction
        self.supports_language = True

    async def load_policy(self):
        """Load the Pi0Fast policy from ``policy_path`` and mark the engine loaded."""
        logger.info("Loading Pi0Fast policy from: %s", self.policy_path)

        # Run LeRobot's logging setup before the policy is instantiated.
        init_logging()

        # Load the Pi0Fast policy, move it to the target device, and switch
        # it to evaluation mode.
        self.policy = PI0FASTPolicy.from_pretrained(self.policy_path)
        self.policy.to(self.device)
        self.policy.eval()

        # Set up image transforms based on policy config
        if hasattr(self.policy, "config"):
            self._setup_image_transforms()

        self.is_loaded = True
        logger.info("Pi0Fast policy loaded successfully on %s", self.device)

    def _setup_image_transforms(self):
        """
        Register one image transform per camera in ``self.image_transforms``.

        The policy's own ``image_processor`` is used when the policy has that
        attribute; otherwise a default resize / to-tensor / normalize pipeline
        is built, sized from ``config.image_size`` (224 when the config does
        not define it).
        """
        config = self.policy.config
        image_size = getattr(config, "image_size", 224)

        # The policy does not change while iterating, so check only once.
        use_policy_processor = hasattr(self.policy, "image_processor")

        for camera_name in self.camera_names:
            if use_policy_processor:
                # Use policy-specific transforms if available
                self.image_transforms[camera_name] = self.policy.image_processor
            else:
                # Fall back to default transform
                self.image_transforms[camera_name] = transforms.Compose([
                    transforms.Resize((image_size, image_size)),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ])

    async def predict(
        self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
    ) -> np.ndarray:
        """
        Run Pi0Fast inference to predict actions.

        Args:
            images: Dictionary of {camera_name: rgb_image_array}
            joint_positions: Current joint positions in LeRobot standard order
            **kwargs: ``task`` (optional) is a language instruction for this
                call. When it is omitted or ``None`` the instance's
                ``language_instruction`` is used instead.

        Returns:
            Array of predicted actions

        Raises:
            RuntimeError: If ``load_policy()`` has not completed yet.
        """
        if not self.is_loaded:
            msg = "Policy not loaded. Call load_policy() first."
            raise RuntimeError(msg)

        # Preprocess inputs
        processed_images = self.preprocess_images(images)
        processed_joints = self.preprocess_joint_positions(joint_positions)

        # An explicit ``task=None`` must not erase the configured default
        # instruction, so fall back whenever no value was actually given.
        task = kwargs.get("task")
        if task is None:
            task = self.language_instruction

        # Prepare batch inputs for Pi0Fast
        batch = self._prepare_batch(processed_images, processed_joints, task)

        # LeRobot policies expose ``select_action``; ``predict`` is still
        # tried first so a policy object that provides it behaves as before.
        infer = getattr(self.policy, "predict", None)
        if infer is None:
            infer = self.policy.select_action

        # Run inference without tracking gradients
        with torch.no_grad():
            action = infer(batch)

        # Convert to numpy
        if isinstance(action, torch.Tensor):
            action = action.cpu().numpy()
        return action

    def _prepare_batch(
        self,
        images: dict[str, torch.Tensor],
        joints: torch.Tensor,
        task: str | None = None,
    ) -> dict:
        """
        Prepare batch inputs for Pi0Fast model.

        Args:
            images: Preprocessed images keyed by camera name
            joints: Preprocessed joint positions
            task: Language instruction; left out of the batch when falsy

        Returns:
            Batch dictionary with ``observation.images.<camera>`` entries,
            ``observation.state`` and, when given, ``task``.
        """
        batch = {}

        # Add images to batch
        for camera_name, image_tensor in images.items():
            # A 3-D tensor has no batch axis yet; add a leading one.
            if image_tensor.ndim == 3:
                image_tensor = image_tensor.unsqueeze(0)
            batch[f"observation.images.{camera_name}"] = image_tensor

        # Add joint positions, again adding a leading batch axis if missing
        if joints.ndim == 1:
            joints = joints.unsqueeze(0)
        batch["observation.state"] = joints

        # Add language instruction if provided
        # NOTE(review): the policy may expect one task string per batch item
        # (a list) rather than a bare string - confirm against the policy.
        if task:
            batch["task"] = task
        return batch
|