Spaces:
Sleeping
Sleeping
File size: 2,133 Bytes
63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 3380376 63ed3a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
#!/usr/bin/env python
# Copyright 2024 The LeRobot Team and contributors.
# Licensed under the Apache License, Version 2.0
"""
RobotHub Inference Server - Model inference engines for various policy types.
This module provides unified inference engines for different policy architectures
including ACT, Pi0, SmolVLA, and Diffusion policies.
"""
import logging
from .act_inference import ACTInferenceEngine
from .base_inference import BaseInferenceEngine
from .diffusion_inference import DiffusionInferenceEngine
from .joint_config import JointConfig
from .pi0_inference import Pi0InferenceEngine
from .pi0fast_inference import Pi0FastInferenceEngine
from .smolvla_inference import SmolVLAInferenceEngine
logger = logging.getLogger(__name__)

# Registry used by ``get_inference_engine``: maps each policy-type key to the
# engine class that is instantiated for it.
POLICY_ENGINES = dict(
    act=ACTInferenceEngine,
    pi0=Pi0InferenceEngine,
    pi0fast=Pi0FastInferenceEngine,
    smolvla=SmolVLAInferenceEngine,
    diffusion=DiffusionInferenceEngine,
)

# Public names re-exported by this package (what ``from ... import *`` yields).
__all__ = [
    "ACTInferenceEngine",
    "BaseInferenceEngine",
    "DiffusionInferenceEngine",
    "JointConfig",
    "Pi0FastInferenceEngine",
    "Pi0InferenceEngine",
    "SmolVLAInferenceEngine",
    "get_inference_engine",
]
def get_inference_engine(policy_type: str, **kwargs) -> BaseInferenceEngine:
    """
    Get an inference engine instance for the specified policy type.

    String policy types are matched after stripping surrounding whitespace
    and lower-casing, so ``"ACT"`` and ``" act "`` both resolve to the
    ``"act"`` engine.

    Args:
        policy_type: Type of policy ('act', 'pi0', 'pi0fast', 'smolvla', 'diffusion')
        **kwargs: Additional arguments passed to the engine constructor

    Returns:
        BaseInferenceEngine: Configured inference engine instance

    Raises:
        ValueError: If policy_type is not supported or not available
            (including non-string / unhashable values that cannot be a key).
    """
    # The registry keys defined in this module are lowercase, so normalize
    # string input before the lookup; non-strings are looked up as-is.
    key = policy_type.strip().lower() if isinstance(policy_type, str) else policy_type

    try:
        engine_class = POLICY_ENGINES[key]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list): without catching it
        # the caller would see an unrelated error instead of the documented
        # ValueError.
        available = list(POLICY_ENGINES)
        if not available:
            # Only reachable if the registry has been emptied at runtime.
            msg = "No policy engines are available. Check your LeRobot installation."
        else:
            msg = f"Unsupported policy type: {policy_type}. Available: {available}"
        # ``from None``: the lookup failure is an implementation detail.
        raise ValueError(msg) from None

    return engine_class(**kwargs)
|