Nils Fleischmann commited on
Commit
5291ba9
·
1 Parent(s): 4f41410

feat: add aws canva + examples in readme + my current environment + disable HPS for now

Browse files
Files changed (5) hide show
  1. README.md +11 -2
  2. api/__init__.py +3 -0
  3. api/aws.py +73 -0
  4. benchmark/metrics/__init__.py +2 -2
  5. nils_installs.txt +177 -0
README.md CHANGED
@@ -6,5 +6,14 @@ Install dependencies with conda like that:
6
  conda env create -f environment.yml
7
 
8
 
9
- Create .env file with all the files you will need.
10
- python sample.py replicate draw_bench genai_bench geneval hps parti
 
 
 
 
 
 
 
 
 
 
6
  conda env create -f environment.yml
7
 
8
 
9
+ Create .env file with all the credentials you will need.
10
+
11
+ This is how you can generate the images.
12
+ ```
13
+ python sample.py replicate draw_bench genai_bench geneval hps parti
14
+ ```
15
+
16
+ This is how you would evaluate the benchmarks once you have all images:
17
+ ```
18
+ python evaluate.py replicate draw_bench genai_bench geneval hps parti
19
+ ```
api/__init__.py CHANGED
@@ -8,6 +8,7 @@ from api.pruna_dev import PrunaDevAPI
8
  from api.replicate import ReplicateAPI
9
  from api.together import TogetherAPI
10
  from api.fal import FalAPI
 
11
 
12
  __all__ = [
13
  'create_api',
@@ -33,6 +34,7 @@ def create_api(api_type: str) -> FluxAPI:
33
  - "replicate"
34
  - "together"
35
  - "fal"
 
36
 
37
  Returns:
38
  FluxAPI: An instance of the requested API implementation
@@ -52,6 +54,7 @@ def create_api(api_type: str) -> FluxAPI:
52
  "replicate": ReplicateAPI,
53
  "together": TogetherAPI,
54
  "fal": FalAPI,
 
55
  }
56
 
57
  if api_type not in api_map:
 
8
  from api.replicate import ReplicateAPI
9
  from api.together import TogetherAPI
10
  from api.fal import FalAPI
11
+ from api.aws import AWSBedrockAPI
12
 
13
  __all__ = [
14
  'create_api',
 
34
  - "replicate"
35
  - "together"
36
  - "fal"
37
+ - "aws"
38
 
39
  Returns:
40
  FluxAPI: An instance of the requested API implementation
 
54
  "replicate": ReplicateAPI,
55
  "together": TogetherAPI,
56
  "fal": FalAPI,
57
+ "aws": AWSBedrockAPI,
58
  }
59
 
60
  if api_type not in api_map:
api/aws.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import base64
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import boto3
9
+ from dotenv import load_dotenv
10
+
11
+ from api.flux import FluxAPI
12
+
13
+
14
+ class AWSBedrockAPI(FluxAPI):
15
+ def __init__(self):
16
+ load_dotenv()
17
+ # AWS credentials should be set via environment variables
18
+ # AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN
19
+ os.environ["AWS_ACCESS_KEY_ID"] = ""
20
+ os.environ["AWS_SECRET_ACCESS_KEY"] = ""
21
+ os.environ["AWS_SESSION_TOKEN"] = ""
22
+ os.environ["AWS_REGION"] = "us-east-1"
23
+ self._client = boto3.client("bedrock-runtime")
24
+ self._model_id = "amazon.nova-canvas-v1:0"
25
+
26
+ @property
27
+ def name(self) -> str:
28
+ return "aws_nova_canvas"
29
+
30
+ def generate_image(self, prompt: str, save_path: Path) -> float:
31
+ start_time = time.time()
32
+ # Format the request payload
33
+ native_request = {
34
+ "taskType": "TEXT_IMAGE",
35
+ "textToImageParams": {"text": prompt},
36
+ "imageGenerationConfig": {
37
+ "seed": 0,
38
+ "quality": "standard",
39
+ "height": 1024,
40
+ "width": 1024,
41
+ "numberOfImages": 1,
42
+ },
43
+ }
44
+
45
+ try:
46
+ # Convert request to JSON and invoke the model
47
+ request = json.dumps(native_request)
48
+ response = self._client.invoke_model(
49
+ modelId=self._model_id,
50
+ body=request
51
+ )
52
+
53
+ # Process the response
54
+ model_response = json.loads(response["body"].read())
55
+ if not model_response.get("images"):
56
+ raise Exception("No images returned from AWS Bedrock API")
57
+
58
+ # Save the image
59
+ base64_image_data = model_response["images"][0]
60
+ self._save_image_from_base64(base64_image_data, save_path)
61
+
62
+ except Exception as e:
63
+ raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
64
+
65
+ end_time = time.time()
66
+ return end_time - start_time
67
+
68
+ def _save_image_from_base64(self, base64_data: str, save_path: Path):
69
+ """Save a base64 encoded image to the specified path."""
70
+ save_path.parent.mkdir(parents=True, exist_ok=True)
71
+ image_data = base64.b64decode(base64_data)
72
+ with open(save_path, "wb") as f:
73
+ f.write(image_data)
benchmark/metrics/__init__.py CHANGED
@@ -6,7 +6,7 @@ from benchmark.metrics.clip_iqa import CLIPIQAMetric
6
  from benchmark.metrics.image_reward import ImageRewardMetric
