Spaces:
Running
Running
Nils Fleischmann
commited on
Commit
·
5291ba9
1
Parent(s):
4f41410
feat: add aws canva + examples in readme + my current environment + disable HPS for now
Browse files- README.md +11 -2
- api/__init__.py +3 -0
- api/aws.py +73 -0
- benchmark/metrics/__init__.py +2 -2
- nils_installs.txt +177 -0
README.md
CHANGED
@@ -6,5 +6,14 @@ Install dependencies with conda like that:
|
|
6 |
conda env create -f environment.yml
|
7 |
|
8 |
|
9 |
-
Create .env file with all the
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
conda env create -f environment.yml
|
7 |
|
8 |
|
9 |
+
Create .env file with all the credentials you will need.
|
10 |
+
|
11 |
+
This is how you can generate the images.
|
12 |
+
```
|
13 |
+
python sample.py replicate draw_bench genai_bench geneval hps parti
|
14 |
+
```
|
15 |
+
|
16 |
+
This is how you would evaluate the benchmarks once you have all images:
|
17 |
+
```
|
18 |
+
python evaluate.py replicate draw_bench genai_bench geneval hps parti
|
19 |
+
```
|
api/__init__.py
CHANGED
@@ -8,6 +8,7 @@ from api.pruna_dev import PrunaDevAPI
|
|
8 |
from api.replicate import ReplicateAPI
|
9 |
from api.together import TogetherAPI
|
10 |
from api.fal import FalAPI
|
|
|
11 |
|
12 |
__all__ = [
|
13 |
'create_api',
|
@@ -33,6 +34,7 @@ def create_api(api_type: str) -> FluxAPI:
|
|
33 |
- "replicate"
|
34 |
- "together"
|
35 |
- "fal"
|
|
|
36 |
|
37 |
Returns:
|
38 |
FluxAPI: An instance of the requested API implementation
|
@@ -52,6 +54,7 @@ def create_api(api_type: str) -> FluxAPI:
|
|
52 |
"replicate": ReplicateAPI,
|
53 |
"together": TogetherAPI,
|
54 |
"fal": FalAPI,
|
|
|
55 |
}
|
56 |
|
57 |
if api_type not in api_map:
|
|
|
8 |
from api.replicate import ReplicateAPI
|
9 |
from api.together import TogetherAPI
|
10 |
from api.fal import FalAPI
|
11 |
+
from api.aws import AWSBedrockAPI
|
12 |
|
13 |
__all__ = [
|
14 |
'create_api',
|
|
|
34 |
- "replicate"
|
35 |
- "together"
|
36 |
- "fal"
|
37 |
+
- "aws"
|
38 |
|
39 |
Returns:
|
40 |
FluxAPI: An instance of the requested API implementation
|
|
|
54 |
"replicate": ReplicateAPI,
|
55 |
"together": TogetherAPI,
|
56 |
"fal": FalAPI,
|
57 |
+
"aws": AWSBedrockAPI,
|
58 |
}
|
59 |
|
60 |
if api_type not in api_map:
|
api/aws.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import base64
|
4 |
+
import json
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
import boto3
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
|
11 |
+
from api.flux import FluxAPI
|
12 |
+
|
13 |
+
|
14 |
+
class AWSBedrockAPI(FluxAPI):
|
15 |
+
def __init__(self):
|
16 |
+
load_dotenv()
|
17 |
+
# AWS credentials should be set via environment variables
|
18 |
+
# AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN
|
19 |
+
os.environ["AWS_ACCESS_KEY_ID"] = ""
|
20 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
|
21 |
+
os.environ["AWS_SESSION_TOKEN"] = ""
|
22 |
+
os.environ["AWS_REGION"] = "us-east-1"
|
23 |
+
self._client = boto3.client("bedrock-runtime")
|
24 |
+
self._model_id = "amazon.nova-canvas-v1:0"
|
25 |
+
|
26 |
+
@property
|
27 |
+
def name(self) -> str:
|
28 |
+
return "aws_nova_canvas"
|
29 |
+
|
30 |
+
def generate_image(self, prompt: str, save_path: Path) -> float:
|
31 |
+
start_time = time.time()
|
32 |
+
# Format the request payload
|
33 |
+
native_request = {
|
34 |
+
"taskType": "TEXT_IMAGE",
|
35 |
+
"textToImageParams": {"text": prompt},
|
36 |
+
"imageGenerationConfig": {
|
37 |
+
"seed": 0,
|
38 |
+
"quality": "standard",
|
39 |
+
"height": 1024,
|
40 |
+
"width": 1024,
|
41 |
+
"numberOfImages": 1,
|
42 |
+
},
|
43 |
+
}
|
44 |
+
|
45 |
+
try:
|
46 |
+
# Convert request to JSON and invoke the model
|
47 |
+
request = json.dumps(native_request)
|
48 |
+
response = self._client.invoke_model(
|
49 |
+
modelId=self._model_id,
|
50 |
+
body=request
|
51 |
+
)
|
52 |
+
|
53 |
+
# Process the response
|
54 |
+
model_response = json.loads(response["body"].read())
|
55 |
+
if not model_response.get("images"):
|
56 |
+
raise Exception("No images returned from AWS Bedrock API")
|
57 |
+
|
58 |
+
# Save the image
|
59 |
+
base64_image_data = model_response["images"][0]
|
60 |
+
self._save_image_from_base64(base64_image_data, save_path)
|
61 |
+
|
62 |
+
except Exception as e:
|
63 |
+
raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
|
64 |
+
|
65 |
+
end_time = time.time()
|
66 |
+
return end_time - start_time
|
67 |
+
|
68 |
+
def _save_image_from_base64(self, base64_data: str, save_path: Path):
|
69 |
+
"""Save a base64 encoded image to the specified path."""
|
70 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
71 |
+
image_data = base64.b64decode(base64_data)
|
72 |
+
with open(save_path, "wb") as f:
|
73 |
+
f.write(image_data)
|
benchmark/metrics/__init__.py
CHANGED
@@ -6,7 +6,7 @@ from benchmark.metrics.clip_iqa import CLIPIQAMetric
|
|
6 |
from benchmark.metrics.image_reward import ImageRewardMetric
|
7 |
from benchmark.metrics.sharpness import SharpnessMetric
|
8 |
from benchmark.metrics.vqa import VQAMetric
|
9 |
-
from benchmark.metrics.hps import HPSMetric
|
10 |
|
11 |
def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
|
12 |
"""
|
@@ -34,7 +34,7 @@ def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAM
|
|
34 |
"image_reward": ImageRewardMetric,
|
35 |
"sharpness": SharpnessMetric,
|
36 |
"vqa": VQAMetric,
|
37 |
-
"hps": HPSMetric,
|
38 |
}
|
39 |
|
40 |
if metric_type not in metric_map:
|
|
|
6 |
from benchmark.metrics.image_reward import ImageRewardMetric
|
7 |
from benchmark.metrics.sharpness import SharpnessMetric
|
8 |
from benchmark.metrics.vqa import VQAMetric
|
9 |
+
#from benchmark.metrics.hps import HPSMetric
|
10 |
|
11 |
def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
|
12 |
"""
|
|
|
34 |
"image_reward": ImageRewardMetric,
|
35 |
"sharpness": SharpnessMetric,
|
36 |
"vqa": VQAMetric,
|
37 |
+
#"hps": HPSMetric,
|
38 |
}
|
39 |
|
40 |
if metric_type not in metric_map:
|
nils_installs.txt
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==1.7.0
|
2 |
+
aiohappyeyeballs==2.6.1
|
3 |
+
aiohttp==3.12.12
|
4 |
+
aiosignal==1.3.2
|
5 |
+
annotated-types==0.7.0
|
6 |
+
antlr4-python3-runtime==4.9.3
|
7 |
+
anyio==4.9.0
|
8 |
+
args==0.1.0
|
9 |
+
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
|
10 |
+
attrs==25.3.0
|
11 |
+
beautifulsoup4==4.13.4
|
12 |
+
boto3==1.38.33
|
13 |
+
botocore==1.38.33
|
14 |
+
braceexpand==0.1.7
|
15 |
+
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1749229842835/work
|
16 |
+
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1746569525376/work/certifi
|
17 |
+
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725560558132/work
|
18 |
+
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1746214863626/work
|
19 |
+
click==8.1.8
|
20 |
+
clint==0.5.1
|
21 |
+
clip==0.2.0
|
22 |
+
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
|
23 |
+
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
|
24 |
+
contourpy==1.3.2
|
25 |
+
cycler==0.12.1
|
26 |
+
datasets==3.6.0
|
27 |
+
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321241074/work
|
28 |
+
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
|
29 |
+
diffusers==0.31.0
|
30 |
+
dill==0.3.8
|
31 |
+
distro==1.9.0
|
32 |
+
einops==0.8.1
|
33 |
+
eval_type_backport==0.2.2
|
34 |
+
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
|
35 |
+
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
|
36 |
+
fairscale==0.4.13
|
37 |
+
fal_client==0.7.0
|
38 |
+
filelock==3.18.0
|
39 |
+
fire==0.4.0
|
40 |
+
fonttools==4.58.2
|
41 |
+
frozenlist==1.7.0
|
42 |
+
fsspec==2025.3.0
|
43 |
+
ftfy==6.3.1
|
44 |
+
gdown==5.2.0
|
45 |
+
h11==0.16.0
|
46 |
+
h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
|
47 |
+
hf-xet==1.1.3
|
48 |
+
hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
|
49 |
+
hpsv2==1.2.0
|
50 |
+
httpcore==1.0.9
|
51 |
+
httpx==0.28.1
|
52 |
+
httpx-sse==0.4.0
|
53 |
+
huggingface-hub==0.32.5
|
54 |
+
hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
|
55 |
+
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
|
56 |
+
image-reward==1.5
|
57 |
+
importlib_metadata==8.7.0
|
58 |
+
iniconfig==2.1.0
|
59 |
+
iopath==0.1.10
|
60 |
+
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
|
61 |
+
ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1748713870/work
|
62 |
+
ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
|
63 |
+
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
|
64 |
+
Jinja2==3.1.6
|
65 |
+
jiter==0.10.0
|
66 |
+
jmespath==1.0.1
|
67 |
+
joblib==1.5.1
|
68 |
+
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
|
69 |
+
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work
|
70 |
+
kiwisolver==1.4.8
|
71 |
+
lightning-utilities==0.14.3
|
72 |
+
markdown-it-py==3.0.0
|
73 |
+
MarkupSafe==3.0.2
|
74 |
+
matplotlib==3.10.3
|
75 |
+
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
|
76 |
+
mdurl==0.1.2
|
77 |
+
mpmath==1.3.0
|
78 |
+
multidict==6.4.4
|
79 |
+
multiprocess==0.70.16
|
80 |
+
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
|
81 |
+
networkx==3.5
|
82 |
+
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1749430504934/work/dist/numpy-2.3.0-cp312-cp312-linux_x86_64.whl#sha256=3c4437a0cbe50dbae872ad4cd8dc5316009165bce459c4ffe2c46cd30aba13d4
|
83 |
+
nvidia-cublas-cu12==12.6.4.1
|
84 |
+
nvidia-cuda-cupti-cu12==12.6.80
|
85 |
+
nvidia-cuda-nvrtc-cu12==12.6.77
|
86 |
+
nvidia-cuda-runtime-cu12==12.6.77
|
87 |
+
nvidia-cudnn-cu12==9.5.1.17
|
88 |
+
nvidia-cufft-cu12==11.3.0.4
|
89 |
+
nvidia-cufile-cu12==1.11.1.6
|
90 |
+
nvidia-curand-cu12==10.3.7.77
|
91 |
+
nvidia-cusolver-cu12==11.7.1.2
|
92 |
+
nvidia-cusparse-cu12==12.5.4.2
|
93 |
+
nvidia-cusparselt-cu12==0.6.3
|
94 |
+
nvidia-nccl-cu12==2.26.2
|
95 |
+
nvidia-nvjitlink-cu12==12.6.85
|
96 |
+
nvidia-nvtx-cu12==12.6.77
|
97 |
+
omegaconf==2.3.0
|
98 |
+
open_clip_torch==2.32.0
|
99 |
+
openai==1.85.0
|
100 |
+
opencv-python==4.11.0
|
101 |
+
opencv-python-headless==4.11.0
|
102 |
+
packaging==25.0
|
103 |
+
pandas==2.3.0
|
104 |
+
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
|
105 |
+
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
|
106 |
+
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
|
107 |
+
pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1746646208260/work
|
108 |
+
piq==0.8.0
|
109 |
+
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
|
110 |
+
pluggy==1.6.0
|
111 |
+
portalocker==3.1.1
|
112 |
+
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
|
113 |
+
propcache==0.3.2
|
114 |
+
protobuf==3.20.3
|
115 |
+
psutil==7.0.0
|
116 |
+
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
|
117 |
+
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
|
118 |
+
pyarrow==20.0.0
|
119 |
+
pycocoevalcap==1.2
|
120 |
+
pycocotools==2.0.10
|
121 |
+
pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
|
122 |
+
pydantic==2.11.5
|
123 |
+
pydantic_core==2.33.2
|
124 |
+
Pygments==2.19.1
|
125 |
+
pyparsing==3.2.3
|
126 |
+
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
|
127 |
+
pytest==7.2.0
|
128 |
+
pytest-split==0.8.0
|
129 |
+
python-dateutil==2.9.0.post0
|
130 |
+
python-dotenv @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dotenv_1742948348/work
|
131 |
+
pytz==2025.2
|
132 |
+
PyYAML==6.0.2
|
133 |
+
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1743831245863/work
|
134 |
+
regex==2024.11.6
|
135 |
+
replicate==1.0.7
|
136 |
+
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1749498106507/work
|
137 |
+
rich==14.0.0
|
138 |
+
s3transfer==0.13.0
|
139 |
+
safetensors==0.5.3
|
140 |
+
scikit-learn==1.7.0
|
141 |
+
scipy==1.15.3
|
142 |
+
sentencepiece==0.2.0
|
143 |
+
setuptools==80.9.0
|
144 |
+
shellingham==1.5.4
|
145 |
+
six==1.17.0
|
146 |
+
sniffio==1.3.1
|
147 |
+
soupsieve==2.7
|
148 |
+
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
|
149 |
+
sympy==1.14.0
|
150 |
+
t2v_metrics==1.2
|
151 |
+
tabulate==0.9.0
|
152 |
+
termcolor==3.1.0
|
153 |
+
threadpoolctl==3.6.0
|
154 |
+
tiktoken==0.9.0
|
155 |
+
timm==0.6.13
|
156 |
+
together==1.5.11
|
157 |
+
tokenizers==0.15.2
|
158 |
+
torch==2.7.1
|
159 |
+
torchmetrics==1.7.2
|
160 |
+
torchvision==0.22.1
|
161 |
+
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003300911/work
|
162 |
+
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work
|
163 |
+
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
|
164 |
+
transformers==4.36.1
|
165 |
+
triton==3.3.1
|
166 |
+
typer==0.15.4
|
167 |
+
typing-inspection==0.4.1
|
168 |
+
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1748959427/work
|
169 |
+
tzdata==2025.2
|
170 |
+
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1744323578849/work
|
171 |
+
wcwidth==0.2.13
|
172 |
+
webdataset==0.2.111
|
173 |
+
wheel==0.45.1
|
174 |
+
xxhash==3.5.0
|
175 |
+
yarl==1.20.1
|
176 |
+
zipp==3.23.0
|
177 |
+
zstandard==0.23.0
|