nifleisch committed
Commit 4f41410 · 1 Parent(s): 2c50826
fix: fix several errors
Files changed:
- .env.example +4 -0
- .gitignore +4 -0
- README.md +8 -0
- api/__init__.py +16 -1
- api/fal.py +1 -1
- api/pruna_dev.py +49 -0
- benchmark/hps.py +2 -2
- benchmark/metrics/__init__.py +3 -2
- benchmark/metrics/arniqa.py +3 -2
- benchmark/metrics/clip_iqa.py +1 -1
- benchmark/metrics/vqa.py +4 -9
- environment.yml +27 -0
- evaluate.py +15 -7
- pyproject.toml +0 -22
- uv.lock +0 -0
.env.example
ADDED
@@ -0,0 +1,4 @@
+FIREWORKS_API_TOKEN=your_fireworks_api_token_here
+REPLICATE_API_TOKEN=your_replicate_api_token_here
+FAL_KEY=your_fal_key_here
+TOGETHER_API_KEY=your_together_api_key_here
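These keys are read at runtime through python-dotenv, as the API wrappers below do. A minimal sketch of that loading pattern, mirroring the check used in api/pruna_dev.py (the variable names come from this file):

import os
from dotenv import load_dotenv

load_dotenv()  # reads the .env file from the working directory
replicate_token = os.getenv("REPLICATE_API_TOKEN")
if not replicate_token:
    raise ValueError("REPLICATE_API_TOKEN not found in environment variables")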
.gitignore
CHANGED
@@ -172,3 +172,7 @@ cython_debug/
 
 # PyPI configuration file
 .pypirc
+
+evaluation_results/
+images/
+hf_cache/
README.md
CHANGED
@@ -1,2 +1,10 @@
 # InferBench
 Evaluate the quality and efficiency of image gen api's.
+
+Install dependencies with conda like that:
+
+conda env create -f environment.yml
+
+
+Create .env file with all the files you will need.
+python sample.py replicate draw_bench genai_bench geneval hps parti
api/__init__.py
CHANGED
@@ -1,13 +1,26 @@
-from typing import
+from typing import Type
 
 from api.baseline import BaselineAPI
 from api.fireworks import FireworksAPI
 from api.flux import FluxAPI
 from api.pruna import PrunaAPI
+from api.pruna_dev import PrunaDevAPI
 from api.replicate import ReplicateAPI
 from api.together import TogetherAPI
 from api.fal import FalAPI
 
+__all__ = [
+    'create_api',
+    'FluxAPI',
+    'BaselineAPI',
+    'FireworksAPI',
+    'PrunaAPI',
+    'ReplicateAPI',
+    'TogetherAPI',
+    'FalAPI',
+    'PrunaDevAPI',
+]
+
 def create_api(api_type: str) -> FluxAPI:
     """
     Factory function to create API instances.

@@ -27,6 +40,8 @@ def create_api(api_type: str) -> FluxAPI:
     Raises:
         ValueError: If an invalid API type is provided
     """
+    if api_type == "pruna_dev":
+        return PrunaDevAPI()
     if api_type.startswith("pruna_"):
         speed_mode = api_type[6:]  # Remove "pruna_" prefix
         return PrunaAPI(speed_mode)
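Usage sketch for the updated factory: "pruna_dev" now short-circuits to the new PrunaDevAPI before the generic "pruna_<speed_mode>" branch. The second argument below is a hypothetical placeholder, not a documented speed mode:

from api import create_api

api = create_api("pruna_dev")  # new branch -> PrunaDevAPI()

# Any other "pruna_<speed_mode>" string still falls through to PrunaAPI;
# "some_speed_mode" is a placeholder suffix for illustration only.
other = create_api("pruna_some_speed_mode")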
api/fal.py
CHANGED
@@ -7,7 +7,7 @@ import fal_client
 import requests
 from PIL import Image
 
-from flux import FluxAPI
+from api.flux import FluxAPI
 
 
 class FalAPI(FluxAPI):
api/pruna_dev.py
ADDED
@@ -0,0 +1,49 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+from dotenv import load_dotenv
+import replicate
+
+from api.flux import FluxAPI
+
+
+class PrunaDevAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return "pruna_dev"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "prunaai/flux.1-dev:938a4eb31a87d65fb7b23fc300fb5b7ab88a36844bb26e54e1d1dec7acf4eefe",
+            input={
+                "seed": 0,
+                "prompt": prompt,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "speed_mode": "Juiced 🔥 (default)",
+                "num_inference_steps": 28,
+            },
+        )
+        end_time = time.time()
+
+        if result:
+            self._save_image_from_result(result, save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
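A minimal usage sketch for the new wrapper; the prompt and output path below are placeholders, and REPLICATE_API_TOKEN must be present in the .env file:

from pathlib import Path
from api.pruna_dev import PrunaDevAPI

api = PrunaDevAPI()
seconds = api.generate_image(
    "a watercolor fox in a snowy forest",  # placeholder prompt
    Path("images/pruna_dev/example.png"),  # placeholder save path
)
print(f"{api.name}: generated in {seconds:.2f}s")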
benchmark/hps.py
CHANGED
@@ -15,7 +15,7 @@ class HPSPrompts:
         self._size = 0
         for file in self.hps_prompt_files:
             category = file.replace('.json', '')
-            with open(os.path.join('
+            with open(os.path.join('downloads/hps', file), 'r') as f:
                 prompts = json.load(f)
             for i, prompt in enumerate(prompts):
                 if i == 100:

@@ -26,7 +26,7 @@
 
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
         for filename, prompt in self.prompts.items():
-            yield prompt, filename
+            yield prompt, Path(filename)
 
     @property
    def name(self) -> str:
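With these fixes the iterator yields (prompt, Path) pairs for files read from downloads/hps/<category>.json. A rough consumer sketch, assuming the JSON files are already downloaded and the class constructs without arguments (construction details are outside this hunk):

from benchmark.hps import HPSPrompts

prompts = HPSPrompts()
for prompt, rel_path in prompts:  # rel_path is now a pathlib.Path, not a str
    print(rel_path, prompt[:60])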
benchmark/metrics/__init__.py
CHANGED
@@ -6,7 +6,7 @@ from benchmark.metrics.clip_iqa import CLIPIQAMetric
 from benchmark.metrics.image_reward import ImageRewardMetric
 from benchmark.metrics.sharpness import SharpnessMetric
 from benchmark.metrics.vqa import VQAMetric
-
+from benchmark.metrics.hps import HPSMetric
 
 def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
     """

@@ -20,7 +20,7 @@
         - "image_reward"
         - "sharpness"
         - "vqa"
-
+        - "hps"
     Returns:
         An instance of the requested metric implementation
 

@@ -34,6 +34,7 @@
         "image_reward": ImageRewardMetric,
         "sharpness": SharpnessMetric,
         "vqa": VQAMetric,
+        "hps": HPSMetric,
     }
 
     if metric_type not in metric_map:
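Usage sketch for the extended factory, assuming benchmark/metrics/hps.py defines HPSMetric with the same compute_score(image, prompt) convention as the other metrics (the image path and prompt below are placeholders):

from PIL import Image
from benchmark.metrics import create_metric

metric = create_metric("hps")
image = Image.open("images/example.png")  # placeholder path
scores = metric.compute_score(image, "a red bicycle leaning against a wall")
print(scores)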
benchmark/metrics/arniqa.py
CHANGED
@@ -8,19 +8,20 @@ from torchmetrics.image.arniqa import ARNIQA
 
 class ARNIQAMetric:
     def __init__(self):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.metric = ARNIQA(
             regressor_dataset="koniq10k",
             reduction="mean",
             normalize=True,
             autocast=False
         )
-
+        self.metric.to(self.device)
     @property
     def name(self) -> str:
         return "arniqa"
 
     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
         image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
-        image_tensor = image_tensor.unsqueeze(0)
+        image_tensor = image_tensor.unsqueeze(0).to(self.device)
         score = self.metric(image_tensor)
         return {"arniqa": score.item()}
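Sketch of the metric after the device fix: the ARNIQA model and the image tensor now end up on the same device. The image path is a placeholder; the prompt argument is accepted but unused by this no-reference metric:

from PIL import Image
from benchmark.metrics.arniqa import ARNIQAMetric

metric = ARNIQAMetric()                   # picks CUDA when available, else CPU
image = Image.open("images/example.png")  # placeholder path; any RGB image works
print(metric.compute_score(image, ""))    # -> {"arniqa": <score>}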
benchmark/metrics/clip_iqa.py
CHANGED
@@ -12,7 +12,7 @@ class CLIPIQAMetric:
         self.metric = CLIPImageQualityAssessment(
             model_name_or_path="clip_iqa",
             data_range=255.0,
-            prompts=
+            prompts=("quality",)
         )
         self.metric.to(self.device)
 
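The fix passes prompts as the tuple the torchmetrics API expects. A standalone sketch of the underlying call with a dummy batch in the declared [0, 255] range (the tensor shape is illustrative):

import torch
from torchmetrics.multimodal import CLIPImageQualityAssessment

metric = CLIPImageQualityAssessment(
    model_name_or_path="clip_iqa",
    data_range=255.0,
    prompts=("quality",),  # a single built-in prompt pair -> one score per image
)
images = torch.randint(0, 255, (2, 3, 224, 224)).float()  # dummy image batch
print(metric(images))  # tensor of per-image quality scores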
benchmark/metrics/vqa.py
CHANGED
@@ -1,9 +1,7 @@
-import os
-import tempfile
+from pathlib import Path
 from typing import Dict
 
 import t2v_metrics
-from PIL import Image
 
 class VQAMetric:
     def __init__(self):

@@ -15,11 +13,8 @@ class VQAMetric:
 
     def compute_score(
         self,
-
+        image_path: Path,
         prompt: str,
     ) -> Dict[str, float]:
-
-
-            score = self.metric(images=[tmp.name], texts=[prompt])
-        os.unlink(tmp.name)
-        return {"vqa_score": score[0][0].item()}
+        score = self.metric(images=[str(image_path)], texts=[prompt])
+        return {"vqa": score[0][0].item()}
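After this change compute_score takes the image path directly instead of routing through a temporary file, and the result key is "vqa", matching the factory name. A hypothetical call (the path and prompt are placeholders; the t2v_metrics scorer is configured in __init__, which this diff does not show):

from pathlib import Path
from benchmark.metrics.vqa import VQAMetric

metric = VQAMetric()
scores = metric.compute_score(Path("images/example.png"), "a blue cube on top of a red sphere")
print(scores["vqa"])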
environment.yml
ADDED
@@ -0,0 +1,27 @@
+name: inferbench
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - python=3.12
+  - numpy
+  - opencv
+  - pillow
+  - python-dotenv
+  - requests
+  - tqdm
+  - pip
+  - pip:
+    - datasets>=3.5.0
+    - fal-client>=0.5.9
+    - hpsv2>=1.2.0
+    - huggingface-hub>=0.30.2
+    - image-reward>=1.5
+    - replicate>=1.0.4
+    - t2v-metrics>=1.2
+    - together>=1.5.5
+    - torch>=2.7.0
+    - torchmetrics>=1.7.1
+    - clip
+    - diffusers<=0.31
+    - piq>=0.8.0
evaluate.py
CHANGED
@@ -2,10 +2,16 @@ import argparse
 import json
 from pathlib import Path
 from typing import Dict
+import warnings
 
 from benchmark import create_benchmark
 from benchmark.metrics import create_metric
+import numpy as np
 from PIL import Image
+from tqdm import tqdm
+
+
+warnings.filterwarnings("ignore", category=FutureWarning)
 
 
 def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:

@@ -39,29 +45,31 @@
         "api": api_type,
         "benchmark": benchmark_type,
         "metrics": {metric: 0.0 for metric in benchmark.metrics},
-        "avg_inference_time": 0.0,
         "total_images": len(metadata)
     }
+    inference_times = []
 
-    for entry in metadata:
+    for entry in tqdm(metadata):
         image_path = benchmark_dir / entry["filepath"]
         if not image_path.exists():
             continue
 
-        image = Image.open(image_path)
-
         for metric_type, metric in metrics.items():
             try:
-
+                if metric_type == "vqa":
+                    score = metric.compute_score(image_path, entry["prompt"])
+                else:
+                    image = Image.open(image_path)
+                    score = metric.compute_score(image, entry["prompt"])
                 results["metrics"][metric_type] += score[metric_type]
             except Exception as e:
                 print(f"Error computing {metric_type} for {image_path}: {str(e)}")
 
-
+        inference_times.append(entry["inference_time"])
 
     for metric in results["metrics"]:
         results["metrics"][metric] /= len(metadata)
-    results["
+    results["median_inference_time"] = np.median(inference_times).item()
 
     return results
 
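A hypothetical call of the updated function: per-image inference times from the metadata are now aggregated as a median rather than a running average, and the VQA metric receives the image path instead of a PIL image. The benchmark and API names below are examples taken from this commit:

from pathlib import Path
from evaluate import evaluate_benchmark

results = evaluate_benchmark("hps", "pruna_dev", images_dir=Path("images"))
print(results["median_inference_time"], results["metrics"])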
pyproject.toml
DELETED
@@ -1,22 +0,0 @@
-[project]
-name = "inferbench"
-version = "0.1.0"
-requires-python = ">=3.12"
-dependencies = [
-    "datasets>=3.5.0",
-    "fal-client>=0.5.9",
-    "hpsv2>=1.2.0",
-    "huggingface-hub>=0.30.2",
-    "image-reward>=1.5",
-    "numpy>=2.2.5",
-    "opencv-python>=4.11.0.86",
-    "pillow>=11.2.1",
-    "python-dotenv>=1.1.0",
-    "replicate>=1.0.4",
-    "requests>=2.32.3",
-    "t2v-metrics>=1.2",
-    "together>=1.5.5",
-    "torch>=2.7.0",
-    "torchmetrics>=1.7.1",
-    "tqdm>=4.67.1",
-]
uv.lock
DELETED
The diff for this file is too large to render.