Spaces:
Runtime error
Runtime error
Commit
Β·
b462bee
0
Parent(s):
commit message
Browse files- .gitattributes +34 -0
- .gitignore +3 -0
- README.md +13 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +143 -0
- assets/test_images/cat_1.png +0 -0
- assets/test_images/cat_2.png +0 -0
- assets/test_images/cat_5.png +0 -0
- environment.yml +23 -0
- requirements.txt +7 -0
- submodules/pix2pix-zero/.gitignore +6 -0
- submodules/pix2pix-zero/LICENSE +21 -0
- submodules/pix2pix-zero/README.md +154 -0
- submodules/pix2pix-zero/environment.yml +23 -0
- submodules/pix2pix-zero/src/edit_real.py +65 -0
- submodules/pix2pix-zero/src/edit_synthetic.py +52 -0
- submodules/pix2pix-zero/src/inversion.py +64 -0
- submodules/pix2pix-zero/src/make_edit_direction.py +61 -0
- submodules/pix2pix-zero/src/utils/__pycache__/base_pipeline.cpython-310.pyc +0 -0
- submodules/pix2pix-zero/src/utils/__pycache__/cross_attention.cpython-310.pyc +0 -0
- submodules/pix2pix-zero/src/utils/__pycache__/ddim_inv.cpython-310.pyc +0 -0
- submodules/pix2pix-zero/src/utils/__pycache__/edit_directions.cpython-310.pyc +0 -0
- submodules/pix2pix-zero/src/utils/__pycache__/edit_pipeline.cpython-310.pyc +0 -0
- submodules/pix2pix-zero/src/utils/__pycache__/scheduler.cpython-310.pyc +0 -0
- submodules/pix2pix-zero/src/utils/base_pipeline.py +322 -0
- submodules/pix2pix-zero/src/utils/cross_attention.py +57 -0
- submodules/pix2pix-zero/src/utils/ddim_inv.py +140 -0
- submodules/pix2pix-zero/src/utils/edit_directions.py +29 -0
- submodules/pix2pix-zero/src/utils/edit_pipeline.py +179 -0
- submodules/pix2pix-zero/src/utils/scheduler.py +289 -0
- utils.py +0 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/direction_utils.cpython-310.pyc +0 -0
- utils/__pycache__/generate_synthetic.cpython-310.pyc +0 -0
- utils/__pycache__/gradio_utils.cpython-310.pyc +0 -0
- utils/direction_utils.py +79 -0
- utils/generate_synthetic.py +316 -0
- utils/gradio_utils.py +616 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
tmp
|
2 |
+
app_fin_v1.py
|
3 |
+
__*.py
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Demo Temp
|
3 |
+
emoji: π
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: indigo
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.18.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (163 Bytes). View file
|
|
app.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from utils.gradio_utils import *
|
7 |
+
from utils.direction_utils import *
|
8 |
+
from utils.generate_synthetic import *
|
9 |
+
|
10 |
+
|
11 |
+
if __name__=="__main__":
|
12 |
+
|
13 |
+
# populate the list of editing directions
|
14 |
+
d_name2desc = get_all_directions_names()
|
15 |
+
d_name2desc["make your own!"] = "make your own!"
|
16 |
+
|
17 |
+
with gr.Blocks(css=CSS_main) as demo:
|
18 |
+
# Make the header of the demo website
|
19 |
+
gr.HTML(HTML_header)
|
20 |
+
|
21 |
+
with gr.Row():
|
22 |
+
# col A: the input image or synthetic image prompt
|
23 |
+
with gr.Column(scale=2) as gc_left:
|
24 |
+
gr.HTML(" <center> <p style='font-size:150%;'> input </p> </center>")
|
25 |
+
img_in_real = gr.Image(type="pil", label="Start by uploading an image", elem_id="input_image")
|
26 |
+
img_in_synth = gr.Image(type="pil", label="Synthesized image", elem_id="input_image_synth", visible=False)
|
27 |
+
gr.Examples( examples="assets/test_images/", inputs=[img_in_real])
|
28 |
+
prompt = gr.Textbox(value="a high resolution painting of a cat in the style of van gogh", label="Or use a synthetic image. Prompt:", interactive=True)
|
29 |
+
with gr.Row():
|
30 |
+
seed = gr.Number(value=42, label="random seed:", interactive=True)
|
31 |
+
negative_guidance = gr.Number(value=5, label="negative guidance:", interactive=True)
|
32 |
+
btn_generate = gr.Button("Generate", label="")
|
33 |
+
fpath_z_gen = gr.Textbox(value="placeholder", visible=False)
|
34 |
+
|
35 |
+
# col B: the output image
|
36 |
+
with gr.Column(scale=2) as gc_left:
|
37 |
+
gr.HTML(" <center> <p style='font-size:150%;'> output </p> </center>")
|
38 |
+
img_out = gr.Image(type="pil", label="Output Image", visible=True)
|
39 |
+
with gr.Row():
|
40 |
+
with gr.Column():
|
41 |
+
src = gr.Dropdown(list(d_name2desc.values()), label="source", interactive=True, value="cat")
|
42 |
+
src_custom = gr.Textbox(placeholder="enter new task here!", interactive=True, visible=False, label="custom source direction:")
|
43 |
+
rad_src = gr.Radio(["GPT3", "flan-t5-xl (free)!", "BLOOMZ-7B (free)!", "fixed-template", "custom sentences"], label="Sentence type:", value="GPT3", interactive=True, visible=False)
|
44 |
+
custom_sentences_src = gr.Textbox(placeholder="paste list of sentences here", interactive=True, visible=False, label="custom sentences:", lines=5, max_lines=20)
|
45 |
+
|
46 |
+
|
47 |
+
with gr.Column():
|
48 |
+
dest = gr.Dropdown(list(d_name2desc.values()), label="target", interactive=True, value="dog")
|
49 |
+
dest_custom = gr.Textbox(placeholder="enter new task here!", interactive=True, visible=False, label="custom target direction:")
|
50 |
+
rad_dest = gr.Radio(["GPT3", "flan-t5-xl (free)!", "BLOOMZ-7B (free)!", "fixed-template", "custom sentences"], label="Sentence type:", value="GPT3", interactive=True, visible=False)
|
51 |
+
custom_sentences_dest = gr.Textbox(placeholder="paste list of sentences here", interactive=True, visible=False, label="custom sentences:", lines=5, max_lines=20)
|
52 |
+
|
53 |
+
|
54 |
+
with gr.Row():
|
55 |
+
api_key = gr.Textbox(placeholder="enter you OpenAI API key here", interactive=True, visible=False, label="OpenAI API key:", type="password")
|
56 |
+
org_key = gr.Textbox(placeholder="enter you OpenAI organization key here", interactive=True, visible=False, label="OpenAI Organization:", type="password")
|
57 |
+
with gr.Row():
|
58 |
+
btn_edit = gr.Button("Run", label="")
|
59 |
+
btn_clear = gr.Button("Clear")
|
60 |
+
|
61 |
+
with gr.Accordion("Change editing settings?", open=False):
|
62 |
+
num_ddim = gr.Slider(0, 200, 100, label="Number of DDIM steps", interactive=True, elem_id="slider_ddim", step=10)
|
63 |
+
xa_guidance = gr.Slider(0, 0.25, 0.1, label="Cross Attention guidance", interactive=True, elem_id="slider_xa", step=0.01)
|
64 |
+
edit_mul = gr.Slider(0, 2, 1.0, label="Edit multiplier", interactive=True, elem_id="slider_edit_mul", step=0.05)
|
65 |
+
|
66 |
+
with gr.Accordion("Generating your own directions", open=False):
|
67 |
+
gr.Textbox("We provide 5 different ways of computing new custom directions:", show_label=False)
|
68 |
+
gr.Textbox("We use GPT3 to generate a list of sentences that describe the desired edit. For this options, the users need to make an OpenAI account and enter the API and organizations keys. This option typically results is the best directions and costs roughly $0.14 for one concept.", label="1. GPT3", show_label=True)
|
69 |
+
gr.Textbox("Alternatively flan-t5-xl model can also be used to to generate a list of sentences that describe the desired edit. This option is free and does not require creating any new accounts.", label="2. flan-t5-xl (free)", show_label=True)
|
70 |
+
gr.Textbox("Similarly BLOOMZ-7B model can also be used to to generate the sentences for free.", label="3. BLOOMZ-7B (free)", show_label=True)
|
71 |
+
gr.Textbox("Next, we provide a fixed template based sentence generation. This option does not require any language model and is therefore free and much faster. However the edit directions with this method are often entangled.", label="4. Fixed template", show_label=True)
|
72 |
+
gr.Textbox("Finally, the user can also generate their own sentences.", label="5. Custom sentences", show_label=True)
|
73 |
+
|
74 |
+
|
75 |
+
with gr.Accordion("Tips for getting better results", open=True):
|
76 |
+
gr.Textbox("The 'Cross Attention guidance' controls the amount of structure guidance to be applied when performing the edit. If the output edited image does not retain the structure from the input, increasing the value will typically address the issue. We recommend changing the value in increments of 0.05.", label="1. Controlling the image structure", show_label=True)
|
77 |
+
gr.Textbox("If the output image quality is low or has some artifacts, using more steps would be helpful. This can be controlled with the 'Number of DDIM steps' slider.", label="2. Improving Image Quality", show_label=True)
|
78 |
+
gr.Textbox("There can be two reasons why the output image does not have the desired edit applied. Either the cross attention guidance is too strong, or the edit is insufficient. These can be addressed by reducing the 'Cross Attention guidance' slider or increasing the 'Edit multiplier' respectively.", label="3. Amount of edit applied", show_label=True)
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
# txt_image_type = gr.Textbox(visible=False)
|
86 |
+
btn_generate.click(launch_generate_sample, [prompt, seed, negative_guidance, num_ddim], [img_in_synth, fpath_z_gen])
|
87 |
+
btn_generate.click(set_visible_true, [], img_in_synth)
|
88 |
+
btn_generate.click(set_visible_false, [], img_in_real)
|
89 |
+
|
90 |
+
def fn_clear_all():
|
91 |
+
return gr.update(value=None), gr.update(value=None), gr.update(value=None)
|
92 |
+
btn_clear.click(fn_clear_all, [], [img_out, img_in_real, img_in_synth])
|
93 |
+
btn_clear.click(set_visible_true, [], img_in_real)
|
94 |
+
btn_clear.click(set_visible_false, [], img_in_synth)
|
95 |
+
|
96 |
+
|
97 |
+
# handling custom directions
|
98 |
+
def on_custom_seleceted(src):
|
99 |
+
if src=="make your own!": return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
100 |
+
else: return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
101 |
+
|
102 |
+
src.change(on_custom_seleceted, [src], [src_custom, rad_src, api_key, org_key])
|
103 |
+
dest.change(on_custom_seleceted, [dest], [dest_custom, rad_dest, api_key, org_key])
|
104 |
+
|
105 |
+
|
106 |
+
def fn_sentence_type_change(rad):
|
107 |
+
print(rad)
|
108 |
+
if rad=="GPT3":
|
109 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
|
110 |
+
elif rad=="custom sentences":
|
111 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
|
112 |
+
else:
|
113 |
+
print("using template sentence or flan-t5-xl or bloomz-7b")
|
114 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
115 |
+
|
116 |
+
rad_dest.change(fn_sentence_type_change, [rad_dest], [api_key, org_key, custom_sentences_dest])
|
117 |
+
rad_src.change(fn_sentence_type_change, [rad_src], [api_key, org_key, custom_sentences_src])
|
118 |
+
|
119 |
+
btn_edit.click(launch_main,
|
120 |
+
[
|
121 |
+
img_in_real, img_in_synth,
|
122 |
+
src, src_custom, dest,
|
123 |
+
dest_custom, num_ddim,
|
124 |
+
xa_guidance, edit_mul,
|
125 |
+
fpath_z_gen, prompt,
|
126 |
+
rad_src, rad_dest,
|
127 |
+
api_key, org_key,
|
128 |
+
custom_sentences_src, custom_sentences_dest
|
129 |
+
],
|
130 |
+
[img_out])
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
gr.HTML("<hr>")
|
135 |
+
|
136 |
+
|
137 |
+
demo.queue(concurrency_count=8)
|
138 |
+
demo.launch(debug=True)
|
139 |
+
|
140 |
+
# gr.close_all()
|
141 |
+
# demo.launch(server_port=8089, server_name="0.0.0.0", debug=True)
|
142 |
+
|
143 |
+
|
assets/test_images/cat_1.png
ADDED
![]() |
assets/test_images/cat_2.png
ADDED
![]() |
assets/test_images/cat_5.png
ADDED
![]() |
environment.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: pix2pix-zero
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- pip
|
8 |
+
- pytorch-cuda=11.6
|
9 |
+
- torchvision
|
10 |
+
- pytorch
|
11 |
+
- pip:
|
12 |
+
- accelerate
|
13 |
+
- diffusers
|
14 |
+
- einops
|
15 |
+
- gradio
|
16 |
+
- ipython
|
17 |
+
- numpy
|
18 |
+
- opencv-python-headless
|
19 |
+
- pillow
|
20 |
+
- psutil
|
21 |
+
- tqdm
|
22 |
+
- transformers
|
23 |
+
- salesforce-lavis
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
joblib
|
3 |
+
accelerate
|
4 |
+
diffusers==0.12.1
|
5 |
+
salesforce-lavis
|
6 |
+
openai
|
7 |
+
#git+https://github.com/pix2pixzero/pix2pix-zero.git
|
submodules/pix2pix-zero/.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output
|
2 |
+
scripts
|
3 |
+
src/folder_*.py
|
4 |
+
src/ig_*.py
|
5 |
+
assets/edit_sentences
|
6 |
+
src/utils/edit_pipeline_spatial.py
|
submodules/pix2pix-zero/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 pix2pixzero
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
submodules/pix2pix-zero/README.md
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pix2pix-zero
|
2 |
+
|
3 |
+
## [**[website]**](https://pix2pixzero.github.io/)
|
4 |
+
|
5 |
+
|
6 |
+
This is author's reimplementation of "Zero-shot Image-to-Image Translation" using the diffusers library. <br>
|
7 |
+
The results in the paper are based on the [CompVis](https://github.com/CompVis/stable-diffusion) library, which will be released later.
|
8 |
+
|
9 |
+
**[New!]** Code for editing real and synthetic images released!
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
<br>
|
14 |
+
<div class="gif">
|
15 |
+
<p align="center">
|
16 |
+
<img src='assets/main.gif' align="center">
|
17 |
+
</p>
|
18 |
+
</div>
|
19 |
+
|
20 |
+
|
21 |
+
We propose pix2pix-zero, a diffusion-based image-to-image approach that allows users to specify the edit direction on-the-fly (e.g., cat to dog). Our method can directly use pre-trained [Stable Diffusion](https://github.com/CompVis/stable-diffusion), for editing real and synthetic images while preserving the input image's structure. Our method is training-free and prompt-free, as it requires neither manual text prompting for each input image nor costly fine-tuning for each task.
|
22 |
+
|
23 |
+
**TL;DR**: no finetuning required, no text input needed, input structure preserved.
|
24 |
+
|
25 |
+
## Results
|
26 |
+
All our results are based on [stable-diffusion-v1-4](https://github.com/CompVis/stable-diffusion) model. Please the website for more results.
|
27 |
+
|
28 |
+
<div>
|
29 |
+
<p align="center">
|
30 |
+
<img src='assets/results_teaser.jpg' align="center" width=800px>
|
31 |
+
</p>
|
32 |
+
</div>
|
33 |
+
<hr>
|
34 |
+
|
35 |
+
The top row for each of the results below show editing of real images, and the bottom row shows synthetic image editing.
|
36 |
+
<div>
|
37 |
+
<p align="center">
|
38 |
+
<img src='assets/grid_dog2cat.jpg' align="center" width=800px>
|
39 |
+
</p>
|
40 |
+
<p align="center">
|
41 |
+
<img src='assets/grid_zebra2horse.jpg' align="center" width=800px>
|
42 |
+
</p>
|
43 |
+
<p align="center">
|
44 |
+
<img src='assets/grid_cat2dog.jpg' align="center" width=800px>
|
45 |
+
</p>
|
46 |
+
<p align="center">
|
47 |
+
<img src='assets/grid_horse2zebra.jpg' align="center" width=800px>
|
48 |
+
</p>
|
49 |
+
<p align="center">
|
50 |
+
<img src='assets/grid_tree2fall.jpg' align="center" width=800px>
|
51 |
+
</p>
|
52 |
+
</div>
|
53 |
+
|
54 |
+
## Real Image Editing
|
55 |
+
<div>
|
56 |
+
<p align="center">
|
57 |
+
<img src='assets/results_real.jpg' align="center" width=800px>
|
58 |
+
</p>
|
59 |
+
</div>
|
60 |
+
|
61 |
+
## Synthetic Image Editing
|
62 |
+
<div>
|
63 |
+
<p align="center">
|
64 |
+
<img src='assets/results_syn.jpg' align="center" width=800px>
|
65 |
+
</p>
|
66 |
+
</div>
|
67 |
+
|
68 |
+
## Method Details
|
69 |
+
|
70 |
+
Given an input image, we first generate text captions using [BLIP](https://github.com/salesforce/LAVIS) and apply regularized DDIM inversion to obtain our inverted noise map.
|
71 |
+
Then, we obtain reference cross-attention maps that correspoind to the structure of the input image by denoising, guided with the CLIP embeddings
|
72 |
+
of our generated text (c). Next, we denoise with edited text embeddings, while enforcing a loss to match current cross-attention maps with the
|
73 |
+
reference cross-attention maps.
|
74 |
+
|
75 |
+
<div>
|
76 |
+
<p align="center">
|
77 |
+
<img src='assets/method.jpeg' align="center" width=900>
|
78 |
+
</p>
|
79 |
+
</div>
|
80 |
+
|
81 |
+
|
82 |
+
## Getting Started
|
83 |
+
|
84 |
+
**Environment Setup**
|
85 |
+
- We provide a [conda env file](environment.yml) that contains all the required dependencies
|
86 |
+
```
|
87 |
+
conda env create -f environment.yml
|
88 |
+
```
|
89 |
+
- Following this, you can activate the conda environment with the command below.
|
90 |
+
```
|
91 |
+
conda activate pix2pix-zero
|
92 |
+
```
|
93 |
+
|
94 |
+
**Real Image Translation**
|
95 |
+
- First, run the inversion command below to obtain the input noise that reconstructs the image.
|
96 |
+
The command below will save the inversion in the results folder as `output/test_cat/inversion/cat_1.pt`
|
97 |
+
and the BLIP-generated prompt as `output/test_cat/prompt/cat_1.txt`
|
98 |
+
```
|
99 |
+
python src/inversion.py \
|
100 |
+
--input_image "assets/test_images/cats/cat_1.png" \
|
101 |
+
--results_folder "output/test_cat"
|
102 |
+
```
|
103 |
+
- Next, we can perform image editing with the editing direction as shown below.
|
104 |
+
The command below will save the edited image as `output/test_cat/edit/cat_1.png`
|
105 |
+
```
|
106 |
+
python src/edit_real.py \
|
107 |
+
--inversion "output/test_cat/inversion/cat_1.pt" \
|
108 |
+
--prompt "output/test_cat/prompt/cat_1.txt" \
|
109 |
+
--task_name "cat2dog" \
|
110 |
+
--results_folder "output/test_cat/"
|
111 |
+
```
|
112 |
+
|
113 |
+
**Editing Synthetic Images**
|
114 |
+
- Similarly, we can edit the synthetic images generated by Stable Diffusion with the following command.
|
115 |
+
```
|
116 |
+
python src/edit_synthetic.py \
|
117 |
+
--results_folder "output/synth_editing" \
|
118 |
+
--prompt_str "a high resolution painting of a cat in the style of van gough" \
|
119 |
+
--task "cat2dog"
|
120 |
+
```
|
121 |
+
|
122 |
+
### **Tips and Debugging**
|
123 |
+
- **Controlling the Image Structure:**<br>
|
124 |
+
The `--xa_guidance` flag controls the amount of cross-attention guidance to be applied when performing the edit. If the output edited image does not retain the structure from the input, increasing the value will typically address the issue. We recommend changing the value in increments of 0.05.
|
125 |
+
|
126 |
+
- **Improving Image Quality:**<br>
|
127 |
+
If the output image quality is low or has some artifacts, using more steps for both the inversion and editing would be helpful.
|
128 |
+
This can be controlled with the `--num_ddim_steps` flag.
|
129 |
+
|
130 |
+
- **Reducing the VRAM Requirements:**<br>
|
131 |
+
We can reduce the VRAM requirements using lower precision and setting the flag `--use_float_16`.
|
132 |
+
|
133 |
+
<br>
|
134 |
+
|
135 |
+
**Finding Custom Edit Directions**<br>
|
136 |
+
- We provide some pre-computed directions in the assets [folder](assets/embeddings_sd_1.4).
|
137 |
+
To generate new edit directions, users can first generate two files containing a large number of sentences (~1000) and then run the command as shown below.
|
138 |
+
```
|
139 |
+
python src/make_edit_direction.py \
|
140 |
+
--file_source_sentences sentences/apple.txt \
|
141 |
+
--file_target_sentences sentences/orange.txt \
|
142 |
+
--output_folder assets/embeddings_sd_1.4
|
143 |
+
```
|
144 |
+
- After running the above command, you can set the flag `--task apple2orange` for the new edit.
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
## Comparison
|
149 |
+
Comparisons with different baselines, including, SDEdit + word swap, DDIM + word swap, and prompt-to-propmt. Our method successfully applies the edit, while preserving the structure of the input image.
|
150 |
+
<div>
|
151 |
+
<p align="center">
|
152 |
+
<img src='assets/comparison.jpg' align="center" width=900>
|
153 |
+
</p>
|
154 |
+
</div>
|
submodules/pix2pix-zero/environment.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: pix2pix-zero
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- pip
|
8 |
+
- pytorch-cuda=11.6
|
9 |
+
- torchvision
|
10 |
+
- pytorch
|
11 |
+
- pip:
|
12 |
+
- accelerate
|
13 |
+
- diffusers
|
14 |
+
- einops
|
15 |
+
- gradio
|
16 |
+
- ipython
|
17 |
+
- numpy
|
18 |
+
- opencv-python-headless
|
19 |
+
- pillow
|
20 |
+
- psutil
|
21 |
+
- tqdm
|
22 |
+
- transformers
|
23 |
+
- salesforce-lavis
|
submodules/pix2pix-zero/src/edit_real.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.ddim_inv import DDIMInversion
|
11 |
+
from utils.edit_directions import construct_direction
|
12 |
+
from utils.edit_pipeline import EditingPipeline
|
13 |
+
|
14 |
+
|
15 |
+
if __name__=="__main__":
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument('--inversion', required=True)
|
18 |
+
parser.add_argument('--prompt', type=str, required=True)
|
19 |
+
parser.add_argument('--task_name', type=str, default='cat2dog')
|
20 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
21 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
22 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
23 |
+
parser.add_argument('--xa_guidance', default=0.1, type=float)
|
24 |
+
parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
|
25 |
+
parser.add_argument('--use_float_16', action='store_true')
|
26 |
+
|
27 |
+
args = parser.parse_args()
|
28 |
+
|
29 |
+
os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True)
|
30 |
+
os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True)
|
31 |
+
|
32 |
+
if args.use_float_16:
|
33 |
+
torch_dtype = torch.float16
|
34 |
+
else:
|
35 |
+
torch_dtype = torch.float32
|
36 |
+
|
37 |
+
# if the inversion is a folder, the prompt should also be a folder
|
38 |
+
assert (os.path.isdir(args.inversion)==os.path.isdir(args.prompt)), "If the inversion is a folder, the prompt should also be a folder"
|
39 |
+
if os.path.isdir(args.inversion):
|
40 |
+
l_inv_paths = sorted(glob(os.path.join(args.inversion, "*.pt")))
|
41 |
+
l_bnames = [os.path.basename(x) for x in l_inv_paths]
|
42 |
+
l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt",".txt")) for x in l_bnames]
|
43 |
+
else:
|
44 |
+
l_inv_paths = [args.inversion]
|
45 |
+
l_prompt_paths = [args.prompt]
|
46 |
+
|
47 |
+
# Make the editing pipeline
|
48 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
49 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
50 |
+
|
51 |
+
|
52 |
+
for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths):
|
53 |
+
prompt_str = open(prompt_path).read().strip()
|
54 |
+
rec_pil, edit_pil = pipe(prompt_str,
|
55 |
+
num_inference_steps=args.num_ddim_steps,
|
56 |
+
x_in=torch.load(inv_path).unsqueeze(0),
|
57 |
+
edit_dir=construct_direction(args.task_name),
|
58 |
+
guidance_amount=args.xa_guidance,
|
59 |
+
guidance_scale=args.negative_guidance_scale,
|
60 |
+
negative_prompt=prompt_str # use the unedited prompt for the negative prompt
|
61 |
+
)
|
62 |
+
|
63 |
+
bname = os.path.basename(args.inversion).split(".")[0]
|
64 |
+
edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png"))
|
65 |
+
rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png"))
|
submodules/pix2pix-zero/src/edit_synthetic.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.edit_directions import construct_direction
|
11 |
+
from utils.edit_pipeline import EditingPipeline
|
12 |
+
|
13 |
+
|
14 |
+
if __name__=="__main__":
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument('--prompt_str', type=str, required=True)
|
17 |
+
parser.add_argument('--random_seed', default=0)
|
18 |
+
parser.add_argument('--task_name', type=str, default='cat2dog')
|
19 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
20 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
21 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
22 |
+
parser.add_argument('--xa_guidance', default=0.15, type=float)
|
23 |
+
parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
|
24 |
+
parser.add_argument('--use_float_16', action='store_true')
|
25 |
+
args = parser.parse_args()
|
26 |
+
|
27 |
+
os.makedirs(args.results_folder, exist_ok=True)
|
28 |
+
|
29 |
+
if args.use_float_16:
|
30 |
+
torch_dtype = torch.float16
|
31 |
+
else:
|
32 |
+
torch_dtype = torch.float32
|
33 |
+
|
34 |
+
# make the input noise map
|
35 |
+
torch.cuda.manual_seed(args.random_seed)
|
36 |
+
x = torch.randn((1,4,64,64), device="cuda")
|
37 |
+
|
38 |
+
# Make the editing pipeline
|
39 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
40 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
41 |
+
|
42 |
+
rec_pil, edit_pil = pipe(args.prompt_str,
|
43 |
+
num_inference_steps=args.num_ddim_steps,
|
44 |
+
x_in=x,
|
45 |
+
edit_dir=construct_direction(args.task_name),
|
46 |
+
guidance_amount=args.xa_guidance,
|
47 |
+
guidance_scale=args.negative_guidance_scale,
|
48 |
+
negative_prompt="" # use the empty string for the negative prompt
|
49 |
+
)
|
50 |
+
|
51 |
+
edit_pil[0].save(os.path.join(args.results_folder, f"edit.png"))
|
52 |
+
rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction.png"))
|
submodules/pix2pix-zero/src/inversion.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from lavis.models import load_model_and_preprocess
|
10 |
+
|
11 |
+
from utils.ddim_inv import DDIMInversion
|
12 |
+
from utils.scheduler import DDIMInverseScheduler
|
13 |
+
|
14 |
+
if __name__=="__main__":
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png')
|
17 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
18 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
19 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
20 |
+
parser.add_argument('--use_float_16', action='store_true')
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
# make the output folders
|
24 |
+
os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True)
|
25 |
+
os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True)
|
26 |
+
|
27 |
+
if args.use_float_16:
|
28 |
+
torch_dtype = torch.float16
|
29 |
+
else:
|
30 |
+
torch_dtype = torch.float32
|
31 |
+
|
32 |
+
|
33 |
+
# load the BLIP model
|
34 |
+
model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
|
35 |
+
# make the DDIM inversion pipeline
|
36 |
+
pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
37 |
+
pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
|
38 |
+
|
39 |
+
|
40 |
+
# if the input is a folder, collect all the images as a list
|
41 |
+
if os.path.isdir(args.input_image):
|
42 |
+
l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png")))
|
43 |
+
else:
|
44 |
+
l_img_paths = [args.input_image]
|
45 |
+
|
46 |
+
|
47 |
+
for img_path in l_img_paths:
|
48 |
+
bname = os.path.basename(args.input_image).split(".")[0]
|
49 |
+
img = Image.open(args.input_image).resize((512,512), Image.Resampling.LANCZOS)
|
50 |
+
# generate the caption
|
51 |
+
_image = vis_processors["eval"](img).unsqueeze(0).cuda()
|
52 |
+
prompt_str = model_blip.generate({"image": _image})[0]
|
53 |
+
x_inv, x_inv_image, x_dec_img = pipe(
|
54 |
+
prompt_str,
|
55 |
+
guidance_scale=1,
|
56 |
+
num_inversion_steps=args.num_ddim_steps,
|
57 |
+
img=img,
|
58 |
+
torch_dtype=torch_dtype
|
59 |
+
)
|
60 |
+
# save the inversion
|
61 |
+
torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
|
62 |
+
# save the prompt string
|
63 |
+
with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
|
64 |
+
f.write(prompt_str)
|
submodules/pix2pix-zero/src/make_edit_direction.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.edit_pipeline import EditingPipeline
|
11 |
+
|
12 |
+
|
13 |
+
## convert sentences to sentence embeddings
|
14 |
+
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
|
15 |
+
with torch.no_grad():
|
16 |
+
l_embeddings = []
|
17 |
+
for sent in l_sentences:
|
18 |
+
text_inputs = tokenizer(
|
19 |
+
sent,
|
20 |
+
padding="max_length",
|
21 |
+
max_length=tokenizer.model_max_length,
|
22 |
+
truncation=True,
|
23 |
+
return_tensors="pt",
|
24 |
+
)
|
25 |
+
text_input_ids = text_inputs.input_ids
|
26 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
|
27 |
+
l_embeddings.append(prompt_embeds)
|
28 |
+
return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
|
29 |
+
|
30 |
+
|
31 |
+
if __name__=="__main__":
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
parser.add_argument('--file_source_sentences', required=True)
|
34 |
+
parser.add_argument('--file_target_sentences', required=True)
|
35 |
+
parser.add_argument('--output_folder', required=True)
|
36 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
37 |
+
args = parser.parse_args()
|
38 |
+
|
39 |
+
# load the model
|
40 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")
|
41 |
+
bname_src = os.path.basename(args.file_source_sentences).strip(".txt")
|
42 |
+
outf_src = os.path.join(args.output_folder, bname_src+".pt")
|
43 |
+
if os.path.exists(outf_src):
|
44 |
+
print(f"Skipping source file {outf_src} as it already exists")
|
45 |
+
else:
|
46 |
+
with open(args.file_source_sentences, "r") as f:
|
47 |
+
l_sents = [x.strip() for x in f.readlines()]
|
48 |
+
mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
|
49 |
+
print(mean_emb.shape)
|
50 |
+
torch.save(mean_emb, outf_src)
|
51 |
+
|
52 |
+
bname_tgt = os.path.basename(args.file_target_sentences).strip(".txt")
|
53 |
+
outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt")
|
54 |
+
if os.path.exists(outf_tgt):
|
55 |
+
print(f"Skipping target file {outf_tgt} as it already exists")
|
56 |
+
else:
|
57 |
+
with open(args.file_target_sentences, "r") as f:
|
58 |
+
l_sents = [x.strip() for x in f.readlines()]
|
59 |
+
mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
|
60 |
+
print(mean_emb.shape)
|
61 |
+
torch.save(mean_emb, outf_tgt)
|
submodules/pix2pix-zero/src/utils/__pycache__/base_pipeline.cpython-310.pyc
ADDED
Binary file (10.7 kB). View file
|
|
submodules/pix2pix-zero/src/utils/__pycache__/cross_attention.cpython-310.pyc
ADDED
Binary file (1.43 kB). View file
|
|
submodules/pix2pix-zero/src/utils/__pycache__/ddim_inv.cpython-310.pyc
ADDED
Binary file (4.4 kB). View file
|
|
submodules/pix2pix-zero/src/utils/__pycache__/edit_directions.cpython-310.pyc
ADDED
Binary file (693 Bytes). View file
|
|
submodules/pix2pix-zero/src/utils/__pycache__/edit_pipeline.cpython-310.pyc
ADDED
Binary file (4.12 kB). View file
|
|
submodules/pix2pix-zero/src/utils/__pycache__/scheduler.cpython-310.pyc
ADDED
Binary file (10.4 kB). View file
|
|
submodules/pix2pix-zero/src/utils/base_pipeline.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import inspect
|
4 |
+
from packaging import version
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
8 |
+
from diffusers import DiffusionPipeline
|
9 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
10 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
11 |
+
from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
|
12 |
+
from diffusers import StableDiffusionPipeline
|
13 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class BasePipeline(DiffusionPipeline):
|
18 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
vae: AutoencoderKL,
|
22 |
+
text_encoder: CLIPTextModel,
|
23 |
+
tokenizer: CLIPTokenizer,
|
24 |
+
unet: UNet2DConditionModel,
|
25 |
+
scheduler: KarrasDiffusionSchedulers,
|
26 |
+
safety_checker: StableDiffusionSafetyChecker,
|
27 |
+
feature_extractor: CLIPFeatureExtractor,
|
28 |
+
requires_safety_checker: bool = True,
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
33 |
+
deprecation_message = (
|
34 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
35 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
36 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
37 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
38 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
39 |
+
" file"
|
40 |
+
)
|
41 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
42 |
+
new_config = dict(scheduler.config)
|
43 |
+
new_config["steps_offset"] = 1
|
44 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
45 |
+
|
46 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
47 |
+
deprecation_message = (
|
48 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
49 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
50 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
51 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
52 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
53 |
+
)
|
54 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
55 |
+
new_config = dict(scheduler.config)
|
56 |
+
new_config["clip_sample"] = False
|
57 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
58 |
+
|
59 |
+
if safety_checker is None and requires_safety_checker:
|
60 |
+
logger.warning(
|
61 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
62 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
63 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
64 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
65 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
66 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
67 |
+
)
|
68 |
+
|
69 |
+
if safety_checker is not None and feature_extractor is None:
|
70 |
+
raise ValueError(
|
71 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
72 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
73 |
+
)
|
74 |
+
|
75 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
76 |
+
version.parse(unet.config._diffusers_version).base_version
|
77 |
+
) < version.parse("0.9.0.dev0")
|
78 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
79 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
80 |
+
deprecation_message = (
|
81 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
82 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
83 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
84 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
85 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
86 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
87 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
88 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
89 |
+
" the `unet/config.json` file"
|
90 |
+
)
|
91 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
92 |
+
new_config = dict(unet.config)
|
93 |
+
new_config["sample_size"] = 64
|
94 |
+
unet._internal_dict = FrozenDict(new_config)
|
95 |
+
|
96 |
+
self.register_modules(
|
97 |
+
vae=vae,
|
98 |
+
text_encoder=text_encoder,
|
99 |
+
tokenizer=tokenizer,
|
100 |
+
unet=unet,
|
101 |
+
scheduler=scheduler,
|
102 |
+
safety_checker=safety_checker,
|
103 |
+
feature_extractor=feature_extractor,
|
104 |
+
)
|
105 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
106 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
107 |
+
|
108 |
+
@property
|
109 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
110 |
+
def _execution_device(self):
|
111 |
+
r"""
|
112 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
113 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
114 |
+
hooks.
|
115 |
+
"""
|
116 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
117 |
+
return self.device
|
118 |
+
for module in self.unet.modules():
|
119 |
+
if (
|
120 |
+
hasattr(module, "_hf_hook")
|
121 |
+
and hasattr(module._hf_hook, "execution_device")
|
122 |
+
and module._hf_hook.execution_device is not None
|
123 |
+
):
|
124 |
+
return torch.device(module._hf_hook.execution_device)
|
125 |
+
return self.device
|
126 |
+
|
127 |
+
|
128 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
129 |
+
def _encode_prompt(
|
130 |
+
self,
|
131 |
+
prompt,
|
132 |
+
device,
|
133 |
+
num_images_per_prompt,
|
134 |
+
do_classifier_free_guidance,
|
135 |
+
negative_prompt=None,
|
136 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
137 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
138 |
+
):
|
139 |
+
r"""
|
140 |
+
Encodes the prompt into text encoder hidden states.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
prompt (`str` or `List[str]`, *optional*):
|
144 |
+
prompt to be encoded
|
145 |
+
device: (`torch.device`):
|
146 |
+
torch device
|
147 |
+
num_images_per_prompt (`int`):
|
148 |
+
number of images that should be generated per prompt
|
149 |
+
do_classifier_free_guidance (`bool`):
|
150 |
+
whether to use classifier free guidance or not
|
151 |
+
negative_ prompt (`str` or `List[str]`, *optional*):
|
152 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
153 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
154 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
155 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
156 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
157 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
158 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
159 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
160 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
161 |
+
argument.
|
162 |
+
"""
|
163 |
+
if prompt is not None and isinstance(prompt, str):
|
164 |
+
batch_size = 1
|
165 |
+
elif prompt is not None and isinstance(prompt, list):
|
166 |
+
batch_size = len(prompt)
|
167 |
+
else:
|
168 |
+
batch_size = prompt_embeds.shape[0]
|
169 |
+
|
170 |
+
if prompt_embeds is None:
|
171 |
+
text_inputs = self.tokenizer(
|
172 |
+
prompt,
|
173 |
+
padding="max_length",
|
174 |
+
max_length=self.tokenizer.model_max_length,
|
175 |
+
truncation=True,
|
176 |
+
return_tensors="pt",
|
177 |
+
)
|
178 |
+
text_input_ids = text_inputs.input_ids
|
179 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
180 |
+
|
181 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
182 |
+
text_input_ids, untruncated_ids
|
183 |
+
):
|
184 |
+
removed_text = self.tokenizer.batch_decode(
|
185 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
186 |
+
)
|
187 |
+
logger.warning(
|
188 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
189 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
190 |
+
)
|
191 |
+
|
192 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
193 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
194 |
+
else:
|
195 |
+
attention_mask = None
|
196 |
+
|
197 |
+
prompt_embeds = self.text_encoder(
|
198 |
+
text_input_ids.to(device),
|
199 |
+
attention_mask=attention_mask,
|
200 |
+
)
|
201 |
+
prompt_embeds = prompt_embeds[0]
|
202 |
+
|
203 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
204 |
+
|
205 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
206 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
207 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
208 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
209 |
+
|
210 |
+
# get unconditional embeddings for classifier free guidance
|
211 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
212 |
+
uncond_tokens: List[str]
|
213 |
+
if negative_prompt is None:
|
214 |
+
uncond_tokens = [""] * batch_size
|
215 |
+
elif type(prompt) is not type(negative_prompt):
|
216 |
+
raise TypeError(
|
217 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
218 |
+
f" {type(prompt)}."
|
219 |
+
)
|
220 |
+
elif isinstance(negative_prompt, str):
|
221 |
+
uncond_tokens = [negative_prompt]
|
222 |
+
elif batch_size != len(negative_prompt):
|
223 |
+
raise ValueError(
|
224 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
225 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
226 |
+
" the batch size of `prompt`."
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
uncond_tokens = negative_prompt
|
230 |
+
|
231 |
+
max_length = prompt_embeds.shape[1]
|
232 |
+
uncond_input = self.tokenizer(
|
233 |
+
uncond_tokens,
|
234 |
+
padding="max_length",
|
235 |
+
max_length=max_length,
|
236 |
+
truncation=True,
|
237 |
+
return_tensors="pt",
|
238 |
+
)
|
239 |
+
|
240 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
241 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
242 |
+
else:
|
243 |
+
attention_mask = None
|
244 |
+
|
245 |
+
negative_prompt_embeds = self.text_encoder(
|
246 |
+
uncond_input.input_ids.to(device),
|
247 |
+
attention_mask=attention_mask,
|
248 |
+
)
|
249 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
250 |
+
|
251 |
+
if do_classifier_free_guidance:
|
252 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
253 |
+
seq_len = negative_prompt_embeds.shape[1]
|
254 |
+
|
255 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
256 |
+
|
257 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
258 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
259 |
+
|
260 |
+
# For classifier free guidance, we need to do two forward passes.
|
261 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
262 |
+
# to avoid doing two forward passes
|
263 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
264 |
+
|
265 |
+
return prompt_embeds
|
266 |
+
|
267 |
+
|
268 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
269 |
+
def decode_latents(self, latents):
|
270 |
+
latents = 1 / 0.18215 * latents
|
271 |
+
image = self.vae.decode(latents).sample
|
272 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
273 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
274 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
275 |
+
return image
|
276 |
+
|
277 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
278 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
279 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
280 |
+
raise ValueError(
|
281 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
282 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
283 |
+
)
|
284 |
+
|
285 |
+
if latents is None:
|
286 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
287 |
+
else:
|
288 |
+
latents = latents.to(device)
|
289 |
+
|
290 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
291 |
+
latents = latents * self.scheduler.init_noise_sigma
|
292 |
+
return latents
|
293 |
+
|
294 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
295 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
296 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
297 |
+
# eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
298 |
+
# eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
|
299 |
+
# and should be between [0, 1]
|
300 |
+
|
301 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
302 |
+
extra_step_kwargs = {}
|
303 |
+
if accepts_eta:
|
304 |
+
extra_step_kwargs["eta"] = eta
|
305 |
+
|
306 |
+
# check if the scheduler accepts generator
|
307 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
308 |
+
if accepts_generator:
|
309 |
+
extra_step_kwargs["generator"] = generator
|
310 |
+
return extra_step_kwargs
|
311 |
+
|
312 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
313 |
+
def run_safety_checker(self, image, device, dtype):
|
314 |
+
if self.safety_checker is not None:
|
315 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
316 |
+
image, has_nsfw_concept = self.safety_checker(
|
317 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
318 |
+
)
|
319 |
+
else:
|
320 |
+
has_nsfw_concept = None
|
321 |
+
return image, has_nsfw_concept
|
322 |
+
|
submodules/pix2pix-zero/src/utils/cross_attention.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers.models.attention import CrossAttention
|
3 |
+
|
4 |
+
class MyCrossAttnProcessor:
|
5 |
+
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
6 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
7 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
8 |
+
|
9 |
+
query = attn.to_q(hidden_states)
|
10 |
+
|
11 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
12 |
+
key = attn.to_k(encoder_hidden_states)
|
13 |
+
value = attn.to_v(encoder_hidden_states)
|
14 |
+
|
15 |
+
query = attn.head_to_batch_dim(query)
|
16 |
+
key = attn.head_to_batch_dim(key)
|
17 |
+
value = attn.head_to_batch_dim(value)
|
18 |
+
|
19 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
20 |
+
# new bookkeeping to save the attn probs
|
21 |
+
attn.attn_probs = attention_probs
|
22 |
+
|
23 |
+
hidden_states = torch.bmm(attention_probs, value)
|
24 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
25 |
+
|
26 |
+
# linear proj
|
27 |
+
hidden_states = attn.to_out[0](hidden_states)
|
28 |
+
# dropout
|
29 |
+
hidden_states = attn.to_out[1](hidden_states)
|
30 |
+
|
31 |
+
return hidden_states
|
32 |
+
|
33 |
+
|
34 |
+
"""
|
35 |
+
A function that prepares a U-Net model for training by enabling gradient computation
|
36 |
+
for a specified set of parameters and setting the forward pass to be performed by a
|
37 |
+
custom cross attention processor.
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
unet: A U-Net model.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
unet: The prepared U-Net model.
|
44 |
+
"""
|
45 |
+
def prep_unet(unet):
|
46 |
+
# set the gradients for XA maps to be true
|
47 |
+
for name, params in unet.named_parameters():
|
48 |
+
if 'attn2' in name:
|
49 |
+
params.requires_grad = True
|
50 |
+
else:
|
51 |
+
params.requires_grad = False
|
52 |
+
# replace the fwd function
|
53 |
+
for name, module in unet.named_modules():
|
54 |
+
module_name = type(module).__name__
|
55 |
+
if module_name == "CrossAttention":
|
56 |
+
module.set_processor(MyCrossAttnProcessor())
|
57 |
+
return unet
|
submodules/pix2pix-zero/src/utils/ddim_inv.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from random import randrange
|
6 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
7 |
+
from diffusers import DDIMScheduler
|
8 |
+
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
|
9 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
10 |
+
sys.path.insert(0, "src/utils")
|
11 |
+
from base_pipeline import BasePipeline
|
12 |
+
from cross_attention import prep_unet
|
13 |
+
|
14 |
+
|
15 |
+
class DDIMInversion(BasePipeline):
|
16 |
+
|
17 |
+
def auto_corr_loss(self, x, random_shift=True):
|
18 |
+
B,C,H,W = x.shape
|
19 |
+
assert B==1
|
20 |
+
x = x.squeeze(0)
|
21 |
+
# x must be shape [C,H,W] now
|
22 |
+
reg_loss = 0.0
|
23 |
+
for ch_idx in range(x.shape[0]):
|
24 |
+
noise = x[ch_idx][None, None,:,:]
|
25 |
+
while True:
|
26 |
+
if random_shift: roll_amount = randrange(noise.shape[2]//2)
|
27 |
+
else: roll_amount = 1
|
28 |
+
reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
|
29 |
+
reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
|
30 |
+
if noise.shape[2] <= 8:
|
31 |
+
break
|
32 |
+
noise = F.avg_pool2d(noise, kernel_size=2)
|
33 |
+
return reg_loss
|
34 |
+
|
35 |
+
def kl_divergence(self, x):
|
36 |
+
_mu = x.mean()
|
37 |
+
_var = x.var()
|
38 |
+
return _var + _mu**2 - 1 - torch.log(_var+1e-7)
|
39 |
+
|
40 |
+
|
41 |
+
def __call__(
|
42 |
+
self,
|
43 |
+
prompt: Union[str, List[str]] = None,
|
44 |
+
num_inversion_steps: int = 50,
|
45 |
+
guidance_scale: float = 7.5,
|
46 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
47 |
+
num_images_per_prompt: Optional[int] = 1,
|
48 |
+
eta: float = 0.0,
|
49 |
+
output_type: Optional[str] = "pil",
|
50 |
+
return_dict: bool = True,
|
51 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
52 |
+
img=None, # the input image as a PIL image
|
53 |
+
torch_dtype=torch.float32,
|
54 |
+
|
55 |
+
# inversion regularization parameters
|
56 |
+
lambda_ac: float = 20.0,
|
57 |
+
lambda_kl: float = 20.0,
|
58 |
+
num_reg_steps: int = 5,
|
59 |
+
num_ac_rolls: int = 5,
|
60 |
+
):
|
61 |
+
|
62 |
+
# 0. modify the unet to be useful :D
|
63 |
+
self.unet = prep_unet(self.unet)
|
64 |
+
|
65 |
+
# set the scheduler to be the Inverse DDIM scheduler
|
66 |
+
# self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config)
|
67 |
+
|
68 |
+
device = self._execution_device
|
69 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
70 |
+
self.scheduler.set_timesteps(num_inversion_steps, device=device)
|
71 |
+
timesteps = self.scheduler.timesteps
|
72 |
+
|
73 |
+
# Encode the input image with the first stage model
|
74 |
+
x0 = np.array(img)/255
|
75 |
+
x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).cuda()
|
76 |
+
x0 = (x0 - 0.5) * 2.
|
77 |
+
with torch.no_grad():
|
78 |
+
x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype)
|
79 |
+
latents = x0_enc = 0.18215 * x0_enc
|
80 |
+
|
81 |
+
# Decode and return the image
|
82 |
+
with torch.no_grad():
|
83 |
+
x0_dec = self.decode_latents(x0_enc.detach())
|
84 |
+
image_x0_dec = self.numpy_to_pil(x0_dec)
|
85 |
+
|
86 |
+
with torch.no_grad():
|
87 |
+
prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device)
|
88 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta)
|
89 |
+
|
90 |
+
# Do the inversion
|
91 |
+
num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order # should be 0?
|
92 |
+
with self.progress_bar(total=num_inversion_steps) as progress_bar:
|
93 |
+
for i, t in enumerate(timesteps.flip(0)[1:-1]):
|
94 |
+
# expand the latents if we are doing classifier free guidance
|
95 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
96 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
97 |
+
|
98 |
+
# predict the noise residual
|
99 |
+
with torch.no_grad():
|
100 |
+
noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
|
101 |
+
|
102 |
+
# perform guidance
|
103 |
+
if do_classifier_free_guidance:
|
104 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
105 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
106 |
+
|
107 |
+
# regularization of the noise prediction
|
108 |
+
e_t = noise_pred
|
109 |
+
for _outer in range(num_reg_steps):
|
110 |
+
if lambda_ac>0:
|
111 |
+
for _inner in range(num_ac_rolls):
|
112 |
+
_var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
|
113 |
+
l_ac = self.auto_corr_loss(_var)
|
114 |
+
l_ac.backward()
|
115 |
+
_grad = _var.grad.detach()/num_ac_rolls
|
116 |
+
e_t = e_t - lambda_ac*_grad
|
117 |
+
if lambda_kl>0:
|
118 |
+
_var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
|
119 |
+
l_kld = self.kl_divergence(_var)
|
120 |
+
l_kld.backward()
|
121 |
+
_grad = _var.grad.detach()
|
122 |
+
e_t = e_t - lambda_kl*_grad
|
123 |
+
e_t = e_t.detach()
|
124 |
+
noise_pred = e_t
|
125 |
+
|
126 |
+
# compute the previous noisy sample x_t -> x_t-1
|
127 |
+
latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample
|
128 |
+
|
129 |
+
# call the callback, if provided
|
130 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
131 |
+
progress_bar.update()
|
132 |
+
|
133 |
+
|
134 |
+
x_inv = latents.detach().clone()
|
135 |
+
# reconstruct the image
|
136 |
+
|
137 |
+
# 8. Post-processing
|
138 |
+
image = self.decode_latents(latents.detach())
|
139 |
+
image = self.numpy_to_pil(image)
|
140 |
+
return x_inv, image, image_x0_dec
|
submodules/pix2pix-zero/src/utils/edit_directions.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
"""
|
6 |
+
This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task.
|
7 |
+
|
8 |
+
Parameters:
|
9 |
+
task_name (str): name of the task for which direction is to be constructed.
|
10 |
+
|
11 |
+
Returns:
|
12 |
+
torch.Tensor: A tensor representing the direction in the embedding space that transforms class A to class B.
|
13 |
+
|
14 |
+
Examples:
|
15 |
+
>>> construct_direction("cat2dog")
|
16 |
+
"""
|
17 |
+
def construct_direction(task_name):
|
18 |
+
if task_name=="cat2dog":
|
19 |
+
emb_dir = f"assets/embeddings_sd_1.4"
|
20 |
+
embs_a = torch.load(os.path.join(emb_dir, f"cat.pt"))
|
21 |
+
embs_b = torch.load(os.path.join(emb_dir, f"dog.pt"))
|
22 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
23 |
+
elif task_name=="dog2cat":
|
24 |
+
emb_dir = f"assets/embeddings_sd_1.4"
|
25 |
+
embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
|
26 |
+
embs_b = torch.load(os.path.join(emb_dir, f"cat.pt"))
|
27 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
28 |
+
else:
|
29 |
+
raise NotImplementedError
|
submodules/pix2pix-zero/src/utils/edit_pipeline.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb, sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
7 |
+
sys.path.insert(0, "src/utils")
|
8 |
+
from base_pipeline import BasePipeline
|
9 |
+
from cross_attention import prep_unet
|
10 |
+
|
11 |
+
|
12 |
+
class EditingPipeline(BasePipeline):
|
13 |
+
def __call__(
|
14 |
+
self,
|
15 |
+
prompt: Union[str, List[str]] = None,
|
16 |
+
height: Optional[int] = None,
|
17 |
+
width: Optional[int] = None,
|
18 |
+
num_inference_steps: int = 50,
|
19 |
+
guidance_scale: float = 7.5,
|
20 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
21 |
+
num_images_per_prompt: Optional[int] = 1,
|
22 |
+
eta: float = 0.0,
|
23 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
24 |
+
latents: Optional[torch.FloatTensor] = None,
|
25 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
26 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
27 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
28 |
+
|
29 |
+
# pix2pix parameters
|
30 |
+
guidance_amount=0.1,
|
31 |
+
edit_dir=None,
|
32 |
+
x_in=None,
|
33 |
+
only_sample=False,
|
34 |
+
|
35 |
+
):
|
36 |
+
|
37 |
+
x_in.to(dtype=self.unet.dtype, device=self._execution_device)
|
38 |
+
|
39 |
+
# 0. modify the unet to be useful :D
|
40 |
+
self.unet = prep_unet(self.unet)
|
41 |
+
|
42 |
+
# 1. setup all caching objects
|
43 |
+
d_ref_t2attn = {} # reference cross attention maps
|
44 |
+
|
45 |
+
# 2. Default height and width to unet
|
46 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
47 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
48 |
+
|
49 |
+
# TODO: add the input checker function
|
50 |
+
# self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds )
|
51 |
+
|
52 |
+
# 2. Define call parameters
|
53 |
+
if prompt is not None and isinstance(prompt, str):
|
54 |
+
batch_size = 1
|
55 |
+
elif prompt is not None and isinstance(prompt, list):
|
56 |
+
batch_size = len(prompt)
|
57 |
+
else:
|
58 |
+
batch_size = prompt_embeds.shape[0]
|
59 |
+
|
60 |
+
device = self._execution_device
|
61 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
62 |
+
x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)
|
63 |
+
# 3. Encode input prompt = 2x77x1024
|
64 |
+
prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,)
|
65 |
+
|
66 |
+
# 4. Prepare timesteps
|
67 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
68 |
+
timesteps = self.scheduler.timesteps
|
69 |
+
|
70 |
+
# 5. Prepare latent variables
|
71 |
+
num_channels_latents = self.unet.in_channels
|
72 |
+
|
73 |
+
# randomly sample a latent code if not provided
|
74 |
+
latents = self.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,)
|
75 |
+
|
76 |
+
latents_init = latents.clone()
|
77 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
78 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
79 |
+
|
80 |
+
# 7. First Denoising loop for getting the reference cross attention maps
|
81 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
82 |
+
with torch.no_grad():
|
83 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
84 |
+
for i, t in enumerate(timesteps):
|
85 |
+
# expand the latents if we are doing classifier free guidance
|
86 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
87 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
88 |
+
|
89 |
+
# predict the noise residual
|
90 |
+
noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
|
91 |
+
|
92 |
+
# add the cross attention map to the dictionary
|
93 |
+
d_ref_t2attn[t.item()] = {}
|
94 |
+
for name, module in self.unet.named_modules():
|
95 |
+
module_name = type(module).__name__
|
96 |
+
if module_name == "CrossAttention" and 'attn2' in name:
|
97 |
+
attn_mask = module.attn_probs # size is num_channel,s*s,77
|
98 |
+
d_ref_t2attn[t.item()][name] = attn_mask.detach().cpu()
|
99 |
+
|
100 |
+
# perform guidance
|
101 |
+
if do_classifier_free_guidance:
|
102 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
103 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
104 |
+
|
105 |
+
# compute the previous noisy sample x_t -> x_t-1
|
106 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
107 |
+
|
108 |
+
# call the callback, if provided
|
109 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
110 |
+
progress_bar.update()
|
111 |
+
|
112 |
+
# make the reference image (reconstruction)
|
113 |
+
image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))
|
114 |
+
|
115 |
+
if only_sample:
|
116 |
+
return image_rec
|
117 |
+
|
118 |
+
|
119 |
+
prompt_embeds_edit = prompt_embeds.clone()
|
120 |
+
#add the edit only to the second prompt, idx 0 is the negative prompt
|
121 |
+
prompt_embeds_edit[1:2] += edit_dir
|
122 |
+
|
123 |
+
latents = latents_init
|
124 |
+
# Second denoising loop for editing the text prompt
|
125 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
126 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
127 |
+
for i, t in enumerate(timesteps):
|
128 |
+
# expand the latents if we are doing classifier free guidance
|
129 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
130 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
131 |
+
|
132 |
+
x_in = latent_model_input.detach().clone()
|
133 |
+
x_in.requires_grad = True
|
134 |
+
|
135 |
+
opt = torch.optim.SGD([x_in], lr=guidance_amount)
|
136 |
+
|
137 |
+
# predict the noise residual
|
138 |
+
noise_pred = self.unet(x_in,t,encoder_hidden_states=prompt_embeds_edit.detach(),cross_attention_kwargs=cross_attention_kwargs,).sample
|
139 |
+
|
140 |
+
loss = 0.0
|
141 |
+
for name, module in self.unet.named_modules():
|
142 |
+
module_name = type(module).__name__
|
143 |
+
if module_name == "CrossAttention" and 'attn2' in name:
|
144 |
+
curr = module.attn_probs # size is num_channel,s*s,77
|
145 |
+
ref = d_ref_t2attn[t.item()][name].detach().cuda()
|
146 |
+
loss += ((curr-ref)**2).sum((1,2)).mean(0)
|
147 |
+
loss.backward(retain_graph=False)
|
148 |
+
opt.step()
|
149 |
+
|
150 |
+
# recompute the noise
|
151 |
+
with torch.no_grad():
|
152 |
+
noise_pred = self.unet(x_in.detach(),t,encoder_hidden_states=prompt_embeds_edit,cross_attention_kwargs=cross_attention_kwargs,).sample
|
153 |
+
|
154 |
+
latents = x_in.detach().chunk(2)[0]
|
155 |
+
|
156 |
+
# perform guidance
|
157 |
+
if do_classifier_free_guidance:
|
158 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
159 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
160 |
+
|
161 |
+
# compute the previous noisy sample x_t -> x_t-1
|
162 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
163 |
+
|
164 |
+
# call the callback, if provided
|
165 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
166 |
+
progress_bar.update()
|
167 |
+
|
168 |
+
|
169 |
+
# 8. Post-processing
|
170 |
+
image = self.decode_latents(latents.detach())
|
171 |
+
|
172 |
+
# 9. Run safety checker
|
173 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
174 |
+
|
175 |
+
# 10. Convert to PIL
|
176 |
+
image_edit = self.numpy_to_pil(image)
|
177 |
+
|
178 |
+
|
179 |
+
return image_rec, image_edit
|
submodules/pix2pix-zero/src/utils/scheduler.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
import os, sys, pdb
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
from diffusers.utils import BaseOutput, randn_tensor
|
27 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
32 |
+
class DDIMSchedulerOutput(BaseOutput):
|
33 |
+
"""
|
34 |
+
Output class for the scheduler's step function output.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
38 |
+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
39 |
+
denoising loop.
|
40 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41 |
+
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
42 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
43 |
+
"""
|
44 |
+
|
45 |
+
prev_sample: torch.FloatTensor
|
46 |
+
pred_original_sample: Optional[torch.FloatTensor] = None
|
47 |
+
|
48 |
+
|
49 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
50 |
+
"""
|
51 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
52 |
+
(1-beta) over time from t = [0,1].
|
53 |
+
|
54 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
55 |
+
to that part of the diffusion process.
|
56 |
+
|
57 |
+
|
58 |
+
Args:
|
59 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
60 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
61 |
+
prevent singularities.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
65 |
+
"""
|
66 |
+
|
67 |
+
def alpha_bar(time_step):
|
68 |
+
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
69 |
+
|
70 |
+
betas = []
|
71 |
+
for i in range(num_diffusion_timesteps):
|
72 |
+
t1 = i / num_diffusion_timesteps
|
73 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
74 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
75 |
+
return torch.tensor(betas)
|
76 |
+
|
77 |
+
|
78 |
+
class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
79 |
+
"""
|
80 |
+
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
|
81 |
+
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
|
82 |
+
|
83 |
+
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
84 |
+
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
85 |
+
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
86 |
+
[`~SchedulerMixin.from_pretrained`] functions.
|
87 |
+
|
88 |
+
For more details, see the original paper: https://arxiv.org/abs/2010.02502
|
89 |
+
|
90 |
+
Args:
|
91 |
+
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
92 |
+
beta_start (`float`): the starting `beta` value of inference.
|
93 |
+
beta_end (`float`): the final `beta` value.
|
94 |
+
beta_schedule (`str`):
|
95 |
+
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
96 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
97 |
+
trained_betas (`np.ndarray`, optional):
|
98 |
+
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
99 |
+
clip_sample (`bool`, default `True`):
|
100 |
+
option to clip predicted sample between -1 and 1 for numerical stability.
|
101 |
+
set_alpha_to_one (`bool`, default `True`):
|
102 |
+
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
103 |
+
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
104 |
+
otherwise it uses the value of alpha at step 0.
|
105 |
+
steps_offset (`int`, default `0`):
|
106 |
+
an offset added to the inference steps. You can use a combination of `offset=1` and
|
107 |
+
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
108 |
+
stable diffusion.
|
109 |
+
prediction_type (`str`, default `epsilon`, optional):
|
110 |
+
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
111 |
+
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
112 |
+
https://imagen.research.google/video/paper.pdf)
|
113 |
+
"""
|
114 |
+
|
115 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
116 |
+
order = 1
|
117 |
+
|
118 |
+
@register_to_config
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
num_train_timesteps: int = 1000,
|
122 |
+
beta_start: float = 0.0001,
|
123 |
+
beta_end: float = 0.02,
|
124 |
+
beta_schedule: str = "linear",
|
125 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
126 |
+
clip_sample: bool = True,
|
127 |
+
set_alpha_to_one: bool = True,
|
128 |
+
steps_offset: int = 0,
|
129 |
+
prediction_type: str = "epsilon",
|
130 |
+
):
|
131 |
+
if trained_betas is not None:
|
132 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
133 |
+
elif beta_schedule == "linear":
|
134 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
135 |
+
elif beta_schedule == "scaled_linear":
|
136 |
+
# this schedule is very specific to the latent diffusion model.
|
137 |
+
self.betas = (
|
138 |
+
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
139 |
+
)
|
140 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
141 |
+
# Glide cosine schedule
|
142 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
143 |
+
else:
|
144 |
+
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
145 |
+
|
146 |
+
self.alphas = 1.0 - self.betas
|
147 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
148 |
+
|
149 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
150 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
151 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
152 |
+
# whether we use the final alpha of the "non-previous" one.
|
153 |
+
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
154 |
+
|
155 |
+
# standard deviation of the initial noise distribution
|
156 |
+
self.init_noise_sigma = 1.0
|
157 |
+
|
158 |
+
# setable values
|
159 |
+
self.num_inference_steps = None
|
160 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
161 |
+
|
162 |
+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
163 |
+
"""
|
164 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
165 |
+
current timestep.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
sample (`torch.FloatTensor`): input sample
|
169 |
+
timestep (`int`, optional): current timestep
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
`torch.FloatTensor`: scaled input sample
|
173 |
+
"""
|
174 |
+
return sample
|
175 |
+
|
176 |
+
def _get_variance(self, timestep, prev_timestep):
|
177 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
178 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
179 |
+
beta_prod_t = 1 - alpha_prod_t
|
180 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
181 |
+
|
182 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
183 |
+
|
184 |
+
return variance
|
185 |
+
|
186 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
187 |
+
"""
|
188 |
+
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
num_inference_steps (`int`):
|
192 |
+
the number of diffusion steps used when generating samples with a pre-trained model.
|
193 |
+
"""
|
194 |
+
|
195 |
+
if num_inference_steps > self.config.num_train_timesteps:
|
196 |
+
raise ValueError(
|
197 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
198 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
199 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
200 |
+
)
|
201 |
+
|
202 |
+
self.num_inference_steps = num_inference_steps
|
203 |
+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
204 |
+
# creates integer timesteps by multiplying by ratio
|
205 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
206 |
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
207 |
+
self.timesteps = torch.from_numpy(timesteps).to(device)
|
208 |
+
self.timesteps += self.config.steps_offset
|
209 |
+
|
210 |
+
def step(
|
211 |
+
self,
|
212 |
+
model_output: torch.FloatTensor,
|
213 |
+
timestep: int,
|
214 |
+
sample: torch.FloatTensor,
|
215 |
+
eta: float = 0.0,
|
216 |
+
use_clipped_model_output: bool = False,
|
217 |
+
generator=None,
|
218 |
+
variance_noise: Optional[torch.FloatTensor] = None,
|
219 |
+
return_dict: bool = True,
|
220 |
+
reverse=False
|
221 |
+
) -> Union[DDIMSchedulerOutput, Tuple]:
|
222 |
+
|
223 |
+
|
224 |
+
e_t = model_output
|
225 |
+
|
226 |
+
x = sample
|
227 |
+
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
|
228 |
+
# print(timestep, prev_timestep)
|
229 |
+
a_t = alpha_prod_t = self.alphas_cumprod[timestep-1]
|
230 |
+
a_prev = alpha_t_prev = self.alphas_cumprod[prev_timestep-1] if prev_timestep >= 0 else self.final_alpha_cumprod
|
231 |
+
beta_prod_t = 1 - alpha_prod_t
|
232 |
+
|
233 |
+
pred_x0 = (x - (1-a_t)**0.5 * e_t) / a_t.sqrt()
|
234 |
+
# direction pointing to x_t
|
235 |
+
dir_xt = (1. - a_prev).sqrt() * e_t
|
236 |
+
x = a_prev.sqrt()*pred_x0 + dir_xt
|
237 |
+
if not return_dict:
|
238 |
+
return (x,)
|
239 |
+
return DDIMSchedulerOutput(prev_sample=x, pred_original_sample=pred_x0)
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
def add_noise(
|
246 |
+
self,
|
247 |
+
original_samples: torch.FloatTensor,
|
248 |
+
noise: torch.FloatTensor,
|
249 |
+
timesteps: torch.IntTensor,
|
250 |
+
) -> torch.FloatTensor:
|
251 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
252 |
+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
253 |
+
timesteps = timesteps.to(original_samples.device)
|
254 |
+
|
255 |
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
256 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
257 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
258 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
259 |
+
|
260 |
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
261 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
262 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
263 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
264 |
+
|
265 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
266 |
+
return noisy_samples
|
267 |
+
|
268 |
+
def get_velocity(
|
269 |
+
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
270 |
+
) -> torch.FloatTensor:
|
271 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
272 |
+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
273 |
+
timesteps = timesteps.to(sample.device)
|
274 |
+
|
275 |
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
276 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
277 |
+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
278 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
279 |
+
|
280 |
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
281 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
282 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
283 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
284 |
+
|
285 |
+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
286 |
+
return velocity
|
287 |
+
|
288 |
+
def __len__(self):
|
289 |
+
return self.config.num_train_timesteps
|
utils.py
ADDED
File without changes
|
utils/__init__.py
ADDED
File without changes
|
utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (172 Bytes). View file
|
|
utils/__pycache__/direction_utils.cpython-310.pyc
ADDED
Binary file (4.55 kB). View file
|
|
utils/__pycache__/generate_synthetic.cpython-310.pyc
ADDED
Binary file (8.97 kB). View file
|
|
utils/__pycache__/gradio_utils.cpython-310.pyc
ADDED
Binary file (14.9 kB). View file
|
|
utils/direction_utils.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys, pdb
|
2 |
+
|
3 |
+
import torch, torchvision
|
4 |
+
from huggingface_hub import hf_hub_url, cached_download, hf_hub_download, HfApi
|
5 |
+
import joblib
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import json
|
9 |
+
import requests
|
10 |
+
import random
|
11 |
+
|
12 |
+
|
13 |
+
"""
|
14 |
+
Returns the list of directions currently available in the HF library
|
15 |
+
"""
|
16 |
+
def get_all_directions_names():
|
17 |
+
hf_api = HfApi()
|
18 |
+
info = hf_api.list_models(author="pix2pix-zero-library")
|
19 |
+
l_model_ids = [m.modelId for m in info]
|
20 |
+
l_model_ids = [m for m in l_model_ids if "_sd14" in m]
|
21 |
+
l_edit_names = [m.split("/")[-1] for m in l_model_ids]
|
22 |
+
# l_edit_names = [m for m in l_edit_names if "_sd14" in m]
|
23 |
+
|
24 |
+
# pdb.set_trace()
|
25 |
+
l_desc = [hf_hub_download(repo_id=m_id, filename="short_description.txt") for m_id in l_model_ids]
|
26 |
+
d_name2desc = {k: open(m).read() for k,m in zip(l_edit_names, l_desc)}
|
27 |
+
|
28 |
+
return d_name2desc
|
29 |
+
|
30 |
+
|
31 |
+
def get_emb(dir_name):
|
32 |
+
REPO_ID = f"pix2pix-zero-library/{dir_name.replace('.pt','')}"
|
33 |
+
if "_sd14" not in REPO_ID: REPO_ID += "_sd14"
|
34 |
+
FILENAME = dir_name
|
35 |
+
if "_sd14" not in FILENAME: FILENAME += "_sd14"
|
36 |
+
if ".pt" not in FILENAME: FILENAME += ".pt"
|
37 |
+
ret = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
|
38 |
+
return torch.load(ret)
|
39 |
+
|
40 |
+
|
41 |
+
def generate_image_prompts_with_templates(word):
|
42 |
+
prompts = []
|
43 |
+
adjectives = ['majestic', 'cute', 'colorful', 'ferocious', 'elegant', 'graceful', 'slimy', 'adorable', 'scary', 'fuzzy', 'tiny', 'gigantic', 'brave', 'fierce', 'mysterious', 'curious', 'fascinating', 'charming', 'gleaming', 'rare']
|
44 |
+
verbs = ['strolling', 'jumping', 'lounging', 'flying', 'sleeping', 'eating', 'playing', 'working', 'gazing', 'standing']
|
45 |
+
adverbs = ['gracefully', 'playfully', 'elegantly', 'fiercely', 'curiously', 'fascinatingly', 'charmingly', 'gently', 'slowly', 'quickly', 'awkwardly', 'carelessly', 'cautiously', 'innocently', 'powerfully', 'grumpily', 'mysteriously']
|
46 |
+
backgrounds = ['a sunny beach', 'a bustling city', 'a quiet forest', 'a cozy living room', 'a futuristic space station', 'a medieval castle', 'an enchanted forest', 'a misty graveyard', 'a snowy mountain peak', 'a crowded market']
|
47 |
+
|
48 |
+
sentence_structures = {
|
49 |
+
"subject verb background": lambda word, bg, verb, adj, adv: f"A {word} {verb} {bg}.",
|
50 |
+
"background subject verb": lambda word, bg, verb, adj, adv: f"{bg}, a {word} is {verb}.",
|
51 |
+
"adjective subject verb background": lambda word, bg, verb, adj, adv: f"A {adj} {word} is {verb} {bg}.",
|
52 |
+
"subject verb adverb background": lambda word, bg, verb, adj, adv: f"A {word} is {verb} {adv} {bg}.",
|
53 |
+
"adverb subject verb background": lambda word, bg, verb, adj, adv: f"{adv.capitalize()}, a {word} is {verb} {bg}.",
|
54 |
+
"background adjective subject verb": lambda word, bg, verb, adj, adv: f"{bg}, there is a {adj} {word} {verb}.",
|
55 |
+
"subject verb adjective background": lambda word, bg, verb, adj, adv: f"A {word} {verb} {adj} {bg}.",
|
56 |
+
"adjective subject verb": lambda word, bg, verb, adj, adv: f"A {adj} {word} is {verb}.",
|
57 |
+
"subject adjective verb background": lambda word, bg, verb, adj, adv: f"A {word} is {adj} and {verb} {bg}.",
|
58 |
+
}
|
59 |
+
|
60 |
+
sentences = []
|
61 |
+
for bg in backgrounds:
|
62 |
+
for verb in verbs:
|
63 |
+
for adj in adjectives:
|
64 |
+
adv = random.choice(adverbs)
|
65 |
+
sentence = f"A {adv} {adj} {word} {verb} on {bg}."
|
66 |
+
sentence_structure = random.choice(list(sentence_structures.keys()))
|
67 |
+
sentence = sentence_structures[sentence_structure](word, bg, verb, adj, adv)
|
68 |
+
sentences.append(sentence)
|
69 |
+
return sentences
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
if __name__=="__main__":
|
75 |
+
print(get_all_directions_names())
|
76 |
+
# print(get_emb("dog_sd14.pt").shape)
|
77 |
+
# print(get_emb("dog").shape)
|
78 |
+
# print(generate_image_prompts("dog")[0:5])
|
79 |
+
|
utils/generate_synthetic.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys, time, re, pdb
|
2 |
+
import torch, torchvision
|
3 |
+
import numpy
|
4 |
+
from PIL import Image
|
5 |
+
import hashlib
|
6 |
+
from tqdm import tqdm
|
7 |
+
import openai
|
8 |
+
from utils.direction_utils import *
|
9 |
+
|
10 |
+
p = "submodules/pix2pix-zero/src/utils"
|
11 |
+
if p not in sys.path:
|
12 |
+
sys.path.append(p)
|
13 |
+
from diffusers import DDIMScheduler
|
14 |
+
from edit_directions import construct_direction
|
15 |
+
from edit_pipeline import EditingPipeline
|
16 |
+
from ddim_inv import DDIMInversion
|
17 |
+
from scheduler import DDIMInverseScheduler
|
18 |
+
from lavis.models import load_model_and_preprocess
|
19 |
+
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration, BloomForCausalLM
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
|
24 |
+
with torch.no_grad():
|
25 |
+
l_embeddings = []
|
26 |
+
for sent in tqdm(l_sentences):
|
27 |
+
text_inputs = tokenizer(
|
28 |
+
sent,
|
29 |
+
padding="max_length",
|
30 |
+
max_length=tokenizer.model_max_length,
|
31 |
+
truncation=True,
|
32 |
+
return_tensors="pt",
|
33 |
+
)
|
34 |
+
text_input_ids = text_inputs.input_ids
|
35 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
|
36 |
+
l_embeddings.append(prompt_embeds)
|
37 |
+
return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
def launch_generate_sample(prompt, seed, negative_scale, num_ddim):
|
42 |
+
os.makedirs("tmp", exist_ok=True)
|
43 |
+
# do the editing
|
44 |
+
edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
|
45 |
+
edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
|
46 |
+
|
47 |
+
# set the random seed and sample the input noise map
|
48 |
+
torch.cuda.manual_seed(int(seed))
|
49 |
+
z = torch.randn((1,4,64,64), device="cuda")
|
50 |
+
|
51 |
+
z_hashname = hashlib.sha256(z.cpu().numpy().tobytes()).hexdigest()
|
52 |
+
z_inv_fname = f"tmp/{z_hashname}_ddim_{num_ddim}_inv.pt"
|
53 |
+
torch.save(z, z_inv_fname)
|
54 |
+
|
55 |
+
rec_pil = edit_pipe(prompt,
|
56 |
+
num_inference_steps=num_ddim, x_in=z,
|
57 |
+
only_sample=True, # this flag will only generate the sampled image, not the edited image
|
58 |
+
guidance_scale=negative_scale,
|
59 |
+
negative_prompt="" # use the empty string for the negative prompt
|
60 |
+
)
|
61 |
+
# print(rec_pil)
|
62 |
+
del edit_pipe
|
63 |
+
torch.cuda.empty_cache()
|
64 |
+
|
65 |
+
return rec_pil[0], z_inv_fname
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
def clean_l_sentences(ls):
|
70 |
+
s = [re.sub('\d', '', x) for x in ls]
|
71 |
+
s = [x.replace(".","").replace("-","").replace(")","").strip() for x in s]
|
72 |
+
return s
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
def gpt3_compute_word2sentences(task_type, word, num=100):
|
77 |
+
l_sentences = []
|
78 |
+
if task_type=="object":
|
79 |
+
template_prompt = f"Provide many captions for images containing {word}."
|
80 |
+
elif task_type=="style":
|
81 |
+
template_prompt = f"Provide many captions for images that are in the {word} style."
|
82 |
+
while True:
|
83 |
+
ret = openai.Completion.create(
|
84 |
+
model="text-davinci-002",
|
85 |
+
prompt=template_prompt,
|
86 |
+
max_tokens=1000,
|
87 |
+
temperature=1.0)
|
88 |
+
raw_return = ret.choices[0].text
|
89 |
+
for line in raw_return.split("\n"):
|
90 |
+
line = line.strip()
|
91 |
+
if len(line)>10:
|
92 |
+
skip=False
|
93 |
+
for subword in word.split(" "):
|
94 |
+
if subword not in line: skip=True
|
95 |
+
if not skip: l_sentences.append(line)
|
96 |
+
else:
|
97 |
+
l_sentences.append(line+f", {word}")
|
98 |
+
time.sleep(0.05)
|
99 |
+
print(len(l_sentences))
|
100 |
+
if len(l_sentences)>=num:
|
101 |
+
break
|
102 |
+
l_sentences = clean_l_sentences(l_sentences)
|
103 |
+
return l_sentences
|
104 |
+
|
105 |
+
|
106 |
+
def flant5xl_compute_word2sentences(word, num=100):
|
107 |
+
text_input = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters."
|
108 |
+
|
109 |
+
l_sentences = []
|
110 |
+
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
|
111 |
+
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
|
112 |
+
input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to("cuda")
|
113 |
+
input_length = input_ids.shape[1]
|
114 |
+
while True:
|
115 |
+
outputs = model.generate(input_ids,temperature=0.9, num_return_sequences=16, do_sample=True, max_length=128)
|
116 |
+
output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
|
117 |
+
for line in output:
|
118 |
+
line = line.strip()
|
119 |
+
skip=False
|
120 |
+
for subword in word.split(" "):
|
121 |
+
if subword not in line: skip=True
|
122 |
+
if not skip: l_sentences.append(line)
|
123 |
+
else: l_sentences.append(line+f", {word}")
|
124 |
+
print(len(l_sentences))
|
125 |
+
if len(l_sentences)>=num:
|
126 |
+
break
|
127 |
+
l_sentences = clean_l_sentences(l_sentences)
|
128 |
+
|
129 |
+
del model
|
130 |
+
del tokenizer
|
131 |
+
torch.cuda.empty_cache()
|
132 |
+
|
133 |
+
return l_sentences
|
134 |
+
|
135 |
+
def bloomz_compute_sentences(word, num=100):
|
136 |
+
l_sentences = []
|
137 |
+
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
|
138 |
+
model = BloomForCausalLM.from_pretrained("bigscience/bloomz-7b1", device_map="auto", torch_dtype=torch.float16)
|
139 |
+
input_text = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters. Caption:"
|
140 |
+
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
|
141 |
+
input_length = input_ids.shape[1]
|
142 |
+
t = 0.95
|
143 |
+
eta = 1e-5
|
144 |
+
min_length = 15
|
145 |
+
|
146 |
+
while True:
|
147 |
+
try:
|
148 |
+
outputs = model.generate(input_ids,temperature=t, num_return_sequences=16, do_sample=True, max_length=128, min_length=min_length, eta_cutoff=eta)
|
149 |
+
output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
|
150 |
+
except:
|
151 |
+
continue
|
152 |
+
for line in output:
|
153 |
+
line = line.strip()
|
154 |
+
skip=False
|
155 |
+
for subword in word.split(" "):
|
156 |
+
if subword not in line: skip=True
|
157 |
+
if not skip: l_sentences.append(line)
|
158 |
+
else: l_sentences.append(line+f", {word}")
|
159 |
+
print(len(l_sentences))
|
160 |
+
if len(l_sentences)>=num:
|
161 |
+
break
|
162 |
+
l_sentences = clean_l_sentences(l_sentences)
|
163 |
+
del model
|
164 |
+
del tokenizer
|
165 |
+
torch.cuda.empty_cache()
|
166 |
+
|
167 |
+
return l_sentences
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
def make_custom_dir(description, sent_type, api_key, org_key, l_custom_sentences):
|
172 |
+
if sent_type=="fixed-template":
|
173 |
+
l_sentences = generate_image_prompts_with_templates(description)
|
174 |
+
elif "GPT3" in sent_type:
|
175 |
+
import openai
|
176 |
+
openai.organization = org_key
|
177 |
+
openai.api_key = api_key
|
178 |
+
_=openai.Model.retrieve("text-davinci-002")
|
179 |
+
l_sentences = gpt3_compute_word2sentences("object", description, num=1000)
|
180 |
+
|
181 |
+
elif "flan-t5-xl" in sent_type:
|
182 |
+
l_sentences = flant5xl_compute_word2sentences(description, num=1000)
|
183 |
+
# save the sentences to file
|
184 |
+
with open(f"tmp/flant5xl_sentences_{description}.txt", "w") as f:
|
185 |
+
for line in l_sentences:
|
186 |
+
f.write(line+"\n")
|
187 |
+
elif "BLOOMZ-7B" in sent_type:
|
188 |
+
l_sentences = bloomz_compute_sentences(description, num=1000)
|
189 |
+
# save the sentences to file
|
190 |
+
with open(f"tmp/bloomz_sentences_{description}.txt", "w") as f:
|
191 |
+
for line in l_sentences:
|
192 |
+
f.write(line+"\n")
|
193 |
+
|
194 |
+
elif sent_type=="custom sentences":
|
195 |
+
l_sentences = l_custom_sentences.split("\n")
|
196 |
+
print(f"length of new sentence is {len(l_sentences)}")
|
197 |
+
|
198 |
+
pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
|
199 |
+
emb = load_sentence_embeddings(l_sentences, pipe.tokenizer, pipe.text_encoder, device="cuda")
|
200 |
+
del pipe
|
201 |
+
torch.cuda.empty_cache()
|
202 |
+
return emb
|
203 |
+
|
204 |
+
|
205 |
+
def launch_main(img_in_real, img_in_synth, src, src_custom, dest, dest_custom, num_ddim, xa_guidance, edit_mul, fpath_z_gen, gen_prompt, sent_type_src, sent_type_dest, api_key, org_key, custom_sentences_src, custom_sentences_dest):
|
206 |
+
d_name2desc = get_all_directions_names()
|
207 |
+
d_desc2name = {v:k for k,v in d_name2desc.items()}
|
208 |
+
os.makedirs("tmp", exist_ok=True)
|
209 |
+
|
210 |
+
# generate custom direction first
|
211 |
+
if src=="make your own!":
|
212 |
+
outf_name = f"tmp/template_emb_{src_custom}_{sent_type_src}.pt"
|
213 |
+
if not os.path.exists(outf_name):
|
214 |
+
src_emb = make_custom_dir(src_custom, sent_type_src, api_key, org_key, custom_sentences_src)
|
215 |
+
torch.save(src_emb, outf_name)
|
216 |
+
else:
|
217 |
+
src_emb = torch.load(outf_name)
|
218 |
+
else:
|
219 |
+
src_emb = get_emb(d_desc2name[src])
|
220 |
+
|
221 |
+
if dest=="make your own!":
|
222 |
+
outf_name = f"tmp/template_emb_{dest_custom}_{sent_type_dest}.pt"
|
223 |
+
if not os.path.exists(outf_name):
|
224 |
+
dest_emb = make_custom_dir(dest_custom, sent_type_dest, api_key, org_key, custom_sentences_dest)
|
225 |
+
torch.save(dest_emb, outf_name)
|
226 |
+
else:
|
227 |
+
dest_emb = torch.load(outf_name)
|
228 |
+
else:
|
229 |
+
dest_emb = get_emb(d_desc2name[dest])
|
230 |
+
text_dir = (dest_emb.cuda() - src_emb.cuda())*edit_mul
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
if img_in_real is not None and img_in_synth is None:
|
235 |
+
print("using real image")
|
236 |
+
# resize the image so that the longer side is 512
|
237 |
+
width, height = img_in_real.size
|
238 |
+
if width > height: scale_factor = 512 / width
|
239 |
+
else: scale_factor = 512 / height
|
240 |
+
new_size = (int(width * scale_factor), int(height * scale_factor))
|
241 |
+
img_in_real = img_in_real.resize(new_size, Image.Resampling.LANCZOS)
|
242 |
+
hash = hashlib.sha256(img_in_real.tobytes()).hexdigest()
|
243 |
+
# print(hash)
|
244 |
+
inv_fname = f"tmp/{hash}_ddim_{num_ddim}_inv.pt"
|
245 |
+
caption_fname = f"tmp/{hash}_caption.txt"
|
246 |
+
|
247 |
+
# make the caption if it hasn't been made before
|
248 |
+
if not os.path.exists(caption_fname):
|
249 |
+
# BLIP
|
250 |
+
model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
|
251 |
+
_image = vis_processors["eval"](img_in_real).unsqueeze(0).cuda()
|
252 |
+
prompt_str = model_blip.generate({"image": _image})[0]
|
253 |
+
del model_blip
|
254 |
+
torch.cuda.empty_cache()
|
255 |
+
with open(caption_fname, "w") as f:
|
256 |
+
f.write(prompt_str)
|
257 |
+
else:
|
258 |
+
prompt_str = open(caption_fname, "r").read().strip()
|
259 |
+
print(f"CAPTION: {prompt_str}")
|
260 |
+
|
261 |
+
# do the inversion if it hasn't been done before
|
262 |
+
if not os.path.exists(inv_fname):
|
263 |
+
# inversion pipeline
|
264 |
+
pipe_inv = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
|
265 |
+
pipe_inv.scheduler = DDIMInverseScheduler.from_config(pipe_inv.scheduler.config)
|
266 |
+
x_inv, x_inv_image, x_dec_img = pipe_inv( prompt_str,
|
267 |
+
guidance_scale=1, num_inversion_steps=num_ddim,
|
268 |
+
img=img_in_real, torch_dtype=torch.float32 )
|
269 |
+
x_inv = x_inv.detach()
|
270 |
+
torch.save(x_inv, inv_fname)
|
271 |
+
del pipe_inv
|
272 |
+
torch.cuda.empty_cache()
|
273 |
+
else:
|
274 |
+
x_inv = torch.load(inv_fname)
|
275 |
+
|
276 |
+
# do the editing
|
277 |
+
edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
|
278 |
+
edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
|
279 |
+
|
280 |
+
_, edit_pil = edit_pipe(prompt_str,
|
281 |
+
num_inference_steps=num_ddim,
|
282 |
+
x_in=x_inv,
|
283 |
+
edit_dir=text_dir,
|
284 |
+
guidance_amount=xa_guidance,
|
285 |
+
guidance_scale=5.0,
|
286 |
+
negative_prompt=prompt_str # use the unedited prompt for the negative prompt
|
287 |
+
)
|
288 |
+
del edit_pipe
|
289 |
+
torch.cuda.empty_cache()
|
290 |
+
return edit_pil[0]
|
291 |
+
|
292 |
+
|
293 |
+
elif img_in_real is None and img_in_synth is not None:
|
294 |
+
print("using synthetic image")
|
295 |
+
x_inv = torch.load(fpath_z_gen)
|
296 |
+
pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
|
297 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
298 |
+
rec_pil, edit_pil = pipe(gen_prompt,
|
299 |
+
num_inference_steps=num_ddim,
|
300 |
+
x_in=x_inv,
|
301 |
+
edit_dir=text_dir,
|
302 |
+
guidance_amount=xa_guidance,
|
303 |
+
guidance_scale=5,
|
304 |
+
negative_prompt="" # use the empty string for the negative prompt
|
305 |
+
)
|
306 |
+
del pipe
|
307 |
+
torch.cuda.empty_cache()
|
308 |
+
return edit_pil[0]
|
309 |
+
|
310 |
+
else:
|
311 |
+
raise ValueError(f"Invalid image type: {image_type}")
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
if __name__=="__main__":
|
316 |
+
print(flant5xl_compute_word2sentences("cat wearing sunglasses", num=100))
|
utils/gradio_utils.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
def set_visible_true():
|
4 |
+
return gr.update(visible=True)
|
5 |
+
|
6 |
+
def set_visible_false():
|
7 |
+
return gr.update(visible=False)
|
8 |
+
|
9 |
+
|
10 |
+
# HTML_header = f"""
|
11 |
+
# <style>
|
12 |
+
# {CSS_main}
|
13 |
+
# </style>
|
14 |
+
# <div style="text-align: center; max-width: 700px; margin: 0 auto;">
|
15 |
+
# <div style=" display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
|
16 |
+
# <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
|
17 |
+
# Zero-shot Image-to-Image Translation
|
18 |
+
# </h1>
|
19 |
+
# </div>
|
20 |
+
# <p style="margin-bottom: 10px; font-size: 94%">
|
21 |
+
# This is the demo for <a href="https://pix2pixzero.github.io/" target="_blank">pix2pix-zero</a>.
|
22 |
+
# Please visit our <a href="https://pix2pixzero.github.io/"> website</a> and <a href="https://github.com/pix2pixzero/pix2pix-zero" target="_blank">github</a> for more details.
|
23 |
+
# </p>
|
24 |
+
# <p style="margin-bottom: 10px; font-size: 94%">
|
25 |
+
# pix2pix-zero is a diffusion-based image-to-image approach that allows users to specify the edit direction on-the-fly
|
26 |
+
# (e.g., cat to dog). Our method can directly use pre-trained text-to-image diffusion models, such as Stable Diffusion,
|
27 |
+
# for editing real and synthetic images while preserving the input image's structure. Our method is training-free and prompt-free,
|
28 |
+
# as it requires neither manual text prompting for each input image nor costly fine-tuning for each task.
|
29 |
+
# </p>
|
30 |
+
|
31 |
+
# </div>
|
32 |
+
# """
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
CSS_main = """
|
38 |
+
body {
|
39 |
+
font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif;
|
40 |
+
font-weight:300;
|
41 |
+
font-size:18px;
|
42 |
+
margin-left: auto;
|
43 |
+
margin-right: auto;
|
44 |
+
padding-left: 10px;
|
45 |
+
padding-right: 10px;
|
46 |
+
width: 800px;
|
47 |
+
}
|
48 |
+
|
49 |
+
h1 {
|
50 |
+
font-size:32px;
|
51 |
+
font-weight:300;
|
52 |
+
text-align: center;
|
53 |
+
}
|
54 |
+
|
55 |
+
h2 {
|
56 |
+
font-size:32px;
|
57 |
+
font-weight:300;
|
58 |
+
text-align: center;
|
59 |
+
}
|
60 |
+
|
61 |
+
#lbl_gallery_input{
|
62 |
+
font-family: 'Helvetica', 'Arial', sans-serif;
|
63 |
+
text-align: center;
|
64 |
+
color: #fff;
|
65 |
+
font-size: 28px;
|
66 |
+
display: inline
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
+
#lbl_gallery_comparision{
|
71 |
+
font-family: 'Helvetica', 'Arial', sans-serif;
|
72 |
+
text-align: center;
|
73 |
+
color: #fff;
|
74 |
+
font-size: 28px;
|
75 |
+
}
|
76 |
+
|
77 |
+
.disclaimerbox {
|
78 |
+
background-color: #eee;
|
79 |
+
border: 1px solid #eeeeee;
|
80 |
+
border-radius: 10px ;
|
81 |
+
-moz-border-radius: 10px ;
|
82 |
+
-webkit-border-radius: 10px ;
|
83 |
+
padding: 20px;
|
84 |
+
}
|
85 |
+
|
86 |
+
video.header-vid {
|
87 |
+
height: 140px;
|
88 |
+
border: 1px solid black;
|
89 |
+
border-radius: 10px ;
|
90 |
+
-moz-border-radius: 10px ;
|
91 |
+
-webkit-border-radius: 10px ;
|
92 |
+
}
|
93 |
+
|
94 |
+
img.header-img {
|
95 |
+
height: 140px;
|
96 |
+
border: 1px solid black;
|
97 |
+
border-radius: 10px ;
|
98 |
+
-moz-border-radius: 10px ;
|
99 |
+
-webkit-border-radius: 10px ;
|
100 |
+
}
|
101 |
+
|
102 |
+
img.rounded {
|
103 |
+
border: 1px solid #eeeeee;
|
104 |
+
border-radius: 10px ;
|
105 |
+
-moz-border-radius: 10px ;
|
106 |
+
-webkit-border-radius: 10px ;
|
107 |
+
}
|
108 |
+
|
109 |
+
a:link
|
110 |
+
{
|
111 |
+
color: #941120;
|
112 |
+
text-decoration: none;
|
113 |
+
}
|
114 |
+
a:visited
|
115 |
+
{
|
116 |
+
color: #941120;
|
117 |
+
text-decoration: none;
|
118 |
+
}
|
119 |
+
a:hover {
|
120 |
+
color: #941120;
|
121 |
+
}
|
122 |
+
|
123 |
+
td.dl-link {
|
124 |
+
height: 160px;
|
125 |
+
text-align: center;
|
126 |
+
font-size: 22px;
|
127 |
+
}
|
128 |
+
|
129 |
+
.layered-paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
|
130 |
+
box-shadow:
|
131 |
+
0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */
|
132 |
+
5px 5px 0 0px #fff, /* The second layer */
|
133 |
+
5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */
|
134 |
+
10px 10px 0 0px #fff, /* The third layer */
|
135 |
+
10px 10px 1px 1px rgba(0,0,0,0.35), /* The third layer shadow */
|
136 |
+
15px 15px 0 0px #fff, /* The fourth layer */
|
137 |
+
15px 15px 1px 1px rgba(0,0,0,0.35), /* The fourth layer shadow */
|
138 |
+
20px 20px 0 0px #fff, /* The fifth layer */
|
139 |
+
20px 20px 1px 1px rgba(0,0,0,0.35), /* The fifth layer shadow */
|
140 |
+
25px 25px 0 0px #fff, /* The fifth layer */
|
141 |
+
25px 25px 1px 1px rgba(0,0,0,0.35); /* The fifth layer shadow */
|
142 |
+
margin-left: 10px;
|
143 |
+
margin-right: 45px;
|
144 |
+
}
|
145 |
+
|
146 |
+
.paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
|
147 |
+
box-shadow:
|
148 |
+
0px 0px 1px 1px rgba(0,0,0,0.35); /* The top layer shadow */
|
149 |
+
|
150 |
+
margin-left: 10px;
|
151 |
+
margin-right: 45px;
|
152 |
+
}
|
153 |
+
|
154 |
+
|
155 |
+
.layered-paper { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
|
156 |
+
box-shadow:
|
157 |
+
0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */
|
158 |
+
5px 5px 0 0px #fff, /* The second layer */
|
159 |
+
5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */
|
160 |
+
10px 10px 0 0px #fff, /* The third layer */
|
161 |
+
10px 10px 1px 1px rgba(0,0,0,0.35); /* The third layer shadow */
|
162 |
+
margin-top: 5px;
|
163 |
+
margin-left: 10px;
|
164 |
+
margin-right: 30px;
|
165 |
+
margin-bottom: 5px;
|
166 |
+
}
|
167 |
+
|
168 |
+
.vert-cent {
|
169 |
+
position: relative;
|
170 |
+
top: 50%;
|
171 |
+
transform: translateY(-50%);
|
172 |
+
}
|
173 |
+
|
174 |
+
hr
|
175 |
+
{
|
176 |
+
border: 0;
|
177 |
+
height: 1px;
|
178 |
+
background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
|
179 |
+
}
|
180 |
+
|
181 |
+
.card {
|
182 |
+
/* width: 130px;
|
183 |
+
height: 195px;
|
184 |
+
width: 1px;
|
185 |
+
height: 1px; */
|
186 |
+
position: relative;
|
187 |
+
display: inline-block;
|
188 |
+
/* margin: 50px; */
|
189 |
+
}
|
190 |
+
.card .img-top {
|
191 |
+
display: none;
|
192 |
+
position: absolute;
|
193 |
+
top: 0;
|
194 |
+
left: 0;
|
195 |
+
z-index: 99;
|
196 |
+
}
|
197 |
+
.card:hover .img-top {
|
198 |
+
display: inline;
|
199 |
+
}
|
200 |
+
details {
|
201 |
+
user-select: none;
|
202 |
+
}
|
203 |
+
|
204 |
+
details>summary span.icon {
|
205 |
+
width: 24px;
|
206 |
+
height: 24px;
|
207 |
+
transition: all 0.3s;
|
208 |
+
margin-left: auto;
|
209 |
+
}
|
210 |
+
|
211 |
+
details[open] summary span.icon {
|
212 |
+
transform: rotate(180deg);
|
213 |
+
}
|
214 |
+
|
215 |
+
summary {
|
216 |
+
display: flex;
|
217 |
+
cursor: pointer;
|
218 |
+
}
|
219 |
+
|
220 |
+
summary::-webkit-details-marker {
|
221 |
+
display: none;
|
222 |
+
}
|
223 |
+
|
224 |
+
ul {
|
225 |
+
display: table;
|
226 |
+
margin: 0 auto;
|
227 |
+
text-align: left;
|
228 |
+
}
|
229 |
+
|
230 |
+
.dark {
|
231 |
+
padding: 1em 2em;
|
232 |
+
background-color: #333;
|
233 |
+
box-shadow: 3px 3px 3px #333;
|
234 |
+
border: 1px #333;
|
235 |
+
}
|
236 |
+
.column {
|
237 |
+
float: left;
|
238 |
+
width: 20%;
|
239 |
+
padding: 0.5%;
|
240 |
+
}
|
241 |
+
|
242 |
+
.galleryImg {
|
243 |
+
transition: opacity 0.3s;
|
244 |
+
-webkit-transition: opacity 0.3s;
|
245 |
+
filter: grayscale(100%);
|
246 |
+
/* filter: blur(2px); */
|
247 |
+
-webkit-transition : -webkit-filter 250ms linear;
|
248 |
+
/* opacity: 0.5; */
|
249 |
+
cursor: pointer;
|
250 |
+
}
|
251 |
+
|
252 |
+
|
253 |
+
|
254 |
+
.selected {
|
255 |
+
/* outline: 100px solid var(--hover-background) !important; */
|
256 |
+
/* outline-offset: -100px; */
|
257 |
+
filter: grayscale(0%);
|
258 |
+
-webkit-transition : -webkit-filter 250ms linear;
|
259 |
+
/*opacity: 1.0 !important; */
|
260 |
+
}
|
261 |
+
|
262 |
+
.galleryImg:hover {
|
263 |
+
filter: grayscale(0%);
|
264 |
+
-webkit-transition : -webkit-filter 250ms linear;
|
265 |
+
|
266 |
+
}
|
267 |
+
|
268 |
+
.row {
|
269 |
+
margin-bottom: 1em;
|
270 |
+
padding: 0px 1em;
|
271 |
+
}
|
272 |
+
/* Clear floats after the columns */
|
273 |
+
.row:after {
|
274 |
+
content: "";
|
275 |
+
display: table;
|
276 |
+
clear: both;
|
277 |
+
}
|
278 |
+
|
279 |
+
/* The expanding image container */
|
280 |
+
#gallery {
|
281 |
+
position: relative;
|
282 |
+
/*display: none;*/
|
283 |
+
}
|
284 |
+
|
285 |
+
#section_comparison{
|
286 |
+
position: relative;
|
287 |
+
width: 100%;
|
288 |
+
height: max-content;
|
289 |
+
}
|
290 |
+
|
291 |
+
/* SLIDER
|
292 |
+
-------------------------------------------------- */
|
293 |
+
|
294 |
+
.slider-container {
|
295 |
+
position: relative;
|
296 |
+
height: 384px;
|
297 |
+
width: 512px;
|
298 |
+
cursor: grab;
|
299 |
+
overflow: hidden;
|
300 |
+
margin: auto;
|
301 |
+
}
|
302 |
+
.slider-after {
|
303 |
+
display: block;
|
304 |
+
position: absolute;
|
305 |
+
top: 0;
|
306 |
+
right: 0;
|
307 |
+
bottom: 0;
|
308 |
+
left: 0;
|
309 |
+
width: 100%;
|
310 |
+
height: 100%;
|
311 |
+
overflow: hidden;
|
312 |
+
}
|
313 |
+
.slider-before {
|
314 |
+
display: block;
|
315 |
+
position: absolute;
|
316 |
+
top: 0;
|
317 |
+
/* right: 0; */
|
318 |
+
bottom: 0;
|
319 |
+
left: 0;
|
320 |
+
width: 100%;
|
321 |
+
height: 100%;
|
322 |
+
z-index: 15;
|
323 |
+
overflow: hidden;
|
324 |
+
}
|
325 |
+
.slider-before-inset {
|
326 |
+
position: absolute;
|
327 |
+
top: 0;
|
328 |
+
bottom: 0;
|
329 |
+
left: 0;
|
330 |
+
}
|
331 |
+
.slider-after img,
|
332 |
+
.slider-before img {
|
333 |
+
object-fit: cover;
|
334 |
+
position: absolute;
|
335 |
+
width: 100%;
|
336 |
+
height: 100%;
|
337 |
+
object-position: 50% 50%;
|
338 |
+
top: 0;
|
339 |
+
bottom: 0;
|
340 |
+
left: 0;
|
341 |
+
-webkit-user-select: none;
|
342 |
+
-khtml-user-select: none;
|
343 |
+
-moz-user-select: none;
|
344 |
+
-o-user-select: none;
|
345 |
+
user-select: none;
|
346 |
+
}
|
347 |
+
|
348 |
+
#lbl_inset_left{
|
349 |
+
text-align: center;
|
350 |
+
position: absolute;
|
351 |
+
top: 384px;
|
352 |
+
width: 150px;
|
353 |
+
left: calc(50% - 256px);
|
354 |
+
z-index: 11;
|
355 |
+
font-size: 16px;
|
356 |
+
color: #fff;
|
357 |
+
margin: 10px;
|
358 |
+
}
|
359 |
+
.inset-before {
|
360 |
+
position: absolute;
|
361 |
+
width: 150px;
|
362 |
+
height: 150px;
|
363 |
+
box-shadow: 3px 3px 3px #333;
|
364 |
+
border: 1px #333;
|
365 |
+
border-style: solid;
|
366 |
+
z-index: 16;
|
367 |
+
top: 410px;
|
368 |
+
left: calc(50% - 256px);
|
369 |
+
margin: 10px;
|
370 |
+
font-size: 1em;
|
371 |
+
background-repeat: no-repeat;
|
372 |
+
pointer-events: none;
|
373 |
+
}
|
374 |
+
|
375 |
+
#lbl_inset_right{
|
376 |
+
text-align: center;
|
377 |
+
position: absolute;
|
378 |
+
top: 384px;
|
379 |
+
width: 150px;
|
380 |
+
right: calc(50% - 256px);
|
381 |
+
z-index: 11;
|
382 |
+
font-size: 16px;
|
383 |
+
color: #fff;
|
384 |
+
margin: 10px;
|
385 |
+
}
|
386 |
+
.inset-after {
|
387 |
+
position: absolute;
|
388 |
+
width: 150px;
|
389 |
+
height: 150px;
|
390 |
+
box-shadow: 3px 3px 3px #333;
|
391 |
+
border: 1px #333;
|
392 |
+
border-style: solid;
|
393 |
+
z-index: 16;
|
394 |
+
top: 410px;
|
395 |
+
right: calc(50% - 256px);
|
396 |
+
margin: 10px;
|
397 |
+
font-size: 1em;
|
398 |
+
background-repeat: no-repeat;
|
399 |
+
pointer-events: none;
|
400 |
+
}
|
401 |
+
|
402 |
+
#lbl_inset_input{
|
403 |
+
text-align: center;
|
404 |
+
position: absolute;
|
405 |
+
top: 384px;
|
406 |
+
width: 150px;
|
407 |
+
left: calc(50% - 256px + 150px + 20px);
|
408 |
+
z-index: 11;
|
409 |
+
font-size: 16px;
|
410 |
+
color: #fff;
|
411 |
+
margin: 10px;
|
412 |
+
}
|
413 |
+
.inset-target {
|
414 |
+
position: absolute;
|
415 |
+
width: 150px;
|
416 |
+
height: 150px;
|
417 |
+
box-shadow: 3px 3px 3px #333;
|
418 |
+
border: 1px #333;
|
419 |
+
border-style: solid;
|
420 |
+
z-index: 16;
|
421 |
+
top: 410px;
|
422 |
+
right: calc(50% - 256px + 150px + 20px);
|
423 |
+
margin: 10px;
|
424 |
+
font-size: 1em;
|
425 |
+
background-repeat: no-repeat;
|
426 |
+
pointer-events: none;
|
427 |
+
}
|
428 |
+
|
429 |
+
.slider-beforePosition {
|
430 |
+
background: #121212;
|
431 |
+
color: #fff;
|
432 |
+
left: 0;
|
433 |
+
pointer-events: none;
|
434 |
+
border-radius: 0.2rem;
|
435 |
+
padding: 2px 10px;
|
436 |
+
}
|
437 |
+
.slider-afterPosition {
|
438 |
+
background: #121212;
|
439 |
+
color: #fff;
|
440 |
+
right: 0;
|
441 |
+
pointer-events: none;
|
442 |
+
border-radius: 0.2rem;
|
443 |
+
padding: 2px 10px;
|
444 |
+
}
|
445 |
+
.beforeLabel {
|
446 |
+
position: absolute;
|
447 |
+
top: 0;
|
448 |
+
margin: 1rem;
|
449 |
+
font-size: 1em;
|
450 |
+
-webkit-user-select: none;
|
451 |
+
-khtml-user-select: none;
|
452 |
+
-moz-user-select: none;
|
453 |
+
-o-user-select: none;
|
454 |
+
user-select: none;
|
455 |
+
}
|
456 |
+
.afterLabel {
|
457 |
+
position: absolute;
|
458 |
+
top: 0;
|
459 |
+
margin: 1rem;
|
460 |
+
font-size: 1em;
|
461 |
+
-webkit-user-select: none;
|
462 |
+
-khtml-user-select: none;
|
463 |
+
-moz-user-select: none;
|
464 |
+
-o-user-select: none;
|
465 |
+
user-select: none;
|
466 |
+
}
|
467 |
+
|
468 |
+
.slider-handle {
|
469 |
+
height: 41px;
|
470 |
+
width: 41px;
|
471 |
+
position: absolute;
|
472 |
+
left: 50%;
|
473 |
+
top: 50%;
|
474 |
+
margin-left: -20px;
|
475 |
+
margin-top: -21px;
|
476 |
+
border: 2px solid #fff;
|
477 |
+
border-radius: 1000px;
|
478 |
+
z-index: 20;
|
479 |
+
pointer-events: none;
|
480 |
+
box-shadow: 0 0 10px rgb(12, 12, 12);
|
481 |
+
}
|
482 |
+
.handle-left-arrow,
|
483 |
+
.handle-right-arrow {
|
484 |
+
width: 0;
|
485 |
+
height: 0;
|
486 |
+
border: 6px inset transparent;
|
487 |
+
position: absolute;
|
488 |
+
top: 50%;
|
489 |
+
margin-top: -6px;
|
490 |
+
}
|
491 |
+
.handle-left-arrow {
|
492 |
+
border-right: 6px solid #fff;
|
493 |
+
left: 50%;
|
494 |
+
margin-left: -17px;
|
495 |
+
}
|
496 |
+
.handle-right-arrow {
|
497 |
+
border-left: 6px solid #fff;
|
498 |
+
right: 50%;
|
499 |
+
margin-right: -17px;
|
500 |
+
}
|
501 |
+
.slider-handle::before {
|
502 |
+
bottom: 50%;
|
503 |
+
margin-bottom: 20px;
|
504 |
+
box-shadow: 0 0 10px rgb(12, 12, 12);
|
505 |
+
}
|
506 |
+
.slider-handle::after {
|
507 |
+
top: 50%;
|
508 |
+
margin-top: 20.5px;
|
509 |
+
box-shadow: 0 0 5px rgb(12, 12, 12);
|
510 |
+
}
|
511 |
+
.slider-handle::before,
|
512 |
+
.slider-handle::after {
|
513 |
+
content: " ";
|
514 |
+
display: block;
|
515 |
+
width: 2px;
|
516 |
+
background: #fff;
|
517 |
+
height: 9999px;
|
518 |
+
position: absolute;
|
519 |
+
left: 50%;
|
520 |
+
margin-left: -1.5px;
|
521 |
+
}
|
522 |
+
|
523 |
+
|
524 |
+
/*
|
525 |
+
-------------------------------------------------
|
526 |
+
The editing results shown below inversion results
|
527 |
+
-------------------------------------------------
|
528 |
+
*/
|
529 |
+
.edit_labels{
|
530 |
+
font-weight:500;
|
531 |
+
font-size: 24px;
|
532 |
+
color: #fff;
|
533 |
+
height: 20px;
|
534 |
+
margin-left: 20px;
|
535 |
+
position: relative;
|
536 |
+
top: 20px;
|
537 |
+
}
|
538 |
+
|
539 |
+
|
540 |
+
.open > a:hover {
|
541 |
+
color: #555;
|
542 |
+
background-color: red;
|
543 |
+
}
|
544 |
+
|
545 |
+
|
546 |
+
#directions { padding-top:30; padding-bottom:0; margin-bottom: 0px; height: 20px; }
|
547 |
+
#custom_task { padding-top:0; padding-bottom:0; margin-bottom: 0px; height: 20px; }
|
548 |
+
#slider_ddim {accent-color: #941120;}
|
549 |
+
#slider_ddim::-webkit-slider-thumb {background-color: #941120;}
|
550 |
+
#slider_xa {accent-color: #941120;}
|
551 |
+
#slider_xa::-webkit-slider-thumb {background-color: #941120;}
|
552 |
+
#slider_edit_mul {accent-color: #941120;}
|
553 |
+
#slider_edit_mul::-webkit-slider-thumb {background-color: #941120;}
|
554 |
+
|
555 |
+
#input_image [data-testid="image"]{
|
556 |
+
height: unset;
|
557 |
+
}
|
558 |
+
#input_image_synth [data-testid="image"]{
|
559 |
+
height: unset;
|
560 |
+
}
|
561 |
+
|
562 |
+
|
563 |
+
"""
|
564 |
+
|
565 |
+
|
566 |
+
|
567 |
+
HTML_header = f"""
|
568 |
+
<body>
|
569 |
+
<center>
|
570 |
+
<span style="font-size:36px">Zero-shot Image-to-Image Translation</span>
|
571 |
+
<table align=center>
|
572 |
+
<tr>
|
573 |
+
<td align=center>
|
574 |
+
<center>
|
575 |
+
<span style="font-size:24px; margin-left: 0px;"><a href='https://pix2pixzero.github.io/'>[Website]</a></span>
|
576 |
+
<span style="font-size:24px; margin-left: 20px;"><a href='https://github.com/pix2pixzero/pix2pix-zero'>[Code]</a></span>
|
577 |
+
</center>
|
578 |
+
</td>
|
579 |
+
</tr>
|
580 |
+
</table>
|
581 |
+
</center>
|
582 |
+
|
583 |
+
<center>
|
584 |
+
<div align=center>
|
585 |
+
<p align=left>
|
586 |
+
This is a demo for <span style="font-weight: bold;">pix2pix-zero</span>, a diffusion-based image-to-image approach that allows users to
|
587 |
+
specify the edit direction on-the-fly (e.g., cat to dog). Our method can directly use pre-trained text-to-image diffusion models, such as Stable Diffusion, for editing real and synthetic images while preserving the input image's structure. Our method is training-free and prompt-free, as it requires neither manual text prompting for each input image nor costly fine-tuning for each task.
|
588 |
+
<br>
|
589 |
+
<span style="font-weight: 800;">TL;DR:</span> <span style=" color: #941120;"> no finetuning</span> required; <span style=" color: #941120;"> no text input</span> needed; input <span style=" color: #941120;"> structure preserved</span>.
|
590 |
+
</p>
|
591 |
+
</div>
|
592 |
+
</center>
|
593 |
+
|
594 |
+
|
595 |
+
<hr>
|
596 |
+
</body>
|
597 |
+
"""
|
598 |
+
|
599 |
+
HTML_input_header = f"""
|
600 |
+
<p style="font-size:150%; padding: 0px">
|
601 |
+
<span font-weight: 800; style=" color: #941120;"> Step 1: </span> select a real input image.
|
602 |
+
</p>
|
603 |
+
"""
|
604 |
+
|
605 |
+
HTML_middle_header = f"""
|
606 |
+
<p style="font-size:150%;">
|
607 |
+
<span font-weight: 800; style=" color: #941120;"> Step 2: </span> select the editing options.
|
608 |
+
</p>
|
609 |
+
"""
|
610 |
+
|
611 |
+
|
612 |
+
HTML_output_header = f"""
|
613 |
+
<p style="font-size:150%;">
|
614 |
+
<span font-weight: 800; style=" color: #941120;"> Step 3: </span> translated image!
|
615 |
+
</p>
|
616 |
+
"""
|