Commit 76cd760 · 1 Parent(s): fd0980c
alex committed

further optimisation
generate.py CHANGED
@@ -114,7 +114,7 @@ def _parse_args():
     args.size = "720*1280"
     args.frame_num = None
     args.ckpt_dir = "./Wan2.2-Animate-14B/"
-    args.offload_model = True
+    args.offload_model = False
     args.ulysses_size = 1
     args.t5_fsdp = False
     args.t5_cpu = False
@@ -130,7 +130,7 @@ def _parse_args():
     args.sample_steps = None
     args.sample_shift = None
     args.sample_guide_scale = None
-    args.convert_model_dtype = False
+    args.convert_model_dtype = True
 
     # animate
     args.refert_num = 1
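Note on the two flipped defaults: offload_model = False keeps the model resident on the GPU between denoising steps instead of shuttling weights to CPU, and convert_model_dtype = True presumably lets generate.py cast the model parameters to the config's compute dtype. As a rough, illustrative sketch of what such a dtype conversion usually amounts to in PyTorch (the helper name and the bfloat16 target are assumptions, not the repo's actual implementation):

import torch
import torch.nn as nn

def convert_module_dtype(module: nn.Module, dtype: torch.dtype = torch.bfloat16) -> nn.Module:
    # Cast floating-point parameters and buffers in place; integer/bool tensors are left untouched.
    for t in list(module.parameters()) + list(module.buffers()):
        if t.is_floating_point():
            t.data = t.data.to(dtype)
    return module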
wan/animate.py CHANGED
@@ -131,6 +131,8 @@ class WanAnimate:
            checkpoint_dir=checkpoint_dir,
            config=config
        )
+
+        # self.noise_model = torch.compile(self.noise_model)
 
        if use_sp:
            self.sp_size = get_world_size()
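The only change here is a commented-out torch.compile call, left as an opt-in. A minimal sketch of how that line would typically be enabled, assuming PyTorch 2.x (the guard and helper below are illustrative, not part of the commit):

import torch
import torch.nn as nn

def maybe_compile(module: nn.Module) -> nn.Module:
    # torch.compile exists from PyTorch 2.0 onward; otherwise fall back to the eager module.
    if hasattr(torch, "compile"):
        return torch.compile(module)
    return module

# e.g. in WanAnimate.__init__, mirroring the commented-out line above:
# self.noise_model = maybe_compile(self.noise_model)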
wan/modules/animate/model_animate.py CHANGED
@@ -313,11 +313,12 @@ class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # buffers (don't use register_buffer otherwise dtype will be changed in to())
         assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         d = dim // num_heads
-        self.freqs = torch.cat([
+        _freqs = torch.cat([
             rope_params(1024, d - 4 * (d // 6)),
             rope_params(1024, 2 * (d // 6)),
             rope_params(1024, 2 * (d // 6))
         ], dim=1)
+        self.register_buffer("freqs", _freqs, persistent=False)
 
         self.img_emb = MLPProj(1280, dim)
 
@@ -381,9 +382,7 @@ class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         face_pixel_values=None
     ):
         # params
-        device = self.patch_embedding.weight.device
-        if self.freqs.device != device:
-            self.freqs = self.freqs.to(device)
+        freqs = self.freqs
 
         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -428,7 +427,7 @@ class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             e=e0,
             seq_lens=seq_lens,
             grid_sizes=grid_sizes,
-            freqs=self.freqs,
+            freqs=freqs,
             context=context,
             context_lens=context_lens)
 
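The freqs table is now a non-persistent buffer instead of a plain attribute, so it follows Module.to(device) automatically, stays out of checkpoints, and the manual device check in forward() goes away. A small, self-contained sketch of that behavior (shapes are arbitrary, not the model's):

import torch
import torch.nn as nn

class RopeCache(nn.Module):
    def __init__(self):
        super().__init__()
        # complex64 table, analogous to the output of rope_params()
        freqs = torch.polar(torch.ones(1024, 64), torch.rand(1024, 64))
        self.register_buffer("freqs", freqs, persistent=False)

m = RopeCache()
print("freqs" in m.state_dict())  # False: non-persistent buffers are not saved/loaded
if torch.cuda.is_available():
    m = m.to("cuda")
    print(m.freqs.device)         # the buffer moves with the module, no manual .to() in forward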
wan/modules/animate/preprocess/process_pipepline.py CHANGED
@@ -19,9 +19,9 @@ from utils import resize_by_area, get_frame_indices, padding_resize, get_face_bb
 from human_visualization import draw_aapose_by_meta_new
 from retarget_pose import get_retarget_pose
 import sam2.modeling.sam.transformer as transformer
-transformer.USE_FLASH_ATTN = False
+transformer.USE_FLASH_ATTN = True
 transformer.MATH_KERNEL_ON = True
-transformer.OLD_GPU = True
+transformer.OLD_GPU = False
 from sam_utils import build_sam2_video_predictor
 
 
@@ -401,52 +401,53 @@ class ProcessPipeline():
            key_frame_boxes.append(np.array([x1, y1, x2, y2], dtype=np.float32))
 
        # init SAM2 for this chunk
-        inference_state = self.predictor.init_state_v2(frames=each_frames)
-        self.predictor.reset_state(inference_state)
-        ann_obj_id = 1
+        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
+            inference_state = self.predictor.init_state_v2(frames=each_frames)
+            self.predictor.reset_state(inference_state)
+            ann_obj_id = 1
 
-        # seed with box prompts (preferred), else fall back to points
-        for ann_frame_idx, box_xyxy in zip(key_frame_index_list, key_frame_boxes):
-            used_box = False
-            try:
-                # If your predictor exposes a box API, this is ideal.
-                _ = self.predictor.add_new_box(
-                    inference_state=inference_state,
-                    frame_idx=ann_frame_idx,
-                    obj_id=ann_obj_id,
-                    box=box_xyxy[None, :]  # shape (1, 4)
-                )
-                used_box = True
-            except Exception:
-                used_box = False
-
-            if not used_box:
-                # Fallback: sample a few positive points inside the box
-                x1, y1, x2, y2 = box_xyxy.astype(int)
-                cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
-                pts = np.array([
-                    [cx, cy],
-                    [x1 + (x2 - x1) // 4, cy],
-                    [x2 - (x2 - x1) // 4, cy],
-                    [cx, y1 + (y2 - y1) // 4],
-                    [cx, y2 - (y2 - y1) // 4],
-                ], dtype=np.int32)
-                labels = np.ones(len(pts), dtype=np.int32)  # 1 = positive
-                _ = self.predictor.add_new_points(
-                    inference_state=inference_state,
-                    frame_idx=ann_frame_idx,
-                    obj_id=ann_obj_id,
-                    points=pts,
-                    labels=labels,
-                )
-
-        # propagate across the chunk
-        video_segments = {}
-        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
-            video_segments[out_frame_idx] = {
-                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-                for i, out_obj_id in enumerate(out_obj_ids)
-            }
+            # seed with box prompts (preferred), else fall back to points
+            for ann_frame_idx, box_xyxy in zip(key_frame_index_list, key_frame_boxes):
+                used_box = False
+                try:
+                    # If your predictor exposes a box API, this is ideal.
+                    _ = self.predictor.add_new_box(
+                        inference_state=inference_state,
+                        frame_idx=ann_frame_idx,
+                        obj_id=ann_obj_id,
+                        box=box_xyxy[None, :]  # shape (1, 4)
+                    )
+                    used_box = True
+                except Exception:
+                    used_box = False
+
+                if not used_box:
+                    # Fallback: sample a few positive points inside the box
+                    x1, y1, x2, y2 = box_xyxy.astype(int)
+                    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
+                    pts = np.array([
+                        [cx, cy],
+                        [x1 + (x2 - x1) // 4, cy],
+                        [x2 - (x2 - x1) // 4, cy],
+                        [cx, y1 + (y2 - y1) // 4],
+                        [cx, y2 - (y2 - y1) // 4],
+                    ], dtype=np.int32)
+                    labels = np.ones(len(pts), dtype=np.int32)  # 1 = positive
+                    _ = self.predictor.add_new_points(
+                        inference_state=inference_state,
+                        frame_idx=ann_frame_idx,
+                        obj_id=ann_obj_id,
+                        points=pts,
+                        labels=labels,
+                    )
+
+            # propagate across the chunk
+            video_segments = {}
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
+                video_segments[out_frame_idx] = {
+                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                    for i, out_obj_id in enumerate(out_obj_ids)
+                }
 
        # collect masks (single object id)
        for out_frame_idx in range(len(video_segments)):
@@ -475,7 +476,7 @@ class ProcessPipeline():
            continue
 
        # choose a few key frames to seed the object
-        key_frame_num = 4 if len(each_frames) > 4 else 1
+        key_frame_num = 1
        key_frame_step = max(1, len(kp2ds) // key_frame_num)
        key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
 
@@ -529,3 +530,56 @@ class ProcessPipeline():
 
        return all_mask
 
+    def get_face_bbox_masks(self, frames, kp2ds_all, scale=1.3, feather_px=0, keep_soft=False):
+        """
+        Create a per-frame mask that's simply the face bounding box.
+        - scale: bbox scale factor used by get_face_bboxes
+        - feather_px: optional Gaussian blur in pixels to feather edges
+        - keep_soft: if True, keep float [0,1] soft mask (after blur); else binarize to {0,1}
+        """
+        H, W = frames[0].shape[:2]
+
+        def _clip_box(x1, y1, x2, y2):
+            x1 = max(0, min(int(x1), W - 1))
+            x2 = max(0, min(int(x2), W - 1))
+            y1 = max(0, min(int(y1), H - 1))
+            y2 = max(0, min(int(y2), H - 1))
+            if x2 <= x1: x2 = min(W - 1, x1 + 1)
+            if y2 <= y1: y2 = min(H - 1, y1 + 1)
+            return x1, y1, x2, y2
+
+        masks = []
+        last_box = None
+        for meta in kp2ds_all:
+            # get_face_bboxes returns (x1, x2, y1, y2)
+            try:
+                x1, x2, y1, y2 = get_face_bboxes(
+                    meta['keypoints_face'][:, :2],
+                    scale=scale,
+                    image_shape=(H, W)
+                )
+                x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2)
+                last_box = (x1, y1, x2, y2)
+            except Exception:
+                # fallback: reuse last seen box to avoid holes
+                if last_box is None:
+                    # no detection yet: push empty mask
+                    masks.append(np.zeros((H, W), dtype=np.uint8))
+                    continue
+                x1, y1, x2, y2 = last_box
+
+            m = np.zeros((H, W), dtype=np.float32)
+            m[y1:y2, x1:x2] = 1.0
+
+            if feather_px and feather_px > 0:
+                # kernel size must be odd and >= 3
+                k = max(3, int(feather_px) | 1)
+                m = cv2.GaussianBlur(m, (k, k), 0)
+
+            if keep_soft:
+                masks.append(m)  # float [0,1]
+            else:
+                masks.append((m >= 0.5).astype(np.uint8))  # hard {0,1}
+
+        return masks
+
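The SAM2 preprocessing now enables the flash-attention path (USE_FLASH_ATTN = True, OLD_GPU = False) and wraps the whole segmentation pass in torch.no_grad() plus fp16 autocast. A stripped-down sketch of that wrapper, with placeholder model/frames rather than the repo's predictor:

import torch

def segment_fp16(model, frames):
    # no_grad() skips autograd bookkeeping; autocast runs eligible CUDA ops in float16.
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        return model(frames)

# On recent PyTorch the same context can also be spelled torch.amp.autocast("cuda", dtype=torch.float16).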
wan/modules/model.py CHANGED
@@ -30,7 +30,7 @@ def rope_params(max_seq_len, dim, theta=10000):
     freqs = torch.outer(
         torch.arange(max_seq_len),
         1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
+                        torch.arange(0, dim, 2).to(torch.float32).div(dim)))
     freqs = torch.polar(torch.ones_like(freqs), freqs)
     return freqs
 
@@ -41,14 +41,14 @@ def rope_apply(x, grid_sizes, freqs):
 
     # split freqs
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-
+
     # loop over samples
     output = []
     for i, (f, h, w) in enumerate(grid_sizes.tolist()):
         seq_len = f * h * w
 
         # precompute multipliers
-        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
+        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
             seq_len, n, -1, 2))
         freqs_i = torch.cat([
             freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
@@ -65,7 +65,6 @@ def rope_apply(x, grid_sizes, freqs):
         output.append(x_i)
     return torch.stack(output).float()
 
-
 class WanRMSNorm(nn.Module):
 
     def __init__(self, dim, eps=1e-5):
@@ -397,12 +396,12 @@ class WanModel(ModelMixin, ConfigMixin):
         # buffers (don't use register_buffer otherwise dtype will be changed in to())
         assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         d = dim // num_heads
-        self.freqs = torch.cat([
+        _freqs = torch.cat([
             rope_params(1024, d - 4 * (d // 6)),
             rope_params(1024, 2 * (d // 6)),
             rope_params(1024, 2 * (d // 6))
-        ],
-                               dim=1)
+        ], dim=1)
+        self.register_buffer("freqs", _freqs, persistent=False)
 
         # initialize weights
         self.init_weights()
@@ -437,9 +436,7 @@ class WanModel(ModelMixin, ConfigMixin):
         if self.model_type == 'i2v':
             assert y is not None
         # params
-        device = self.patch_embedding.weight.device
-        if self.freqs.device != device:
-            self.freqs = self.freqs.to(device)
+        freqs = self.freqs
 
         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -482,7 +479,7 @@ class WanModel(ModelMixin, ConfigMixin):
             e=e0,
             seq_lens=seq_lens,
             grid_sizes=grid_sizes,
-            freqs=self.freqs,
+            freqs=freqs,
             context=context,
             context_lens=context_lens)
 
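The rotary tables are now built in float32 instead of float64, which halves their memory footprint and the per-token complex math (complex64 vs complex128). A quick check of what rope_params produces at each precision; the dtype argument is added here only for the comparison:

import torch

def rope_params(max_seq_len, dim, theta=10000, dtype=torch.float32):
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(dtype).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)

print(rope_params(1024, 128).dtype)                       # complex64 with the new float32 path
print(rope_params(1024, 128, dtype=torch.float64).dtype)  # complex128, the old float64 path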