# Spaces: Running
import argparse | |
import json | |
from pathlib import Path | |
from typing import List | |
from tqdm import tqdm | |
from api import create_api | |
from benchmark import create_benchmark | |
def _load_existing_metadata(metadata_file: Path) -> dict:
    """Read ``metadata_file`` (JSON Lines) and index its entries by ``"filepath"``.

    Returns an empty dict when the file does not exist. Blank lines and lines
    that are not valid JSON (for example a record cut short when a previous
    run was interrupted mid-write) are skipped with a warning rather than
    aborting the whole run.
    """
    existing_metadata = {}
    if not metadata_file.exists():
        return existing_metadata
    with open(metadata_file, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                print(f"\nSkipping malformed line {line_number} in {metadata_file}")
                continue
            existing_metadata[entry["filepath"]] = entry
    return existing_metadata


def _iter_generation_jobs(benchmark_type: str, samples, benchmark_dir: Path):
    """Normalize benchmark samples to ``(prompt, output_path, metadata_key)`` triples.

    * ``"geneval"`` samples are ``(metadata, folder_name)`` pairs. The image is
      written to ``benchmark_dir/folder_name/samples/0000.png``, and the
      metadata key is that full path as a string.
    * Every other benchmark yields ``(prompt, image_path)`` pairs. The image is
      written to ``benchmark_dir/image_path``, and the metadata key is
      ``str(image_path)``, i.e. relative to ``benchmark_dir``.

    The two key formats differ. Both are kept as-is because they match what
    this script has always written to ``metadata.jsonl``, so resuming against
    existing metadata files keeps working.
    """
    if benchmark_type == "geneval":
        for metadata, folder_name in samples:
            image_path = benchmark_dir / folder_name / "samples" / "0000.png"
            yield metadata["prompt"], image_path, str(image_path)
    else:
        for prompt, image_path in samples:
            yield prompt, benchmark_dir / image_path, str(image_path)


def generate_images(
    api_type: str,
    benchmarks: List[str],
    images_dir: Path = Path("images"),
):
    """Generate images for each benchmark with the given API, resuming prior runs.

    Output goes to ``<images_dir>/<api_type>/<benchmark_type>/``. For each
    image generated successfully, one JSON line is appended to
    ``metadata.jsonl`` in that directory. The line has the keys ``filepath``,
    ``prompt`` and ``inference_time``, where ``inference_time`` is whatever
    ``api.generate_image`` returns (presumably seconds; not visible here).

    A sample is skipped when it is already listed in ``metadata.jsonl`` or its
    image file already exists. An interrupted run can therefore simply be
    restarted.

    Args:
        api_type: Name passed to ``create_api`` to build the generator. It is
            also used as the output sub-directory name.
        benchmarks: Names passed to ``create_benchmark``. ``"geneval"`` uses a
            different sample format and output layout (see
            ``_iter_generation_jobs``).
        images_dir: Root output directory (default ``images``).

    When a single sample fails, the error is printed and that sample is
    skipped. The run continues with the next sample.
    """
    api = create_api(api_type)
    api_dir = Path(images_dir) / api_type
    api_dir.mkdir(parents=True, exist_ok=True)

    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
        print(f"\nProcessing benchmark: {benchmark_type}")
        benchmark = create_benchmark(benchmark_type)

        benchmark_dir = api_dir / benchmark_type
        benchmark_dir.mkdir(parents=True, exist_ok=True)
        metadata_file = benchmark_dir / "metadata.jsonl"
        existing_metadata = _load_existing_metadata(metadata_file)

        # Wrap the benchmark itself (not the derived generator) so tqdm can
        # still use len(benchmark) for its total when available.
        samples = tqdm(
            benchmark,
            desc=f"Generating images for {benchmark_type}",
            leave=False,
        )
        with open(metadata_file, "a", encoding="utf-8") as f:
            for prompt, output_path, metadata_key in _iter_generation_jobs(
                benchmark_type, samples, benchmark_dir
            ):
                # Keys are compared as strings because that is how they are
                # stored in metadata.jsonl. A Path yielded by the benchmark
                # would never match a str key.
                if metadata_key in existing_metadata or output_path.exists():
                    continue
                # This file does not show whether api.generate_image creates
                # missing parent directories, so make sure they exist.
                output_path.parent.mkdir(parents=True, exist_ok=True)
                try:
                    inference_time = api.generate_image(prompt, output_path)
                    metadata_entry = {
                        "filepath": metadata_key,
                        "prompt": prompt,
                        "inference_time": inference_time,
                    }
                    f.write(json.dumps(metadata_entry) + "\n")
                    # Flush each record so a crash or Ctrl-C keeps the entries
                    # already written. The resume check above relies on them.
                    f.flush()
                except Exception as e:
                    # Best effort: report the failure and go on to the next sample.
                    print(f"\nError generating image for prompt: {prompt}")
                    print(f"Error: {str(e)}")
def main():
    """Command-line entry point.

    Reads the positional arguments ``api_type`` (one value) and ``benchmarks``
    (one or more values), then hands them to ``generate_images``.
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate images for specified benchmarks using a given API"
    )
    positional_specs = (
        ("api_type", {"help": "Type of API to use for image generation"}),
        ("benchmarks", {"nargs": "+", "help": "List of benchmark types to run"}),
    )
    for arg_name, options in positional_specs:
        arg_parser.add_argument(arg_name, **options)
    namespace = arg_parser.parse_args()
    generate_images(namespace.api_type, namespace.benchmarks)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()