Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# Copyright 2024 The LeRobot Team and contributors. | |
# Licensed under the Apache License, Version 2.0 | |
""" | |
Inference Server Models Module. | |
This module provides inference engines for several robotic policy types,
including ACT, Pi0, Pi0-FAST, SmolVLA, and Diffusion Policy models. Engines whose
modules fail to import are skipped (with a logged warning) rather than breaking the package.
""" | |
import logging | |
from .joint_config import JointConfig | |
logger = logging.getLogger(__name__)

# Engine classes whose modules imported successfully, keyed by the lowercase
# policy name used for lookup.
_available_engines = {}
# The same policy names, in the order they were registered.
_available_policies = []


def _register_engine(policy_name, engine_cls):
    """Record *engine_cls* as the inference engine for *policy_name*."""
    _available_engines[policy_name] = engine_cls
    _available_policies.append(policy_name)


# Each engine lives in its own submodule. Any ImportError raised while importing
# one is logged as a warning and that engine is skipped, so the remaining
# engines (and this package) stay importable.

# ACT (expected to always be importable)
try:
    from .act_inference import ACTInferenceEngine
except ImportError as exc:
    logger.warning(f"ACT policy not available: {exc}")
else:
    _register_engine("act", ACTInferenceEngine)

# Pi0 (optional)
try:
    from .pi0_inference import Pi0InferenceEngine
except ImportError as exc:
    logger.warning(f"Pi0 policy not available: {exc}")
else:
    _register_engine("pi0", Pi0InferenceEngine)

# Pi0Fast (optional)
try:
    from .pi0fast_inference import Pi0FastInferenceEngine
except ImportError as exc:
    logger.warning(f"Pi0Fast policy not available: {exc}")
else:
    _register_engine("pi0fast", Pi0FastInferenceEngine)

# SmolVLA (optional)
try:
    from .smolvla_inference import SmolVLAInferenceEngine
except ImportError as exc:
    logger.warning(f"SmolVLA policy not available: {exc}")
else:
    _register_engine("smolvla", SmolVLAInferenceEngine)

# Diffusion (optional; noted upstream as having dependency issues)
try:
    from .diffusion_inference import DiffusionInferenceEngine
except ImportError as exc:
    logger.warning(f"Diffusion policy not available: {exc}")
else:
    _register_engine("diffusion", DiffusionInferenceEngine)
# Public API. The factory helpers and JointConfig are always exported; engine
# classes are appended below only for the policies whose modules imported.
__all__ = [
    # Shared configuration
    "JointConfig",
    # Factory functions
    "get_inference_engine",
    "list_supported_policies",
]

# Export name of the engine class for each known policy. A registered policy
# that is missing from this map is simply not re-exported.
_ENGINE_EXPORT_NAMES = {
    "act": "ACTInferenceEngine",
    "pi0": "Pi0InferenceEngine",
    "pi0fast": "Pi0FastInferenceEngine",
    "smolvla": "SmolVLAInferenceEngine",
    "diffusion": "DiffusionInferenceEngine",
}

# A generator expression (rather than a module-level for-loop) keeps the loop
# variable out of the package namespace. Registration order is preserved.
__all__.extend(
    _ENGINE_EXPORT_NAMES[policy]
    for policy in _available_policies
    if policy in _ENGINE_EXPORT_NAMES
)
def list_supported_policies() -> list[str]:
    """Return the names of the policies whose engines were importable.

    The result is a new list in registration order, so callers may modify it
    without touching the module's internal registry.
    """
    return list(_available_policies)
def get_inference_engine(policy_type: str, **kwargs):
    """
    Factory function to create an inference engine based on policy type.

    The lookup is case-insensitive and ignores surrounding whitespace, so
    ``"ACT"`` and ``" act "`` both select the ACT engine.

    Args:
        policy_type: Type of policy (act, pi0, pi0fast, smolvla, diffusion)
        **kwargs: Arguments to pass to the inference engine constructor

    Returns:
        Appropriate inference engine instance

    Raises:
        TypeError: If policy_type is not a string
        ValueError: If policy type is not supported or not available
    """
    if not isinstance(policy_type, str):
        # Without this guard a non-string (e.g. None from a missing config key)
        # fails with an opaque AttributeError on ``.lower()``.
        msg = f"policy_type must be a string, got {type(policy_type).__name__}"
        raise TypeError(msg)

    normalized = policy_type.strip().lower()
    try:
        engine_class = _available_engines[normalized]
    except KeyError:
        # List the keys of the same mapping used for the lookup, so the message
        # always reflects exactly what would have been accepted.
        available = list(_available_engines)
        msg = f"Policy type '{normalized}' is not available. Available policies: {available}"
        raise ValueError(msg) from None
    return engine_class(**kwargs)