pix2pix-zero committed
Commit b462bee · 0 parents

commit message

Files changed (39)
  1. .gitattributes +34 -0
  2. .gitignore +3 -0
  3. README.md +13 -0
  4. __pycache__/utils.cpython-310.pyc +0 -0
  5. app.py +143 -0
  6. assets/test_images/cat_1.png +0 -0
  7. assets/test_images/cat_2.png +0 -0
  8. assets/test_images/cat_5.png +0 -0
  9. environment.yml +23 -0
  10. requirements.txt +7 -0
  11. submodules/pix2pix-zero/.gitignore +6 -0
  12. submodules/pix2pix-zero/LICENSE +21 -0
  13. submodules/pix2pix-zero/README.md +154 -0
  14. submodules/pix2pix-zero/environment.yml +23 -0
  15. submodules/pix2pix-zero/src/edit_real.py +65 -0
  16. submodules/pix2pix-zero/src/edit_synthetic.py +52 -0
  17. submodules/pix2pix-zero/src/inversion.py +64 -0
  18. submodules/pix2pix-zero/src/make_edit_direction.py +61 -0
  19. submodules/pix2pix-zero/src/utils/__pycache__/base_pipeline.cpython-310.pyc +0 -0
  20. submodules/pix2pix-zero/src/utils/__pycache__/cross_attention.cpython-310.pyc +0 -0
  21. submodules/pix2pix-zero/src/utils/__pycache__/ddim_inv.cpython-310.pyc +0 -0
  22. submodules/pix2pix-zero/src/utils/__pycache__/edit_directions.cpython-310.pyc +0 -0
  23. submodules/pix2pix-zero/src/utils/__pycache__/edit_pipeline.cpython-310.pyc +0 -0
  24. submodules/pix2pix-zero/src/utils/__pycache__/scheduler.cpython-310.pyc +0 -0
  25. submodules/pix2pix-zero/src/utils/base_pipeline.py +322 -0
  26. submodules/pix2pix-zero/src/utils/cross_attention.py +57 -0
  27. submodules/pix2pix-zero/src/utils/ddim_inv.py +140 -0
  28. submodules/pix2pix-zero/src/utils/edit_directions.py +29 -0
  29. submodules/pix2pix-zero/src/utils/edit_pipeline.py +179 -0
  30. submodules/pix2pix-zero/src/utils/scheduler.py +289 -0
  31. utils.py +0 -0
  32. utils/__init__.py +0 -0
  33. utils/__pycache__/__init__.cpython-310.pyc +0 -0
  34. utils/__pycache__/direction_utils.cpython-310.pyc +0 -0
  35. utils/__pycache__/generate_synthetic.cpython-310.pyc +0 -0
  36. utils/__pycache__/gradio_utils.cpython-310.pyc +0 -0
  37. utils/direction_utils.py +79 -0
  38. utils/generate_synthetic.py +316 -0
  39. utils/gradio_utils.py +616 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ tmp
+ app_fin_v1.py
+ __*.py
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Demo Temp
+ emoji: 📈
+ colorFrom: purple
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.18.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/utils.cpython-310.pyc ADDED
Binary file (163 Bytes)
app.py ADDED
@@ -0,0 +1,143 @@
+ import os
+
+ from PIL import Image
+ import gradio as gr
+
+ from utils.gradio_utils import *
+ from utils.direction_utils import *
+ from utils.generate_synthetic import *
+
+
+ if __name__=="__main__":
+
+     # populate the list of editing directions
+     d_name2desc = get_all_directions_names()
+     d_name2desc["make your own!"] = "make your own!"
+
+     with gr.Blocks(css=CSS_main) as demo:
+         # Make the header of the demo website
+         gr.HTML(HTML_header)
+
+         with gr.Row():
+             # col A: the input image or synthetic image prompt
+             with gr.Column(scale=2) as gc_left:
+                 gr.HTML(" <center> <p style='font-size:150%;'> input </p> </center>")
+                 img_in_real = gr.Image(type="pil", label="Start by uploading an image", elem_id="input_image")
+                 img_in_synth = gr.Image(type="pil", label="Synthesized image", elem_id="input_image_synth", visible=False)
+                 gr.Examples(examples="assets/test_images/", inputs=[img_in_real])
+                 prompt = gr.Textbox(value="a high resolution painting of a cat in the style of van gogh", label="Or use a synthetic image. Prompt:", interactive=True)
+                 with gr.Row():
+                     seed = gr.Number(value=42, label="random seed:", interactive=True)
+                     negative_guidance = gr.Number(value=5, label="negative guidance:", interactive=True)
+                 btn_generate = gr.Button("Generate", label="")
+                 fpath_z_gen = gr.Textbox(value="placeholder", visible=False)
+
+             # col B: the output image
+             with gr.Column(scale=2) as gc_left:
+                 gr.HTML(" <center> <p style='font-size:150%;'> output </p> </center>")
+                 img_out = gr.Image(type="pil", label="Output Image", visible=True)
+                 with gr.Row():
+                     with gr.Column():
+                         src = gr.Dropdown(list(d_name2desc.values()), label="source", interactive=True, value="cat")
+                         src_custom = gr.Textbox(placeholder="enter new task here!", interactive=True, visible=False, label="custom source direction:")
+                         rad_src = gr.Radio(["GPT3", "flan-t5-xl (free)!", "BLOOMZ-7B (free)!", "fixed-template", "custom sentences"], label="Sentence type:", value="GPT3", interactive=True, visible=False)
+                         custom_sentences_src = gr.Textbox(placeholder="paste list of sentences here", interactive=True, visible=False, label="custom sentences:", lines=5, max_lines=20)
+
+
+                     with gr.Column():
+                         dest = gr.Dropdown(list(d_name2desc.values()), label="target", interactive=True, value="dog")
+                         dest_custom = gr.Textbox(placeholder="enter new task here!", interactive=True, visible=False, label="custom target direction:")
+                         rad_dest = gr.Radio(["GPT3", "flan-t5-xl (free)!", "BLOOMZ-7B (free)!", "fixed-template", "custom sentences"], label="Sentence type:", value="GPT3", interactive=True, visible=False)
+                         custom_sentences_dest = gr.Textbox(placeholder="paste list of sentences here", interactive=True, visible=False, label="custom sentences:", lines=5, max_lines=20)
+
+
+                 with gr.Row():
+                     api_key = gr.Textbox(placeholder="enter your OpenAI API key here", interactive=True, visible=False, label="OpenAI API key:", type="password")
+                     org_key = gr.Textbox(placeholder="enter your OpenAI organization key here", interactive=True, visible=False, label="OpenAI Organization:", type="password")
+                 with gr.Row():
+                     btn_edit = gr.Button("Run", label="")
+                     btn_clear = gr.Button("Clear")
+
+                 with gr.Accordion("Change editing settings?", open=False):
+                     num_ddim = gr.Slider(0, 200, 100, label="Number of DDIM steps", interactive=True, elem_id="slider_ddim", step=10)
+                     xa_guidance = gr.Slider(0, 0.25, 0.1, label="Cross Attention guidance", interactive=True, elem_id="slider_xa", step=0.01)
+                     edit_mul = gr.Slider(0, 2, 1.0, label="Edit multiplier", interactive=True, elem_id="slider_edit_mul", step=0.05)
+
+                 with gr.Accordion("Generating your own directions", open=False):
+                     gr.Textbox("We provide 5 different ways of computing new custom directions:", show_label=False)
+                     gr.Textbox("We use GPT3 to generate a list of sentences that describe the desired edit. For this option, users need to make an OpenAI account and enter the API and organization keys. This option typically results in the best directions and costs roughly $0.14 for one concept.", label="1. GPT3", show_label=True)
+                     gr.Textbox("Alternatively, the flan-t5-xl model can also be used to generate a list of sentences that describe the desired edit. This option is free and does not require creating any new accounts.", label="2. flan-t5-xl (free)", show_label=True)
+                     gr.Textbox("Similarly, the BLOOMZ-7B model can also be used to generate the sentences for free.", label="3. BLOOMZ-7B (free)", show_label=True)
+                     gr.Textbox("Next, we provide fixed-template-based sentence generation. This option does not require any language model and is therefore free and much faster. However, the edit directions obtained with this method are often entangled.", label="4. Fixed template", show_label=True)
+                     gr.Textbox("Finally, the user can also provide their own sentences.", label="5. Custom sentences", show_label=True)
+
+
+                 with gr.Accordion("Tips for getting better results", open=True):
+                     gr.Textbox("The 'Cross Attention guidance' controls the amount of structure guidance to be applied when performing the edit. If the output edited image does not retain the structure from the input, increasing the value will typically address the issue. We recommend changing the value in increments of 0.05.", label="1. Controlling the image structure", show_label=True)
+                     gr.Textbox("If the output image quality is low or has some artifacts, using more steps would be helpful. This can be controlled with the 'Number of DDIM steps' slider.", label="2. Improving Image Quality", show_label=True)
+                     gr.Textbox("There can be two reasons why the output image does not have the desired edit applied: either the cross attention guidance is too strong, or the edit is insufficient. These can be addressed by reducing the 'Cross Attention guidance' slider or increasing the 'Edit multiplier', respectively.", label="3. Amount of edit applied", show_label=True)
+
+
+
+
+
+
+         # txt_image_type = gr.Textbox(visible=False)
+         btn_generate.click(launch_generate_sample, [prompt, seed, negative_guidance, num_ddim], [img_in_synth, fpath_z_gen])
+         btn_generate.click(set_visible_true, [], img_in_synth)
+         btn_generate.click(set_visible_false, [], img_in_real)
+
+         def fn_clear_all():
+             return gr.update(value=None), gr.update(value=None), gr.update(value=None)
+         btn_clear.click(fn_clear_all, [], [img_out, img_in_real, img_in_synth])
+         btn_clear.click(set_visible_true, [], img_in_real)
+         btn_clear.click(set_visible_false, [], img_in_synth)
+
+
+         # handling custom directions
+         def on_custom_selected(src):
+             if src=="make your own!": return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
+             else: return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+
+         src.change(on_custom_selected, [src], [src_custom, rad_src, api_key, org_key])
+         dest.change(on_custom_selected, [dest], [dest_custom, rad_dest, api_key, org_key])
+
+
+         def fn_sentence_type_change(rad):
+             print(rad)
+             if rad=="GPT3":
+                 return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
+             elif rad=="custom sentences":
+                 return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+             else:
+                 print("using template sentence or flan-t5-xl or bloomz-7b")
+                 return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+
+         rad_dest.change(fn_sentence_type_change, [rad_dest], [api_key, org_key, custom_sentences_dest])
+         rad_src.change(fn_sentence_type_change, [rad_src], [api_key, org_key, custom_sentences_src])
+
+         btn_edit.click(launch_main,
+             [
+                 img_in_real, img_in_synth,
+                 src, src_custom, dest,
+                 dest_custom, num_ddim,
+                 xa_guidance, edit_mul,
+                 fpath_z_gen, prompt,
+                 rad_src, rad_dest,
+                 api_key, org_key,
+                 custom_sentences_src, custom_sentences_dest
+             ],
+             [img_out])
+
+
+
+         gr.HTML("<hr>")
+
+
+     demo.queue(concurrency_count=8)
+     demo.launch(debug=True)
+
+     # gr.close_all()
+     # demo.launch(server_port=8089, server_name="0.0.0.0", debug=True)
+
+
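The "Generating your own directions" accordion above lists five ways to build a custom direction. As a quick illustration of option 4 (fixed template), the sketch below fills a few templates with a concept word; the helper name and the template strings are assumptions made for illustration only — the Space's actual logic lives in `utils/direction_utils.py`, which is not shown in this commit view.

```python
# Illustrative sketch of fixed-template sentence generation
# (assumed helper, not the Space's actual utils/direction_utils.py code).
TEMPLATES = [
    "a photo of a {}.",
    "a high resolution painting of a {}.",
    "a cropped photo of the {}.",
    "a close-up photo of a {}.",
]

def template_sentences(concept: str, n: int = 1000) -> list:
    """Cycle through the fixed templates until n sentences are produced."""
    return [TEMPLATES[i % len(TEMPLATES)].format(concept) for i in range(n)]

if __name__ == "__main__":
    # Two sentence banks, e.g. for a cat -> dog direction; in the full pipeline
    # these would be embedded and averaged into source/target embeddings.
    src_sentences = template_sentences("cat")
    tgt_sentences = template_sentences("dog")
    print(src_sentences[0], tgt_sentences[0])
```

Because every sentence comes from the same small template set, the resulting directions tend to be more entangled than the language-model options, which is the trade-off the accordion text describes.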
assets/test_images/cat_1.png ADDED
assets/test_images/cat_2.png ADDED
assets/test_images/cat_5.png ADDED
environment.yml ADDED
@@ -0,0 +1,23 @@
+ name: pix2pix-zero
+ channels:
+   - pytorch
+   - nvidia
+   - defaults
+ dependencies:
+   - pip
+   - pytorch-cuda=11.6
+   - torchvision
+   - pytorch
+   - pip:
+     - accelerate
+     - diffusers
+     - einops
+     - gradio
+     - ipython
+     - numpy
+     - opencv-python-headless
+     - pillow
+     - psutil
+     - tqdm
+     - transformers
+     - salesforce-lavis
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ joblib
+ accelerate
+ diffusers==0.12.1
+ salesforce-lavis
+ openai
+ #git+https://github.com/pix2pixzero/pix2pix-zero.git
submodules/pix2pix-zero/.gitignore ADDED
@@ -0,0 +1,6 @@
+ output
+ scripts
+ src/folder_*.py
+ src/ig_*.py
+ assets/edit_sentences
+ src/utils/edit_pipeline_spatial.py
submodules/pix2pix-zero/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 pix2pixzero
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
submodules/pix2pix-zero/README.md ADDED
@@ -0,0 +1,154 @@
+ # pix2pix-zero
+
+ ## [**[website]**](https://pix2pixzero.github.io/)
+
+
+ This is the authors' reimplementation of "Zero-shot Image-to-Image Translation" using the diffusers library. <br>
+ The results in the paper are based on the [CompVis](https://github.com/CompVis/stable-diffusion) library, which will be released later.
+
+ **[New!]** Code for editing real and synthetic images released!
+
+
+
+ <br>
+ <div class="gif">
+ <p align="center">
+ <img src='assets/main.gif' align="center">
+ </p>
+ </div>
+
+
+ We propose pix2pix-zero, a diffusion-based image-to-image approach that allows users to specify the edit direction on-the-fly (e.g., cat to dog). Our method can directly use a pre-trained [Stable Diffusion](https://github.com/CompVis/stable-diffusion) model for editing real and synthetic images while preserving the input image's structure. Our method is training-free and prompt-free, as it requires neither manual text prompting for each input image nor costly fine-tuning for each task.
+
+ **TL;DR**: no finetuning required, no text input needed, input structure preserved.
+
+ ## Results
+ All our results are based on the [stable-diffusion-v1-4](https://github.com/CompVis/stable-diffusion) model. Please see the website for more results.
+
+ <div>
+ <p align="center">
+ <img src='assets/results_teaser.jpg' align="center" width=800px>
+ </p>
+ </div>
+ <hr>
+
+ The top row for each of the results below shows editing of real images, and the bottom row shows synthetic image editing.
+ <div>
+ <p align="center">
+ <img src='assets/grid_dog2cat.jpg' align="center" width=800px>
+ </p>
+ <p align="center">
+ <img src='assets/grid_zebra2horse.jpg' align="center" width=800px>
+ </p>
+ <p align="center">
+ <img src='assets/grid_cat2dog.jpg' align="center" width=800px>
+ </p>
+ <p align="center">
+ <img src='assets/grid_horse2zebra.jpg' align="center" width=800px>
+ </p>
+ <p align="center">
+ <img src='assets/grid_tree2fall.jpg' align="center" width=800px>
+ </p>
+ </div>
+
+ ## Real Image Editing
+ <div>
+ <p align="center">
+ <img src='assets/results_real.jpg' align="center" width=800px>
+ </p>
+ </div>
+
+ ## Synthetic Image Editing
+ <div>
+ <p align="center">
+ <img src='assets/results_syn.jpg' align="center" width=800px>
+ </p>
+ </div>
+
+ ## Method Details
+
+ Given an input image, we first generate text captions using [BLIP](https://github.com/salesforce/LAVIS) and apply regularized DDIM inversion to obtain our inverted noise map.
+ Then, we obtain reference cross-attention maps that correspond to the structure of the input image by denoising, guided with the CLIP embeddings
+ of our generated text (c). Next, we denoise with edited text embeddings, while enforcing a loss to match the current cross-attention maps with the
+ reference cross-attention maps.
+
+ <div>
+ <p align="center">
+ <img src='assets/method.jpeg' align="center" width=900>
+ </p>
+ </div>
+
+
+ ## Getting Started
+
+ **Environment Setup**
+ - We provide a [conda env file](environment.yml) that contains all the required dependencies.
+ ```
+ conda env create -f environment.yml
+ ```
+ - Following this, you can activate the conda environment with the command below.
+ ```
+ conda activate pix2pix-zero
+ ```
+
+ **Real Image Translation**
+ - First, run the inversion command below to obtain the input noise that reconstructs the image.
+ The command below will save the inversion in the results folder as `output/test_cat/inversion/cat_1.pt`
+ and the BLIP-generated prompt as `output/test_cat/prompt/cat_1.txt`.
+ ```
+ python src/inversion.py \
+     --input_image "assets/test_images/cats/cat_1.png" \
+     --results_folder "output/test_cat"
+ ```
+ - Next, we can perform image editing with the editing direction as shown below.
+ The command below will save the edited image as `output/test_cat/edit/cat_1.png`.
+ ```
+ python src/edit_real.py \
+     --inversion "output/test_cat/inversion/cat_1.pt" \
+     --prompt "output/test_cat/prompt/cat_1.txt" \
+     --task_name "cat2dog" \
+     --results_folder "output/test_cat/"
+ ```
+
+ **Editing Synthetic Images**
+ - Similarly, we can edit the synthetic images generated by Stable Diffusion with the following command.
+ ```
+ python src/edit_synthetic.py \
+     --results_folder "output/synth_editing" \
+     --prompt_str "a high resolution painting of a cat in the style of van gogh" \
+     --task_name "cat2dog"
+ ```
+
+ ### **Tips and Debugging**
+ - **Controlling the Image Structure:**<br>
+ The `--xa_guidance` flag controls the amount of cross-attention guidance to be applied when performing the edit. If the output edited image does not retain the structure from the input, increasing the value will typically address the issue. We recommend changing the value in increments of 0.05.
+
+ - **Improving Image Quality:**<br>
+ If the output image quality is low or has some artifacts, using more steps for both the inversion and editing would be helpful.
+ This can be controlled with the `--num_ddim_steps` flag.
+
+ - **Reducing the VRAM Requirements:**<br>
+ We can reduce the VRAM requirements by using lower precision and setting the `--use_float_16` flag.
+
+ <br>
+
+ **Finding Custom Edit Directions**<br>
+ - We provide some pre-computed directions in the assets [folder](assets/embeddings_sd_1.4).
+ To generate new edit directions, users can first generate two files containing a large number of sentences (~1000) and then run the command shown below.
+ ```
+ python src/make_edit_direction.py \
+     --file_source_sentences sentences/apple.txt \
+     --file_target_sentences sentences/orange.txt \
+     --output_folder assets/embeddings_sd_1.4
+ ```
+ - After running the above command, you can set the flag `--task_name apple2orange` for the new edit.
+
+
+
+ ## Comparison
+ Comparisons with different baselines, including SDEdit + word swap, DDIM + word swap, and prompt-to-prompt. Our method successfully applies the edit, while preserving the structure of the input image.
+ <div>
+ <p align="center">
+ <img src='assets/comparison.jpg' align="center" width=900>
+ </p>
+ </div>
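The Real Image Translation steps above can also be driven from Python. Below is a condensed sketch that strings together `src/inversion.py` and `src/edit_real.py` from this commit; it assumes a CUDA GPU, the conda environment above, and that it is run from the repository root with `src/` on the import path.

```python
# Condensed sketch of invert-then-edit, following src/inversion.py and
# src/edit_real.py in this commit (CUDA and src/ on sys.path assumed).
import torch
from PIL import Image
from diffusers import DDIMScheduler
from lavis.models import load_model_and_preprocess
from utils.ddim_inv import DDIMInversion
from utils.scheduler import DDIMInverseScheduler
from utils.edit_directions import construct_direction
from utils.edit_pipeline import EditingPipeline

model_path = "CompVis/stable-diffusion-v1-4"
device = torch.device("cuda")

# 1. Caption the input image with BLIP, then invert it with regularized DDIM.
blip, vis_processors, _ = load_model_and_preprocess(
    name="blip_caption", model_type="base_coco", is_eval=True, device=device)
img = Image.open("assets/test_images/cats/cat_1.png").resize((512, 512), Image.Resampling.LANCZOS)
prompt = blip.generate({"image": vis_processors["eval"](img).unsqueeze(0).to(device)})[0]

inv_pipe = DDIMInversion.from_pretrained(model_path).to("cuda")
inv_pipe.scheduler = DDIMInverseScheduler.from_config(inv_pipe.scheduler.config)
x_inv, _, _ = inv_pipe(prompt, guidance_scale=1, num_inversion_steps=50, img=img)

# 2. Denoise along the cat -> dog direction while matching the reference
#    cross-attention maps (the structure-preserving part of the method).
edit_pipe = EditingPipeline.from_pretrained(model_path).to("cuda")
edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
rec_pil, edit_pil = edit_pipe(
    prompt,
    num_inference_steps=50,
    x_in=x_inv[0].unsqueeze(0),
    edit_dir=construct_direction("cat2dog"),
    guidance_amount=0.1,      # corresponds to the --xa_guidance flag
    guidance_scale=5.0,       # corresponds to --negative_guidance_scale
    negative_prompt=prompt,   # the unedited prompt is used as the negative prompt
)
edit_pil[0].save("edit.png")
rec_pil[0].save("reconstruction.png")
```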
submodules/pix2pix-zero/environment.yml ADDED
@@ -0,0 +1,23 @@
+ name: pix2pix-zero
+ channels:
+   - pytorch
+   - nvidia
+   - defaults
+ dependencies:
+   - pip
+   - pytorch-cuda=11.6
+   - torchvision
+   - pytorch
+   - pip:
+     - accelerate
+     - diffusers
+     - einops
+     - gradio
+     - ipython
+     - numpy
+     - opencv-python-headless
+     - pillow
+     - psutil
+     - tqdm
+     - transformers
+     - salesforce-lavis
submodules/pix2pix-zero/src/edit_real.py ADDED
@@ -0,0 +1,65 @@
+ import os, pdb
+
+ import argparse
+ import numpy as np
+ import torch
+ import requests
+ from glob import glob
+ from PIL import Image
+
+ from diffusers import DDIMScheduler
+ from utils.ddim_inv import DDIMInversion
+ from utils.edit_directions import construct_direction
+ from utils.edit_pipeline import EditingPipeline
+
+
+ if __name__=="__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--inversion', required=True)
+     parser.add_argument('--prompt', type=str, required=True)
+     parser.add_argument('--task_name', type=str, default='cat2dog')
+     parser.add_argument('--results_folder', type=str, default='output/test_cat')
+     parser.add_argument('--num_ddim_steps', type=int, default=50)
+     parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
+     parser.add_argument('--xa_guidance', default=0.1, type=float)
+     parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
+     parser.add_argument('--use_float_16', action='store_true')
+
+     args = parser.parse_args()
+
+     os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True)
+     os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True)
+
+     if args.use_float_16:
+         torch_dtype = torch.float16
+     else:
+         torch_dtype = torch.float32
+
+     # if the inversion is a folder, the prompt should also be a folder
+     assert (os.path.isdir(args.inversion)==os.path.isdir(args.prompt)), "If the inversion is a folder, the prompt should also be a folder"
+     if os.path.isdir(args.inversion):
+         l_inv_paths = sorted(glob(os.path.join(args.inversion, "*.pt")))
+         l_bnames = [os.path.basename(x) for x in l_inv_paths]
+         l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt",".txt")) for x in l_bnames]
+     else:
+         l_inv_paths = [args.inversion]
+         l_prompt_paths = [args.prompt]
+
+     # Make the editing pipeline
+     pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
+     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+
+     for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths):
+         prompt_str = open(prompt_path).read().strip()
+         rec_pil, edit_pil = pipe(prompt_str,
+             num_inference_steps=args.num_ddim_steps,
+             x_in=torch.load(inv_path).unsqueeze(0),
+             edit_dir=construct_direction(args.task_name),
+             guidance_amount=args.xa_guidance,
+             guidance_scale=args.negative_guidance_scale,
+             negative_prompt=prompt_str # use the unedited prompt for the negative prompt
+         )
+
+         bname = os.path.basename(inv_path).split(".")[0]
+         edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png"))
+         rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png"))
submodules/pix2pix-zero/src/edit_synthetic.py ADDED
@@ -0,0 +1,52 @@
+ import os, pdb
+
+ import argparse
+ import numpy as np
+ import torch
+ import requests
+ from PIL import Image
+
+ from diffusers import DDIMScheduler
+ from utils.edit_directions import construct_direction
+ from utils.edit_pipeline import EditingPipeline
+
+
+ if __name__=="__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--prompt_str', type=str, required=True)
+     parser.add_argument('--random_seed', type=int, default=0)
+     parser.add_argument('--task_name', type=str, default='cat2dog')
+     parser.add_argument('--results_folder', type=str, default='output/test_cat')
+     parser.add_argument('--num_ddim_steps', type=int, default=50)
+     parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
+     parser.add_argument('--xa_guidance', default=0.15, type=float)
+     parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
+     parser.add_argument('--use_float_16', action='store_true')
+     args = parser.parse_args()
+
+     os.makedirs(args.results_folder, exist_ok=True)
+
+     if args.use_float_16:
+         torch_dtype = torch.float16
+     else:
+         torch_dtype = torch.float32
+
+     # make the input noise map
+     torch.cuda.manual_seed(args.random_seed)
+     x = torch.randn((1,4,64,64), device="cuda")
+
+     # Make the editing pipeline
+     pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
+     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+     rec_pil, edit_pil = pipe(args.prompt_str,
+         num_inference_steps=args.num_ddim_steps,
+         x_in=x,
+         edit_dir=construct_direction(args.task_name),
+         guidance_amount=args.xa_guidance,
+         guidance_scale=args.negative_guidance_scale,
+         negative_prompt="" # use the empty string for the negative prompt
+     )
+
+     edit_pil[0].save(os.path.join(args.results_folder, "edit.png"))
+     rec_pil[0].save(os.path.join(args.results_folder, "reconstruction.png"))
submodules/pix2pix-zero/src/inversion.py ADDED
@@ -0,0 +1,64 @@
+ import os, pdb
+
+ import argparse
+ import numpy as np
+ import torch
+ import requests
+ from glob import glob
+ from PIL import Image
+
+ from lavis.models import load_model_and_preprocess
+
+ from utils.ddim_inv import DDIMInversion
+ from utils.scheduler import DDIMInverseScheduler
+
+ if __name__=="__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png')
+     parser.add_argument('--results_folder', type=str, default='output/test_cat')
+     parser.add_argument('--num_ddim_steps', type=int, default=50)
+     parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
+     parser.add_argument('--use_float_16', action='store_true')
+     args = parser.parse_args()
+
+     # make the output folders
+     os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True)
+     os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True)
+
+     if args.use_float_16:
+         torch_dtype = torch.float16
+     else:
+         torch_dtype = torch.float32
+
+
+     # load the BLIP model
+     model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
+     # make the DDIM inversion pipeline
+     pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
+     pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+
+
+     # if the input is a folder, collect all the images as a list
+     if os.path.isdir(args.input_image):
+         l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png")))
+     else:
+         l_img_paths = [args.input_image]
+
+
+     for img_path in l_img_paths:
+         bname = os.path.basename(img_path).split(".")[0]
+         img = Image.open(img_path).resize((512,512), Image.Resampling.LANCZOS)
+         # generate the caption
+         _image = vis_processors["eval"](img).unsqueeze(0).cuda()
+         prompt_str = model_blip.generate({"image": _image})[0]
+         x_inv, x_inv_image, x_dec_img = pipe(
+             prompt_str,
+             guidance_scale=1,
+             num_inversion_steps=args.num_ddim_steps,
+             img=img,
+             torch_dtype=torch_dtype
+         )
+         # save the inversion
+         torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
+         # save the prompt string
+         with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
+             f.write(prompt_str)
submodules/pix2pix-zero/src/make_edit_direction.py ADDED
@@ -0,0 +1,61 @@
+ import os, pdb
+
+ import argparse
+ import numpy as np
+ import torch
+ import requests
+ from PIL import Image
+
+ from diffusers import DDIMScheduler
+ from utils.edit_pipeline import EditingPipeline
+
+
+ ## convert sentences to sentence embeddings
+ def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
+     with torch.no_grad():
+         l_embeddings = []
+         for sent in l_sentences:
+             text_inputs = tokenizer(
+                 sent,
+                 padding="max_length",
+                 max_length=tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+             text_input_ids = text_inputs.input_ids
+             prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
+             l_embeddings.append(prompt_embeds)
+     return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
+
+
+ if __name__=="__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--file_source_sentences', required=True)
+     parser.add_argument('--file_target_sentences', required=True)
+     parser.add_argument('--output_folder', required=True)
+     parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
+     args = parser.parse_args()
+
+     # load the model
+     pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")
+     bname_src = os.path.splitext(os.path.basename(args.file_source_sentences))[0]
+     outf_src = os.path.join(args.output_folder, bname_src+".pt")
+     if os.path.exists(outf_src):
+         print(f"Skipping source file {outf_src} as it already exists")
+     else:
+         with open(args.file_source_sentences, "r") as f:
+             l_sents = [x.strip() for x in f.readlines()]
+         mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
+         print(mean_emb.shape)
+         torch.save(mean_emb, outf_src)
+
+     bname_tgt = os.path.splitext(os.path.basename(args.file_target_sentences))[0]
+     outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt")
+     if os.path.exists(outf_tgt):
+         print(f"Skipping target file {outf_tgt} as it already exists")
+     else:
+         with open(args.file_target_sentences, "r") as f:
+             l_sents = [x.strip() for x in f.readlines()]
+         mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
+         print(mean_emb.shape)
+         torch.save(mean_emb, outf_tgt)
submodules/pix2pix-zero/src/utils/__pycache__/base_pipeline.cpython-310.pyc ADDED
Binary file (10.7 kB)
submodules/pix2pix-zero/src/utils/__pycache__/cross_attention.cpython-310.pyc ADDED
Binary file (1.43 kB)
submodules/pix2pix-zero/src/utils/__pycache__/ddim_inv.cpython-310.pyc ADDED
Binary file (4.4 kB)
submodules/pix2pix-zero/src/utils/__pycache__/edit_directions.cpython-310.pyc ADDED
Binary file (693 Bytes)
submodules/pix2pix-zero/src/utils/__pycache__/edit_pipeline.cpython-310.pyc ADDED
Binary file (4.12 kB)
submodules/pix2pix-zero/src/utils/__pycache__/scheduler.cpython-310.pyc ADDED
Binary file (10.4 kB)
submodules/pix2pix-zero/src/utils/base_pipeline.py ADDED
@@ -0,0 +1,322 @@
+
+ import torch
+ import inspect
+ from packaging import version
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+ from diffusers import DiffusionPipeline
+ from diffusers.configuration_utils import FrozenDict
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.schedulers import KarrasDiffusionSchedulers
+ from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
+ from diffusers import StableDiffusionPipeline
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+ logger = logging.get_logger(__name__)
+
+
+ class BasePipeline(DiffusionPipeline):
+     _optional_components = ["safety_checker", "feature_extractor"]
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: KarrasDiffusionSchedulers,
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPFeatureExtractor,
+         requires_safety_checker: bool = True,
+     ):
+         super().__init__()
+
+         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+             deprecation_message = (
+                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                 "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+                 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                 " file"
+             )
+             deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(scheduler.config)
+             new_config["steps_offset"] = 1
+             scheduler._internal_dict = FrozenDict(new_config)
+
+         if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+             deprecation_message = (
+                 f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                 " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                 " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                 " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                 " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+             )
+             deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(scheduler.config)
+             new_config["clip_sample"] = False
+             scheduler._internal_dict = FrozenDict(new_config)
+
+         if safety_checker is None and requires_safety_checker:
+             logger.warning(
+                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                 " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                 " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                 " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                 " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+             )
+
+         if safety_checker is not None and feature_extractor is None:
+             raise ValueError(
+                 "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+             )
+
+         is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+             version.parse(unet.config._diffusers_version).base_version
+         ) < version.parse("0.9.0.dev0")
+         is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+             deprecation_message = (
+                 "The configuration file of the unet has set the default `sample_size` to smaller than"
+                 " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                 " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                 " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                 " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                 " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                 " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                 " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                 " the `unet/config.json` file"
+             )
+             deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(unet.config)
+             new_config["sample_size"] = 64
+             unet._internal_dict = FrozenDict(new_config)
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
+         )
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+     @property
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+     def _execution_device(self):
+         r"""
+         Returns the device on which the pipeline's models will be executed. After calling
+         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+         hooks.
+         """
+         if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+             return self.device
+         for module in self.unet.modules():
+             if (
+                 hasattr(module, "_hf_hook")
+                 and hasattr(module._hf_hook, "execution_device")
+                 and module._hf_hook.execution_device is not None
+             ):
+                 return torch.device(module._hf_hook.execution_device)
+         return self.device
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+     def _encode_prompt(
+         self,
+         prompt,
+         device,
+         num_images_per_prompt,
+         do_classifier_free_guidance,
+         negative_prompt=None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+     ):
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             device: (`torch.device`):
+                 torch device
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             do_classifier_free_guidance (`bool`):
+                 whether to use classifier free guidance or not
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                 `guidance_scale` is less than `1`).
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+         """
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if prompt_embeds is None:
+             text_inputs = self.tokenizer(
+                 prompt,
+                 padding="max_length",
+                 max_length=self.tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+             text_input_ids = text_inputs.input_ids
+             untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                 text_input_ids, untruncated_ids
+             ):
+                 removed_text = self.tokenizer.batch_decode(
+                     untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                 )
+                 logger.warning(
+                     "The following part of your input was truncated because CLIP can only handle sequences up to"
+                     f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                 )
+
+             if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                 attention_mask = text_inputs.attention_mask.to(device)
+             else:
+                 attention_mask = None
+
+             prompt_embeds = self.text_encoder(
+                 text_input_ids.to(device),
+                 attention_mask=attention_mask,
+             )
+             prompt_embeds = prompt_embeds[0]
+
+         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+         bs_embed, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+         # get unconditional embeddings for classifier free guidance
+         if do_classifier_free_guidance and negative_prompt_embeds is None:
+             uncond_tokens: List[str]
+             if negative_prompt is None:
+                 uncond_tokens = [""] * batch_size
+             elif type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif isinstance(negative_prompt, str):
+                 uncond_tokens = [negative_prompt]
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+             else:
+                 uncond_tokens = negative_prompt
+
+             max_length = prompt_embeds.shape[1]
+             uncond_input = self.tokenizer(
+                 uncond_tokens,
+                 padding="max_length",
+                 max_length=max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+
+             if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                 attention_mask = uncond_input.attention_mask.to(device)
+             else:
+                 attention_mask = None
+
+             negative_prompt_embeds = self.text_encoder(
+                 uncond_input.input_ids.to(device),
+                 attention_mask=attention_mask,
+             )
+             negative_prompt_embeds = negative_prompt_embeds[0]
+
+         if do_classifier_free_guidance:
+             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+             seq_len = negative_prompt_embeds.shape[1]
+
+             negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+             # For classifier free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+         return prompt_embeds
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+     def decode_latents(self, latents):
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents).sample
+         image = (image / 2 + 0.5).clamp(0, 1)
+         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+         image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
+         return image
+
+     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if latents is None:
+             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         else:
+             latents = latents.to(device)
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+     def run_safety_checker(self, image, device, dtype):
+         if self.safety_checker is not None:
+             safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+             image, has_nsfw_concept = self.safety_checker(
+                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+             )
+         else:
+             has_nsfw_concept = None
+         return image, has_nsfw_concept
+
submodules/pix2pix-zero/src/utils/cross_attention.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ from diffusers.models.attention import CrossAttention
+
+ class MyCrossAttnProcessor:
+     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+         batch_size, sequence_length, _ = hidden_states.shape
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+         query = attn.to_q(hidden_states)
+
+         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         # new bookkeeping to save the attn probs
+         attn.attn_probs = attention_probs
+
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states
+
+
+ """
+ A function that prepares a U-Net model for training by enabling gradient computation
+ for a specified set of parameters and setting the forward pass to be performed by a
+ custom cross attention processor.
+
+ Parameters:
+ unet: A U-Net model.
+
+ Returns:
+ unet: The prepared U-Net model.
+ """
+ def prep_unet(unet):
+     # set the gradients for XA maps to be true
+     for name, params in unet.named_parameters():
+         if 'attn2' in name:
+             params.requires_grad = True
+         else:
+             params.requires_grad = False
+     # replace the fwd function
+     for name, module in unet.named_modules():
+         module_name = type(module).__name__
+         if module_name == "CrossAttention":
+             module.set_processor(MyCrossAttnProcessor())
+     return unet
submodules/pix2pix-zero/src/utils/ddim_inv.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from random import randrange
6
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
7
+ from diffusers import DDIMScheduler
8
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
10
+ sys.path.insert(0, "src/utils")
11
+ from base_pipeline import BasePipeline
12
+ from cross_attention import prep_unet
13
+
14
+
15
+ class DDIMInversion(BasePipeline):
16
+
17
+ def auto_corr_loss(self, x, random_shift=True):
18
+ B,C,H,W = x.shape
19
+ assert B==1
20
+ x = x.squeeze(0)
21
+ # x must be shape [C,H,W] now
22
+ reg_loss = 0.0
23
+ for ch_idx in range(x.shape[0]):
24
+ noise = x[ch_idx][None, None,:,:]
25
+ while True:
26
+ if random_shift: roll_amount = randrange(noise.shape[2]//2)
27
+ else: roll_amount = 1
28
+ reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
29
+ reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
30
+ if noise.shape[2] <= 8:
31
+ break
32
+ noise = F.avg_pool2d(noise, kernel_size=2)
33
+ return reg_loss
34
+
35
+ def kl_divergence(self, x):
36
+ _mu = x.mean()
37
+ _var = x.var()
38
+ return _var + _mu**2 - 1 - torch.log(_var+1e-7)
39
+
40
+
41
+ def __call__(
42
+ self,
43
+ prompt: Union[str, List[str]] = None,
44
+ num_inversion_steps: int = 50,
45
+ guidance_scale: float = 7.5,
46
+ negative_prompt: Optional[Union[str, List[str]]] = None,
47
+ num_images_per_prompt: Optional[int] = 1,
48
+ eta: float = 0.0,
49
+ output_type: Optional[str] = "pil",
50
+ return_dict: bool = True,
51
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
52
+ img=None, # the input image as a PIL image
53
+ torch_dtype=torch.float32,
54
+
55
+ # inversion regularization parameters
56
+ lambda_ac: float = 20.0,
57
+ lambda_kl: float = 20.0,
58
+ num_reg_steps: int = 5,
59
+ num_ac_rolls: int = 5,
60
+ ):
61
+
62
+ # 0. modify the unet to be useful :D
63
+ self.unet = prep_unet(self.unet)
64
+
65
+ # set the scheduler to be the Inverse DDIM scheduler
66
+ # self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config)
67
+
68
+ device = self._execution_device
69
+ do_classifier_free_guidance = guidance_scale > 1.0
70
+ self.scheduler.set_timesteps(num_inversion_steps, device=device)
71
+ timesteps = self.scheduler.timesteps
72
+
73
+ # Encode the input image with the first stage model
74
+ x0 = np.array(img)/255
75
+ x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).cuda()
76
+ x0 = (x0 - 0.5) * 2.
77
+ with torch.no_grad():
78
+ x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype)
79
+ latents = x0_enc = 0.18215 * x0_enc
80
+
81
+ # Decode and return the image
82
+ with torch.no_grad():
83
+ x0_dec = self.decode_latents(x0_enc.detach())
84
+ image_x0_dec = self.numpy_to_pil(x0_dec)
85
+
86
+ with torch.no_grad():
87
+ prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device)
88
+ extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta)
89
+
90
+ # Do the inversion
91
+ num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order # should be 0?
92
+ with self.progress_bar(total=num_inversion_steps) as progress_bar:
93
+ for i, t in enumerate(timesteps.flip(0)[1:-1]):
94
+ # expand the latents if we are doing classifier free guidance
95
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
96
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
97
+
98
+ # predict the noise residual
99
+ with torch.no_grad():
100
+ noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
101
+
102
+ # perform guidance
103
+ if do_classifier_free_guidance:
104
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
105
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
106
+
107
+ # regularization of the noise prediction
108
+ e_t = noise_pred
109
+ for _outer in range(num_reg_steps):
110
+ if lambda_ac>0:
111
+ for _inner in range(num_ac_rolls):
112
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
113
+ l_ac = self.auto_corr_loss(_var)
114
+ l_ac.backward()
115
+ _grad = _var.grad.detach()/num_ac_rolls
116
+ e_t = e_t - lambda_ac*_grad
117
+ if lambda_kl>0:
118
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
119
+ l_kld = self.kl_divergence(_var)
120
+ l_kld.backward()
121
+ _grad = _var.grad.detach()
122
+ e_t = e_t - lambda_kl*_grad
123
+ e_t = e_t.detach()
124
+ noise_pred = e_t
125
+
126
+ # compute the previous noisy sample x_t -> x_t-1
127
+ latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample
128
+
129
+ # call the callback, if provided
130
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
131
+ progress_bar.update()
132
+
133
+
134
+ x_inv = latents.detach().clone()
135
+ # reconstruct the image
136
+
137
+ # 8. Post-processing
138
+ image = self.decode_latents(latents.detach())
139
+ image = self.numpy_to_pil(image)
140
+ return x_inv, image, image_x0_dec
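
A minimal usage sketch for the inversion pipeline above, mirroring how utils/generate_synthetic.py (further down in this commit) drives it. It assumes submodules/pix2pix-zero/src/utils is on sys.path, a CUDA device, and that `caption` (a BLIP caption string) and `img` (a 512x512 PIL image) are already available:

    import torch
    from ddim_inv import DDIMInversion
    from scheduler import DDIMInverseScheduler

    # Load SD 1.4 and swap in the inverse (noising) DDIM scheduler.
    pipe = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
    pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

    # Invert the image; lambda_ac / lambda_kl weight the auto-correlation and KL
    # regularization of the predicted noise implemented in auto_corr_loss / kl_divergence.
    x_inv, x_inv_image, x_dec_img = pipe(
        caption,
        guidance_scale=1,          # no classifier-free guidance during inversion
        num_inversion_steps=50,
        img=img,
        lambda_ac=20.0,
        lambda_kl=20.0,
    )
    torch.save(x_inv.detach(), "tmp/example_inv.pt")  # reuse later for editing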
submodules/pix2pix-zero/src/utils/edit_directions.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+ import torch
3
+
4
+
5
+ """
6
+ This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task.
7
+
8
+ Parameters:
9
+ task_name (str): name of the task for which direction is to be constructed.
10
+
11
+ Returns:
12
+ torch.Tensor: A tensor representing the direction in the embedding space that transforms class A to class B.
13
+
14
+ Examples:
15
+ >>> construct_direction("cat2dog")
16
+ """
17
+ def construct_direction(task_name):
18
+ if task_name=="cat2dog":
19
+ emb_dir = f"assets/embeddings_sd_1.4"
20
+ embs_a = torch.load(os.path.join(emb_dir, f"cat.pt"))
21
+ embs_b = torch.load(os.path.join(emb_dir, f"dog.pt"))
22
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
23
+ elif task_name=="dog2cat":
24
+ emb_dir = f"assets/embeddings_sd_1.4"
25
+ embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
26
+ embs_b = torch.load(os.path.join(emb_dir, f"cat.pt"))
27
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
28
+ else:
29
+ raise NotImplementedError
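
For reference, a short sketch of how construct_direction is consumed and what it computes. It assumes the script is run from the pix2pix-zero submodule root so that assets/embeddings_sd_1.4/{cat,dog}.pt resolve, with src/utils on sys.path:

    import torch
    from edit_directions import construct_direction

    # Direction = mean(target prompt embeddings) - mean(source prompt embeddings),
    # kept as a (1, 77, 768) tensor for the SD 1.4 text encoder.
    edit_dir = construct_direction("cat2dog")
    print(edit_dir.shape)

    # The same recipe works for any two embedding banks of shape (N, 77, 768):
    #   direction = embs_b.mean(0).unsqueeze(0) - embs_a.mean(0).unsqueeze(0)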
submodules/pix2pix-zero/src/utils/edit_pipeline.py ADDED
@@ -0,0 +1,179 @@
1
+ import pdb, sys
2
+
3
+ import numpy as np
4
+ import torch
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
7
+ sys.path.insert(0, "src/utils")
8
+ from base_pipeline import BasePipeline
9
+ from cross_attention import prep_unet
10
+
11
+
12
+ class EditingPipeline(BasePipeline):
13
+ def __call__(
14
+ self,
15
+ prompt: Union[str, List[str]] = None,
16
+ height: Optional[int] = None,
17
+ width: Optional[int] = None,
18
+ num_inference_steps: int = 50,
19
+ guidance_scale: float = 7.5,
20
+ negative_prompt: Optional[Union[str, List[str]]] = None,
21
+ num_images_per_prompt: Optional[int] = 1,
22
+ eta: float = 0.0,
23
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
24
+ latents: Optional[torch.FloatTensor] = None,
25
+ prompt_embeds: Optional[torch.FloatTensor] = None,
26
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
27
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
28
+
29
+ # pix2pix parameters
30
+ guidance_amount=0.1,
31
+ edit_dir=None,
32
+ x_in=None,
33
+ only_sample=False,
34
+
35
+ ):
36
+
37
+ x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)
38
+
39
+ # 0. modify the unet to be useful :D
40
+ self.unet = prep_unet(self.unet)
41
+
42
+ # 1. setup all caching objects
43
+ d_ref_t2attn = {} # reference cross attention maps
44
+
45
+ # 2. Default height and width to unet
46
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
47
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
48
+
49
+ # TODO: add the input checker function
50
+ # self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds )
51
+
52
+ # 2. Define call parameters
53
+ if prompt is not None and isinstance(prompt, str):
54
+ batch_size = 1
55
+ elif prompt is not None and isinstance(prompt, list):
56
+ batch_size = len(prompt)
57
+ else:
58
+ batch_size = prompt_embeds.shape[0]
59
+
60
+ device = self._execution_device
61
+ do_classifier_free_guidance = guidance_scale > 1.0
62
+ x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)
63
+ # 3. Encode input prompt = 2x77x1024
64
+ prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,)
65
+
66
+ # 4. Prepare timesteps
67
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
68
+ timesteps = self.scheduler.timesteps
69
+
70
+ # 5. Prepare latent variables
71
+ num_channels_latents = self.unet.in_channels
72
+
73
+ # randomly sample a latent code if not provided
74
+ latents = self.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,)
75
+
76
+ latents_init = latents.clone()
77
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
78
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
79
+
80
+ # 7. First Denoising loop for getting the reference cross attention maps
81
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
82
+ with torch.no_grad():
83
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
84
+ for i, t in enumerate(timesteps):
85
+ # expand the latents if we are doing classifier free guidance
86
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
87
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
88
+
89
+ # predict the noise residual
90
+ noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
91
+
92
+ # add the cross attention map to the dictionary
93
+ d_ref_t2attn[t.item()] = {}
94
+ for name, module in self.unet.named_modules():
95
+ module_name = type(module).__name__
96
+ if module_name == "CrossAttention" and 'attn2' in name:
97
+ attn_mask = module.attn_probs # size is num_channels, s*s, 77
98
+ d_ref_t2attn[t.item()][name] = attn_mask.detach().cpu()
99
+
100
+ # perform guidance
101
+ if do_classifier_free_guidance:
102
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
103
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
104
+
105
+ # compute the previous noisy sample x_t -> x_t-1
106
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
107
+
108
+ # call the callback, if provided
109
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
110
+ progress_bar.update()
111
+
112
+ # make the reference image (reconstruction)
113
+ image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))
114
+
115
+ if only_sample:
116
+ return image_rec
117
+
118
+
119
+ prompt_embeds_edit = prompt_embeds.clone()
120
+ #add the edit only to the second prompt, idx 0 is the negative prompt
121
+ prompt_embeds_edit[1:2] += edit_dir
122
+
123
+ latents = latents_init
124
+ # Second denoising loop for editing the text prompt
125
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
126
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
127
+ for i, t in enumerate(timesteps):
128
+ # expand the latents if we are doing classifier free guidance
129
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
130
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
131
+
132
+ x_in = latent_model_input.detach().clone()
133
+ x_in.requires_grad = True
134
+
135
+ opt = torch.optim.SGD([x_in], lr=guidance_amount)
136
+
137
+ # predict the noise residual
138
+ noise_pred = self.unet(x_in,t,encoder_hidden_states=prompt_embeds_edit.detach(),cross_attention_kwargs=cross_attention_kwargs,).sample
139
+
140
+ loss = 0.0
141
+ for name, module in self.unet.named_modules():
142
+ module_name = type(module).__name__
143
+ if module_name == "CrossAttention" and 'attn2' in name:
144
+ curr = module.attn_probs # size is num_channels, s*s, 77
145
+ ref = d_ref_t2attn[t.item()][name].detach().cuda()
146
+ loss += ((curr-ref)**2).sum((1,2)).mean(0)
147
+ loss.backward(retain_graph=False)
148
+ opt.step()
149
+
150
+ # recompute the noise
151
+ with torch.no_grad():
152
+ noise_pred = self.unet(x_in.detach(),t,encoder_hidden_states=prompt_embeds_edit,cross_attention_kwargs=cross_attention_kwargs,).sample
153
+
154
+ latents = x_in.detach().chunk(2)[0]
155
+
156
+ # perform guidance
157
+ if do_classifier_free_guidance:
158
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
159
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
160
+
161
+ # compute the previous noisy sample x_t -> x_t-1
162
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
163
+
164
+ # call the callback, if provided
165
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
166
+ progress_bar.update()
167
+
168
+
169
+ # 8. Post-processing
170
+ image = self.decode_latents(latents.detach())
171
+
172
+ # 9. Run safety checker
173
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
174
+
175
+ # 10. Convert to PIL
176
+ image_edit = self.numpy_to_pil(image)
177
+
178
+
179
+ return image_rec, image_edit
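
A minimal sketch of calling this editing pipeline, following launch_main in utils/generate_synthetic.py below. It assumes `caption` (str), `x_inv` (an inverted latent from the DDIM inversion pipeline) and `edit_dir` (a text-embedding direction) are available and that src/utils is on sys.path:

    import torch
    from diffusers import DDIMScheduler
    from edit_pipeline import EditingPipeline

    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    rec_pil, edit_pil = pipe(
        caption,
        num_inference_steps=50,
        x_in=x_inv,              # start from the inverted noise map instead of random noise
        edit_dir=edit_dir,       # shift applied to the conditional prompt embedding
        guidance_amount=0.1,     # SGD step size for matching the reference cross-attention maps
        guidance_scale=5.0,
        negative_prompt=caption, # unedited caption as the negative prompt
    )
    edit_pil[0].save("edited.png")

The first denoising loop records the reference cross-attention maps of the reconstruction; the second loop nudges the latents so the edited run's attention matches them, which is what preserves the input structure.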
submodules/pix2pix-zero/src/utils/scheduler.py ADDED
@@ -0,0 +1,289 @@
1
+ # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+ import os, sys, pdb
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput, randn_tensor
27
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
28
+
29
+
30
+ @dataclass
31
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
32
+ class DDIMSchedulerOutput(BaseOutput):
33
+ """
34
+ Output class for the scheduler's step function output.
35
+
36
+ Args:
37
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
39
+ denoising loop.
40
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
42
+ `pred_original_sample` can be used to preview progress or for guidance.
43
+ """
44
+
45
+ prev_sample: torch.FloatTensor
46
+ pred_original_sample: Optional[torch.FloatTensor] = None
47
+
48
+
49
+ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
50
+ """
51
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
52
+ (1-beta) over time from t = [0,1].
53
+
54
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
55
+ to that part of the diffusion process.
56
+
57
+
58
+ Args:
59
+ num_diffusion_timesteps (`int`): the number of betas to produce.
60
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
61
+ prevent singularities.
62
+
63
+ Returns:
64
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
65
+ """
66
+
67
+ def alpha_bar(time_step):
68
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
69
+
70
+ betas = []
71
+ for i in range(num_diffusion_timesteps):
72
+ t1 = i / num_diffusion_timesteps
73
+ t2 = (i + 1) / num_diffusion_timesteps
74
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
75
+ return torch.tensor(betas)
76
+
77
+
78
+ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
79
+ """
80
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
81
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
82
+
83
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
84
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
85
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
86
+ [`~SchedulerMixin.from_pretrained`] functions.
87
+
88
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
89
+
90
+ Args:
91
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
92
+ beta_start (`float`): the starting `beta` value of inference.
93
+ beta_end (`float`): the final `beta` value.
94
+ beta_schedule (`str`):
95
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
96
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
97
+ trained_betas (`np.ndarray`, optional):
98
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
99
+ clip_sample (`bool`, default `True`):
100
+ option to clip predicted sample between -1 and 1 for numerical stability.
101
+ set_alpha_to_one (`bool`, default `True`):
102
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
103
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
104
+ otherwise it uses the value of alpha at step 0.
105
+ steps_offset (`int`, default `0`):
106
+ an offset added to the inference steps. You can use a combination of `offset=1` and
107
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
108
+ stable diffusion.
109
+ prediction_type (`str`, default `epsilon`, optional):
110
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
111
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
112
+ https://imagen.research.google/video/paper.pdf)
113
+ """
114
+
115
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
116
+ order = 1
117
+
118
+ @register_to_config
119
+ def __init__(
120
+ self,
121
+ num_train_timesteps: int = 1000,
122
+ beta_start: float = 0.0001,
123
+ beta_end: float = 0.02,
124
+ beta_schedule: str = "linear",
125
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
126
+ clip_sample: bool = True,
127
+ set_alpha_to_one: bool = True,
128
+ steps_offset: int = 0,
129
+ prediction_type: str = "epsilon",
130
+ ):
131
+ if trained_betas is not None:
132
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
133
+ elif beta_schedule == "linear":
134
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
135
+ elif beta_schedule == "scaled_linear":
136
+ # this schedule is very specific to the latent diffusion model.
137
+ self.betas = (
138
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
139
+ )
140
+ elif beta_schedule == "squaredcos_cap_v2":
141
+ # Glide cosine schedule
142
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
143
+ else:
144
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
145
+
146
+ self.alphas = 1.0 - self.betas
147
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
148
+
149
+ # At every step in ddim, we are looking into the previous alphas_cumprod
150
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
151
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
152
+ # whether we use the final alpha of the "non-previous" one.
153
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
154
+
155
+ # standard deviation of the initial noise distribution
156
+ self.init_noise_sigma = 1.0
157
+
158
+ # setable values
159
+ self.num_inference_steps = None
160
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
161
+
162
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
163
+ """
164
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
165
+ current timestep.
166
+
167
+ Args:
168
+ sample (`torch.FloatTensor`): input sample
169
+ timestep (`int`, optional): current timestep
170
+
171
+ Returns:
172
+ `torch.FloatTensor`: scaled input sample
173
+ """
174
+ return sample
175
+
176
+ def _get_variance(self, timestep, prev_timestep):
177
+ alpha_prod_t = self.alphas_cumprod[timestep]
178
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
179
+ beta_prod_t = 1 - alpha_prod_t
180
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
181
+
182
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
183
+
184
+ return variance
185
+
186
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
187
+ """
188
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
189
+
190
+ Args:
191
+ num_inference_steps (`int`):
192
+ the number of diffusion steps used when generating samples with a pre-trained model.
193
+ """
194
+
195
+ if num_inference_steps > self.config.num_train_timesteps:
196
+ raise ValueError(
197
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
198
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
199
+ f" maximal {self.config.num_train_timesteps} timesteps."
200
+ )
201
+
202
+ self.num_inference_steps = num_inference_steps
203
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
204
+ # creates integer timesteps by multiplying by ratio
205
+ # casting to int to avoid issues when num_inference_step is power of 3
206
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
207
+ self.timesteps = torch.from_numpy(timesteps).to(device)
208
+ self.timesteps += self.config.steps_offset
209
+
210
+ def step(
211
+ self,
212
+ model_output: torch.FloatTensor,
213
+ timestep: int,
214
+ sample: torch.FloatTensor,
215
+ eta: float = 0.0,
216
+ use_clipped_model_output: bool = False,
217
+ generator=None,
218
+ variance_noise: Optional[torch.FloatTensor] = None,
219
+ return_dict: bool = True,
220
+ reverse=False
221
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
222
+
223
+
224
+ e_t = model_output
225
+
226
+ x = sample
227
+ prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
228
+ # print(timestep, prev_timestep)
229
+ a_t = alpha_prod_t = self.alphas_cumprod[timestep-1]
230
+ a_prev = alpha_t_prev = self.alphas_cumprod[prev_timestep-1] if prev_timestep >= 0 else self.final_alpha_cumprod
231
+ beta_prod_t = 1 - alpha_prod_t
232
+
233
+ pred_x0 = (x - (1-a_t)**0.5 * e_t) / a_t.sqrt()
234
+ # direction pointing to x_t
235
+ dir_xt = (1. - a_prev).sqrt() * e_t
236
+ x = a_prev.sqrt()*pred_x0 + dir_xt
237
+ if not return_dict:
238
+ return (x,)
239
+ return DDIMSchedulerOutput(prev_sample=x, pred_original_sample=pred_x0)
240
+
241
+
242
+
243
+
244
+
245
+ def add_noise(
246
+ self,
247
+ original_samples: torch.FloatTensor,
248
+ noise: torch.FloatTensor,
249
+ timesteps: torch.IntTensor,
250
+ ) -> torch.FloatTensor:
251
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
252
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
253
+ timesteps = timesteps.to(original_samples.device)
254
+
255
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
256
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
257
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
258
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
259
+
260
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
261
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
262
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
263
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
264
+
265
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
266
+ return noisy_samples
267
+
268
+ def get_velocity(
269
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
270
+ ) -> torch.FloatTensor:
271
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
272
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
273
+ timesteps = timesteps.to(sample.device)
274
+
275
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
276
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
277
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
278
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
279
+
280
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
281
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
282
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
283
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
284
+
285
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
286
+ return velocity
287
+
288
+ def __len__(self):
289
+ return self.config.num_train_timesteps
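
A small sketch of the forward (inversion) update this scheduler performs; it constructs the scheduler standalone, so it only assumes src/utils is on sys.path:

    import torch
    from scheduler import DDIMInverseScheduler

    sched = DDIMInverseScheduler()        # default linear beta schedule, 1000 train timesteps
    sched.set_timesteps(50)

    x_t = torch.randn(1, 4, 64, 64)       # current latent
    e_t = torch.randn(1, 4, 64, 64)       # stand-in for the UNet noise prediction at t
    t = int(sched.timesteps[10])

    # Deterministic DDIM move toward the next (noisier) inversion timestep:
    #   pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
    #   x_next  = sqrt(a_next) * pred_x0 + sqrt(1 - a_next) * e_t
    out = sched.step(e_t, t, x_t)
    x_next, pred_x0 = out.prev_sample, out.pred_original_sample  # "prev_sample" name kept for diffusers compatibility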
utils.py ADDED
File without changes
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (172 Bytes).
 
utils/__pycache__/direction_utils.cpython-310.pyc ADDED
Binary file (4.55 kB).
 
utils/__pycache__/generate_synthetic.cpython-310.pyc ADDED
Binary file (8.97 kB).
 
utils/__pycache__/gradio_utils.cpython-310.pyc ADDED
Binary file (14.9 kB).
 
utils/direction_utils.py ADDED
@@ -0,0 +1,79 @@
1
+ import os, sys, pdb
2
+
3
+ import torch, torchvision
4
+ from huggingface_hub import hf_hub_url, cached_download, hf_hub_download, HfApi
5
+ import joblib
6
+ from pathlib import Path
7
+
8
+ import json
9
+ import requests
10
+ import random
11
+
12
+
13
+ """
14
+ Returns the list of directions currently available in the HF library
15
+ """
16
+ def get_all_directions_names():
17
+ hf_api = HfApi()
18
+ info = hf_api.list_models(author="pix2pix-zero-library")
19
+ l_model_ids = [m.modelId for m in info]
20
+ l_model_ids = [m for m in l_model_ids if "_sd14" in m]
21
+ l_edit_names = [m.split("/")[-1] for m in l_model_ids]
22
+ # l_edit_names = [m for m in l_edit_names if "_sd14" in m]
23
+
24
+ # pdb.set_trace()
25
+ l_desc = [hf_hub_download(repo_id=m_id, filename="short_description.txt") for m_id in l_model_ids]
26
+ d_name2desc = {k: open(m).read() for k,m in zip(l_edit_names, l_desc)}
27
+
28
+ return d_name2desc
29
+
30
+
31
+ def get_emb(dir_name):
32
+ REPO_ID = f"pix2pix-zero-library/{dir_name.replace('.pt','')}"
33
+ if "_sd14" not in REPO_ID: REPO_ID += "_sd14"
34
+ FILENAME = dir_name
35
+ if "_sd14" not in FILENAME: FILENAME += "_sd14"
36
+ if ".pt" not in FILENAME: FILENAME += ".pt"
37
+ ret = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
38
+ return torch.load(ret)
39
+
40
+
41
+ def generate_image_prompts_with_templates(word):
42
+ prompts = []
43
+ adjectives = ['majestic', 'cute', 'colorful', 'ferocious', 'elegant', 'graceful', 'slimy', 'adorable', 'scary', 'fuzzy', 'tiny', 'gigantic', 'brave', 'fierce', 'mysterious', 'curious', 'fascinating', 'charming', 'gleaming', 'rare']
44
+ verbs = ['strolling', 'jumping', 'lounging', 'flying', 'sleeping', 'eating', 'playing', 'working', 'gazing', 'standing']
45
+ adverbs = ['gracefully', 'playfully', 'elegantly', 'fiercely', 'curiously', 'fascinatingly', 'charmingly', 'gently', 'slowly', 'quickly', 'awkwardly', 'carelessly', 'cautiously', 'innocently', 'powerfully', 'grumpily', 'mysteriously']
46
+ backgrounds = ['a sunny beach', 'a bustling city', 'a quiet forest', 'a cozy living room', 'a futuristic space station', 'a medieval castle', 'an enchanted forest', 'a misty graveyard', 'a snowy mountain peak', 'a crowded market']
47
+
48
+ sentence_structures = {
49
+ "subject verb background": lambda word, bg, verb, adj, adv: f"A {word} {verb} {bg}.",
50
+ "background subject verb": lambda word, bg, verb, adj, adv: f"{bg}, a {word} is {verb}.",
51
+ "adjective subject verb background": lambda word, bg, verb, adj, adv: f"A {adj} {word} is {verb} {bg}.",
52
+ "subject verb adverb background": lambda word, bg, verb, adj, adv: f"A {word} is {verb} {adv} {bg}.",
53
+ "adverb subject verb background": lambda word, bg, verb, adj, adv: f"{adv.capitalize()}, a {word} is {verb} {bg}.",
54
+ "background adjective subject verb": lambda word, bg, verb, adj, adv: f"{bg}, there is a {adj} {word} {verb}.",
55
+ "subject verb adjective background": lambda word, bg, verb, adj, adv: f"A {word} {verb} {adj} {bg}.",
56
+ "adjective subject verb": lambda word, bg, verb, adj, adv: f"A {adj} {word} is {verb}.",
57
+ "subject adjective verb background": lambda word, bg, verb, adj, adv: f"A {word} is {adj} and {verb} {bg}.",
58
+ }
59
+
60
+ sentences = []
61
+ for bg in backgrounds:
62
+ for verb in verbs:
63
+ for adj in adjectives:
64
+ adv = random.choice(adverbs)
65
+ sentence = f"A {adv} {adj} {word} {verb} on {bg}."
66
+ sentence_structure = random.choice(list(sentence_structures.keys()))
67
+ sentence = sentence_structures[sentence_structure](word, bg, verb, adj, adv)
68
+ sentences.append(sentence)
69
+ return sentences
70
+
71
+
72
+
73
+
74
+ if __name__=="__main__":
75
+ print(get_all_directions_names())
76
+ # print(get_emb("dog_sd14.pt").shape)
77
+ # print(get_emb("dog").shape)
78
+ # print(generate_image_prompts("dog")[0:5])
79
+
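
A short sketch of pulling a pre-computed direction pair from the hub with these helpers, as launch_main in generate_synthetic.py does; it assumes "cat" and "dog" embedding banks exist under the pix2pix-zero-library organization (the commented __main__ lines above suggest "dog" does):

    import torch
    from utils.direction_utils import get_all_directions_names, get_emb

    d_name2desc = get_all_directions_names()   # {"<name>_sd14": "<short description>", ...}
    src_emb = get_emb("cat")                   # "_sd14" suffix and ".pt" extension are appended automatically
    dest_emb = get_emb("dog")

    edit_mul = 1.0                             # edit-strength multiplier exposed as a slider in the demo
    edit_dir = (dest_emb.cuda() - src_emb.cuda()) * edit_mul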
utils/generate_synthetic.py ADDED
@@ -0,0 +1,316 @@
1
+ import os, sys, time, re, pdb
2
+ import torch, torchvision
3
+ import numpy
4
+ from PIL import Image
5
+ import hashlib
6
+ from tqdm import tqdm
7
+ import openai
8
+ from utils.direction_utils import *
9
+
10
+ p = "submodules/pix2pix-zero/src/utils"
11
+ if p not in sys.path:
12
+ sys.path.append(p)
13
+ from diffusers import DDIMScheduler
14
+ from edit_directions import construct_direction
15
+ from edit_pipeline import EditingPipeline
16
+ from ddim_inv import DDIMInversion
17
+ from scheduler import DDIMInverseScheduler
18
+ from lavis.models import load_model_and_preprocess
19
+ from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration, BloomForCausalLM
20
+
21
+
22
+
23
+ def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
24
+ with torch.no_grad():
25
+ l_embeddings = []
26
+ for sent in tqdm(l_sentences):
27
+ text_inputs = tokenizer(
28
+ sent,
29
+ padding="max_length",
30
+ max_length=tokenizer.model_max_length,
31
+ truncation=True,
32
+ return_tensors="pt",
33
+ )
34
+ text_input_ids = text_inputs.input_ids
35
+ prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
36
+ l_embeddings.append(prompt_embeds)
37
+ return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
38
+
39
+
40
+
41
+ def launch_generate_sample(prompt, seed, negative_scale, num_ddim):
42
+ os.makedirs("tmp", exist_ok=True)
43
+ # do the editing
44
+ edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
45
+ edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
46
+
47
+ # set the random seed and sample the input noise map
48
+ torch.cuda.manual_seed(int(seed))
49
+ z = torch.randn((1,4,64,64), device="cuda")
50
+
51
+ z_hashname = hashlib.sha256(z.cpu().numpy().tobytes()).hexdigest()
52
+ z_inv_fname = f"tmp/{z_hashname}_ddim_{num_ddim}_inv.pt"
53
+ torch.save(z, z_inv_fname)
54
+
55
+ rec_pil = edit_pipe(prompt,
56
+ num_inference_steps=num_ddim, x_in=z,
57
+ only_sample=True, # this flag will only generate the sampled image, not the edited image
58
+ guidance_scale=negative_scale,
59
+ negative_prompt="" # use the empty string for the negative prompt
60
+ )
61
+ # print(rec_pil)
62
+ del edit_pipe
63
+ torch.cuda.empty_cache()
64
+
65
+ return rec_pil[0], z_inv_fname
66
+
67
+
68
+
69
+ def clean_l_sentences(ls):
70
+ s = [re.sub(r'\d', '', x) for x in ls]
71
+ s = [x.replace(".","").replace("-","").replace(")","").strip() for x in s]
72
+ return s
73
+
74
+
75
+
76
+ def gpt3_compute_word2sentences(task_type, word, num=100):
77
+ l_sentences = []
78
+ if task_type=="object":
79
+ template_prompt = f"Provide many captions for images containing {word}."
80
+ elif task_type=="style":
81
+ template_prompt = f"Provide many captions for images that are in the {word} style."
82
+ while True:
83
+ ret = openai.Completion.create(
84
+ model="text-davinci-002",
85
+ prompt=template_prompt,
86
+ max_tokens=1000,
87
+ temperature=1.0)
88
+ raw_return = ret.choices[0].text
89
+ for line in raw_return.split("\n"):
90
+ line = line.strip()
91
+ if len(line)>10:
92
+ skip=False
93
+ for subword in word.split(" "):
94
+ if subword not in line: skip=True
95
+ if not skip: l_sentences.append(line)
96
+ else:
97
+ l_sentences.append(line+f", {word}")
98
+ time.sleep(0.05)
99
+ print(len(l_sentences))
100
+ if len(l_sentences)>=num:
101
+ break
102
+ l_sentences = clean_l_sentences(l_sentences)
103
+ return l_sentences
104
+
105
+
106
+ def flant5xl_compute_word2sentences(word, num=100):
107
+ text_input = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters."
108
+
109
+ l_sentences = []
110
+ tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
111
+ model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
112
+ input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to("cuda")
113
+ input_length = input_ids.shape[1]
114
+ while True:
115
+ outputs = model.generate(input_ids,temperature=0.9, num_return_sequences=16, do_sample=True, max_length=128)
116
+ output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
117
+ for line in output:
118
+ line = line.strip()
119
+ skip=False
120
+ for subword in word.split(" "):
121
+ if subword not in line: skip=True
122
+ if not skip: l_sentences.append(line)
123
+ else: l_sentences.append(line+f", {word}")
124
+ print(len(l_sentences))
125
+ if len(l_sentences)>=num:
126
+ break
127
+ l_sentences = clean_l_sentences(l_sentences)
128
+
129
+ del model
130
+ del tokenizer
131
+ torch.cuda.empty_cache()
132
+
133
+ return l_sentences
134
+
135
+ def bloomz_compute_sentences(word, num=100):
136
+ l_sentences = []
137
+ tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
138
+ model = BloomForCausalLM.from_pretrained("bigscience/bloomz-7b1", device_map="auto", torch_dtype=torch.float16)
139
+ input_text = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters. Caption:"
140
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
141
+ input_length = input_ids.shape[1]
142
+ t = 0.95
143
+ eta = 1e-5
144
+ min_length = 15
145
+
146
+ while True:
147
+ try:
148
+ outputs = model.generate(input_ids,temperature=t, num_return_sequences=16, do_sample=True, max_length=128, min_length=min_length, eta_cutoff=eta)
149
+ output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
150
+ except Exception:
151
+ continue
152
+ for line in output:
153
+ line = line.strip()
154
+ skip=False
155
+ for subword in word.split(" "):
156
+ if subword not in line: skip=True
157
+ if not skip: l_sentences.append(line)
158
+ else: l_sentences.append(line+f", {word}")
159
+ print(len(l_sentences))
160
+ if len(l_sentences)>=num:
161
+ break
162
+ l_sentences = clean_l_sentences(l_sentences)
163
+ del model
164
+ del tokenizer
165
+ torch.cuda.empty_cache()
166
+
167
+ return l_sentences
168
+
169
+
170
+
171
+ def make_custom_dir(description, sent_type, api_key, org_key, l_custom_sentences):
172
+ if sent_type=="fixed-template":
173
+ l_sentences = generate_image_prompts_with_templates(description)
174
+ elif "GPT3" in sent_type:
175
+ import openai
176
+ openai.organization = org_key
177
+ openai.api_key = api_key
178
+ _=openai.Model.retrieve("text-davinci-002")
179
+ l_sentences = gpt3_compute_word2sentences("object", description, num=1000)
180
+
181
+ elif "flan-t5-xl" in sent_type:
182
+ l_sentences = flant5xl_compute_word2sentences(description, num=1000)
183
+ # save the sentences to file
184
+ with open(f"tmp/flant5xl_sentences_{description}.txt", "w") as f:
185
+ for line in l_sentences:
186
+ f.write(line+"\n")
187
+ elif "BLOOMZ-7B" in sent_type:
188
+ l_sentences = bloomz_compute_sentences(description, num=1000)
189
+ # save the sentences to file
190
+ with open(f"tmp/bloomz_sentences_{description}.txt", "w") as f:
191
+ for line in l_sentences:
192
+ f.write(line+"\n")
193
+
194
+ elif sent_type=="custom sentences":
195
+ l_sentences = l_custom_sentences.split("\n")
196
+ print(f"length of new sentence is {len(l_sentences)}")
197
+
198
+ pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
199
+ emb = load_sentence_embeddings(l_sentences, pipe.tokenizer, pipe.text_encoder, device="cuda")
200
+ del pipe
201
+ torch.cuda.empty_cache()
202
+ return emb
203
+
204
+
205
+ def launch_main(img_in_real, img_in_synth, src, src_custom, dest, dest_custom, num_ddim, xa_guidance, edit_mul, fpath_z_gen, gen_prompt, sent_type_src, sent_type_dest, api_key, org_key, custom_sentences_src, custom_sentences_dest):
206
+ d_name2desc = get_all_directions_names()
207
+ d_desc2name = {v:k for k,v in d_name2desc.items()}
208
+ os.makedirs("tmp", exist_ok=True)
209
+
210
+ # generate custom direction first
211
+ if src=="make your own!":
212
+ outf_name = f"tmp/template_emb_{src_custom}_{sent_type_src}.pt"
213
+ if not os.path.exists(outf_name):
214
+ src_emb = make_custom_dir(src_custom, sent_type_src, api_key, org_key, custom_sentences_src)
215
+ torch.save(src_emb, outf_name)
216
+ else:
217
+ src_emb = torch.load(outf_name)
218
+ else:
219
+ src_emb = get_emb(d_desc2name[src])
220
+
221
+ if dest=="make your own!":
222
+ outf_name = f"tmp/template_emb_{dest_custom}_{sent_type_dest}.pt"
223
+ if not os.path.exists(outf_name):
224
+ dest_emb = make_custom_dir(dest_custom, sent_type_dest, api_key, org_key, custom_sentences_dest)
225
+ torch.save(dest_emb, outf_name)
226
+ else:
227
+ dest_emb = torch.load(outf_name)
228
+ else:
229
+ dest_emb = get_emb(d_desc2name[dest])
230
+ text_dir = (dest_emb.cuda() - src_emb.cuda())*edit_mul
231
+
232
+
233
+
234
+ if img_in_real is not None and img_in_synth is None:
235
+ print("using real image")
236
+ # resize the image so that the longer side is 512
237
+ width, height = img_in_real.size
238
+ if width > height: scale_factor = 512 / width
239
+ else: scale_factor = 512 / height
240
+ new_size = (int(width * scale_factor), int(height * scale_factor))
241
+ img_in_real = img_in_real.resize(new_size, Image.Resampling.LANCZOS)
242
+ hash = hashlib.sha256(img_in_real.tobytes()).hexdigest()
243
+ # print(hash)
244
+ inv_fname = f"tmp/{hash}_ddim_{num_ddim}_inv.pt"
245
+ caption_fname = f"tmp/{hash}_caption.txt"
246
+
247
+ # make the caption if it hasn't been made before
248
+ if not os.path.exists(caption_fname):
249
+ # BLIP
250
+ model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
251
+ _image = vis_processors["eval"](img_in_real).unsqueeze(0).cuda()
252
+ prompt_str = model_blip.generate({"image": _image})[0]
253
+ del model_blip
254
+ torch.cuda.empty_cache()
255
+ with open(caption_fname, "w") as f:
256
+ f.write(prompt_str)
257
+ else:
258
+ prompt_str = open(caption_fname, "r").read().strip()
259
+ print(f"CAPTION: {prompt_str}")
260
+
261
+ # do the inversion if it hasn't been done before
262
+ if not os.path.exists(inv_fname):
263
+ # inversion pipeline
264
+ pipe_inv = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
265
+ pipe_inv.scheduler = DDIMInverseScheduler.from_config(pipe_inv.scheduler.config)
266
+ x_inv, x_inv_image, x_dec_img = pipe_inv( prompt_str,
267
+ guidance_scale=1, num_inversion_steps=num_ddim,
268
+ img=img_in_real, torch_dtype=torch.float32 )
269
+ x_inv = x_inv.detach()
270
+ torch.save(x_inv, inv_fname)
271
+ del pipe_inv
272
+ torch.cuda.empty_cache()
273
+ else:
274
+ x_inv = torch.load(inv_fname)
275
+
276
+ # do the editing
277
+ edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
278
+ edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
279
+
280
+ _, edit_pil = edit_pipe(prompt_str,
281
+ num_inference_steps=num_ddim,
282
+ x_in=x_inv,
283
+ edit_dir=text_dir,
284
+ guidance_amount=xa_guidance,
285
+ guidance_scale=5.0,
286
+ negative_prompt=prompt_str # use the unedited prompt for the negative prompt
287
+ )
288
+ del edit_pipe
289
+ torch.cuda.empty_cache()
290
+ return edit_pil[0]
291
+
292
+
293
+ elif img_in_real is None and img_in_synth is not None:
294
+ print("using synthetic image")
295
+ x_inv = torch.load(fpath_z_gen)
296
+ pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
297
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
298
+ rec_pil, edit_pil = pipe(gen_prompt,
299
+ num_inference_steps=num_ddim,
300
+ x_in=x_inv,
301
+ edit_dir=text_dir,
302
+ guidance_amount=xa_guidance,
303
+ guidance_scale=5,
304
+ negative_prompt="" # use the empty string for the negative prompt
305
+ )
306
+ del pipe
307
+ torch.cuda.empty_cache()
308
+ return edit_pil[0]
309
+
310
+ else:
311
+ raise ValueError("Invalid input: provide exactly one of a real input image or a synthetic input image.")
312
+
313
+
314
+
315
+ if __name__=="__main__":
316
+ print(flant5xl_compute_word2sentences("cat wearing sunglasses", num=100))
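
A compact sketch of building a custom edit direction end-to-end with the pieces in this file, using the fixed-template sentence generator so no OpenAI key or extra language model is needed (a GPU and the CompVis/stable-diffusion-v1-4 weights are assumed):

    import torch
    from utils.direction_utils import generate_image_prompts_with_templates
    from utils.generate_synthetic import EditingPipeline, load_sentence_embeddings

    src_sentences = generate_image_prompts_with_templates("cat")
    dest_sentences = generate_image_prompts_with_templates("dog")

    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
    src_emb = load_sentence_embeddings(src_sentences, pipe.tokenizer, pipe.text_encoder, device="cuda")
    dest_emb = load_sentence_embeddings(dest_sentences, pipe.tokenizer, pipe.text_encoder, device="cuda")
    del pipe; torch.cuda.empty_cache()

    edit_dir = dest_emb - src_emb   # passed as edit_dir= to EditingPipeline.__call__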
utils/gradio_utils.py ADDED
@@ -0,0 +1,616 @@
1
+ import gradio as gr
2
+
3
+ def set_visible_true():
4
+ return gr.update(visible=True)
5
+
6
+ def set_visible_false():
7
+ return gr.update(visible=False)
8
+
9
+
10
+ # HTML_header = f"""
11
+ # <style>
12
+ # {CSS_main}
13
+ # </style>
14
+ # <div style="text-align: center; max-width: 700px; margin: 0 auto;">
15
+ # <div style=" display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
16
+ # <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
17
+ # Zero-shot Image-to-Image Translation
18
+ # </h1>
19
+ # </div>
20
+ # <p style="margin-bottom: 10px; font-size: 94%">
21
+ # This is the demo for <a href="https://pix2pixzero.github.io/" target="_blank">pix2pix-zero</a>.
22
+ # Please visit our <a href="https://pix2pixzero.github.io/"> website</a> and <a href="https://github.com/pix2pixzero/pix2pix-zero" target="_blank">github</a> for more details.
23
+ # </p>
24
+ # <p style="margin-bottom: 10px; font-size: 94%">
25
+ # pix2pix-zero is a diffusion-based image-to-image approach that allows users to specify the edit direction on-the-fly
26
+ # (e.g., cat to dog). Our method can directly use pre-trained text-to-image diffusion models, such as Stable Diffusion,
27
+ # for editing real and synthetic images while preserving the input image's structure. Our method is training-free and prompt-free,
28
+ # as it requires neither manual text prompting for each input image nor costly fine-tuning for each task.
29
+ # </p>
30
+
31
+ # </div>
32
+ # """
33
+
34
+
35
+
36
+
37
+ CSS_main = """
38
+ body {
39
+ font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif;
40
+ font-weight:300;
41
+ font-size:18px;
42
+ margin-left: auto;
43
+ margin-right: auto;
44
+ padding-left: 10px;
45
+ padding-right: 10px;
46
+ width: 800px;
47
+ }
48
+
49
+ h1 {
50
+ font-size:32px;
51
+ font-weight:300;
52
+ text-align: center;
53
+ }
54
+
55
+ h2 {
56
+ font-size:32px;
57
+ font-weight:300;
58
+ text-align: center;
59
+ }
60
+
61
+ #lbl_gallery_input{
62
+ font-family: 'Helvetica', 'Arial', sans-serif;
63
+ text-align: center;
64
+ color: #fff;
65
+ font-size: 28px;
66
+ display: inline
67
+ }
68
+
69
+
70
+ #lbl_gallery_comparision{
71
+ font-family: 'Helvetica', 'Arial', sans-serif;
72
+ text-align: center;
73
+ color: #fff;
74
+ font-size: 28px;
75
+ }
76
+
77
+ .disclaimerbox {
78
+ background-color: #eee;
79
+ border: 1px solid #eeeeee;
80
+ border-radius: 10px ;
81
+ -moz-border-radius: 10px ;
82
+ -webkit-border-radius: 10px ;
83
+ padding: 20px;
84
+ }
85
+
86
+ video.header-vid {
87
+ height: 140px;
88
+ border: 1px solid black;
89
+ border-radius: 10px ;
90
+ -moz-border-radius: 10px ;
91
+ -webkit-border-radius: 10px ;
92
+ }
93
+
94
+ img.header-img {
95
+ height: 140px;
96
+ border: 1px solid black;
97
+ border-radius: 10px ;
98
+ -moz-border-radius: 10px ;
99
+ -webkit-border-radius: 10px ;
100
+ }
101
+
102
+ img.rounded {
103
+ border: 1px solid #eeeeee;
104
+ border-radius: 10px ;
105
+ -moz-border-radius: 10px ;
106
+ -webkit-border-radius: 10px ;
107
+ }
108
+
109
+ a:link
110
+ {
111
+ color: #941120;
112
+ text-decoration: none;
113
+ }
114
+ a:visited
115
+ {
116
+ color: #941120;
117
+ text-decoration: none;
118
+ }
119
+ a:hover {
120
+ color: #941120;
121
+ }
122
+
123
+ td.dl-link {
124
+ height: 160px;
125
+ text-align: center;
126
+ font-size: 22px;
127
+ }
128
+
129
+ .layered-paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
130
+ box-shadow:
131
+ 0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */
132
+ 5px 5px 0 0px #fff, /* The second layer */
133
+ 5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */
134
+ 10px 10px 0 0px #fff, /* The third layer */
135
+ 10px 10px 1px 1px rgba(0,0,0,0.35), /* The third layer shadow */
136
+ 15px 15px 0 0px #fff, /* The fourth layer */
137
+ 15px 15px 1px 1px rgba(0,0,0,0.35), /* The fourth layer shadow */
138
+ 20px 20px 0 0px #fff, /* The fifth layer */
139
+ 20px 20px 1px 1px rgba(0,0,0,0.35), /* The fifth layer shadow */
140
+ 25px 25px 0 0px #fff, /* The fifth layer */
141
+ 25px 25px 1px 1px rgba(0,0,0,0.35); /* The fifth layer shadow */
142
+ margin-left: 10px;
143
+ margin-right: 45px;
144
+ }
145
+
146
+ .paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
147
+ box-shadow:
148
+ 0px 0px 1px 1px rgba(0,0,0,0.35); /* The top layer shadow */
149
+
150
+ margin-left: 10px;
151
+ margin-right: 45px;
152
+ }
153
+
154
+
155
+ .layered-paper { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
156
+ box-shadow:
157
+ 0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */
158
+ 5px 5px 0 0px #fff, /* The second layer */
159
+ 5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */
160
+ 10px 10px 0 0px #fff, /* The third layer */
161
+ 10px 10px 1px 1px rgba(0,0,0,0.35); /* The third layer shadow */
162
+ margin-top: 5px;
163
+ margin-left: 10px;
164
+ margin-right: 30px;
165
+ margin-bottom: 5px;
166
+ }
167
+
168
+ .vert-cent {
169
+ position: relative;
170
+ top: 50%;
171
+ transform: translateY(-50%);
172
+ }
173
+
174
+ hr
175
+ {
176
+ border: 0;
177
+ height: 1px;
178
+ background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
179
+ }
180
+
181
+ .card {
182
+ /* width: 130px;
183
+ height: 195px;
184
+ width: 1px;
185
+ height: 1px; */
186
+ position: relative;
187
+ display: inline-block;
188
+ /* margin: 50px; */
189
+ }
190
+ .card .img-top {
191
+ display: none;
192
+ position: absolute;
193
+ top: 0;
194
+ left: 0;
195
+ z-index: 99;
196
+ }
197
+ .card:hover .img-top {
198
+ display: inline;
199
+ }
200
+ details {
201
+ user-select: none;
202
+ }
203
+
204
+ details>summary span.icon {
205
+ width: 24px;
206
+ height: 24px;
207
+ transition: all 0.3s;
208
+ margin-left: auto;
209
+ }
210
+
211
+ details[open] summary span.icon {
212
+ transform: rotate(180deg);
213
+ }
214
+
215
+ summary {
216
+ display: flex;
217
+ cursor: pointer;
218
+ }
219
+
220
+ summary::-webkit-details-marker {
221
+ display: none;
222
+ }
223
+
224
+ ul {
225
+ display: table;
226
+ margin: 0 auto;
227
+ text-align: left;
228
+ }
229
+
230
+ .dark {
231
+ padding: 1em 2em;
232
+ background-color: #333;
233
+ box-shadow: 3px 3px 3px #333;
234
+ border: 1px #333;
235
+ }
236
+ .column {
237
+ float: left;
238
+ width: 20%;
239
+ padding: 0.5%;
240
+ }
241
+
242
+ .galleryImg {
243
+ transition: opacity 0.3s;
244
+ -webkit-transition: opacity 0.3s;
245
+ filter: grayscale(100%);
246
+ /* filter: blur(2px); */
247
+ -webkit-transition : -webkit-filter 250ms linear;
248
+ /* opacity: 0.5; */
249
+ cursor: pointer;
250
+ }
251
+
252
+
253
+
254
+ .selected {
255
+ /* outline: 100px solid var(--hover-background) !important; */
256
+ /* outline-offset: -100px; */
257
+ filter: grayscale(0%);
258
+ -webkit-transition : -webkit-filter 250ms linear;
259
+ /*opacity: 1.0 !important; */
260
+ }
261
+
262
+ .galleryImg:hover {
263
+ filter: grayscale(0%);
264
+ -webkit-transition : -webkit-filter 250ms linear;
265
+
266
+ }
267
+
268
+ .row {
269
+ margin-bottom: 1em;
270
+ padding: 0px 1em;
271
+ }
272
+ /* Clear floats after the columns */
273
+ .row:after {
274
+ content: "";
275
+ display: table;
276
+ clear: both;
277
+ }
278
+
279
+ /* The expanding image container */
280
+ #gallery {
281
+ position: relative;
282
+ /*display: none;*/
283
+ }
284
+
285
+ #section_comparison{
286
+ position: relative;
287
+ width: 100%;
288
+ height: max-content;
289
+ }
290
+
291
+ /* SLIDER
292
+ -------------------------------------------------- */
293
+
294
+ .slider-container {
295
+ position: relative;
296
+ height: 384px;
297
+ width: 512px;
298
+ cursor: grab;
299
+ overflow: hidden;
300
+ margin: auto;
301
+ }
302
+ .slider-after {
303
+ display: block;
304
+ position: absolute;
305
+ top: 0;
306
+ right: 0;
307
+ bottom: 0;
308
+ left: 0;
309
+ width: 100%;
310
+ height: 100%;
311
+ overflow: hidden;
312
+ }
313
+ .slider-before {
314
+ display: block;
315
+ position: absolute;
316
+ top: 0;
317
+ /* right: 0; */
318
+ bottom: 0;
319
+ left: 0;
320
+ width: 100%;
321
+ height: 100%;
322
+ z-index: 15;
323
+ overflow: hidden;
324
+ }
325
+ .slider-before-inset {
326
+ position: absolute;
327
+ top: 0;
328
+ bottom: 0;
329
+ left: 0;
330
+ }
331
+ .slider-after img,
332
+ .slider-before img {
333
+ object-fit: cover;
334
+ position: absolute;
335
+ width: 100%;
336
+ height: 100%;
337
+ object-position: 50% 50%;
338
+ top: 0;
339
+ bottom: 0;
340
+ left: 0;
341
+ -webkit-user-select: none;
342
+ -khtml-user-select: none;
343
+ -moz-user-select: none;
344
+ -o-user-select: none;
345
+ user-select: none;
346
+ }
347
+
348
+ #lbl_inset_left{
349
+ text-align: center;
350
+ position: absolute;
351
+ top: 384px;
352
+ width: 150px;
353
+ left: calc(50% - 256px);
354
+ z-index: 11;
355
+ font-size: 16px;
356
+ color: #fff;
357
+ margin: 10px;
358
+ }
359
+ .inset-before {
360
+ position: absolute;
361
+ width: 150px;
362
+ height: 150px;
363
+ box-shadow: 3px 3px 3px #333;
364
+ border: 1px #333;
365
+ border-style: solid;
366
+ z-index: 16;
367
+ top: 410px;
368
+ left: calc(50% - 256px);
369
+ margin: 10px;
370
+ font-size: 1em;
371
+ background-repeat: no-repeat;
372
+ pointer-events: none;
373
+ }
374
+
375
+ #lbl_inset_right{
376
+ text-align: center;
377
+ position: absolute;
378
+ top: 384px;
379
+ width: 150px;
380
+ right: calc(50% - 256px);
381
+ z-index: 11;
382
+ font-size: 16px;
383
+ color: #fff;
384
+ margin: 10px;
385
+ }
386
+ .inset-after {
387
+ position: absolute;
388
+ width: 150px;
389
+ height: 150px;
390
+ box-shadow: 3px 3px 3px #333;
391
+ border: 1px #333;
392
+ border-style: solid;
393
+ z-index: 16;
394
+ top: 410px;
395
+ right: calc(50% - 256px);
396
+ margin: 10px;
397
+ font-size: 1em;
398
+ background-repeat: no-repeat;
399
+ pointer-events: none;
400
+ }
401
+
402
+ #lbl_inset_input{
403
+ text-align: center;
404
+ position: absolute;
405
+ top: 384px;
406
+ width: 150px;
407
+ left: calc(50% - 256px + 150px + 20px);
408
+ z-index: 11;
409
+ font-size: 16px;
410
+ color: #fff;
411
+ margin: 10px;
412
+ }
413
+ .inset-target {
414
+ position: absolute;
415
+ width: 150px;
416
+ height: 150px;
417
+ box-shadow: 3px 3px 3px #333;
418
+ border: 1px #333;
419
+ border-style: solid;
420
+ z-index: 16;
421
+ top: 410px;
422
+ right: calc(50% - 256px + 150px + 20px);
423
+ margin: 10px;
424
+ font-size: 1em;
425
+ background-repeat: no-repeat;
426
+ pointer-events: none;
427
+ }
428
+
429
+ .slider-beforePosition {
430
+ background: #121212;
431
+ color: #fff;
432
+ left: 0;
433
+ pointer-events: none;
434
+ border-radius: 0.2rem;
435
+ padding: 2px 10px;
436
+ }
437
+ .slider-afterPosition {
438
+ background: #121212;
439
+ color: #fff;
440
+ right: 0;
441
+ pointer-events: none;
442
+ border-radius: 0.2rem;
443
+ padding: 2px 10px;
444
+ }
445
+ .beforeLabel {
446
+ position: absolute;
447
+ top: 0;
448
+ margin: 1rem;
449
+ font-size: 1em;
450
+ -webkit-user-select: none;
451
+ -khtml-user-select: none;
452
+ -moz-user-select: none;
453
+ -o-user-select: none;
454
+ user-select: none;
455
+ }
456
+ .afterLabel {
457
+ position: absolute;
458
+ top: 0;
459
+ margin: 1rem;
460
+ font-size: 1em;
461
+ -webkit-user-select: none;
462
+ -khtml-user-select: none;
463
+ -moz-user-select: none;
464
+ -o-user-select: none;
465
+ user-select: none;
466
+ }
467
+
468
+ .slider-handle {
469
+ height: 41px;
470
+ width: 41px;
471
+ position: absolute;
472
+ left: 50%;
473
+ top: 50%;
474
+ margin-left: -20px;
475
+ margin-top: -21px;
476
+ border: 2px solid #fff;
477
+ border-radius: 1000px;
478
+ z-index: 20;
479
+ pointer-events: none;
480
+ box-shadow: 0 0 10px rgb(12, 12, 12);
481
+ }
482
+ .handle-left-arrow,
483
+ .handle-right-arrow {
484
+ width: 0;
485
+ height: 0;
486
+ border: 6px inset transparent;
487
+ position: absolute;
488
+ top: 50%;
489
+ margin-top: -6px;
490
+ }
491
+ .handle-left-arrow {
492
+ border-right: 6px solid #fff;
493
+ left: 50%;
494
+ margin-left: -17px;
495
+ }
496
+ .handle-right-arrow {
497
+ border-left: 6px solid #fff;
498
+ right: 50%;
499
+ margin-right: -17px;
500
+ }
501
+ .slider-handle::before {
502
+ bottom: 50%;
503
+ margin-bottom: 20px;
504
+ box-shadow: 0 0 10px rgb(12, 12, 12);
505
+ }
506
+ .slider-handle::after {
507
+ top: 50%;
508
+ margin-top: 20.5px;
509
+ box-shadow: 0 0 5px rgb(12, 12, 12);
510
+ }
511
+ .slider-handle::before,
512
+ .slider-handle::after {
513
+ content: " ";
514
+ display: block;
515
+ width: 2px;
516
+ background: #fff;
517
+ height: 9999px;
518
+ position: absolute;
519
+ left: 50%;
520
+ margin-left: -1.5px;
521
+ }
522
+
523
+
524
+ /*
525
+ -------------------------------------------------
526
+ The editing results shown below the inversion results
527
+ -------------------------------------------------
528
+ */
529
+ .edit_labels{
530
+ font-weight:500;
531
+ font-size: 24px;
532
+ color: #fff;
533
+ height: 20px;
534
+ margin-left: 20px;
535
+ position: relative;
536
+ top: 20px;
537
+ }
538
+
539
+
540
+ .open > a:hover {
541
+ color: #555;
542
+ background-color: red;
543
+ }
544
+
545
+
546
+ #directions { padding-top: 30px; padding-bottom: 0; margin-bottom: 0px; height: 20px; }
547
+ #custom_task { padding-top:0; padding-bottom:0; margin-bottom: 0px; height: 20px; }
548
+ #slider_ddim {accent-color: #941120;}
549
+ #slider_ddim::-webkit-slider-thumb {background-color: #941120;}
550
+ #slider_xa {accent-color: #941120;}
551
+ #slider_xa::-webkit-slider-thumb {background-color: #941120;}
552
+ #slider_edit_mul {accent-color: #941120;}
553
+ #slider_edit_mul::-webkit-slider-thumb {background-color: #941120;}
554
+
555
+ #input_image [data-testid="image"]{
556
+ height: unset;
557
+ }
558
+ #input_image_synth [data-testid="image"]{
559
+ height: unset;
560
+ }
561
+
562
+
563
+ """
564
+
565
+
566
+
567
+ HTML_header = f"""
568
+ <body>
569
+ <center>
570
+ <span style="font-size:36px">Zero-shot Image-to-Image Translation</span>
571
+ <table align=center>
572
+ <tr>
573
+ <td align=center>
574
+ <center>
575
+ <span style="font-size:24px; margin-left: 0px;"><a href='https://pix2pixzero.github.io/'>[Website]</a></span>
576
+ <span style="font-size:24px; margin-left: 20px;"><a href='https://github.com/pix2pixzero/pix2pix-zero'>[Code]</a></span>
577
+ </center>
578
+ </td>
579
+ </tr>
580
+ </table>
581
+ </center>
582
+
583
+ <center>
584
+ <div align=center>
585
+ <p align=left>
586
+ This is a demo for <span style="font-weight: bold;">pix2pix-zero</span>, a diffusion-based image-to-image approach that allows users to
587
+ specify the edit direction on-the-fly (e.g., cat to dog). Our method can directly use pre-trained text-to-image diffusion models, such as Stable Diffusion, for editing real and synthetic images while preserving the input image's structure. Our method is training-free and prompt-free, as it requires neither manual text prompting for each input image nor costly fine-tuning for each task.
588
+ <br>
589
+ <span style="font-weight: 800;">TL;DR:</span> <span style=" color: #941120;"> no finetuning</span> required; <span style=" color: #941120;"> no text input</span> needed; input <span style=" color: #941120;"> structure preserved</span>.
590
+ </p>
591
+ </div>
592
+ </center>
593
+
594
+
595
+ <hr>
596
+ </body>
597
+ """
598
+
599
+ HTML_input_header = f"""
600
+ <p style="font-size:150%; padding: 0px">
601
+ <span style="font-weight: 800; color: #941120;"> Step 1: </span> select a real input image.
602
+ </p>
603
+ """
604
+
605
+ HTML_middle_header = f"""
606
+ <p style="font-size:150%;">
607
+ <span style="font-weight: 800; color: #941120;"> Step 2: </span> select the editing options.
608
+ </p>
609
+ """
610
+
611
+
612
+ HTML_output_header = f"""
613
+ <p style="font-size:150%;">
614
+ <span style="font-weight: 800; color: #941120;"> Step 3: </span> translated image!
615
+ </p>
616
+ """
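
A hypothetical wiring sketch for the two visibility helpers at the top of this file; the Blocks layout, component and button names are illustrative only (the real layout lives in app.py):

    import gradio as gr
    from utils.gradio_utils import CSS_main, HTML_header, set_visible_true, set_visible_false

    with gr.Blocks(css=CSS_main) as demo:
        gr.HTML(HTML_header)
        advanced = gr.Markdown("Advanced editing options go here.", visible=False)
        show_btn = gr.Button("Show advanced options")
        hide_btn = gr.Button("Hide advanced options")
        # Each helper returns gr.update(visible=...), which Gradio applies to the output component.
        show_btn.click(fn=set_visible_true, inputs=[], outputs=[advanced])
        hide_btn.click(fn=set_visible_false, inputs=[], outputs=[advanced])

    demo.launch()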