George Krupenchenkov committed
Commit 796d285
1 Parent(s): 7c45782
Files changed (1)
  1. app.py +128 -33
app.py CHANGED
@@ -1,12 +1,15 @@
 import os
 import random
 
+import cv2
 import gradio as gr
 import numpy as np
 import torch
 # import spaces #[uncomment to use ZeroGPU]
-from diffusers import StableDiffusionPipeline
+from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
+                       StableDiffusionPipeline)
 from peft import LoraConfig, PeftModel
+from PIL import Image
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 # model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
@@ -14,11 +17,12 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 # model_dropdown = ["stabilityai/sdxl-turbo", "CompVis/stable-diffusion-v1-4"]
 
 models = [
-    "gstranger/kawaiicat-lora-1.4",
+    # "gstranger/kawaiicat-lora-1.4",
     "CompVis/stable-diffusion-v1-4",
     "stabilityai/sdxl-turbo",
    "sd-legacy/stable-diffusion-v1-5",
 ]
+controlnet_modes = ["canny", "Line Art"]
 
 model_dropdown = [
     "stabilityai/sdxl-turbo",
@@ -26,6 +30,14 @@ model_dropdown = [
     "sd-legacy/stable-diffusion-v1-5",
 ]
 
+def process_control_image(image, mode="canny"):
+    if mode == "canny":
+        image = np.array(image)
+        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
+        canny = cv2.Canny(blurred, 50, 150)
+        return Image.fromarray(canny)
+    return image
 
 if torch.cuda.is_available():
     torch_dtype = torch.float16
@@ -81,47 +93,92 @@ def infer(
     model_id,
     prompt,
     negative_prompt,
-    randomize_seed,
-    width,
-    height,
-    # model_repo_id=model_repo_id,
+    randomize_seed=False,
+    width=512,
+    height=512,
+
+    lora_scale=0.8,
+    lora_enable=True,
+
+    controlnet_enable=False,
+    control_mode="Line Art",
+    control_strength=0.8,
+    control_image=None,
+
+    ip_adapter_enable=False,
+    ip_adapter_scale=0.8,
+    ip_image=None,
+
+
+    torch_dtype=torch_dtype,
     seed=42,
     guidance_scale=7,
     num_inference_steps=50,
     progress=gr.Progress(track_tqdm=True),
-    lora_scale=1,
 ):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+    else:
+        seed = 488
 
     generator = torch.Generator().manual_seed(seed)
 
-    if model_id == "gstranger/kawaiicat-lora-1.4":
-        # add LoRA
-        pipe = get_lora_sd_pipeline(
-            os.path.join(CKPT_DIR, ""), adapter_name="sd-14-lora", dtype=torch_dtype
-        ).to(device)
-        pipe.safety_checker = None
-        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
+    params = {'prompt': prompt,
+              'negative_prompt': negative_prompt,
+              'guidance_scale': guidance_scale,
+              'num_inference_steps': num_inference_steps,
+              'width': width,
+              'height': height,
+              'generator': generator,
+              }
+
+    if controlnet_enable:
+        if control_mode == "canny":
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny",
+                                                         torch_dtype=torch_dtype, cache_dir="./models_cache")
+        elif control_mode == "Line Art":
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_lineart",
+                                                         torch_dtype=torch_dtype, cache_dir="./models_cache")
+
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
+                                                                 controlnet=controlnet,
+                                                                 torch_dtype=torch_dtype,
+                                                                 safety_checker=None)  # .to(device)
+
+        params['image'] = process_control_image(control_image, control_mode)
+        params['controlnet_conditioning_scale'] = float(control_strength)
+
 
     else:
-        pipe = StableDiffusionPipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch_dtype,
-            requires_safety_checker=False,
-            safety_checker=None,
+        pipe = StableDiffusionPipeline.from_pretrained(model_id,
+                                                       torch_dtype=torch_dtype,
+                                                       safety_checker=None)  # .to(device)
+
+    if lora_enable:
+        unet_sub_dir = os.path.join(CKPT_DIR, "unet")
+        text_encoder_sub_dir = os.path.join(CKPT_DIR, "text_encoder")
+        adapter_name = "sd-14-lora"
+
+        pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
+        pipe.text_encoder = PeftModel.from_pretrained(
+            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
         )
-        pipe = pipe.to(device)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-        cross_attention_kwargs={"scale": lora_scale},
+        params['cross_attention_kwargs'] = {"scale": lora_scale}
+
+        if torch_dtype in (torch.float16, torch.bfloat16):
+            pipe.unet.half()
+            pipe.text_encoder.half()
+
+    if ip_adapter_enable:
+        pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+        pipe.set_ip_adapter_scale(ip_adapter_scale)
+        params['ip_adapter_image'] = process_control_image(ip_image, "")
+
+
+    # pipe.to(device)
+
+
+    image = pipe(**params
     ).images[0]
 
     return image, seed
@@ -155,9 +212,23 @@ with gr.Blocks(css=css) as demo:
                 minimum=0,
                 maximum=1,
                 step=0.01,
-                value=1,
+                value=0.8,
             )
 
+            lora_enable = gr.Checkbox(label="Use LoRA", value=True)
+
+            with gr.Column():
+                controlnet_enable = gr.Checkbox(label="Enable ControlNet")
+                with gr.Accordion("ControlNet Settings", visible=False) as controlnet_accordion:
+                    control_mode = gr.Dropdown(controlnet_modes, label="Control Mode", value="canny")
+                    control_strength = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Control Strength")
+                    control_image = gr.Image(label="Control Image", type="pil")
+
+                ip_adapter_enable = gr.Checkbox(label="Enable IP-Adapter")
+                with gr.Accordion("IP-Adapter Settings", visible=False) as ipadapter_accordion:
+                    ip_adapter_scale = gr.Slider(0, 1, value=0.5, label="IP-Adapter Scale")
+                    ip_image = gr.Image(label="Reference Image", type="pil")
+
         with gr.Row():
             prompt = gr.Text(
                 label="Prompt",
@@ -165,8 +236,16 @@ with gr.Blocks(css=css) as demo:
                 max_lines=1,
                 placeholder="Enter your prompt",
                 container=False,
+            )
+
+            negative_prompt = gr.Textbox(
+                label="Negative prompt",
+                max_lines=1,
+                placeholder="Enter your negative prompt",
+                value="bad anatomy, crop image, bad face of the cat"
             )
 
+
             run_button = gr.Button("Run", scale=0, variant="primary")
 
         result = gr.Image(label="Result", show_label=False)
@@ -187,7 +266,7 @@ with gr.Blocks(css=css) as demo:
                 value=42,
             )
 
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
 
             with gr.Row():
                 width = gr.Slider(
@@ -224,6 +303,18 @@ with gr.Blocks(css=css) as demo:
             )
 
         gr.Examples(examples=examples, inputs=[prompt])
+
+        controlnet_enable.change(
+            lambda x: gr.update(visible=x),
+            controlnet_enable,
+            controlnet_accordion
+        )
+        ip_adapter_enable.change(
+            lambda x: gr.update(visible=x),
+            ip_adapter_enable,
+            ipadapter_accordion
+        )
+
     gr.on(
         triggers=[run_button.click, prompt.submit],
        fn=infer,
@@ -237,10 +328,14 @@ with gr.Blocks(css=css) as demo:
             seed,
             guidance_scale,
             num_inference_steps,
-            lora_scale,
+            lora_enable, lora_scale,
+            controlnet_enable, control_mode, control_strength, control_image,
+            ip_adapter_enable, ip_adapter_scale, ip_image
+
         ],
         outputs=[result, seed],
     )
 
+
 if __name__ == "__main__":
     demo.launch()
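
For reference, the ControlNet branch this commit adds to infer follows the usual diffusers pattern: preprocess a condition image, load a ControlNetModel, wrap it in a StableDiffusionControlNetPipeline, and pass the condition image plus controlnet_conditioning_scale at call time. A minimal standalone sketch of that path, assuming a CUDA device; the file name, prompt, and sampler settings are illustrative placeholders, not part of the commit:

import cv2
import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from PIL import Image

# Build a Canny edge map as the ControlNet condition
# (same preprocessing idea as process_control_image in the diff).
rgb = np.array(Image.open("control.png").convert("RGB"))      # placeholder input image
gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(cv2.GaussianBlur(gray, (5, 5), 0), 50, 150)
condition = Image.fromarray(np.stack([edges] * 3, axis=-1))   # 3-channel edge image

# Same checkpoints the commit uses for the "canny" mode.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "sd-legacy/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    safety_checker=None,
).to("cuda")

image = pipe(
    prompt="a kawaii cat, watercolor",        # placeholder prompt
    image=condition,                          # ControlNet condition image
    controlnet_conditioning_scale=0.8,        # plays the role of control_strength
    num_inference_steps=30,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("controlnet_out.png")

The sketch stacks the single-channel Canny output into three channels, as in the diffusers ControlNet examples; the app's process_control_image instead returns the single-channel map directly.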