7
  from benchmark.metrics.sharpness import SharpnessMetric
8
  from benchmark.metrics.vqa import VQAMetric
9
- from benchmark.metrics.hps import HPSMetric
10
 
11
  def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
12
  """
@@ -34,7 +34,7 @@ def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAM
34
  "image_reward": ImageRewardMetric,
35
  "sharpness": SharpnessMetric,
36
  "vqa": VQAMetric,
37
- "hps": HPSMetric,
38
  }
39
 
40
  if metric_type not in metric_map:
 
6
  from benchmark.metrics.image_reward import ImageRewardMetric
7
  from benchmark.metrics.sharpness import SharpnessMetric
8
  from benchmark.metrics.vqa import VQAMetric
9
+ #from benchmark.metrics.hps import HPSMetric
10
 
11
  def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
12
  """
 
34
  "image_reward": ImageRewardMetric,
35
  "sharpness": SharpnessMetric,
36
  "vqa": VQAMetric,
37
+ #"hps": HPSMetric,
38
  }
39
 
40
  if metric_type not in metric_map:
nils_installs.txt ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.7.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.12.12
4
+ aiosignal==1.3.2
5
+ annotated-types==0.7.0
6
+ antlr4-python3-runtime==4.9.3
7
+ anyio==4.9.0
8
+ args==0.1.0
9
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
10
+ attrs==25.3.0
11
+ beautifulsoup4==4.13.4
12
+ boto3==1.38.33
13
+ botocore==1.38.33
14
+ braceexpand==0.1.7
15
+ Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1749229842835/work
16
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1746569525376/work/certifi
17
+ cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725560558132/work
18
+ charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1746214863626/work
19
+ click==8.1.8
20
+ clint==0.5.1
21
+ clip==0.2.0
22
+ colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
23
+ comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
24
+ contourpy==1.3.2
25
+ cycler==0.12.1
26
+ datasets==3.6.0
27
+ debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321241074/work
28
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
29
+ diffusers==0.31.0
30
+ dill==0.3.8
31
+ distro==1.9.0
32
+ einops==0.8.1
33
+ eval_type_backport==0.2.2
34
+ exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
35
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
36
+ fairscale==0.4.13
37
+ fal_client==0.7.0
38
+ filelock==3.18.0
39
+ fire==0.4.0
40
+ fonttools==4.58.2
41
+ frozenlist==1.7.0
42
+ fsspec==2025.3.0
43
+ ftfy==6.3.1
44
+ gdown==5.2.0
45
+ h11==0.16.0
46
+ h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
47
+ hf-xet==1.1.3
48
+ hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
49
+ hpsv2==1.2.0
50
+ httpcore==1.0.9
51
+ httpx==0.28.1
52
+ httpx-sse==0.4.0
53
+ huggingface-hub==0.32.5
54
+ hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
55
+ idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
56
+ image-reward==1.5
57
+ importlib_metadata==8.7.0
58
+ iniconfig==2.1.0
59
+ iopath==0.1.10
60
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
61
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1748713870/work
62
+ ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
63
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
64
+ Jinja2==3.1.6
65
+ jiter==0.10.0
66
+ jmespath==1.0.1
67
+ joblib==1.5.1
68
+ jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
69
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work
70
+ kiwisolver==1.4.8
71
+ lightning-utilities==0.14.3
72
+ markdown-it-py==3.0.0
73
+ MarkupSafe==3.0.2
74
+ matplotlib==3.10.3
75
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
76
+ mdurl==0.1.2
77
+ mpmath==1.3.0
78
+ multidict==6.4.4
79
+ multiprocess==0.70.16
80
+ nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
81
+ networkx==3.5
82
+ numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1749430504934/work/dist/numpy-2.3.0-cp312-cp312-linux_x86_64.whl#sha256=3c4437a0cbe50dbae872ad4cd8dc5316009165bce459c4ffe2c46cd30aba13d4
83
+ nvidia-cublas-cu12==12.6.4.1
84
+ nvidia-cuda-cupti-cu12==12.6.80
85
+ nvidia-cuda-nvrtc-cu12==12.6.77
86
+ nvidia-cuda-runtime-cu12==12.6.77
87
+ nvidia-cudnn-cu12==9.5.1.17
88
+ nvidia-cufft-cu12==11.3.0.4
89
+ nvidia-cufile-cu12==1.11.1.6
90
+ nvidia-curand-cu12==10.3.7.77
91
+ nvidia-cusolver-cu12==11.7.1.2
92
+ nvidia-cusparse-cu12==12.5.4.2
93
+ nvidia-cusparselt-cu12==0.6.3
94
+ nvidia-nccl-cu12==2.26.2
95
+ nvidia-nvjitlink-cu12==12.6.85
96
+ nvidia-nvtx-cu12==12.6.77
97
+ omegaconf==2.3.0
98
+ open_clip_torch==2.32.0
99
+ openai==1.85.0
100
+ opencv-python==4.11.0
101
+ opencv-python-headless==4.11.0
102
+ packaging==25.0
103
+ pandas==2.3.0
104
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
105
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
106
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
107
+ pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1746646208260/work
108
+ piq==0.8.0
109
+ platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
110
+ pluggy==1.6.0
111
+ portalocker==3.1.1
112
+ prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
113
+ propcache==0.3.2
114
+ protobuf==3.20.3
115
+ psutil==7.0.0
116
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
117
+ pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
118
+ pyarrow==20.0.0
119
+ pycocoevalcap==1.2
120
+ pycocotools==2.0.10
121
+ pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
122
+ pydantic==2.11.5
123
+ pydantic_core==2.33.2
124
+ Pygments==2.19.1
125
+ pyparsing==3.2.3
126
+ PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
127
+ pytest==7.2.0
128
+ pytest-split==0.8.0
129
+ python-dateutil==2.9.0.post0
130
+ python-dotenv @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dotenv_1742948348/work
131
+ pytz==2025.2
132
+ PyYAML==6.0.2
133
+ pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1743831245863/work
134
+ regex==2024.11.6
135
+ replicate==1.0.7
136
+ requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1749498106507/work
137
+ rich==14.0.0
138
+ s3transfer==0.13.0
139
+ safetensors==0.5.3
140
+ scikit-learn==1.7.0
141
+ scipy==1.15.3
142
+ sentencepiece==0.2.0
143
+ setuptools==80.9.0
144
+ shellingham==1.5.4
145
+ six==1.17.0
146
+ sniffio==1.3.1
147
+ soupsieve==2.7
148
+ stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
149
+ sympy==1.14.0
150
+ t2v_metrics==1.2
151
+ tabulate==0.9.0
152
+ termcolor==3.1.0
153
+ threadpoolctl==3.6.0
154
+ tiktoken==0.9.0
155
+ timm==0.6.13
156
+ together==1.5.11
157
+ tokenizers==0.15.2
158
+ torch==2.7.1
159
+ torchmetrics==1.7.2
160
+ torchvision==0.22.1
161
+ tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003300911/work
162
+ tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work
163
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
164
+ transformers==4.36.1
165
+ triton==3.3.1
166
+ typer==0.15.4
167
+ typing-inspection==0.4.1
168
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1748959427/work
169
+ tzdata==2025.2
170
+ urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1744323578849/work
171
+ wcwidth==0.2.13
172
+ webdataset==0.2.111
173
+ wheel==0.45.1
174
+ xxhash==3.5.0
175
+ yarl==1.20.1
176
+ zipp==3.23.0
177
+ zstandard==0.23.0