import os
import time
from pathlib import Path
from typing import Any

import requests
from dotenv import load_dotenv

from api.flux import FluxAPI


class FireworksAPI(FluxAPI):
    """Client for the Fireworks ``flux-1-dev-fp8`` text-to-image endpoint.

    The API token is read from the ``FIREWORKS_API_TOKEN`` environment
    variable (a ``.env`` file is loaded first via ``load_dotenv``).
    """

    def __init__(self, timeout: float = 300.0):
        """Load credentials and configure the endpoint.

        Args:
            timeout: Seconds to wait for the HTTP request before giving up.
                Without a timeout ``requests`` would block indefinitely on
                a stalled connection.

        Raises:
            ValueError: If ``FIREWORKS_API_TOKEN`` is unset or empty.
        """
        load_dotenv()
        self._api_key = os.getenv("FIREWORKS_API_TOKEN")
        if not self._api_key:
            raise ValueError("FIREWORKS_API_TOKEN not found in environment variables")
        self._url = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image"
        self._timeout = timeout

    @property
    def name(self) -> str:
        """Return the identifier string for this backend."""
        return "fireworks_fp8"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        """Generate an image for *prompt*, write it to *save_path*, and time it.

        The request uses fixed generation settings (1:1 aspect ratio,
        guidance scale 3.5, 28 inference steps, seed 0) and asks for a JPEG
        response body.

        Args:
            prompt: Text prompt sent to the model.
            save_path: Destination file; missing parent directories are
                created.

        Returns:
            Seconds elapsed for the HTTP request only. Writing the file to
            disk is deliberately excluded from the measurement.

        Raises:
            RuntimeError: If the API responds with a status other than 200.
            requests.exceptions.RequestException: On connection failure or
                when the request exceeds the configured timeout.
        """
        headers = {
            "Content-Type": "application/json",
            "Accept": "image/jpeg",
            "Authorization": f"Bearer {self._api_key}",
        }
        data = {
            "prompt": prompt,
            "aspect_ratio": "1:1",
            "guidance_scale": 3.5,
            "num_inference_steps": 28,
            "seed": 0,
        }

        # perf_counter is monotonic, so the measured latency cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        result = requests.post(
            self._url,
            headers=headers,
            json=data,
            # Instances built without __init__ fall back to the default.
            timeout=getattr(self, "_timeout", 300.0),
        )
        elapsed = time.perf_counter() - start_time

        if result.status_code != 200:
            raise RuntimeError(f"Error: {result.status_code} {result.text}")

        self._save_image_from_result(result, save_path)
        return elapsed

    def _save_image_from_result(self, result: Any, save_path: Path) -> None:
        """Write the raw response body (``result.content``) to *save_path*."""
        save_path.parent.mkdir(parents=True, exist_ok=True)
        with open(save_path, "wb") as f:
            f.write(result.content)