InferBench / api /fal.py
davidberenstein1957's picture
refactor: improve code formatting and organization across multiple API and benchmark files
34046e2
raw
history blame
1.38 kB
import time
from io import BytesIO
from pathlib import Path
from typing import Any
import fal_client
import requests
from PIL import Image
from api.flux import FluxAPI
class FalAPI(FluxAPI):
    """FLUX.1 [dev] text-to-image backend called through fal.ai's ``fal_client``."""

    # Seconds ``requests`` may wait while connecting to / reading from the image
    # URL. Without a timeout, ``requests.get`` can block indefinitely.
    _DOWNLOAD_TIMEOUT = 60

    @property
    def name(self) -> str:
        """Short identifier for this provider."""
        return "fal"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        """Generate one image for ``prompt`` via fal.ai and save it to ``save_path``.

        Args:
            prompt: Text prompt sent to the ``fal-ai/flux/dev`` endpoint.
            save_path: Destination file; missing parent directories are created.

        Returns:
            Elapsed seconds spent in the ``fal_client.subscribe`` call only; the
            subsequent image download and save are not included.
        """
        # perf_counter is monotonic and high-resolution, so the measured latency
        # cannot be skewed by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        result = fal_client.subscribe(
            "fal-ai/flux/dev",
            arguments={
                "seed": 0,
                "prompt": prompt,
                "image_size": "square_hd",  # 1024x1024 image
                "num_images": 1,
                "guidance_scale": 3.5,
                "num_inference_steps": 28,
                "enable_safety_checker": True,
            },
        )
        end_time = time.perf_counter()
        url = result["images"][0]["url"]
        self._save_image_from_url(url, save_path)
        return end_time - start_time

    def _save_image_from_url(self, url: str, save_path: Path) -> None:
        """Download the image at ``url`` and write it to ``save_path``.

        The bytes are decoded with Pillow and re-saved; Pillow infers the output
        format from ``save_path``'s file extension.

        Raises:
            requests.HTTPError: If the download returns an error status code.
            requests.Timeout: If connecting or reading stalls for longer than
                ``_DOWNLOAD_TIMEOUT`` seconds.
        """
        response = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT)
        # Surface the HTTP error directly instead of letting Pillow fail later
        # with a confusing "cannot identify image file" on an error-page body.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as image:
            save_path.parent.mkdir(parents=True, exist_ok=True)
            image.save(save_path)

    def _save_image_from_result(self, result: Any, save_path: Path) -> None:
        """Write the raw bytes in ``result.content`` to ``save_path``.

        NOTE(review): ``result`` presumably is an HTTP-response-like object whose
        ``.content`` is ``bytes`` — confirm against callers. ``generate_image`` in
        this class uses ``_save_image_from_url`` instead.
        """
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(result.content)