caohy666 committed
Commit ba53076 · 1 Parent(s): 2060744

<feat> complete app.py.
Files changed (2):
  1. .gitignore +1 -0
  2. app.py +261 -0
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/
app.py CHANGED
@@ -18,10 +18,13 @@ from peft import LoraConfig
 from omegaconf import OmegaConf
 from safetensors.torch import safe_open
 from PIL import Image, ImageDraw, ImageFilter
+from huggingface_hub import hf_hub_download
+from transformers import pipeline

 from models import HunyuanVideoTransformer3DModel
 from pipelines import HunyuanVideoImageToVideoPipeline

+
 header = """
 # DRA-Ctrl Gradio App

@@ -33,9 +36,267 @@ header = """
 </div>
 """

+notice = """
+For easier testing of spatially-aligned image generation tasks, there is no need to manually provide edge maps,
+depth maps, or other condition images when passing the condition image to `gradio_app`; only the original image is required.
+The corresponding condition images are extracted automatically.
+"""
+
+
+@spaces.GPU
+def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task):
+    # init models
+    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="transformer",
+        inference_subject_driven=task in ['subject_driven'])
+    scheduler = diffusers.FlowMatchEulerDiscreteScheduler()
+    vae = diffusers.AutoencoderKLHunyuanVideo.from_pretrained(
+        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="vae")
+    text_encoder = transformers.LlavaForConditionalGeneration.from_pretrained(
+        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="text_encoder")
+    text_encoder_2 = transformers.CLIPTextModel.from_pretrained(
+        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="text_encoder_2")
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="tokenizer")
+    tokenizer_2 = transformers.CLIPTokenizer.from_pretrained(
+        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="tokenizer_2")
+    image_processor = transformers.CLIPImageProcessor.from_pretrained(
+        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="image_processor")
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    weight_dtype = torch.bfloat16
+
+    transformer.requires_grad_(False)
+    vae.requires_grad_(False).to(device, dtype=weight_dtype)
+    text_encoder.requires_grad_(False).to(device, dtype=weight_dtype)
+    text_encoder_2.requires_grad_(False).to(device, dtype=weight_dtype)
+    transformer.to(device, dtype=weight_dtype)
+    vae.enable_tiling()
+    vae.enable_slicing()
+
+ # insert LoRA
79
+ lora_config = LoraConfig(
80
+ r=16,
81
+ lora_alpha=16,
82
+ init_lora_weights="gaussian",
83
+ target_modules=[
84
+ 'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
85
+ 'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
86
+ 'ff.net.0.proj', 'ff.net.2',
87
+ 'ff_context.net.0.proj', 'ff_context.net.2',
88
+ 'norm1_context.linear', 'norm1.linear',
89
+ 'norm.linear', 'proj_mlp', 'proj_out',
90
+ ]
91
+ )
92
+ transformer.add_adapter(lora_config)
93
+
94
+ # hack LoRA forward
95
+ def create_hacked_forward(module):
96
+ lora_forward = module.forward
97
+ non_lora_forward = module.base_layer.forward
98
+ img_sequence_length = int((args.img_size / 8 / 2) ** 2)
99
+ encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
100
+ num_imgs = 4
101
+ num_generated_imgs = 3
102
+ num_encoder_sequences = 2 if args.task in ['subject_driven', 'style_transfer'] else 1
103
+
104
+ def hacked_lora_forward(self, x, *args, **kwargs):
105
+ if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
106
+ return torch.cat((
107
+ lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
108
+ non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
109
+ ), dim=1)
110
+ elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
111
+ return lora_forward(x, *args, **kwargs)
112
+ elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
113
+ return torch.cat((
114
+ lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
115
+ non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
116
+ lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
117
+ ), dim=1)
118
+ elif x.shape[1] == 3072:
119
+ return non_lora_forward(x, *args, **kwargs)
120
+ else:
121
+ raise ValueError(
122
+ f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
123
+ )
124
+
125
+ return hacked_lora_forward.__get__(module, type(module))
126
+
127
+ for n, m in transformer.named_modules():
128
+ if isinstance(m, peft.tuners.lora.layer.Linear):
129
+ m.forward = create_hacked_forward(m)
130
+
131
+ # load LoRA weights
132
+ model_root = hf_hub_download(
133
+ repo_id="Kunbyte/DRA-Ctrl",
134
+ filename=f"{task}.safetensors",
135
+ resume_download=True)
136
+
137
+ try:
138
+ with safe_open(model_root, framework="pt") as f:
139
+ lora_weights = {}
140
+ for k in f.keys():
141
+ param = f.get_tensor(k)
142
+ if k.endswith(".weight"):
143
+ k = k.replace('.weight', '.default.weight')
144
+ lora_weights[k] = param
145
+ transformer.load_state_dict(lora_weights, strict=False)
146
+ except Exception as e:
147
+ raise ValueError(f'{e}')
148
+
149
+ transformer.requires_grad_(False)
150
+
+    pipe = HunyuanVideoImageToVideoPipeline(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=copy.deepcopy(scheduler),
+        text_encoder_2=text_encoder_2,
+        tokenizer_2=tokenizer_2,
+        image_processor=image_processor,
+    )
+
+ # start generation
163
+ c_txt = None if condition_image_prompt == "" else condition_image_prompt
164
+ c_img = condition_image.resize((512, 512))
165
+ t_txt = target_prompt
166
+
167
+ if args.task not in ['subject_driven', 'style_transfer']:
168
+ if args.task == "canny":
169
+ def get_canny_edge(img):
170
+ img_np = np.array(img)
171
+ img_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
172
+ edges = cv2.Canny(img_gray, 100, 200)
173
+ edges_tmp = Image.fromarray(edges).convert("RGB")
174
+ edges_tmp.save(os.path.join(save_dir, f"edges.png"))
175
+ edges[edges == 0] = 128
176
+ return Image.fromarray(edges).convert("RGB")
177
+ c_img = get_canny_edge(c_img)
178
+ elif args.task == "coloring":
179
+ c_img = (
180
+ c_img.resize((args.img_size, args.img_size))
181
+ .convert("L")
182
+ .convert("RGB")
183
+ )
184
+ elif args.task == "deblurring":
185
+ blur_radius = 10
186
+ c_img = (
187
+ c_img.convert("RGB")
188
+ .filter(ImageFilter.GaussianBlur(blur_radius))
189
+ .resize((args.img_size, args.img_size))
190
+ .convert("RGB")
191
+ )
192
+ elif args.task == "depth":
193
+ def get_depth_map(img):
194
+ from transformers import pipeline
195
+
196
+ depth_pipe = pipeline(
197
+ task="depth-estimation",
198
+ model="LiheYoung/depth-anything-small-hf",
199
+ device="cpu",
200
+ )
201
+ return depth_pipe(img)["depth"].convert("RGB").resize((args.img_size, args.img_size))
202
+ c_img = get_depth_map(c_img)
203
+ c_img.save(os.path.join(save_dir, f"depth.png"))
204
+ k = (255 - 128) / 255
205
+ b = 128
206
+ c_img = c_img.point(lambda x: k * x + b)
207
+ elif args.task == "depth_pred":
208
+ c_img = c_img
209
+ elif args.task == "fill":
210
+ c_img = c_img.resize((args.img_size, args.img_size)).convert("RGB")
211
+ x1, x2 = args.fill_x1, args.fill_x2
212
+ y1, y2 = args.fill_y1, args.fill_y2
213
+ mask = Image.new("L", (args.img_size, args.img_size), 0)
214
+ draw = ImageDraw.Draw(mask)
215
+ draw.rectangle((x1, y1, x2, y2), fill=255)
216
+ if args.inpainting:
217
+ mask = Image.eval(mask, lambda a: 255 - a)
218
+ c_img = Image.composite(
219
+ c_img,
220
+ Image.new("RGB", (args.img_size, args.img_size), (255, 255, 255)),
221
+ mask
222
+ )
223
+ c_img.save(os.path.join(save_dir, f"mask.png"))
224
+ c_img = Image.composite(
225
+ c_img,
226
+ Image.new("RGB", (args.img_size, args.img_size), (128, 128, 128)),
227
+ mask
228
+ )
229
+ elif args.task == "sr":
230
+ c_img = c_img.resize((int(args.img_size / 4), int(args.img_size / 4))).convert("RGB")
231
+ c_img.save(os.path.join(save_dir, f"low_resolution.png"))
232
+ c_img = c_img.resize((args.img_size, args.img_size))
233
+ c_img.save(os.path.join(save_dir, f"low_to_high.png"))
234
+
+    gen_img = pipe(
+        image=c_img,
+        prompt=[t_txt.strip()],
+        prompt_condition=[c_txt.strip()] if c_txt is not None else None,
+        prompt_2=[t_txt],
+        height=512,
+        width=512,
+        num_frames=5,
+        num_inference_steps=50,
+        guidance_scale=6.0,
+        num_videos_per_prompt=1,
+        generator=torch.Generator(device=pipe.transformer.device).manual_seed(0),
+        output_type='pt',
+        image_embed_interleave=4,
+        frame_gap=48,
+        mixup=True,
+        mixup_num_imgs=2,
+    ).frames
+
+    gen_img = gen_img[:, 0:1, :, :, :]
+    gen_img = gen_img.squeeze(0).squeeze(0).cpu().to(torch.float32).numpy()
+    gen_img = np.transpose(gen_img, (1, 2, 0))
+    gen_img = (gen_img * 255).astype(np.uint8)
+    gen_img = Image.fromarray(gen_img)
+
+    return gen_img
+
 def create_app():
     with gr.Blocks() as app:
         gr.Markdown(header, elem_id="header")
+        with gr.Row(equal_height=False):
+            with gr.Column(variant="panel", elem_classes="inputPanel"):
+                condition_image = gr.Image(
+                    type="pil", label="Condition Image", width=300, elem_id="input"
+                )
+                task = gr.Radio(
+                    [
+                        ("Subject-driven Image Generation", "subject_driven"),
+                        ("Canny-to-Image", "canny"),
+                        ("Colorization", "coloring"),
+                        ("Deblurring", "deblurring"),
+                        ("Depth-to-Image", "depth"),
+                        ("Depth Prediction", "depth_pred"),
+                        ("In/Out-Painting", "fill"),
+                        ("Super-Resolution", "sr"),
+                        ("Style Transfer", "style_transfer")
+                    ],
+                    label="Task Selection",
+                    value="subject_driven",
+                    interactive=True,
+                    elem_id="task_selection"
+                )
+                gr.Markdown(notice, elem_id="notice")
+                target_prompt = gr.Textbox(lines=2, label="Target Prompt", elem_id="text")
+                condition_image_prompt = gr.Textbox(lines=2, label="Condition Image Prompt", elem_id="text")
+                submit_btn = gr.Button("Run", elem_id="submit_btn")
+
+            with gr.Column(variant="panel", elem_classes="outputPanel"):
+                output_image = gr.Image(type="pil", elem_id="output")
+
+        submit_btn.click(
+            fn=process_image_and_text,
+            inputs=[condition_image, target_prompt, condition_image_prompt, task],
+            outputs=output_image,
+        )

     return app
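
The hunks above end at `create_app()`'s `return app`; the code that actually launches the Gradio interface is outside the shown diff. A minimal launch sketch, assuming the usual Gradio entry point (the `__main__` guard and `launch()` call below are illustrative, not taken from this commit):

if __name__ == "__main__":
    demo = create_app()  # build the gr.Blocks app defined above
    demo.launch()        # start the Gradio server; options such as share or server_name are omitted here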