InferBench / api /pruna.py
davidberenstein1957's picture
refactor: improve code formatting and organization across multiple API and benchmark files
34046e2
raw
history blame
1.67 kB
import os
import time
from pathlib import Path
from typing import Any
import replicate
from dotenv import load_dotenv
from api.flux import FluxAPI
class PrunaAPI(FluxAPI):
    """Benchmark client for the ``prunaai/flux.1-juiced`` model on Replicate.

    Each instance is bound to one ``speed_mode`` value. That value is sent to
    the model unchanged, and its first word (lower-cased) is used in ``name``.
    """

    # Model reference in Replicate's "owner/model:version" form; the version
    # hash pins the exact model build used for every request.
    _MODEL_REF = (
        "prunaai/flux.1-juiced:"
        "58977759ff2870cc010597ae75f4d87866d169b248e02b6e86c4e1bf8afe2410"
    )

    def __init__(self, speed_mode: str):
        """Store the speed mode and read the Replicate API token.

        Args:
            speed_mode: Value passed verbatim as the model's ``speed_mode`` input.

        Raises:
            ValueError: If ``REPLICATE_API_TOKEN`` is not set after calling
                ``load_dotenv()``.
        """
        # NOTE(review): FluxAPI.__init__ is not called here; presumably on
        # purpose, to skip the parent's own setup. Confirm against api/flux.py.
        self._speed_mode = speed_mode
        # Keep only the first word for the short name (e.g. a mode starting
        # with "Juiced" -> "juiced"). split() with no argument ignores leading
        # whitespace; an empty mode gives "".
        words = speed_mode.split()
        self._speed_mode_name = words[0].lower() if words else ""
        load_dotenv()
        self._api_key = os.getenv("REPLICATE_API_TOKEN")
        if not self._api_key:
            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")

    @property
    def name(self) -> str:
        """Identifier for this configuration, e.g. ``"pruna_juiced"``."""
        return f"pruna_{self._speed_mode_name}"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        """Generate one image for ``prompt`` and save it to ``save_path``.

        Args:
            prompt: Text prompt sent to the model.
            save_path: Destination file; parent directories are created as needed.

        Returns:
            Seconds spent in the ``replicate.run`` call only. Saving the output,
            including any download that reading it may trigger, is excluded.

        Raises:
            RuntimeError: If Replicate returns an empty result.
        """
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments during the request.
        start_time = time.perf_counter()
        result = replicate.run(
            self._MODEL_REF,
            input={
                "seed": 0,
                "prompt": prompt,
                "guidance": 3.5,
                "num_outputs": 1,
                "aspect_ratio": "1:1",
                "output_format": "png",
                "speed_mode": self._speed_mode,
                "num_inference_steps": 28,
            },
        )
        elapsed = time.perf_counter() - start_time
        if not result:
            raise RuntimeError("No result returned from Replicate API")
        self._save_image_from_result(result, save_path)
        return elapsed

    def _save_image_from_result(self, result: Any, save_path: Path) -> None:
        """Write the image bytes from a Replicate output to ``save_path``.

        Args:
            result: An object whose ``read()`` returns the image bytes, or a
                list/tuple of such objects (only the first is saved).
            save_path: Destination file path; parent directories are created.

        Raises:
            RuntimeError: If ``result`` is an empty list/tuple.
        """
        # replicate.run may return a list of outputs when the model's output
        # schema is an array; only one output is requested, so take the first.
        if isinstance(result, (list, tuple)):
            if not result:
                raise RuntimeError("No result returned from Replicate API")
            result = result[0]
        # Read the bytes before touching the filesystem so a failed read does
        # not leave an empty or truncated image file behind.
        data = result.read()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(data)