# InferBench / evaluate.py
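# Usage (based on the argument parser in main()):
#   python evaluate.py <api_type> <benchmark> [<benchmark> ...]
# Results are appended to evaluation_results/<api_type>.jsonl, one JSON object per benchmark.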
import argparse
import json
import warnings
from pathlib import Path
from typing import Dict

import numpy as np
from PIL import Image
from tqdm import tqdm

from benchmark import create_benchmark
from benchmark.metrics import create_metric

warnings.filterwarnings("ignore", category=FutureWarning)


def evaluate_benchmark(
    benchmark_type: str, api_type: str, images_dir: Path = Path("images")
) -> Dict:
"""
Evaluate a benchmark's images using its specific metrics.
Args:
benchmark_type (str): Type of benchmark to evaluate
api_type (str): Type of API used to generate images
images_dir (Path): Base directory containing generated images
Returns:
Dict containing evaluation results
"""
benchmark = create_benchmark(benchmark_type)
benchmark_dir = images_dir / api_type / benchmark_type
metadata_file = benchmark_dir / "metadata.jsonl"
if not metadata_file.exists():
raise FileNotFoundError(
f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first."
)
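    # metadata.jsonl holds one JSON object per generated image, written by sample.py;
    # the fields used below are "filepath", "prompt", and "inference_time".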
metadata = []
with open(metadata_file, "r") as f:
for line in f:
metadata.append(json.loads(line))
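    # Instantiate one metric object per metric name declared by the benchmark.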
metrics = {
metric_type: create_metric(metric_type) for metric_type in benchmark.metrics
}
results = {
"api": api_type,
"benchmark": benchmark_type,
"metrics": {metric: 0.0 for metric in benchmark.metrics},
"total_images": len(metadata),
}
inference_times = []
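    # Score every generated image with each metric; entries whose image file is
    # missing are skipped but still count toward the averages computed below.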
for entry in tqdm(metadata):
image_path = benchmark_dir / entry["filepath"]
if not image_path.exists():
continue
        for metric_type, metric in metrics.items():
            try:
                if metric_type == "vqa":
                    # The VQA metric takes the image path rather than a PIL image.
                    score = metric.compute_score(image_path, entry["prompt"])
                else:
                    with Image.open(image_path) as image:
                        score = metric.compute_score(image, entry["prompt"])
                results["metrics"][metric_type] += score[metric_type]
            except Exception as e:
                print(f"Error computing {metric_type} for {image_path}: {str(e)}")
inference_times.append(entry["inference_time"])
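    # Convert the accumulated sums into averages over all metadata entries and
    # record the median inference time.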
for metric in results["metrics"]:
results["metrics"][metric] /= len(metadata)
results["median_inference_time"] = np.median(inference_times).item()
    return results


def main():
parser = argparse.ArgumentParser(
description="Evaluate generated images using benchmark-specific metrics"
)
parser.add_argument("api_type", help="Type of API to evaluate")
parser.add_argument(
"benchmarks", nargs="+", help="List of benchmark types to evaluate"
)
args = parser.parse_args()
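    # One JSONL results file per API type; previously evaluated benchmarks are
    # detected from it and skipped, so the script can be re-run safely.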
results_dir = Path("evaluation_results")
results_dir.mkdir(exist_ok=True)
results_file = results_dir / f"{args.api_type}.jsonl"
existing_results = set()
if results_file.exists():
with open(results_file, "r") as f:
for line in f:
result = json.loads(line)
existing_results.add(result["benchmark"])
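    # Evaluate each requested benchmark and append its results as one JSON line.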
for benchmark_type in args.benchmarks:
if benchmark_type in existing_results:
print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
continue
try:
print(f"Evaluating {args.api_type}/{benchmark_type}")
results = evaluate_benchmark(benchmark_type, args.api_type)
# Append results to file
with open(results_file, "a") as f:
f.write(json.dumps(results) + "\n")
except Exception as e:
print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
if __name__ == "__main__":
main()