File size: 3,519 Bytes
63ed3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python

# Copyright 2024 The LeRobot Team and contributors.
# Licensed under the Apache License, Version 2.0

"""
Inference Server Models Module.

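This module provides inference engines for different robotic policy types,
including ACT, Pi0, Pi0Fast, SmolVLA, and Diffusion Policy models. Engines whose
imports fail (e.g. because of missing optional dependencies) are skipped with a
logged warning.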
"""

import logging

from .joint_config import JointConfig

logger = logging.getLogger(__name__)

# Registry of engines whose modules imported cleanly. Each engine import below
# is guarded by ``except ImportError`` so that a missing optional dependency only
# drops that single policy (with a logged warning) instead of breaking the
# import of the whole package.
_available_engines: dict[str, type] = {}  # policy key -> engine class
_available_policies: list[str] = []  # policy keys, in registration order

# ACT: expected to be importable in every install, but guarded all the same.
try:
    from .act_inference import ACTInferenceEngine
except ImportError as exc:
    logger.warning(f"ACT policy not available: {exc}")
else:
    _available_engines["act"] = ACTInferenceEngine
    _available_policies.append("act")

# Pi0: optional.
try:
    from .pi0_inference import Pi0InferenceEngine
except ImportError as exc:
    logger.warning(f"Pi0 policy not available: {exc}")
else:
    _available_engines["pi0"] = Pi0InferenceEngine
    _available_policies.append("pi0")

# Pi0Fast: optional.
try:
    from .pi0fast_inference import Pi0FastInferenceEngine
except ImportError as exc:
    logger.warning(f"Pi0Fast policy not available: {exc}")
else:
    _available_engines["pi0fast"] = Pi0FastInferenceEngine
    _available_policies.append("pi0fast")

# SmolVLA: optional.
try:
    from .smolvla_inference import SmolVLAInferenceEngine
except ImportError as exc:
    logger.warning(f"SmolVLA policy not available: {exc}")
else:
    _available_engines["smolvla"] = SmolVLAInferenceEngine
    _available_policies.append("smolvla")

# Diffusion: optional, and its dependencies are known to be troublesome.
try:
    from .diffusion_inference import DiffusionInferenceEngine
except ImportError as exc:
    logger.warning(f"Diffusion policy not available: {exc}")
else:
    _available_engines["diffusion"] = DiffusionInferenceEngine
    _available_policies.append("diffusion")

# Public API. The shared config and the factory helpers are always exported;
# engine classes are appended below only for policies whose import succeeded.
__all__ = [
    # Shared configuration
    "JointConfig",
    # Factory functions
    "get_inference_engine",
    "list_supported_policies",
]

# Exported class name for each policy key that can appear in
# ``_available_policies``. Keys not listed here are not exported.
_ENGINE_EXPORT_NAMES = {
    "act": "ACTInferenceEngine",
    "pi0": "Pi0InferenceEngine",
    "pi0fast": "Pi0FastInferenceEngine",
    "smolvla": "SmolVLAInferenceEngine",
    "diffusion": "DiffusionInferenceEngine",
}

# A generator expression (rather than a top-level ``for`` loop) keeps the
# iteration variable out of the package namespace, so importing this package
# does not leave a stray module-level name behind. Registration order is kept.
__all__.extend(
    _ENGINE_EXPORT_NAMES[policy]
    for policy in _available_policies
    if policy in _ENGINE_EXPORT_NAMES
)


def list_supported_policies() -> list[str]:
    """Return the policy types whose engines imported successfully.

    A new list is returned on every call, so callers may modify it freely
    without affecting the module's internal registry.
    """
    return list(_available_policies)


def get_inference_engine(policy_type: str, **kwargs):
    """
    Build the inference engine registered for ``policy_type``.

    Matching is case-insensitive: ``policy_type`` is lower-cased before it is
    looked up among the engines that were importable when this package loaded.

    Args:
        policy_type: Policy name, e.g. "act", "pi0", "pi0fast", "smolvla" or
            "diffusion".
        **kwargs: Forwarded unchanged to the engine class constructor.

    Returns:
        A new instance of the matching engine class.

    Raises:
        ValueError: If no engine is registered under the lower-cased name.

    """
    key = policy_type.lower()

    # Happy path first: a registered engine is constructed immediately.
    if key in _available_engines:
        return _available_engines[key](**kwargs)

    available = list_supported_policies()
    msg = f"Policy type '{key}' is not available. Available policies: {available}"
    raise ValueError(msg)