Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# Copyright 2024 The LeRobot Team and contributors. | |
# Licensed under the Apache License, Version 2.0 | |
""" | |
RobotHub Inference Server - Model inference engines for various policy types. | |
This module provides unified inference engines for different policy architectures | |
including ACT, Pi0, SmolVLA, and Diffusion policies. | |
""" | |
import logging | |
from .act_inference import ACTInferenceEngine | |
from .base_inference import BaseInferenceEngine | |
from .diffusion_inference import DiffusionInferenceEngine | |
from .joint_config import JointConfig | |
from .pi0_inference import Pi0InferenceEngine | |
from .pi0fast_inference import Pi0FastInferenceEngine | |
from .smolvla_inference import SmolVLAInferenceEngine | |
logger = logging.getLogger(__name__)

# Public API of this package: every engine class, the shared base class,
# the joint configuration helper, and the ``get_inference_engine`` factory.
__all__ = [
    "ACTInferenceEngine",
    "BaseInferenceEngine",
    "DiffusionInferenceEngine",
    "JointConfig",
    "Pi0FastInferenceEngine",
    "Pi0InferenceEngine",
    "SmolVLAInferenceEngine",
    "get_inference_engine",
]

# Registry from policy-type name to the engine class implementing it.
# ``get_inference_engine`` resolves its ``policy_type`` argument against
# these keys, so adding an entry here makes a new policy type selectable.
POLICY_ENGINES = dict(
    act=ACTInferenceEngine,
    pi0=Pi0InferenceEngine,
    pi0fast=Pi0FastInferenceEngine,
    smolvla=SmolVLAInferenceEngine,
    diffusion=DiffusionInferenceEngine,
)
def get_inference_engine(policy_type: str, **kwargs) -> BaseInferenceEngine:
    """
    Create an inference engine instance for the specified policy type.

    String policy types are matched case-insensitively and ignoring
    surrounding whitespace, so ``"ACT"`` and ``" act "`` both resolve to the
    ``"act"`` engine.

    Args:
        policy_type: Type of policy ('act', 'pi0', 'pi0fast', 'smolvla', 'diffusion')
        **kwargs: Additional arguments passed unchanged to the engine constructor

    Returns:
        BaseInferenceEngine: Configured inference engine instance

    Raises:
        ValueError: If policy_type does not name a registered engine, or if
            the engine registry is empty.
    """
    # Normalise strings so values coming from configs/CLIs (e.g. "ACT") still
    # match the lowercase registry keys; non-strings are looked up as-is so
    # they fail exactly as they did before.
    key = policy_type.strip().lower() if isinstance(policy_type, str) else policy_type

    # Single lookup instead of a membership test followed by indexing.
    engine_class = POLICY_ENGINES.get(key)
    if engine_class is None:
        available = list(POLICY_ENGINES)
        # Defensive: the registry could be emptied/mutated at runtime.
        if not available:
            msg = "No policy engines are available. Check your LeRobot installation."
        else:
            msg = f"Unsupported policy type: {policy_type}. Available: {available}"
        raise ValueError(msg)

    return engine_class(**kwargs)