jihuayang3 committed
Commit c971eb1 · verified · 1 Parent(s): 020be71

Update diffusers_image_outpaint_jupyter.ipynb

Files changed (1)
  1. diffusers_image_outpaint_jupyter.ipynb +237 -108
diffusers_image_outpaint_jupyter.ipynb CHANGED
@@ -1,108 +1,237 @@
- %cd /content
- !git clone -b dev https://github.com/camenduru/diffusers-image-outpaint-hf
- %cd /content/diffusers-image-outpaint-hf
- !pip install transformers accelerate diffusers
-
- !apt -y install -qq aria2
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/model_index.json -d /content/model/lightning -o model_index.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/scheduler/scheduler_config.json -d /content/model/lightning/scheduler -o scheduler_config.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/text_encoder/config.json -d /content/model/lightning/text_encoder -o config.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/resolve/main/lightning/text_encoder/model.fp16.safetensors -d /content/model/lightning/text_encoder -o model.fp16.safetensors
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/text_encoder_2/config.json -d /content/model/lightning/text_encoder_2 -o config.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/resolve/main/lightning/text_encoder_2/model.fp16.safetensors -d /content/model/lightning/text_encoder_2 -o model.fp16.safetensors
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/tokenizer/merges.txt -d /content/model/lightning/tokenizer -o merges.txt
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/tokenizer/special_tokens_map.json -d /content/model/lightning/tokenizer -o special_tokens_map.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/tokenizer/tokenizer_config.json -d /content/model/lightning/tokenizer -o tokenizer_config.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/tokenizer/vocab.json -d /content/model/lightning/tokenizer -o vocab.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/tokenizer_2/merges.txt -d /content/model/lightning/tokenizer_2 -o merges.txt
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/tokenizer_2/special_tokens_map.json -d /content/model/lightning/tokenizer_2 -o special_tokens_map.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/tokenizer_2/tokenizer_config.json -d /content/model/lightning/tokenizer_2 -o tokenizer_config.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/tokenizer_2/vocab.json -d /content/model/lightning/tokenizer_2 -o vocab.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/unet/config.json -d /content/model/lightning/unet -o config.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/resolve/main/lightning/unet/diffusion_pytorch_model.fp16.safetensors -d /content/model/lightning/unet -o diffusion_pytorch_model.fp16.safetensors
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/unet/diffusion_pytorch_model.safetensors.index.json -d /content/model/lightning/unet -o diffusion_pytorch_model.safetensors.index.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/lightning/vae/config.json -d /content/model/lightning/vae -o config.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/resolve/main/lightning/vae/diffusion_pytorch_model.fp16.safetensors -d /content/model/lightning/vae -o diffusion_pytorch_model.fp16.safetensors
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/vae-fix/config.json -d /content/model/vae-fix -o config.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/resolve/main/vae-fix/diffusion_pytorch_model.safetensors -d /content/model/vae-fix -o diffusion_pytorch_model.safetensors
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/raw/main/union/config_promax.json -d /content/model/union -o config_promax.json
- !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/outpaint/resolve/main/union/diffusion_pytorch_model_promax.safetensors -d /content/model/union -o diffusion_pytorch_model_promax.safetensors
-
-
- import torch
- from diffusers import AutoencoderKL
- from diffusers.models.model_loading_utils import load_state_dict
- from controlnet_union import ControlNetModel_Union
- from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
- from PIL import Image, ImageDraw
-
- config = ControlNetModel_Union.load_config("/content/model/union/config_promax.json")
- controlnet_model = ControlNetModel_Union.from_config(config)
- state_dict = load_state_dict("/content/model/union/diffusion_pytorch_model_promax.safetensors")
- model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(controlnet_model, state_dict, "/content/model/union/diffusion_pytorch_model_promax.safetensors", "/content/model/union")
- model.to(device="cuda", dtype=torch.float16)
- vae = AutoencoderKL.from_pretrained("/content/model/vae-fix", torch_dtype=torch.float16).to("cuda")
- pipe = StableDiffusionXLFillPipeline.from_pretrained("/content/model/lightning", torch_dtype=torch.float16, vae=vae, controlnet=model, variant="fp16").to("cuda")
-
- def infer(image, width, height, overlap_width, num_inference_steps, prompt_input=None):
-     source = image
-     target_size = (width, height)
-     overlap = overlap_width
-
-     if source.width < target_size[0] and source.height < target_size[1]:
-         scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
-         new_width = int(source.width * scale_factor)
-         new_height = int(source.height * scale_factor)
-         source = source.resize((new_width, new_height), Image.LANCZOS)
-
-     if source.width > target_size[0] or source.height > target_size[1]:
-         scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
-         new_width = int(source.width * scale_factor)
-         new_height = int(source.height * scale_factor)
-         source = source.resize((new_width, new_height), Image.LANCZOS)
-
-     margin_x = (target_size[0] - source.width) // 2
-     margin_y = (target_size[1] - source.height) // 2
-
-     background = Image.new('RGB', target_size, (255, 255, 255))
-     background.paste(source, (margin_x, margin_y))
-
-     mask = Image.new('L', target_size, 255)
-     mask_draw = ImageDraw.Draw(mask)
-     mask_draw.rectangle([
-         (margin_x + overlap, margin_y + overlap),
-         (margin_x + source.width - overlap, margin_y + source.height - overlap)
-     ], fill=0)
-
-     cnet_image = background.copy()
-     cnet_image.paste(0, (0, 0), mask)
-
-     final_prompt = "high quality"
-     if prompt_input and prompt_input.strip():
-         final_prompt += ", " + prompt_input
-
-     (
-         prompt_embeds,
-         negative_prompt_embeds,
-         pooled_prompt_embeds,
-         negative_pooled_prompt_embeds,
-     ) = pipe.encode_prompt(final_prompt, "cuda", True)
-
-     results = []
-
-     for image in pipe(
-         prompt_embeds=prompt_embeds,
-         negative_prompt_embeds=negative_prompt_embeds,
-         pooled_prompt_embeds=pooled_prompt_embeds,
-         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-         image=cnet_image,
-         num_inference_steps=num_inference_steps
-     ):
-         results.append((cnet_image, image))
-
-     image = image.convert("RGBA")
-     cnet_image.paste(image, (0, 0), mask)
-
-     results.append((background, cnet_image))
-
-     return results
+ import gradio as gr
+ import torch
+ from diffusers import AutoencoderKL, TCDScheduler
+ from diffusers.models.model_loading_utils import load_state_dict
+ from gradio_imageslider import ImageSlider
+ from huggingface_hub import hf_hub_download
+
+ from controlnet_union import ControlNetModel_Union
+ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
+
+ from PIL import Image, ImageDraw
+ import numpy as np
+
+ MODELS = {
+     "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
+ }
+
+ config_file = hf_hub_download(
+     "xinsir/controlnet-union-sdxl-1.0",
+     filename="config_promax.json",
+ )
+
+ config = ControlNetModel_Union.load_config(config_file)
+ controlnet_model = ControlNetModel_Union.from_config(config)
+ model_file = hf_hub_download(
+     "xinsir/controlnet-union-sdxl-1.0",
+     filename="diffusion_pytorch_model_promax.safetensors",
+ )
+ state_dict = load_state_dict(model_file)
+ model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
+     controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
+ )
+ model.to(device="cuda", dtype=torch.float16)
+
+ vae = AutoencoderKL.from_pretrained(
+     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+ ).to("cuda")
+
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
+     "SG161222/RealVisXL_V5.0_Lightning",
+     torch_dtype=torch.float16,
+     vae=vae,
+     controlnet=model,
+     variant="fp16",
+ ).to("cuda")
+
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+
+ def infer(image, model_selection, width, height, overlap_width, num_inference_steps, prompt_input=None):
+     source = image
+     target_size = (width, height)
+     target_ratio = (width, height) # Calculate aspect ratio from width and height
+     overlap = overlap_width
+
+     # Upscale if source is smaller than target in both dimensions
+     if source.width < target_size[0] and source.height < target_size[1]:
+         scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
+         new_width = int(source.width * scale_factor)
+         new_height = int(source.height * scale_factor)
+         source = source.resize((new_width, new_height), Image.LANCZOS)
+
+     if source.width > target_size[0] or source.height > target_size[1]:
+         scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
+         new_width = int(source.width * scale_factor)
+         new_height = int(source.height * scale_factor)
+         source = source.resize((new_width, new_height), Image.LANCZOS)
+
+     margin_x = (target_size[0] - source.width) // 2
+     margin_y = (target_size[1] - source.height) // 2
+
+     background = Image.new('RGB', target_size, (255, 255, 255))
+     background.paste(source, (margin_x, margin_y))
+
+     mask = Image.new('L', target_size, 255)
+     mask_draw = ImageDraw.Draw(mask)
+     mask_draw.rectangle([
+         (margin_x + overlap, margin_y + overlap),
+         (margin_x + source.width - overlap, margin_y + source.height - overlap)
+     ], fill=0)
+
+     cnet_image = background.copy()
+     cnet_image.paste(0, (0, 0), mask)
+
+     final_prompt = "high quality"
+     if prompt_input.strip() != "":
+         final_prompt += ", " + prompt_input
+
+     (
+         prompt_embeds,
+         negative_prompt_embeds,
+         pooled_prompt_embeds,
+         negative_pooled_prompt_embeds,
+     ) = pipe.encode_prompt(final_prompt, "cuda", True)
+
+     for image in pipe(
+         prompt_embeds=prompt_embeds,
+         negative_prompt_embeds=negative_prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+         image=cnet_image,
+         num_inference_steps=num_inference_steps
+     ):
+         yield cnet_image, image
+
+     image = image.convert("RGBA")
+     cnet_image.paste(image, (0, 0), mask)
+
+     yield background, cnet_image
+
+ def preload_presets(target_ratio):
+     if target_ratio == "9:16":
+         changed_width = 720
+         changed_height = 1280
+         return changed_width, changed_height, gr.update(open=False)
+     elif target_ratio == "16:9":
+         changed_width = 1280
+         changed_height = 720
+         return changed_width, changed_height, gr.update(open=False)
+     elif target_ratio == "Custom":
+         return 720, 1280, gr.update(open=True)
+
+ def clear_result():
+     return gr.update(value=None)
+
+
+ css = """
+ .gradio-container {
+     width: 1200px !important;
+ }
+ """
+
+
+ title = """<h1 align="center">Diffusers Image Outpaint</h1>
+ <div align="center">Drop an image you would like to extend, pick your expected ratio and hit Generate.</div>
+ """
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column():
+         gr.HTML(title)
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(
+                     type="pil",
+                     label="Input Image",
+                     sources=["upload"],
+                     height = 300
+                 )
+
+                 prompt_input = gr.Textbox(label="Prompt (Optional)")
+
+                 with gr.Row():
+                     target_ratio = gr.Radio(
+                         label = "Expected Ratio",
+                         choices = ["9:16", "16:9", "Custom"],
+                         value = "9:16",
+                         scale = 2
+                     )
+
+                     run_button = gr.Button("Generate", scale=1)
+
+                 with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
+                     with gr.Column():
+                         with gr.Row():
+                             width_slider = gr.Slider(
+                                 label="Width",
+                                 minimum=720,
+                                 maximum=1440,
+                                 step=8,
+                                 value=720, # Set a default value
+                             )
+                             height_slider = gr.Slider(
+                                 label="Height",
+                                 minimum=720,
+                                 maximum=1440,
+                                 step=8,
+                                 value=1280, # Set a default value
+                             )
+                         with gr.Row():
+                             model_selection = gr.Dropdown(
+                                 choices=list(MODELS.keys()),
+                                 value="RealVisXL V5.0 Lightning",
+                                 label="Model",
+                             )
+                             num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
+
+                         overlap_width = gr.Slider(
+                             label="Mask overlap width",
+                             minimum=1,
+                             maximum=50,
+                             value=42,
+                             step=1
+                         )
+
+                 gr.Examples(
+                     examples=[
+                         ["./examples/example_1.webp", "RealVisXL V5.0 Lightning", 1280, 720],
+                         ["./examples/example_2.jpg", "RealVisXL V5.0 Lightning", 720, 1280],
+                         ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024],
+                     ],
+                     inputs=[input_image, model_selection, width_slider, height_slider],
+                 )
+
+             with gr.Column():
+                 result = ImageSlider(
+                     interactive=False,
+                     label="Generated Image",
+                 )
+
+     target_ratio.change(
+         fn = preload_presets,
+         inputs = [target_ratio],
+         outputs = [width_slider, height_slider, settings_panel],
+         queue = False
+     )
+     run_button.click(
+         fn=clear_result,
+         inputs=None,
+         outputs=result,
+     ).then(
+         fn=infer,
+         inputs=[input_image, model_selection, width_slider, height_slider, overlap_width, num_inference_steps, prompt_input],
+         outputs=result,
+     )
+
+     prompt_input.submit(
+         fn=clear_result,
+         inputs=None,
+         outputs=result,
+     ).then(
+         fn=infer,
+         inputs=[input_image, model_selection, width_slider, height_slider, overlap_width, num_inference_steps, prompt_input],
+         outputs=result,
+     )
+
+ demo.queue().launch(share=True, show_error=True, show_api=True, inline=False)