File size: 1,702 Bytes
4f41410
2c50826
34046e2
2c50826
34046e2
2c50826
 
 
4f41410
2c50826
 
 
4f41410
34046e2
 
 
 
 
 
 
 
 
4f41410
 
34046e2
2c50826
 
 
34046e2
2c50826
 
 
 
 
 
 
 
5291ba9
34046e2
2c50826
 
34046e2
2c50826
 
 
4f41410
 
2c50826
 
 
34046e2
2c50826
 
 
 
 
 
5291ba9
2c50826
34046e2
2c50826
34046e2
 
 
 
2c50826
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Type

from api.aws import AWSBedrockAPI
from api.baseline import BaselineAPI
from api.fal import FalAPI
from api.fireworks import FireworksAPI
from api.flux import FluxAPI
from api.pruna import PrunaAPI
from api.pruna_dev import PrunaDevAPI
from api.replicate import ReplicateAPI
from api.together import TogetherAPI

# Public names of this package: the factory function plus every backend class
# the factory can return. AWSBedrockAPI is listed because create_api("aws")
# returns an instance of it, just like the other backends.
__all__ = [
    "create_api",
    "FluxAPI",
    "BaselineAPI",
    "FireworksAPI",
    "PrunaAPI",
    "ReplicateAPI",
    "TogetherAPI",
    "FalAPI",
    "PrunaDevAPI",
    "AWSBedrockAPI",
]


def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_dev"
            - "pruna_<speed_mode>" (where <speed_mode> is the desired,
              non-empty speed mode; it is passed to PrunaAPI unchanged)
            - "replicate"
            - "together"
            - "fal"
            - "aws"

    Returns:
        FluxAPI: An instance of the requested API implementation

    Raises:
        ValueError: If an invalid API type is provided, including a bare
            "pruna_" with no speed mode after the prefix
    """
    # "pruna_dev" has to be matched before the generic "pruna_" prefix check;
    # otherwise it would be routed to PrunaAPI with speed mode "dev".
    if api_type == "pruna_dev":
        return PrunaDevAPI()

    pruna_prefix = "pruna_"
    if api_type.startswith(pruna_prefix):
        # Everything after the prefix is the speed mode.
        speed_mode = api_type[len(pruna_prefix):]
        if not speed_mode:
            # Reject a bare prefix here instead of handing PrunaAPI an empty
            # speed mode.
            raise ValueError(
                f"Invalid API type: {api_type}. A speed mode is required after the 'pruna_' prefix"
            )
        return PrunaAPI(speed_mode)

    # Backends that are constructed without arguments, keyed by API type.
    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }

    if api_type not in api_map:
        raise ValueError(
            f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
        )

    return api_map[api_type]()