InferBench / api /baseline.py
davidberenstein1957's picture
refactor: improve code formatting and organization across multiple API and benchmark files
34046e2
raw
history blame
1.53 kB
import os
import time
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
from api.flux import FluxAPI
class BaselineAPI(FluxAPI):
    """
    As our baseline, we use the Replicate API with go_fast=False.

    Every request goes to Replicate's ``black-forest-labs/flux-dev`` model
    with a fixed configuration (28 steps, guidance 3.5, seed 0, one 1:1 PNG).
    """

    # Replicate model identifier used for every request.
    _REPLICATE_MODEL = "black-forest-labs/flux-dev"

    # Fixed generation settings sent alongside the prompt. ``go_fast`` is
    # False, as required for the baseline. Copied per request, never mutated.
    _REPLICATE_INPUT = {
        "go_fast": False,
        "guidance": 3.5,
        "num_outputs": 1,
        "aspect_ratio": "1:1",
        "output_format": "png",
        "num_inference_steps": 28,
        "seed": 0,
    }

    def __init__(self):
        """Load variables from a ``.env`` file and require a Replicate token.

        The token is only validated here, not passed to ``replicate``
        explicitly. NOTE(review): presumably the ``replicate`` client reads
        ``REPLICATE_API_TOKEN`` from the environment itself; confirm.
        NOTE(review): ``FluxAPI.__init__`` is not called; confirm the base
        class needs no initialisation.

        Raises:
            ValueError: If ``REPLICATE_API_TOKEN`` is unset or empty.
        """
        load_dotenv()
        self._api_key = os.getenv("REPLICATE_API_TOKEN")
        if not self._api_key:
            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")

    @property
    def name(self) -> str:
        """Identifier of this API in benchmark results: ``"baseline"``."""
        return "baseline"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        """Generate one image for *prompt* and write it to *save_path*.

        Args:
            prompt: Text prompt sent to the model.
            save_path: Destination file; missing parent directories are
                created.

        Returns:
            Seconds spent in the ``replicate.run`` call. Writing the image to
            disk happens after the timer stops and is not included.

        Raises:
            RuntimeError: If Replicate returns an empty result.
        """
        # Imported inside the method rather than at module top.
        # NOTE(review): presumably keeps ``replicate`` an optional dependency.
        import replicate

        payload = {"prompt": prompt, **self._REPLICATE_INPUT}

        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump mid-request, corrupting the latency.
        start = time.perf_counter()
        result = replicate.run(self._REPLICATE_MODEL, input=payload)
        elapsed = time.perf_counter() - start

        if not result:
            raise RuntimeError("No result returned from Replicate API")
        self._save_image_from_result(result[0], save_path)
        return elapsed

    def _save_image_from_result(self, result: Any, save_path: Path) -> None:
        """Write the bytes of one Replicate output item to *save_path*.

        *result* must expose ``read()`` returning ``bytes``.
        NOTE(review): presumably a replicate ``FileOutput``; confirm against
        the client version in use.
        """
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # read() runs before the file is opened, so a failed read no longer
        # leaves an empty, truncated file behind.
        save_path.write_bytes(result.read())