blanchon's picture
Update
3380376
#!/usr/bin/env python
# Copyright 2024 The LeRobot Team and contributors.
# Licensed under the Apache License, Version 2.0
"""
RobotHub Inference Server - Model inference engines for various policy types.
This module provides unified inference engines for different policy architectures
including ACT, Pi0, SmolVLA, and Diffusion policies.
"""
import logging
from .act_inference import ACTInferenceEngine
from .base_inference import BaseInferenceEngine
from .diffusion_inference import DiffusionInferenceEngine
from .joint_config import JointConfig
from .pi0_inference import Pi0InferenceEngine
from .pi0fast_inference import Pi0FastInferenceEngine
from .smolvla_inference import SmolVLAInferenceEngine
logger = logging.getLogger(__name__)
# Core exports that are always available
__all__ = [
"ACTInferenceEngine",
"BaseInferenceEngine",
"DiffusionInferenceEngine",
"JointConfig",
"Pi0FastInferenceEngine",
"Pi0InferenceEngine",
"SmolVLAInferenceEngine",
"get_inference_engine",
]
POLICY_ENGINES = {
"act": ACTInferenceEngine,
"pi0": Pi0InferenceEngine,
"pi0fast": Pi0FastInferenceEngine,
"smolvla": SmolVLAInferenceEngine,
"diffusion": DiffusionInferenceEngine,
}
def get_inference_engine(policy_type: str, **kwargs) -> BaseInferenceEngine:
"""
Get an inference engine instance for the specified policy type.
Args:
policy_type: Type of policy ('act', 'pi0', 'pi0fast', 'smolvla', 'diffusion')
**kwargs: Additional arguments passed to the engine constructor
Returns:
BaseInferenceEngine: Configured inference engine instance
Raises:
ValueError: If policy_type is not supported or not available
"""
if policy_type not in POLICY_ENGINES:
available = list(POLICY_ENGINES.keys())
if not available:
msg = "No policy engines are available. Check your LeRobot installation."
else:
msg = f"Unsupported policy type: {policy_type}. Available: {available}"
raise ValueError(msg)
engine_class = POLICY_ENGINES[policy_type]
return engine_class(**kwargs)