import os
import time
from pathlib import Path
from typing import Any

import replicate
from dotenv import load_dotenv

from api.flux import FluxAPI


class ReplicateAPI(FluxAPI):
    """FluxAPI backend that generates images through Replicate.

    Calls the ``black-forest-labs/flux-dev`` model with ``go_fast`` enabled
    and a fixed set of generation parameters (seed 0, 28 steps, 1:1 PNG).
    """

    def __init__(self):
        """Load environment variables and verify the Replicate token exists.

        Raises:
            ValueError: if ``REPLICATE_API_TOKEN`` is unset or empty.
        """
        # Populate os.environ from a .env file (if present) before reading.
        load_dotenv()
        self._api_key = os.getenv("REPLICATE_API_TOKEN")
        if not self._api_key:
            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
        # NOTE(review): the token is not passed to replicate.run() below;
        # presumably the replicate client reads REPLICATE_API_TOKEN from the
        # environment itself — confirm.

    @property
    def name(self) -> str:
        """Identifier for this provider/configuration."""
        return "replicate_go_fast"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        """Generate a single image for ``prompt`` and write it to ``save_path``.

        Args:
            prompt: Text prompt sent to the model.
            save_path: Destination file; parent directories are created.

        Returns:
            Seconds spent in the ``replicate.run`` call. Saving the image
            is not included in the timing.

        Raises:
            RuntimeError: if Replicate returns an empty/falsy result.
        """
        # perf_counter is monotonic, so the duration is immune to wall-clock
        # adjustments (unlike time.time()).
        start_time = time.perf_counter()
        result = replicate.run(
            "black-forest-labs/flux-dev",
            input={
                "seed": 0,
                "prompt": prompt,
                "go_fast": True,
                "guidance": 3.5,
                "num_outputs": 1,
                "aspect_ratio": "1:1",
                "output_format": "png",
                "num_inference_steps": 28,
            },
        )
        elapsed = time.perf_counter() - start_time

        if not result:
            # RuntimeError subclasses Exception, so existing broad handlers
            # still catch it.
            raise RuntimeError("No result returned from Replicate API")

        # Only one output was requested (num_outputs=1), so save the first.
        self._save_image_from_result(result[0], save_path)
        return elapsed

    def _save_image_from_result(self, result: Any, save_path: Path):
        """Write the bytes of one Replicate output to ``save_path``.

        ``result`` must expose ``.read()`` returning bytes. Presumably this is
        the replicate client's file-output object — confirm against the
        installed client version.
        """
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(result.read())