# Spaces: Running
from typing import Type

from api.aws import AWSBedrockAPI
from api.baseline import BaselineAPI
from api.fal import FalAPI
from api.fireworks import FireworksAPI
from api.flux import FluxAPI
from api.pruna import PrunaAPI
from api.pruna_dev import PrunaDevAPI
from api.replicate import ReplicateAPI
from api.together import TogetherAPI
# Public API of this package: the factory plus every API implementation it can
# return, so that `from api import *` exposes all of them consistently.
__all__ = [
    "create_api",
    "FluxAPI",
    "BaselineAPI",
    "FireworksAPI",
    "PrunaAPI",
    "ReplicateAPI",
    "TogetherAPI",
    "FalAPI",
    "PrunaDevAPI",
    # Previously missing even though create_api("aws") returns it.
    "AWSBedrockAPI",
]
def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_dev"
            - "pruna_<speed_mode>" (where <speed_mode> is the desired,
              non-empty speed mode passed to PrunaAPI)
            - "replicate"
            - "together"
            - "fal"
            - "aws"

    Returns:
        FluxAPI: An instance of the requested API implementation

    Raises:
        ValueError: If an invalid API type is provided, including a bare
            "pruna_" with no speed mode.
    """
    # "pruna_dev" must be matched before the generic "pruna_" prefix below;
    # otherwise it would be routed to PrunaAPI("dev").
    if api_type == "pruna_dev":
        return PrunaDevAPI()

    pruna_prefix = "pruna_"
    if api_type.startswith(pruna_prefix):
        # Everything after the prefix is the speed mode.
        speed_mode = api_type[len(pruna_prefix):]
        if not speed_mode:
            raise ValueError(
                f"Invalid API type: {api_type}. A speed mode must follow 'pruna_' "
                "(e.g. 'pruna_<speed_mode>')"
            )
        return PrunaAPI(speed_mode)

    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }

    api_cls = api_map.get(api_type)
    if api_cls is None:
        raise ValueError(
            f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
        )
    return api_cls()