File size: 2,133 Bytes
63ed3a7
 
 
 
 
 
3380376
63ed3a7
3380376
 
63ed3a7
 
 
 
3380376
 
 
63ed3a7
3380376
 
 
63ed3a7
 
 
3380376
63ed3a7
3380376
 
 
63ed3a7
3380376
 
 
63ed3a7
 
 
 
3380376
 
 
 
 
 
 
63ed3a7
 
3380376
63ed3a7
3380376
63ed3a7
 
3380376
 
63ed3a7
 
3380376
63ed3a7
 
3380376
63ed3a7
 
3380376
 
 
 
 
 
63ed3a7
 
3380376
63ed3a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python

# Copyright 2024 The LeRobot Team and contributors.
# Licensed under the Apache License, Version 2.0

"""
RobotHub Inference Server - Model inference engines for various policy types.

This module provides unified inference engines for different policy architectures
including ACT, Pi0, Pi0-FAST, SmolVLA, and Diffusion policies.
"""

import logging

from .act_inference import ACTInferenceEngine
from .base_inference import BaseInferenceEngine
from .diffusion_inference import DiffusionInferenceEngine
from .joint_config import JointConfig
from .pi0_inference import Pi0InferenceEngine
from .pi0fast_inference import Pi0FastInferenceEngine
from .smolvla_inference import SmolVLAInferenceEngine

logger = logging.getLogger(__name__)

# Public API of this package: every engine class imported above, the shared
# joint configuration helper, and the factory function defined below.
__all__ = [
    "ACTInferenceEngine",
    "BaseInferenceEngine",
    "DiffusionInferenceEngine",
    "JointConfig",
    "Pi0FastInferenceEngine",
    "Pi0InferenceEngine",
    "SmolVLAInferenceEngine",
    "get_inference_engine",
]


# Registry used by get_inference_engine(): maps each lowercase policy-type key
# to the engine class that serves it. Insertion order is act, pi0, pi0fast,
# smolvla, diffusion; that order is what error messages list as "Available".
POLICY_ENGINES = dict(
    act=ACTInferenceEngine,
    pi0=Pi0InferenceEngine,
    pi0fast=Pi0FastInferenceEngine,
    smolvla=SmolVLAInferenceEngine,
    diffusion=DiffusionInferenceEngine,
)


def get_inference_engine(policy_type: str, **kwargs) -> BaseInferenceEngine:
    """
    Get an inference engine instance for the specified policy type.

    String lookups are case-insensitive and ignore surrounding whitespace, so
    ``"ACT"``, ``" act "`` and ``"act"`` all select the same engine.

    Args:
        policy_type: Type of policy ('act', 'pi0', 'pi0fast', 'smolvla', 'diffusion')
        **kwargs: Additional arguments passed unchanged to the engine constructor

    Returns:
        BaseInferenceEngine: Configured inference engine instance

    Raises:
        ValueError: If policy_type does not name a registered engine, or if
            the engine registry (POLICY_ENGINES) is empty.

    """
    # Registry keys are lowercase with no whitespace, so normalizing string
    # input lets display-style names such as "SmolVLA" resolve while every
    # previously valid key still maps to itself. Non-string values are looked
    # up unchanged, exactly as before.
    key = policy_type.strip().lower() if isinstance(policy_type, str) else policy_type

    try:
        engine_class = POLICY_ENGINES[key]
    except KeyError:
        available = list(POLICY_ENGINES.keys())
        if not available:
            msg = "No policy engines are available. Check your LeRobot installation."
        else:
            msg = f"Unsupported policy type: {policy_type}. Available: {available}"
        # Drop the internal KeyError from the traceback; the ValueError already
        # carries everything the caller needs.
        raise ValueError(msg) from None

    # Constructed outside the try so a KeyError raised by the engine's own
    # constructor is not misreported as an unsupported policy type.
    return engine_class(**kwargs)