Commit 76cd760 by alex
Parent(s): fd0980c
further optimisation
Files changed:
- generate.py (+2 -2)
- wan/animate.py (+2 -0)
- wan/modules/animate/model_animate.py (+4 -5)
- wan/modules/animate/preprocess/process_pipepline.py (+101 -47)
- wan/modules/model.py (+8 -11)
generate.py

@@ -114,7 +114,7 @@ def _parse_args():
     args.size = "720*1280"
     args.frame_num = None
     args.ckpt_dir = "./Wan2.2-Animate-14B/"
-    args.offload_model = True
+    args.offload_model = False
     args.ulysses_size = 1
     args.t5_fsdp = False
     args.t5_cpu = False
@@ -130,7 +130,7 @@ def _parse_args():
     args.sample_steps = None
     args.sample_shift = None
     args.sample_guide_scale = None
-    args.convert_model_dtype = False
+    args.convert_model_dtype = True
 
     # animate
     args.refert_num = 1
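
A minimal sketch, not taken from this repo, of what the two toggled defaults trade off: offload_model=False keeps the diffusion model resident on the GPU between denoising steps, and convert_model_dtype=True casts its weights to a lower-precision dtype up front. The builder function and flag handling below are illustrative assumptions, not the actual Wan2.2 loader.

```python
import torch

def load_noise_model(build_fn, device="cuda",
                     offload_model=False, convert_model_dtype=True):
    """Illustrative only: how flags like the two above typically trade VRAM for speed."""
    model = build_fn()                    # hypothetical constructor for the denoising model
    if convert_model_dtype:
        model = model.to(torch.bfloat16)  # smaller weights, faster matmuls
    if not offload_model:
        model = model.to(device)          # stay resident; no per-step CPU<->GPU copies
    return model
```
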
wan/animate.py

@@ -131,6 +131,8 @@ class WanAnimate:
             checkpoint_dir=checkpoint_dir,
             config=config
         )
+
+        # self.noise_model = torch.compile(self.noise_model)
 
         if use_sp:
             self.sp_size = get_world_size()
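
The only change here is a commented-out torch.compile call. As a hedged sketch, if one wanted to experiment with it, the call could be gated behind a flag so the compilation warm-up is only paid when requested; maybe_compile and its arguments are hypothetical helpers, not part of WanAnimate.

```python
import torch

def maybe_compile(model, enabled: bool = False):
    # torch.compile can reduce steady-state latency, but the first call pays a
    # compilation warm-up, which is presumably why the line above stays commented out.
    if enabled and hasattr(torch, "compile"):
        model = torch.compile(model, mode="reduce-overhead")
    return model

# e.g.: self.noise_model = maybe_compile(self.noise_model, enabled=False)
```
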
wan/modules/animate/model_animate.py

@@ -313,11 +313,12 @@ class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # buffers (don't use register_buffer otherwise dtype will be changed in to())
         assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         d = dim // num_heads
-        self.freqs = torch.cat([
+        _freqs = torch.cat([
             rope_params(1024, d - 4 * (d // 6)),
             rope_params(1024, 2 * (d // 6)),
             rope_params(1024, 2 * (d // 6))
         ], dim=1)
+        self.register_buffer("freqs", _freqs, persistent=False)
 
         self.img_emb = MLPProj(1280, dim)
 
@@ -381,9 +382,7 @@ class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             face_pixel_values=None
     ):
         # params
-        device = self.patch_embedding.weight.device
-        if self.freqs.device != device:
-            self.freqs = self.freqs.to(device)
+        freqs = self.freqs
 
         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 
@@ -428,7 +427,7 @@ class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 e=e0,
                 seq_lens=seq_lens,
                 grid_sizes=grid_sizes,
-                freqs=self.freqs,
+                freqs=freqs,
                 context=context,
                 context_lens=context_lens)
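
The change above replaces a plain tensor attribute plus a per-forward device check with a non-persistent buffer that follows the module's device. A minimal standalone sketch of that pattern, with an invented module and table sizes:

```python
import torch
import torch.nn as nn

class RopeTable(nn.Module):
    def __init__(self, dim=64, max_len=1024):
        super().__init__()
        t = torch.arange(max_len, dtype=torch.float32)
        inv = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        # A buffer moves with .to()/.cuda() on the module, so forward() needs no
        # "if self.freqs.device != device" check; persistent=False keeps the
        # precomputed table out of state_dict checkpoints.
        self.register_buffer("freqs", torch.outer(t, inv), persistent=False)

    def forward(self, positions):
        return self.freqs[positions]

m = RopeTable().to("cuda" if torch.cuda.is_available() else "cpu")
print(m.freqs.device)  # follows the module's device automatically
```
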
wan/modules/animate/preprocess/process_pipepline.py

@@ -19,9 +19,9 @@ from utils import resize_by_area, get_frame_indices, padding_resize, get_face_bboxes
 from human_visualization import draw_aapose_by_meta_new
 from retarget_pose import get_retarget_pose
 import sam2.modeling.sam.transformer as transformer
-transformer.USE_FLASH_ATTN = False
+transformer.USE_FLASH_ATTN = True
 transformer.MATH_KERNEL_ON = True
-transformer.OLD_GPU = True
+transformer.OLD_GPU = False
 from sam_utils import build_sam2_video_predictor
 
 
@@ -401,52 +401,53 @@ class ProcessPipeline():
             key_frame_boxes.append(np.array([x1, y1, x2, y2], dtype=np.float32))
 
         # init SAM2 for this chunk
-        used_box = False
-        try:
-            # If your predictor exposes a box API, this is ideal.
-            _ = self.predictor.add_new_box(
-                inference_state=inference_state,
-                frame_idx=ann_frame_idx,
-                obj_id=ann_obj_id,
-                box=box_xyxy[None, :]  # shape (1, 4)
-            )
-            used_box = True
-        except Exception:
-            used_box = False
+        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
+            inference_state = self.predictor.init_state_v2(frames=each_frames)
+            self.predictor.reset_state(inference_state)
+            ann_obj_id = 1
+
+            # seed with box prompts (preferred), else fall back to points
+            for ann_frame_idx, box_xyxy in zip(key_frame_index_list, key_frame_boxes):
+                used_box = False
+                try:
+                    # If your predictor exposes a box API, this is ideal.
+                    _ = self.predictor.add_new_box(
+                        inference_state=inference_state,
+                        frame_idx=ann_frame_idx,
+                        obj_id=ann_obj_id,
+                        box=box_xyxy[None, :]  # shape (1, 4)
+                    )
+                    used_box = True
+                except Exception:
+                    used_box = False
+
+                if not used_box:
+                    # Fallback: sample a few positive points inside the box
+                    x1, y1, x2, y2 = box_xyxy.astype(int)
+                    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
+                    pts = np.array([
+                        [cx, cy],
+                        [x1 + (x2 - x1) // 4, cy],
+                        [x2 - (x2 - x1) // 4, cy],
+                        [cx, y1 + (y2 - y1) // 4],
+                        [cx, y2 - (y2 - y1) // 4],
+                    ], dtype=np.int32)
+                    labels = np.ones(len(pts), dtype=np.int32)  # 1 = positive
+                    _ = self.predictor.add_new_points(
+                        inference_state=inference_state,
+                        frame_idx=ann_frame_idx,
+                        obj_id=ann_obj_id,
+                        points=pts,
+                        labels=labels,
+                    )
+
+            # propagate across the chunk
+            video_segments = {}
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
+                video_segments[out_frame_idx] = {
+                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                    for i, out_obj_id in enumerate(out_obj_ids)
+                }
 
         # collect masks (single object id)
         for out_frame_idx in range(len(video_segments)):
 
@@ -475,7 +476,7 @@ class ProcessPipeline():
             continue
 
         # choose a few key frames to seed the object
-        key_frame_num =
+        key_frame_num = 1
         key_frame_step = max(1, len(kp2ds) // key_frame_num)
         key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
 
@@ -529,3 +530,56 @@ class ProcessPipeline():
 
         return all_mask
 
+    def get_face_bbox_masks(self, frames, kp2ds_all, scale=1.3, feather_px=0, keep_soft=False):
+        """
+        Create a per-frame mask that's simply the face bounding box.
+        - scale: bbox scale factor used by get_face_bboxes
+        - feather_px: optional Gaussian blur in pixels to feather edges
+        - keep_soft: if True, keep float [0,1] soft mask (after blur); else binarize to {0,1}
+        """
+        H, W = frames[0].shape[:2]
+
+        def _clip_box(x1, y1, x2, y2):
+            x1 = max(0, min(int(x1), W - 1))
+            x2 = max(0, min(int(x2), W - 1))
+            y1 = max(0, min(int(y1), H - 1))
+            y2 = max(0, min(int(y2), H - 1))
+            if x2 <= x1: x2 = min(W - 1, x1 + 1)
+            if y2 <= y1: y2 = min(H - 1, y1 + 1)
+            return x1, y1, x2, y2
+
+        masks = []
+        last_box = None
+        for meta in kp2ds_all:
+            # get_face_bboxes returns (x1, x2, y1, y2)
+            try:
+                x1, x2, y1, y2 = get_face_bboxes(
+                    meta['keypoints_face'][:, :2],
+                    scale=scale,
+                    image_shape=(H, W)
+                )
+                x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2)
+                last_box = (x1, y1, x2, y2)
+            except Exception:
+                # fallback: reuse last seen box to avoid holes
+                if last_box is None:
+                    # no detection yet: push empty mask
+                    masks.append(np.zeros((H, W), dtype=np.uint8))
+                    continue
+                x1, y1, x2, y2 = last_box
+
+            m = np.zeros((H, W), dtype=np.float32)
+            m[y1:y2, x1:x2] = 1.0
+
+            if feather_px and feather_px > 0:
+                # kernel size must be odd and >= 3
+                k = max(3, int(feather_px) | 1)
+                m = cv2.GaussianBlur(m, (k, k), 0)
+
+            if keep_soft:
+                masks.append(m)  # float [0,1]
+            else:
+                masks.append((m >= 0.5).astype(np.uint8))  # hard {0,1}
+
+        return masks
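
In the new get_face_bbox_masks helper, the feathering step is the only non-obvious part. Below is a self-contained sketch of that step alone; the frame size and face box are arbitrary values chosen for illustration.

```python
import cv2
import numpy as np

H, W = 480, 640
x1, y1, x2, y2 = 200, 120, 440, 360     # example face box (x1, y1, x2, y2)

m = np.zeros((H, W), dtype=np.float32)
m[y1:y2, x1:x2] = 1.0                   # hard box mask

feather_px = 15
k = max(3, int(feather_px) | 1)         # Gaussian kernel must be odd and >= 3
soft = cv2.GaussianBlur(m, (k, k), 0)   # soft mask in [0, 1], as when keep_soft=True
hard = (soft >= 0.5).astype(np.uint8)   # re-binarized, as when keep_soft=False
```
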
wan/modules/model.py

@@ -30,7 +30,7 @@ def rope_params(max_seq_len, dim, theta=10000):
     freqs = torch.outer(
         torch.arange(max_seq_len),
         1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
+                        torch.arange(0, dim, 2).to(torch.float32).div(dim)))
     freqs = torch.polar(torch.ones_like(freqs), freqs)
     return freqs
 
@@ -41,14 +41,14 @@ def rope_apply(x, grid_sizes, freqs):
 
     # split freqs
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
 
     # loop over samples
     output = []
     for i, (f, h, w) in enumerate(grid_sizes.tolist()):
         seq_len = f * h * w
 
         # precompute multipliers
-        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
+        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
             seq_len, n, -1, 2))
         freqs_i = torch.cat([
             freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
 
@@ -65,7 +65,6 @@ def rope_apply(x, grid_sizes, freqs):
         output.append(x_i)
     return torch.stack(output).float()
 
-
 class WanRMSNorm(nn.Module):
 
     def __init__(self, dim, eps=1e-5):
 
@@ -397,12 +396,12 @@ class WanModel(ModelMixin, ConfigMixin):
         # buffers (don't use register_buffer otherwise dtype will be changed in to())
         assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         d = dim // num_heads
-        self.freqs = torch.cat([
+        _freqs = torch.cat([
             rope_params(1024, d - 4 * (d // 6)),
             rope_params(1024, 2 * (d // 6)),
             rope_params(1024, 2 * (d // 6))
-        ],
-                               dim=1)
+        ], dim=1)
+        self.register_buffer("freqs", _freqs, persistent=False)
 
         # initialize weights
         self.init_weights()
 
@@ -437,9 +436,7 @@ class WanModel(ModelMixin, ConfigMixin):
         if self.model_type == 'i2v':
             assert y is not None
         # params
-        device = self.patch_embedding.weight.device
-        if self.freqs.device != device:
-            self.freqs = self.freqs.to(device)
+        freqs = self.freqs
 
         if y is not None:
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 
@@ -482,7 +479,7 @@ class WanModel(ModelMixin, ConfigMixin):
                 e=e0,
                 seq_lens=seq_lens,
                 grid_sizes=grid_sizes,
-                freqs=self.freqs,
+                freqs=freqs,
                 context=context,
                 context_lens=context_lens)
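
The dtype edits in rope_params and rope_apply move the RoPE math from float64 to float32. A small self-contained sketch of the float32 variant; the 1024 length and the d - 4 * (d // 6) split mirror the calls above, while the concrete head dimension is an assumption for illustration.

```python
import torch

def rope_params_f32(max_seq_len, dim, theta=10000):
    # Same call pattern as above, but the angle table is built in float32,
    # so torch.polar returns complex64 instead of complex128.
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)

d = 128                                  # illustrative per-head dimension
table = rope_params_f32(1024, d - 4 * (d // 6))
print(table.dtype, table.shape)          # torch.complex64 torch.Size([1024, 22])
```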