alexnasa committed on
Commit 26e5c98 · verified · 1 Parent(s): 410b861

avoid vram issues

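The hunk below flips SAM2's module-level attention flag `transformer.MATH_KERNEL_ON` from `False` to `True`; in `sam2.modeling.sam.transformer` these flags (together with `USE_FLASH_ATTN` and `OLD_GPU`) select which scaled-dot-product-attention backends the video predictor is allowed to use. As a rough illustration of the kind of switch such flags feed into, here is a minimal standalone PyTorch sketch. It is not code from this repository, the tensor shapes are arbitrary, and the exact wiring inside SAM2 may differ.

import torch
import torch.nn.functional as F

# Illustrative only: restrict which scaled-dot-product-attention backends may run.
# Keeping the math kernel enabled (roughly what MATH_KERNEL_ON = True opts into)
# lets the call fall back gracefully on GPUs where flash attention is unavailable.
q = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 256, 64])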
wan/modules/animate/preprocess/process_pipepline.py CHANGED
@@ -1,585 +1,585 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import os
3
- import numpy as np
4
- import shutil
5
- import torch
6
- from diffusers import FluxKontextPipeline
7
- import cv2
8
- from loguru import logger
9
- from PIL import Image
10
- try:
11
- import moviepy.editor as mpy
12
- except:
13
- import moviepy as mpy
14
-
15
- from decord import VideoReader
16
- from pose2d import Pose2d
17
- from pose2d_utils import AAPoseMeta
18
- from utils import resize_by_area, get_frame_indices, padding_resize, get_face_bboxes, get_aug_mask, get_mask_body_img
19
- from human_visualization import draw_aapose_by_meta_new
20
- from retarget_pose import get_retarget_pose
21
- import sam2.modeling.sam.transformer as transformer
22
- transformer.USE_FLASH_ATTN = True
23
- transformer.MATH_KERNEL_ON = False
24
- transformer.OLD_GPU = False
25
- from sam_utils import build_sam2_video_predictor
26
-
27
-
28
- class ProcessPipeline():
29
- def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
30
- self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
31
-
32
- model_cfg = "sam2_hiera_l.yaml"
33
- if sam_checkpoint_path is not None:
34
- self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path)
35
- if flux_kontext_path is not None:
36
- self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
37
-
38
- def __call__(self, video_path, refer_image_path, output_path, resolution_area=[1280, 720], fps=30, iterations=3, k=7, w_len=1, h_len=1, retarget_flag=False, use_flux=False, replace_flag=False):
39
- if replace_flag:
40
-
41
- video_reader = VideoReader(video_path)
42
- frame_num = len(video_reader)
43
- print('frame_num: {}'.format(frame_num))
44
-
45
- video_fps = video_reader.get_avg_fps()
46
- print('video_fps: {}'.format(video_fps))
47
- print('fps: {}'.format(fps))
48
-
49
- # TODO: Maybe we can switch to PyAV later, which can get accurate frame num
50
- duration = video_reader.get_frame_timestamp(-1)[-1]
51
- expected_frame_num = int(duration * video_fps + 0.5)
52
- ratio = abs((frame_num - expected_frame_num)/frame_num)
53
- if ratio > 0.1:
54
- print("Warning: The difference between the actual number of frames and the expected number of frames is too large")
55
- frame_num = expected_frame_num
56
-
57
- if fps == -1:
58
- fps = video_fps
59
-
60
- target_num = int(frame_num / video_fps * fps)
61
- print('target_num: {}'.format(target_num))
62
- idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
63
- frames = video_reader.get_batch(idxs).asnumpy()
64
-
65
- frames = [resize_by_area(frame, resolution_area[0] * resolution_area[1], divisor=16) for frame in frames]
66
- height, width = frames[0].shape[:2]
67
- logger.info(f"Processing pose meta")
68
-
69
-
70
- tpl_pose_metas = self.pose2d(frames)
71
-
72
- face_images = []
73
- for idx, meta in enumerate(tpl_pose_metas):
74
- face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
75
- image_shape=(frames[0].shape[0], frames[0].shape[1]))
76
-
77
- x1, x2, y1, y2 = face_bbox_for_image
78
- face_image = frames[idx][y1:y2, x1:x2]
79
- face_image = cv2.resize(face_image, (512, 512))
80
- face_images.append(face_image)
81
-
82
- logger.info(f"Processing reference image: {refer_image_path}")
83
- refer_img = cv2.imread(refer_image_path)
84
- src_ref_path = os.path.join(output_path, 'src_ref.png')
85
- shutil.copy(refer_image_path, src_ref_path)
86
- refer_img = refer_img[..., ::-1]
87
-
88
- refer_img = padding_resize(refer_img, height, width)
89
- logger.info(f"Processing template video: {video_path}")
90
- tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
91
- cond_images = []
92
-
93
- for idx, meta in enumerate(tpl_retarget_pose_metas):
94
- canvas = np.zeros_like(refer_img)
95
- conditioning_image = draw_aapose_by_meta_new(canvas, meta)
96
- cond_images.append(conditioning_image)
97
- masks = self.get_mask_from_face_bbox(frames, 400, tpl_pose_metas)
98
-
99
- bg_images = []
100
- aug_masks = []
101
-
102
- for frame, mask in zip(frames, masks):
103
- if iterations > 0:
104
- _, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k)
105
- each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len)
106
- else:
107
- each_aug_mask = mask
108
-
109
- each_bg_image = frame * (1 - each_aug_mask[:, :, None])
110
- bg_images.append(each_bg_image)
111
- aug_masks.append(each_aug_mask)
112
-
113
- src_face_path = os.path.join(output_path, 'src_face.mp4')
114
- mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
115
-
116
- src_pose_path = os.path.join(output_path, 'src_pose.mp4')
117
- mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
118
-
119
- src_bg_path = os.path.join(output_path, 'src_bg.mp4')
120
- mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path)
121
-
122
- aug_masks_new = [np.stack([mask * 255, mask * 255, mask * 255], axis=2) for mask in aug_masks]
123
- src_mask_path = os.path.join(output_path, 'src_mask.mp4')
124
- mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path)
125
- return True
126
- else:
127
- logger.info(f"Processing reference image: {refer_image_path}")
128
- refer_img = cv2.imread(refer_image_path)
129
- src_ref_path = os.path.join(output_path, 'src_ref.png')
130
- shutil.copy(refer_image_path, src_ref_path)
131
- refer_img = refer_img[..., ::-1]
132
-
133
- refer_img = resize_by_area(refer_img, resolution_area[0] * resolution_area[1], divisor=16)
134
-
135
- refer_pose_meta = self.pose2d([refer_img])[0]
136
-
137
-
138
- logger.info(f"Processing template video: {video_path}")
139
- video_reader = VideoReader(video_path)
140
- frame_num = len(video_reader)
141
- print('frame_num: {}'.format(frame_num))
142
-
143
- video_fps = video_reader.get_avg_fps()
144
- print('video_fps: {}'.format(video_fps))
145
- print('fps: {}'.format(fps))
146
-
147
- # TODO: Maybe we can switch to PyAV later, which can get accurate frame num
148
- duration = video_reader.get_frame_timestamp(-1)[-1]
149
- expected_frame_num = int(duration * video_fps + 0.5)
150
- ratio = abs((frame_num - expected_frame_num)/frame_num)
151
- if ratio > 0.1:
152
- print("Warning: The difference between the actual number of frames and the expected number of frames is too large")
153
- frame_num = expected_frame_num
154
-
155
- if fps == -1:
156
- fps = video_fps
157
-
158
- target_num = int(frame_num / video_fps * fps)
159
- print('target_num: {}'.format(target_num))
160
- idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
161
- frames = video_reader.get_batch(idxs).asnumpy()
162
-
163
- logger.info(f"Processing pose meta")
164
-
165
- tpl_pose_meta0 = self.pose2d(frames[:1])[0]
166
- tpl_pose_metas = self.pose2d(frames)
167
-
168
- face_images = []
169
- for idx, meta in enumerate(tpl_pose_metas):
170
- face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
171
- image_shape=(frames[0].shape[0], frames[0].shape[1]))
172
-
173
- x1, x2, y1, y2 = face_bbox_for_image
174
- face_image = frames[idx][y1:y2, x1:x2]
175
- face_image = cv2.resize(face_image, (512, 512))
176
- face_images.append(face_image)
177
-
178
- if retarget_flag:
179
- if use_flux:
180
- tpl_prompt, refer_prompt = self.get_editing_prompts(tpl_pose_metas, refer_pose_meta)
181
- refer_input = Image.fromarray(refer_img)
182
- refer_edit = self.flux_kontext(
183
- image=refer_input,
184
- height=refer_img.shape[0],
185
- width=refer_img.shape[1],
186
- prompt=refer_prompt,
187
- guidance_scale=2.5,
188
- num_inference_steps=28,
189
- ).images[0]
190
-
191
- refer_edit = Image.fromarray(padding_resize(np.array(refer_edit), refer_img.shape[0], refer_img.shape[1]))
192
- refer_edit_path = os.path.join(output_path, 'refer_edit.png')
193
- refer_edit.save(refer_edit_path)
194
- refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0]
195
-
196
- tpl_img = frames[1]
197
- tpl_input = Image.fromarray(tpl_img)
198
-
199
- tpl_edit = self.flux_kontext(
200
- image=tpl_input,
201
- height=tpl_img.shape[0],
202
- width=tpl_img.shape[1],
203
- prompt=tpl_prompt,
204
- guidance_scale=2.5,
205
- num_inference_steps=28,
206
- ).images[0]
207
-
208
- tpl_edit = Image.fromarray(padding_resize(np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1]))
209
- tpl_edit_path = os.path.join(output_path, 'tpl_edit.png')
210
- tpl_edit.save(tpl_edit_path)
211
- tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0]
212
- tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tpl_edit_pose_meta0, refer_edit_pose_meta)
213
- else:
214
- tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None)
215
- else:
216
- tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
217
-
218
- cond_images = []
219
- for idx, meta in enumerate(tpl_retarget_pose_metas):
220
- if retarget_flag:
221
- canvas = np.zeros_like(refer_img)
222
- conditioning_image = draw_aapose_by_meta_new(canvas, meta)
223
- else:
224
- canvas = np.zeros_like(frames[0])
225
- conditioning_image = draw_aapose_by_meta_new(canvas, meta)
226
- conditioning_image = padding_resize(conditioning_image, refer_img.shape[0], refer_img.shape[1])
227
-
228
- cond_images.append(conditioning_image)
229
-
230
- src_face_path = os.path.join(output_path, 'src_face.mp4')
231
- mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
232
-
233
- src_pose_path = os.path.join(output_path, 'src_pose.mp4')
234
- mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
235
- return True
236
-
237
- def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta):
238
- arm_visible = False
239
- leg_visible = False
240
- for tpl_pose_meta in tpl_pose_metas:
241
- tpl_keypoints = tpl_pose_meta['keypoints_body']
242
- if tpl_keypoints[3].all() != 0 or tpl_keypoints[4].all() != 0 or tpl_keypoints[6].all() != 0 or tpl_keypoints[7].all() != 0:
243
- if (tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] >= 0.75) or (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75) or \
244
- (tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75) or (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75):
245
- arm_visible = True
246
- if tpl_keypoints[9].all() != 0 or tpl_keypoints[12].all() != 0 or tpl_keypoints[10].all() != 0 or tpl_keypoints[13].all() != 0:
247
- if (tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75) or (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75) or \
248
- (tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75) or (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75):
249
- leg_visible = True
250
- if arm_visible and leg_visible:
251
- break
252
-
253
- if leg_visible:
254
- if tpl_pose_meta['width'] > tpl_pose_meta['height']:
255
- tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
256
- else:
257
- tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
258
-
259
- if refer_pose_meta['width'] > refer_pose_meta['height']:
260
- refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
261
- else:
262
- refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
263
- elif arm_visible:
264
- if tpl_pose_meta['width'] > tpl_pose_meta['height']:
265
- tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
266
- else:
267
- tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
268
-
269
- if refer_pose_meta['width'] > refer_pose_meta['height']:
270
- refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
271
- else:
272
- refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
273
- else:
274
- tpl_prompt = "Change the person to face forward."
275
- refer_prompt = "Change the person to face forward."
276
-
277
- return tpl_prompt, refer_prompt
278
-
279
-
280
- def get_mask(self, frames, th_step, kp2ds_all):
281
- frame_num = len(frames)
282
- if frame_num < th_step:
283
- num_step = 1
284
- else:
285
- num_step = (frame_num + th_step) // th_step
286
-
287
- all_mask = []
288
- for index in range(num_step):
289
- each_frames = frames[index * th_step:(index + 1) * th_step]
290
-
291
- kp2ds = kp2ds_all[index * th_step:(index + 1) * th_step]
292
- if len(each_frames) > 4:
293
- key_frame_num = 4
294
- elif 4 >= len(each_frames) > 0:
295
- key_frame_num = 1
296
- else:
297
- continue
298
-
299
- key_frame_step = len(kp2ds) // key_frame_num
300
- key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))
301
-
302
- key_points_index = [0, 1, 2, 5, 8, 11, 10, 13]
303
- key_frame_body_points_list = []
304
- for key_frame_index in key_frame_index_list:
305
- keypoints_body_list = []
306
- body_key_points = kp2ds[key_frame_index]['keypoints_face']
307
- for each_index in key_points_index:
308
- each_keypoint = body_key_points[each_index]
309
- if None is each_keypoint:
310
- continue
311
- keypoints_body_list.append(each_keypoint)
312
-
313
- keypoints_body = np.array(keypoints_body_list)[:, :2]
314
- wh = np.array([[kp2ds[0]['width'], kp2ds[0]['height']]])
315
- points = (keypoints_body * wh).astype(np.int32)
316
- key_frame_body_points_list.append(points)
317
-
318
- inference_state = self.predictor.init_state_v2(frames=each_frames)
319
- self.predictor.reset_state(inference_state)
320
- ann_obj_id = 1
321
- for ann_frame_idx, points in zip(key_frame_index_list, key_frame_body_points_list):
322
- labels = np.array([1] * points.shape[0], np.int32)
323
- _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
324
- inference_state=inference_state,
325
- frame_idx=ann_frame_idx,
326
- obj_id=ann_obj_id,
327
- points=points,
328
- labels=labels,
329
- )
330
-
331
- video_segments = {}
332
- for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
333
- video_segments[out_frame_idx] = {
334
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
335
- for i, out_obj_id in enumerate(out_obj_ids)
336
- }
337
-
338
- for out_frame_idx in range(len(video_segments)):
339
- for out_obj_id, out_mask in video_segments[out_frame_idx].items():
340
- out_mask = out_mask[0].astype(np.uint8)
341
- all_mask.append(out_mask)
342
-
343
- return all_mask
344
-
345
- def convert_list_to_array(self, metas):
346
- metas_list = []
347
- for meta in metas:
348
- for key, value in meta.items():
349
- if type(value) is list:
350
- value = np.array(value)
351
- meta[key] = value
352
- metas_list.append(meta)
353
- return metas_list
354
-
355
- def get_mask_from_face_bbox(self, frames, th_step, kp2ds_all):
356
- """
357
- Build masks using a face bounding box per key frame (derived from keypoints_face),
358
- then propagate with SAM2 across each chunk of frames.
359
- """
360
- H, W = frames[0].shape[:2]
361
-
362
- def _clip_box(x1, y1, x2, y2, W, H):
363
- x1 = max(0, min(int(x1), W - 1))
364
- x2 = max(0, min(int(x2), W - 1))
365
- y1 = max(0, min(int(y1), H - 1))
366
- y2 = max(0, min(int(y2), H - 1))
367
- if x2 <= x1: x2 = min(W - 1, x1 + 1)
368
- if y2 <= y1: y2 = min(H - 1, y1 + 1)
369
- return x1, y1, x2, y2
370
-
371
- frame_num = len(frames)
372
- if frame_num < th_step:
373
- num_step = 1
374
- else:
375
- num_step = (frame_num + th_step) // th_step
376
-
377
- all_mask = []
378
-
379
- for step_idx in range(num_step):
380
- each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
381
- kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
382
- if len(each_frames) == 0:
383
- continue
384
-
385
- # pick a few key frames in this chunk
386
- key_frame_num = 4 if len(each_frames) > 4 else 1
387
- key_frame_step = max(1, len(kp2ds) // key_frame_num)
388
- key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
389
-
390
- # compute face boxes on the selected key frames
391
- key_frame_boxes = []
392
- for kfi in key_frame_index_list:
393
- meta = kp2ds[kfi]
394
- # get_face_bboxes returns (x1, x2, y1, y2) in your code
395
- x1, x2, y1, y2 = get_face_bboxes(
396
- meta['keypoints_face'][:, :2],
397
- scale=1.3,
398
- image_shape=(H, W)
399
- )
400
- x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2, W, H)
401
- key_frame_boxes.append(np.array([x1, y1, x2, y2], dtype=np.float32))
402
-
403
- # init SAM2 for this chunk
404
- with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
405
- inference_state = self.predictor.init_state_v2(frames=each_frames)
406
- self.predictor.reset_state(inference_state)
407
- ann_obj_id = 1
408
-
409
- # seed with box prompts (preferred), else fall back to points
410
- for ann_frame_idx, box_xyxy in zip(key_frame_index_list, key_frame_boxes):
411
- used_box = False
412
- try:
413
- # If your predictor exposes a box API, this is ideal.
414
- _ = self.predictor.add_new_box(
415
- inference_state=inference_state,
416
- frame_idx=ann_frame_idx,
417
- obj_id=ann_obj_id,
418
- box=box_xyxy[None, :] # shape (1, 4)
419
- )
420
- used_box = True
421
- except Exception:
422
- used_box = False
423
-
424
- if not used_box:
425
- # Fallback: sample a few positive points inside the box
426
- x1, y1, x2, y2 = box_xyxy.astype(int)
427
- cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
428
- pts = np.array([
429
- [cx, cy],
430
- [x1 + (x2 - x1) // 4, cy],
431
- [x2 - (x2 - x1) // 4, cy],
432
- [cx, y1 + (y2 - y1) // 4],
433
- [cx, y2 - (y2 - y1) // 4],
434
- ], dtype=np.int32)
435
- labels = np.ones(len(pts), dtype=np.int32) # 1 = positive
436
- _ = self.predictor.add_new_points(
437
- inference_state=inference_state,
438
- frame_idx=ann_frame_idx,
439
- obj_id=ann_obj_id,
440
- points=pts,
441
- labels=labels,
442
- )
443
-
444
- # propagate across the chunk
445
- video_segments = {}
446
- for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
447
- video_segments[out_frame_idx] = {
448
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
449
- for i, out_obj_id in enumerate(out_obj_ids)
450
- }
451
-
452
- # collect masks (single object id)
453
- for out_frame_idx in range(len(video_segments)):
454
- # (H, W) boolean/uint8
455
- mask = next(iter(video_segments[out_frame_idx].values()))
456
- mask = mask[0].astype(np.uint8)
457
- all_mask.append(mask)
458
-
459
- return all_mask
460
- def get_mask_from_face_point(self, frames, th_step, kp2ds_all):
461
- """
462
- Build masks using a single face *center point* per key frame,
463
- then propagate with SAM2 across each chunk of frames.
464
- """
465
- H, W = frames[0].shape[:2]
466
-
467
- frame_num = len(frames)
468
- num_step = 1 if frame_num < th_step else (frame_num + th_step) // th_step
469
-
470
- all_mask = []
471
-
472
- for step_idx in range(num_step):
473
- each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
474
- kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
475
- if len(each_frames) == 0:
476
- continue
477
-
478
- # choose a few key frames to seed the object
479
- key_frame_num = 1
480
- key_frame_step = max(1, len(kp2ds) // key_frame_num)
481
- key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
482
-
483
- # compute center point from face bbox for each selected key frame
484
- center_pts = []
485
- for kfi in key_frame_index_list:
486
- meta = kp2ds[kfi]
487
- # your helper returns (x1, x2, y1, y2)
488
- x1, x2, y1, y2 = get_face_bboxes(
489
- meta['keypoints_face'][:, :2],
490
- scale=1.3,
491
- image_shape=(H, W)
492
- )
493
- cx = (x1 + x2) // 2
494
- cy = (y1 + y2) // 2
495
- # clip just in case
496
- cx = int(max(0, min(cx, W - 1)))
497
- cy = int(max(0, min(cy, H - 1)))
498
- center_pts.append(np.array([cx, cy], dtype=np.int32))
499
-
500
- # init SAM2 for this chunk
501
- inference_state = self.predictor.init_state_v2(frames=each_frames)
502
- self.predictor.reset_state(inference_state)
503
- ann_obj_id = 1
504
-
505
- # seed each key frame with a single positive point at the face center
506
- for ann_frame_idx, pt in zip(key_frame_index_list, center_pts):
507
- pts = pt[None, :] # shape (1, 2)
508
- labels = np.ones(1, dtype=np.int32) # 1 = positive
509
- _ = self.predictor.add_new_points(
510
- inference_state=inference_state,
511
- frame_idx=ann_frame_idx,
512
- obj_id=ann_obj_id,
513
- points=pts,
514
- labels=labels,
515
- )
516
-
517
- # propagate across the chunk
518
- video_segments = {}
519
- for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
520
- video_segments[out_frame_idx] = {
521
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
522
- for i, out_obj_id in enumerate(out_obj_ids)
523
- }
524
-
525
- # collect masks (single object id)
526
- for out_frame_idx in range(len(video_segments)):
527
- mask = next(iter(video_segments[out_frame_idx].values()))
528
- mask = mask[0].astype(np.uint8)
529
- all_mask.append(mask)
530
-
531
- return all_mask
532
-
533
- def get_face_bbox_masks(self, frames, kp2ds_all, scale=1.3, feather_px=0, keep_soft=False):
534
- """
535
- Create a per-frame mask that's simply the face bounding box.
536
- - scale: bbox scale factor used by get_face_bboxes
537
- - feather_px: optional Gaussian blur in pixels to feather edges
538
- - keep_soft: if True, keep float [0,1] soft mask (after blur); else binarize to {0,1}
539
- """
540
- H, W = frames[0].shape[:2]
541
-
542
- def _clip_box(x1, y1, x2, y2):
543
- x1 = max(0, min(int(x1), W - 1))
544
- x2 = max(0, min(int(x2), W - 1))
545
- y1 = max(0, min(int(y1), H - 1))
546
- y2 = max(0, min(int(y2), H - 1))
547
- if x2 <= x1: x2 = min(W - 1, x1 + 1)
548
- if y2 <= y1: y2 = min(H - 1, y1 + 1)
549
- return x1, y1, x2, y2
550
-
551
- masks = []
552
- last_box = None
553
- for meta in kp2ds_all:
554
- # get_face_bboxes returns (x1, x2, y1, y2)
555
- try:
556
- x1, x2, y1, y2 = get_face_bboxes(
557
- meta['keypoints_face'][:, :2],
558
- scale=scale,
559
- image_shape=(H, W)
560
- )
561
- x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2)
562
- last_box = (x1, y1, x2, y2)
563
- except Exception:
564
- # fallback: reuse last seen box to avoid holes
565
- if last_box is None:
566
- # no detection yet: push empty mask
567
- masks.append(np.zeros((H, W), dtype=np.uint8))
568
- continue
569
- x1, y1, x2, y2 = last_box
570
-
571
- m = np.zeros((H, W), dtype=np.float32)
572
- m[y1:y2, x1:x2] = 1.0
573
-
574
- if feather_px and feather_px > 0:
575
- # kernel size must be odd and >= 3
576
- k = max(3, int(feather_px) | 1)
577
- m = cv2.GaussianBlur(m, (k, k), 0)
578
-
579
- if keep_soft:
580
- masks.append(m) # float [0,1]
581
- else:
582
- masks.append((m >= 0.5).astype(np.uint8)) # hard {0,1}
583
-
584
- return masks
585
-
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import numpy as np
4
+ import shutil
5
+ import torch
6
+ from diffusers import FluxKontextPipeline
7
+ import cv2
8
+ from loguru import logger
9
+ from PIL import Image
10
+ try:
11
+ import moviepy.editor as mpy
12
+ except:
13
+ import moviepy as mpy
14
+
15
+ from decord import VideoReader
16
+ from pose2d import Pose2d
17
+ from pose2d_utils import AAPoseMeta
18
+ from utils import resize_by_area, get_frame_indices, padding_resize, get_face_bboxes, get_aug_mask, get_mask_body_img
19
+ from human_visualization import draw_aapose_by_meta_new
20
+ from retarget_pose import get_retarget_pose
21
+ import sam2.modeling.sam.transformer as transformer
22
+ transformer.USE_FLASH_ATTN = True
23
+ transformer.MATH_KERNEL_ON = True
24
+ transformer.OLD_GPU = False
25
+ from sam_utils import build_sam2_video_predictor
26
+
27
+
28
+ class ProcessPipeline():
29
+ def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
30
+ self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
31
+
32
+ model_cfg = "sam2_hiera_l.yaml"
33
+ if sam_checkpoint_path is not None:
34
+ self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path)
35
+ if flux_kontext_path is not None:
36
+ self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
37
+
38
+ def __call__(self, video_path, refer_image_path, output_path, resolution_area=[1280, 720], fps=30, iterations=3, k=7, w_len=1, h_len=1, retarget_flag=False, use_flux=False, replace_flag=False):
39
+ if replace_flag:
40
+
41
+ video_reader = VideoReader(video_path)
42
+ frame_num = len(video_reader)
43
+ print('frame_num: {}'.format(frame_num))
44
+
45
+ video_fps = video_reader.get_avg_fps()
46
+ print('video_fps: {}'.format(video_fps))
47
+ print('fps: {}'.format(fps))
48
+
49
+ # TODO: Maybe we can switch to PyAV later, which can get accurate frame num
50
+ duration = video_reader.get_frame_timestamp(-1)[-1]
51
+ expected_frame_num = int(duration * video_fps + 0.5)
52
+ ratio = abs((frame_num - expected_frame_num)/frame_num)
53
+ if ratio > 0.1:
54
+ print("Warning: The difference between the actual number of frames and the expected number of frames is too large")
55
+ frame_num = expected_frame_num
56
+
57
+ if fps == -1:
58
+ fps = video_fps
59
+
60
+ target_num = int(frame_num / video_fps * fps)
61
+ print('target_num: {}'.format(target_num))
62
+ idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
63
+ frames = video_reader.get_batch(idxs).asnumpy()
64
+
65
+ frames = [resize_by_area(frame, resolution_area[0] * resolution_area[1], divisor=16) for frame in frames]
66
+ height, width = frames[0].shape[:2]
67
+ logger.info(f"Processing pose meta")
68
+
69
+
70
+ tpl_pose_metas = self.pose2d(frames)
71
+
72
+ face_images = []
73
+ for idx, meta in enumerate(tpl_pose_metas):
74
+ face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
75
+ image_shape=(frames[0].shape[0], frames[0].shape[1]))
76
+
77
+ x1, x2, y1, y2 = face_bbox_for_image
78
+ face_image = frames[idx][y1:y2, x1:x2]
79
+ face_image = cv2.resize(face_image, (512, 512))
80
+ face_images.append(face_image)
81
+
82
+ logger.info(f"Processing reference image: {refer_image_path}")
83
+ refer_img = cv2.imread(refer_image_path)
84
+ src_ref_path = os.path.join(output_path, 'src_ref.png')
85
+ shutil.copy(refer_image_path, src_ref_path)
86
+ refer_img = refer_img[..., ::-1]
87
+
88
+ refer_img = padding_resize(refer_img, height, width)
89
+ logger.info(f"Processing template video: {video_path}")
90
+ tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
91
+ cond_images = []
92
+
93
+ for idx, meta in enumerate(tpl_retarget_pose_metas):
94
+ canvas = np.zeros_like(refer_img)
95
+ conditioning_image = draw_aapose_by_meta_new(canvas, meta)
96
+ cond_images.append(conditioning_image)
97
+ masks = self.get_mask_from_face_bbox(frames, 400, tpl_pose_metas)
98
+
99
+ bg_images = []
100
+ aug_masks = []
101
+
102
+ for frame, mask in zip(frames, masks):
103
+ if iterations > 0:
104
+ _, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k)
105
+ each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len)
106
+ else:
107
+ each_aug_mask = mask
108
+
109
+ each_bg_image = frame * (1 - each_aug_mask[:, :, None])
110
+ bg_images.append(each_bg_image)
111
+ aug_masks.append(each_aug_mask)
112
+
113
+ src_face_path = os.path.join(output_path, 'src_face.mp4')
114
+ mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
115
+
116
+ src_pose_path = os.path.join(output_path, 'src_pose.mp4')
117
+ mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
118
+
119
+ src_bg_path = os.path.join(output_path, 'src_bg.mp4')
120
+ mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path)
121
+
122
+ aug_masks_new = [np.stack([mask * 255, mask * 255, mask * 255], axis=2) for mask in aug_masks]
123
+ src_mask_path = os.path.join(output_path, 'src_mask.mp4')
124
+ mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path)
125
+ return True
126
+ else:
127
+ logger.info(f"Processing reference image: {refer_image_path}")
128
+ refer_img = cv2.imread(refer_image_path)
129
+ src_ref_path = os.path.join(output_path, 'src_ref.png')
130
+ shutil.copy(refer_image_path, src_ref_path)
131
+ refer_img = refer_img[..., ::-1]
132
+
133
+ refer_img = resize_by_area(refer_img, resolution_area[0] * resolution_area[1], divisor=16)
134
+
135
+ refer_pose_meta = self.pose2d([refer_img])[0]
136
+
137
+
138
+ logger.info(f"Processing template video: {video_path}")
139
+ video_reader = VideoReader(video_path)
140
+ frame_num = len(video_reader)
141
+ print('frame_num: {}'.format(frame_num))
142
+
143
+ video_fps = video_reader.get_avg_fps()
144
+ print('video_fps: {}'.format(video_fps))
145
+ print('fps: {}'.format(fps))
146
+
147
+ # TODO: Maybe we can switch to PyAV later, which can get accurate frame num
148
+ duration = video_reader.get_frame_timestamp(-1)[-1]
149
+ expected_frame_num = int(duration * video_fps + 0.5)
150
+ ratio = abs((frame_num - expected_frame_num)/frame_num)
151
+ if ratio > 0.1:
152
+ print("Warning: The difference between the actual number of frames and the expected number of frames is too large")
153
+ frame_num = expected_frame_num
154
+
155
+ if fps == -1:
156
+ fps = video_fps
157
+
158
+ target_num = int(frame_num / video_fps * fps)
159
+ print('target_num: {}'.format(target_num))
160
+ idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
161
+ frames = video_reader.get_batch(idxs).asnumpy()
162
+
163
+ logger.info(f"Processing pose meta")
164
+
165
+ tpl_pose_meta0 = self.pose2d(frames[:1])[0]
166
+ tpl_pose_metas = self.pose2d(frames)
167
+
168
+ face_images = []
169
+ for idx, meta in enumerate(tpl_pose_metas):
170
+ face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
171
+ image_shape=(frames[0].shape[0], frames[0].shape[1]))
172
+
173
+ x1, x2, y1, y2 = face_bbox_for_image
174
+ face_image = frames[idx][y1:y2, x1:x2]
175
+ face_image = cv2.resize(face_image, (512, 512))
176
+ face_images.append(face_image)
177
+
178
+ if retarget_flag:
179
+ if use_flux:
180
+ tpl_prompt, refer_prompt = self.get_editing_prompts(tpl_pose_metas, refer_pose_meta)
181
+ refer_input = Image.fromarray(refer_img)
182
+ refer_edit = self.flux_kontext(
183
+ image=refer_input,
184
+ height=refer_img.shape[0],
185
+ width=refer_img.shape[1],
186
+ prompt=refer_prompt,
187
+ guidance_scale=2.5,
188
+ num_inference_steps=28,
189
+ ).images[0]
190
+
191
+ refer_edit = Image.fromarray(padding_resize(np.array(refer_edit), refer_img.shape[0], refer_img.shape[1]))
192
+ refer_edit_path = os.path.join(output_path, 'refer_edit.png')
193
+ refer_edit.save(refer_edit_path)
194
+ refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0]
195
+
196
+ tpl_img = frames[1]
197
+ tpl_input = Image.fromarray(tpl_img)
198
+
199
+ tpl_edit = self.flux_kontext(
200
+ image=tpl_input,
201
+ height=tpl_img.shape[0],
202
+ width=tpl_img.shape[1],
203
+ prompt=tpl_prompt,
204
+ guidance_scale=2.5,
205
+ num_inference_steps=28,
206
+ ).images[0]
207
+
208
+ tpl_edit = Image.fromarray(padding_resize(np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1]))
209
+ tpl_edit_path = os.path.join(output_path, 'tpl_edit.png')
210
+ tpl_edit.save(tpl_edit_path)
211
+ tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0]
212
+ tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tpl_edit_pose_meta0, refer_edit_pose_meta)
213
+ else:
214
+ tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None)
215
+ else:
216
+ tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
217
+
218
+ cond_images = []
219
+ for idx, meta in enumerate(tpl_retarget_pose_metas):
220
+ if retarget_flag:
221
+ canvas = np.zeros_like(refer_img)
222
+ conditioning_image = draw_aapose_by_meta_new(canvas, meta)
223
+ else:
224
+ canvas = np.zeros_like(frames[0])
225
+ conditioning_image = draw_aapose_by_meta_new(canvas, meta)
226
+ conditioning_image = padding_resize(conditioning_image, refer_img.shape[0], refer_img.shape[1])
227
+
228
+ cond_images.append(conditioning_image)
229
+
230
+ src_face_path = os.path.join(output_path, 'src_face.mp4')
231
+ mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
232
+
233
+ src_pose_path = os.path.join(output_path, 'src_pose.mp4')
234
+ mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
235
+ return True
236
+
237
+ def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta):
238
+ arm_visible = False
239
+ leg_visible = False
240
+ for tpl_pose_meta in tpl_pose_metas:
241
+ tpl_keypoints = tpl_pose_meta['keypoints_body']
242
+ if tpl_keypoints[3].all() != 0 or tpl_keypoints[4].all() != 0 or tpl_keypoints[6].all() != 0 or tpl_keypoints[7].all() != 0:
243
+ if (tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] >= 0.75) or (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75) or \
244
+ (tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75) or (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75):
245
+ arm_visible = True
246
+ if tpl_keypoints[9].all() != 0 or tpl_keypoints[12].all() != 0 or tpl_keypoints[10].all() != 0 or tpl_keypoints[13].all() != 0:
247
+ if (tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75) or (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75) or \
248
+ (tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75) or (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75):
249
+ leg_visible = True
250
+ if arm_visible and leg_visible:
251
+ break
252
+
253
+ if leg_visible:
254
+ if tpl_pose_meta['width'] > tpl_pose_meta['height']:
255
+ tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
256
+ else:
257
+ tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
258
+
259
+ if refer_pose_meta['width'] > refer_pose_meta['height']:
260
+ refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
261
+ else:
262
+ refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
263
+ elif arm_visible:
264
+ if tpl_pose_meta['width'] > tpl_pose_meta['height']:
265
+ tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
266
+ else:
267
+ tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
268
+
269
+ if refer_pose_meta['width'] > refer_pose_meta['height']:
270
+ refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
271
+ else:
272
+ refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
273
+ else:
274
+ tpl_prompt = "Change the person to face forward."
275
+ refer_prompt = "Change the person to face forward."
276
+
277
+ return tpl_prompt, refer_prompt
278
+
279
+
280
+ def get_mask(self, frames, th_step, kp2ds_all):
281
+ frame_num = len(frames)
282
+ if frame_num < th_step:
283
+ num_step = 1
284
+ else:
285
+ num_step = (frame_num + th_step) // th_step
286
+
287
+ all_mask = []
288
+ for index in range(num_step):
289
+ each_frames = frames[index * th_step:(index + 1) * th_step]
290
+
291
+ kp2ds = kp2ds_all[index * th_step:(index + 1) * th_step]
292
+ if len(each_frames) > 4:
293
+ key_frame_num = 4
294
+ elif 4 >= len(each_frames) > 0:
295
+ key_frame_num = 1
296
+ else:
297
+ continue
298
+
299
+ key_frame_step = len(kp2ds) // key_frame_num
300
+ key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))
301
+
302
+ key_points_index = [0, 1, 2, 5, 8, 11, 10, 13]
303
+ key_frame_body_points_list = []
304
+ for key_frame_index in key_frame_index_list:
305
+ keypoints_body_list = []
306
+ body_key_points = kp2ds[key_frame_index]['keypoints_face']
307
+ for each_index in key_points_index:
308
+ each_keypoint = body_key_points[each_index]
309
+ if None is each_keypoint:
310
+ continue
311
+ keypoints_body_list.append(each_keypoint)
312
+
313
+ keypoints_body = np.array(keypoints_body_list)[:, :2]
314
+ wh = np.array([[kp2ds[0]['width'], kp2ds[0]['height']]])
315
+ points = (keypoints_body * wh).astype(np.int32)
316
+ key_frame_body_points_list.append(points)
317
+
318
+ inference_state = self.predictor.init_state_v2(frames=each_frames)
319
+ self.predictor.reset_state(inference_state)
320
+ ann_obj_id = 1
321
+ for ann_frame_idx, points in zip(key_frame_index_list, key_frame_body_points_list):
322
+ labels = np.array([1] * points.shape[0], np.int32)
323
+ _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
324
+ inference_state=inference_state,
325
+ frame_idx=ann_frame_idx,
326
+ obj_id=ann_obj_id,
327
+ points=points,
328
+ labels=labels,
329
+ )
330
+
331
+ video_segments = {}
332
+ for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
333
+ video_segments[out_frame_idx] = {
334
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
335
+ for i, out_obj_id in enumerate(out_obj_ids)
336
+ }
337
+
338
+ for out_frame_idx in range(len(video_segments)):
339
+ for out_obj_id, out_mask in video_segments[out_frame_idx].items():
340
+ out_mask = out_mask[0].astype(np.uint8)
341
+ all_mask.append(out_mask)
342
+
343
+ return all_mask
344
+
345
+ def convert_list_to_array(self, metas):
346
+ metas_list = []
347
+ for meta in metas:
348
+ for key, value in meta.items():
349
+ if type(value) is list:
350
+ value = np.array(value)
351
+ meta[key] = value
352
+ metas_list.append(meta)
353
+ return metas_list
354
+
355
+ def get_mask_from_face_bbox(self, frames, th_step, kp2ds_all):
356
+ """
357
+ Build masks using a face bounding box per key frame (derived from keypoints_face),
358
+ then propagate with SAM2 across each chunk of frames.
359
+ """
360
+ H, W = frames[0].shape[:2]
361
+
362
+ def _clip_box(x1, y1, x2, y2, W, H):
363
+ x1 = max(0, min(int(x1), W - 1))
364
+ x2 = max(0, min(int(x2), W - 1))
365
+ y1 = max(0, min(int(y1), H - 1))
366
+ y2 = max(0, min(int(y2), H - 1))
367
+ if x2 <= x1: x2 = min(W - 1, x1 + 1)
368
+ if y2 <= y1: y2 = min(H - 1, y1 + 1)
369
+ return x1, y1, x2, y2
370
+
371
+ frame_num = len(frames)
372
+ if frame_num < th_step:
373
+ num_step = 1
374
+ else:
375
+ num_step = (frame_num + th_step) // th_step
376
+
377
+ all_mask = []
378
+
379
+ for step_idx in range(num_step):
380
+ each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
381
+ kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
382
+ if len(each_frames) == 0:
383
+ continue
384
+
385
+ # pick a few key frames in this chunk
386
+ key_frame_num = 4 if len(each_frames) > 4 else 1
387
+ key_frame_step = max(1, len(kp2ds) // key_frame_num)
388
+ key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
389
+
390
+ # compute face boxes on the selected key frames
391
+ key_frame_boxes = []
392
+ for kfi in key_frame_index_list:
393
+ meta = kp2ds[kfi]
394
+ # get_face_bboxes returns (x1, x2, y1, y2) in your code
395
+ x1, x2, y1, y2 = get_face_bboxes(
396
+ meta['keypoints_face'][:, :2],
397
+ scale=1.3,
398
+ image_shape=(H, W)
399
+ )
400
+ x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2, W, H)
401
+ key_frame_boxes.append(np.array([x1, y1, x2, y2], dtype=np.float32))
402
+
403
+ # init SAM2 for this chunk
404
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
405
+ inference_state = self.predictor.init_state_v2(frames=each_frames)
406
+ self.predictor.reset_state(inference_state)
407
+ ann_obj_id = 1
408
+
409
+ # seed with box prompts (preferred), else fall back to points
410
+ for ann_frame_idx, box_xyxy in zip(key_frame_index_list, key_frame_boxes):
411
+ used_box = False
412
+ try:
413
+ # If your predictor exposes a box API, this is ideal.
414
+ _ = self.predictor.add_new_box(
415
+ inference_state=inference_state,
416
+ frame_idx=ann_frame_idx,
417
+ obj_id=ann_obj_id,
418
+ box=box_xyxy[None, :] # shape (1, 4)
419
+ )
420
+ used_box = True
421
+ except Exception:
422
+ used_box = False
423
+
424
+ if not used_box:
425
+ # Fallback: sample a few positive points inside the box
426
+ x1, y1, x2, y2 = box_xyxy.astype(int)
427
+ cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
428
+ pts = np.array([
429
+ [cx, cy],
430
+ [x1 + (x2 - x1) // 4, cy],
431
+ [x2 - (x2 - x1) // 4, cy],
432
+ [cx, y1 + (y2 - y1) // 4],
433
+ [cx, y2 - (y2 - y1) // 4],
434
+ ], dtype=np.int32)
435
+ labels = np.ones(len(pts), dtype=np.int32) # 1 = positive
436
+ _ = self.predictor.add_new_points(
437
+ inference_state=inference_state,
438
+ frame_idx=ann_frame_idx,
439
+ obj_id=ann_obj_id,
440
+ points=pts,
441
+ labels=labels,
442
+ )
443
+
444
+ # propagate across the chunk
445
+ video_segments = {}
446
+ for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
447
+ video_segments[out_frame_idx] = {
448
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
449
+ for i, out_obj_id in enumerate(out_obj_ids)
450
+ }
451
+
452
+ # collect masks (single object id)
453
+ for out_frame_idx in range(len(video_segments)):
454
+ # (H, W) boolean/uint8
455
+ mask = next(iter(video_segments[out_frame_idx].values()))
456
+ mask = mask[0].astype(np.uint8)
457
+ all_mask.append(mask)
458
+
459
+ return all_mask
460
+ def get_mask_from_face_point(self, frames, th_step, kp2ds_all):
461
+ """
462
+ Build masks using a single face *center point* per key frame,
463
+ then propagate with SAM2 across each chunk of frames.
464
+ """
465
+ H, W = frames[0].shape[:2]
466
+
467
+ frame_num = len(frames)
468
+ num_step = 1 if frame_num < th_step else (frame_num + th_step) // th_step
469
+
470
+ all_mask = []
471
+
472
+ for step_idx in range(num_step):
473
+ each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
474
+ kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
475
+ if len(each_frames) == 0:
476
+ continue
477
+
478
+ # choose a few key frames to seed the object
479
+ key_frame_num = 1
480
+ key_frame_step = max(1, len(kp2ds) // key_frame_num)
481
+ key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
482
+
483
+ # compute center point from face bbox for each selected key frame
484
+ center_pts = []
485
+ for kfi in key_frame_index_list:
486
+ meta = kp2ds[kfi]
487
+ # your helper returns (x1, x2, y1, y2)
488
+ x1, x2, y1, y2 = get_face_bboxes(
489
+ meta['keypoints_face'][:, :2],
490
+ scale=1.3,
491
+ image_shape=(H, W)
492
+ )
493
+ cx = (x1 + x2) // 2
494
+ cy = (y1 + y2) // 2
495
+ # clip just in case
496
+ cx = int(max(0, min(cx, W - 1)))
497
+ cy = int(max(0, min(cy, H - 1)))
498
+ center_pts.append(np.array([cx, cy], dtype=np.int32))
499
+
500
+ # init SAM2 for this chunk
501
+ inference_state = self.predictor.init_state_v2(frames=each_frames)
502
+ self.predictor.reset_state(inference_state)
503
+ ann_obj_id = 1
504
+
505
+ # seed each key frame with a single positive point at the face center
506
+ for ann_frame_idx, pt in zip(key_frame_index_list, center_pts):
507
+ pts = pt[None, :] # shape (1, 2)
508
+ labels = np.ones(1, dtype=np.int32) # 1 = positive
509
+ _ = self.predictor.add_new_points(
510
+ inference_state=inference_state,
511
+ frame_idx=ann_frame_idx,
512
+ obj_id=ann_obj_id,
513
+ points=pts,
514
+ labels=labels,
515
+ )
516
+
517
+ # propagate across the chunk
518
+ video_segments = {}
519
+ for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
520
+ video_segments[out_frame_idx] = {
521
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
522
+ for i, out_obj_id in enumerate(out_obj_ids)
523
+ }
524
+
525
+ # collect masks (single object id)
526
+ for out_frame_idx in range(len(video_segments)):
527
+ mask = next(iter(video_segments[out_frame_idx].values()))
528
+ mask = mask[0].astype(np.uint8)
529
+ all_mask.append(mask)
530
+
531
+ return all_mask
532
+
533
+ def get_face_bbox_masks(self, frames, kp2ds_all, scale=1.3, feather_px=0, keep_soft=False):
534
+ """
535
+ Create a per-frame mask that's simply the face bounding box.
536
+ - scale: bbox scale factor used by get_face_bboxes
537
+ - feather_px: optional Gaussian blur in pixels to feather edges
538
+ - keep_soft: if True, keep float [0,1] soft mask (after blur); else binarize to {0,1}
539
+ """
540
+ H, W = frames[0].shape[:2]
541
+
542
+ def _clip_box(x1, y1, x2, y2):
543
+ x1 = max(0, min(int(x1), W - 1))
544
+ x2 = max(0, min(int(x2), W - 1))
545
+ y1 = max(0, min(int(y1), H - 1))
546
+ y2 = max(0, min(int(y2), H - 1))
547
+ if x2 <= x1: x2 = min(W - 1, x1 + 1)
548
+ if y2 <= y1: y2 = min(H - 1, y1 + 1)
549
+ return x1, y1, x2, y2
550
+
551
+ masks = []
552
+ last_box = None
553
+ for meta in kp2ds_all:
554
+ # get_face_bboxes returns (x1, x2, y1, y2)
555
+ try:
556
+ x1, x2, y1, y2 = get_face_bboxes(
557
+ meta['keypoints_face'][:, :2],
558
+ scale=scale,
559
+ image_shape=(H, W)
560
+ )
561
+ x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2)
562
+ last_box = (x1, y1, x2, y2)
563
+ except Exception:
564
+ # fallback: reuse last seen box to avoid holes
565
+ if last_box is None:
566
+ # no detection yet: push empty mask
567
+ masks.append(np.zeros((H, W), dtype=np.uint8))
568
+ continue
569
+ x1, y1, x2, y2 = last_box
570
+
571
+ m = np.zeros((H, W), dtype=np.float32)
572
+ m[y1:y2, x1:x2] = 1.0
573
+
574
+ if feather_px and feather_px > 0:
575
+ # kernel size must be odd and >= 3
576
+ k = max(3, int(feather_px) | 1)
577
+ m = cv2.GaussianBlur(m, (k, k), 0)
578
+
579
+ if keep_soft:
580
+ masks.append(m) # float [0,1]
581
+ else:
582
+ masks.append((m >= 0.5).astype(np.uint8)) # hard {0,1}
583
+
584
+ return masks
585
+
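For orientation, here is a minimal usage sketch of the class defined in this file, based only on the constructor and `__call__` signatures shown in the diff above. The import path, checkpoint files, and directories are placeholders (assumptions), not values taken from this commit.

from process_pipepline import ProcessPipeline  # module name matches this file's path

# All paths below are placeholders.
pipeline = ProcessPipeline(
    det_checkpoint_path="checkpoints/det.pth",
    pose2d_checkpoint_path="checkpoints/pose2d.pth",
    sam_checkpoint_path="checkpoints/sam2_hiera_large.pt",
    flux_kontext_path=None,  # skip loading the FLUX Kontext pipeline onto the GPU
)

# replace_flag=True takes the branch that segments the subject with
# get_mask_from_face_bbox and writes src_face/src_pose/src_bg/src_mask videos.
pipeline(
    video_path="inputs/template.mp4",
    refer_image_path="inputs/reference.png",
    output_path="outputs",
    replace_flag=True,
)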