InferBench / sample.py
davidberenstein1957's picture
fix: change file opening mode to append for metadata file in generate_images function
9fa4df6
raw
history blame
4.41 kB
import argparse
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
from api import create_api
from benchmark import create_benchmark
def _load_existing_metadata(metadata_file: Path) -> dict:
    """Return entries already recorded in ``metadata_file``, keyed by "filepath".

    Returns an empty dict when the file does not exist. Blank lines and lines
    that are not valid JSON (for example a record cut short when an earlier
    run was interrupted mid-write) are skipped with a warning instead of
    aborting the whole run.
    """
    existing = {}
    if not metadata_file.exists():
        return existing
    with open(metadata_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                tqdm.write(f"Skipping malformed line {line_no} in {metadata_file}")
                continue
            existing[entry["filepath"]] = entry
    return existing


def _append_metadata(f, filepath: str, prompt, inference_time) -> None:
    """Write one metadata record as a JSON line to the open file ``f``."""
    entry = {
        "filepath": filepath,
        "prompt": prompt,
        "inference_time": inference_time,
    }
    f.write(json.dumps(entry) + "\n")
    # Flush each record: resumed runs skip images that already exist on disk,
    # so a record still sitting in the write buffer when the process dies
    # would otherwise never be written.
    f.flush()


def _report_failure(prompt, error: Exception) -> None:
    """Print a per-prompt generation failure without breaking the progress bars."""
    tqdm.write(f"\nError generating image for prompt: {prompt}")
    tqdm.write(f"Error: {str(error)}")


def _generate_geneval(
    api,
    benchmark,
    benchmark_type: str,
    benchmark_dir: Path,
    metadata_file: Path,
) -> None:
    """Generate GenEval images, one per prompt, at ``<folder>/samples/0000.png``.

    ``benchmark`` yields ``(metadata, folder_name)`` pairs whose ``metadata``
    carries a ``"prompt"`` key. Prompts whose image file already exists are
    skipped.
    """
    with open(metadata_file, "a", encoding="utf-8") as f:
        for metadata, folder_name in tqdm(
            benchmark,
            desc=f"Generating images for {benchmark_type}",
            leave=False,
        ):
            samples_path = benchmark_dir / folder_name / "samples"
            samples_path.mkdir(parents=True, exist_ok=True)
            image_path = samples_path / "0000.png"
            if image_path.exists():
                continue
            try:
                inference_time = api.generate_image(metadata["prompt"], image_path)
                # NOTE(review): unlike other benchmarks this records the full path
                # (including images/<api_type>/...) rather than one relative to
                # benchmark_dir — confirm downstream consumers expect that.
                _append_metadata(
                    f, str(image_path), metadata["prompt"], inference_time
                )
            except Exception as e:  # best effort: one bad prompt must not stop the run
                _report_failure(metadata["prompt"], e)


def _generate_flat(
    api,
    benchmark,
    benchmark_type: str,
    benchmark_dir: Path,
    metadata_file: Path,
) -> None:
    """Generate images for a benchmark yielding ``(prompt, image_path)`` pairs.

    ``image_path`` is joined onto ``benchmark_dir`` to get the output file.
    Entries already recorded in ``metadata_file``, or whose image already
    exists on disk, are skipped.
    """
    existing_metadata = _load_existing_metadata(metadata_file)
    with open(metadata_file, "a", encoding="utf-8") as f:
        for prompt, image_path in tqdm(
            benchmark,
            desc=f"Generating images for {benchmark_type}",
            leave=False,
        ):
            # Keys come back from JSON as strings and records are written with
            # str(image_path), so compare in that form; a Path-typed
            # image_path would otherwise never match.
            if str(image_path) in existing_metadata:
                continue
            full_image_path = benchmark_dir / image_path
            if full_image_path.exists():
                continue
            try:
                inference_time = api.generate_image(prompt, full_image_path)
                _append_metadata(f, str(image_path), prompt, inference_time)
            except Exception as e:  # best effort: one bad prompt must not stop the run
                _report_failure(prompt, e)


def generate_images(api_type: str, benchmarks: List[str]):
    """Generate images for each requested benchmark with the given API.

    Output goes under ``images/<api_type>/<benchmark_type>/``. Each
    successfully generated image is recorded as one JSON line (``filepath``,
    ``prompt``, ``inference_time``) appended to that directory's
    ``metadata.jsonl``. Existing images are skipped, so an interrupted run can
    be resumed by re-running the same command. A failure on one prompt is
    reported and skipped; it does not abort the run.

    Args:
        api_type: Name passed to ``create_api`` to build the generation backend.
        benchmarks: Benchmark names passed to ``create_benchmark``.
            ``"geneval"`` uses the GenEval folder layout
            (``<folder>/samples/0000.png``). Every other benchmark is iterated
            as ``(prompt, image_path)`` pairs.
    """
    images_dir = Path("images")
    api = create_api(api_type)
    api_dir = images_dir / api_type
    api_dir.mkdir(parents=True, exist_ok=True)

    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
        tqdm.write(f"\nProcessing benchmark: {benchmark_type}")
        benchmark = create_benchmark(benchmark_type)

        benchmark_dir = api_dir / benchmark_type
        benchmark_dir.mkdir(parents=True, exist_ok=True)
        metadata_file = benchmark_dir / "metadata.jsonl"

        if benchmark_type == "geneval":
            _generate_geneval(
                api, benchmark, benchmark_type, benchmark_dir, metadata_file
            )
        else:
            _generate_flat(
                api, benchmark, benchmark_type, benchmark_dir, metadata_file
            )
def main():
    """Command-line entry point: parse arguments and generate the images."""
    cli = argparse.ArgumentParser(
        description="Generate images for specified benchmarks using a given API"
    )
    positional_args = (
        ("api_type", {"help": "Type of API to use for image generation"}),
        ("benchmarks", {"nargs": "+", "help": "List of benchmark types to run"}),
    )
    for arg_name, arg_options in positional_args:
        cli.add_argument(arg_name, **arg_options)
    parsed = cli.parse_args()
    generate_images(api_type=parsed.api_type, benchmarks=parsed.benchmarks)


# Only run the CLI when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()