import os

import gradio as gr
import cv2
import numpy as np
from PIL import Image

os.makedirs("./SAM2-Video-Predictor/checkpoints/", exist_ok=True)
os.makedirs("./model/", exist_ok=True)

from huggingface_hub import snapshot_download


def download_sam2():
    snapshot_download(repo_id="facebook/sam2-hiera-large", local_dir="./SAM2-Video-Predictor/checkpoints/")
    print("Download sam2 completed")


def download_remover():
    snapshot_download(repo_id="zibojia/minimax-remover", local_dir="./model/")
    print("Download minimax remover completed")


download_sam2()
download_remover()

import torch
import argparse
import random
import torch.nn.functional as F
import time
from omegaconf import OmegaConf
from einops import rearrange
from diffusers.models import AutoencoderKLWan
import scipy
from transformer_minimax_remover import Transformer3DModel
from diffusers.schedulers import UniPCMultistepScheduler
from pipeline_minimax_remover import Minimax_Remover_Pipeline
from diffusers.utils import export_to_video
from decord import VideoReader, cpu
from moviepy.editor import ImageSequenceClip
from sam2 import load_model
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
import spaces

COLOR_PALETTE = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
    (0, 255, 255), (255, 128, 0), (128, 0, 255), (0, 128, 255), (128, 255, 0),
]

random_seed = 42
video_length = 201
W = 1024
H = W
device = "cpu"


def get_pipe_image_and_video_predictor():
    # Load the MiniMax-Remover pipeline (VAE + transformer + scheduler) in fp16
    vae = AutoencoderKLWan.from_pretrained("./model/vae", torch_dtype=torch.float16)
    transformer = Transformer3DModel.from_pretrained("./model/transformer", torch_dtype=torch.float16)
    scheduler = UniPCMultistepScheduler.from_pretrained("./model/scheduler")
    pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler)

    # Build SAM2 predictors for single-image segmentation and video tracking
    sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
    config = "sam2_hiera_l.yaml"
    video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device=device)
    model = build_sam2(config, sam2_checkpoint, device=device)
    model.image_size = 1024
    image_predictor = SAM2ImagePredictor(sam_model=model)
    return pipe, image_predictor, video_predictor


def get_video_info(video_path, video_state):
    video_state["input_points"] = []
    video_state["scaled_points"] = []
    video_state["input_labels"] = []
    video_state["frame_idx"] = 0

    vr = VideoReader(video_path, ctx=cpu(0))
    first_frame = vr[0].asnumpy()
    del vr

    # Resize so the shorter side is 1024 while preserving the aspect ratio
    if first_frame.shape[0] > first_frame.shape[1]:
        W_ = W
        H_ = int(W_ * first_frame.shape[0] / first_frame.shape[1])
    else:
        H_ = H
        W_ = int(H_ * first_frame.shape[1] / first_frame.shape[0])
    first_frame = cv2.resize(first_frame, (W_, H_))

    video_state["origin_images"] = np.expand_dims(first_frame, axis=0)
    video_state["inference_state"] = None
    video_state["video_path"] = video_path
    video_state["masks"] = None
    video_state["painted_images"] = None

    image = Image.fromarray(first_frame)
    return image


def segment_frame(evt: gr.SelectData, label, video_state):
    if video_state["origin_images"] is None:
        return None

    x, y = evt.index
    new_point = [x, y]
    label_value = 1 if label == "Positive" else 0
    video_state["input_points"].append(new_point)
    video_state["input_labels"].append(label_value)

    # Normalize click coordinates to [0, 1] for the SAM2 image predictor
    height, width = video_state["origin_images"][0].shape[0:2]
    scaled_points = []
    for pt in video_state["input_points"]:
        sx = pt[0] / width
        sy = pt[1] / height
        scaled_points.append([sx, sy])
    video_state["scaled_points"] = scaled_points
    image_predictor.set_image(video_state["origin_images"][0])
    mask, _, _ = image_predictor.predict(
        point_coords=video_state["scaled_points"],
        point_labels=video_state["input_labels"],
        multimask_output=False,
        normalize_coords=False,
    )
    mask = np.squeeze(mask)
    mask = cv2.resize(mask, (width, height))
    mask = mask[:, :, None]

    # Overlay the mask on the frame at 50% opacity using a palette color
    color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
    color = color[None, None, :]
    org_image = video_state["origin_images"][0].astype(np.float32) / 255.0
    painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color
    painted_image = np.uint8(np.clip(painted_image * 255, 0, 255))

    video_state["painted_images"] = np.expand_dims(painted_image, axis=0)
    video_state["masks"] = np.expand_dims(mask[:, :, 0], axis=0)

    # Draw the clicked prompt points as radius-3 dots; the color marks negative vs. positive clicks
    for i in range(len(video_state["input_points"])):
        point = video_state["input_points"][i]
        if video_state["input_labels"][i] == 0:
            cv2.circle(painted_image, point, radius=3, color=(0, 0, 255), thickness=-1)
        else:
            cv2.circle(painted_image, point, radius=3, color=(255, 0, 0), thickness=-1)

    return Image.fromarray(painted_image)


def clear_clicks(video_state):
    # Reset all click prompts and cached masks, then show the clean first frame again
    video_state["input_points"] = []
    video_state["input_labels"] = []
    video_state["scaled_points"] = []
    video_state["inference_state"] = None
    video_state["masks"] = None
    video_state["painted_images"] = None
    return Image.fromarray(video_state["origin_images"][0]) if video_state["origin_images"] is not None else None


def preprocess_for_removal(images, masks):
    out_images = []
    out_masks = []
    for img, msk in zip(images, masks):
        # The remover expects 480x832 (portrait) or 832x480 (landscape) inputs
        if img.shape[0] > img.shape[1]:
            img_resized = cv2.resize(img, (480, 832), interpolation=cv2.INTER_LINEAR)
        else:
            img_resized = cv2.resize(img, (832, 480), interpolation=cv2.INTER_LINEAR)
        img_resized = img_resized.astype(np.float32) / 127.5 - 1.0  # normalize to [-1, 1]
        out_images.append(img_resized)

        if msk.shape[0] > msk.shape[1]:
            msk_resized = cv2.resize(msk, (480, 832), interpolation=cv2.INTER_NEAREST)
        else:
            msk_resized = cv2.resize(msk, (832, 480), interpolation=cv2.INTER_NEAREST)
        msk_resized = msk_resized.astype(np.float32)
        msk_resized = (msk_resized > 0.5).astype(np.float32)  # binarize
        out_masks.append(msk_resized)

    arr_images = np.stack(out_images)
    arr_masks = np.stack(out_masks)
    return torch.from_numpy(arr_images).half(), torch.from_numpy(arr_masks).half()


@spaces.GPU(duration=200)
def inference_and_return_video(dilation_iterations, num_inference_steps, video_state):
    if video_state["origin_images"] is None or video_state["masks"] is None:
        return None

    images = np.array(video_state["origin_images"])
    masks = np.array(video_state["masks"])

    img_tensor, mask_tensor = preprocess_for_removal(images, masks)
    img_tensor = img_tensor.to("cuda")
    mask_tensor = mask_tensor[:, :, :, :1].to("cuda")

    if mask_tensor.shape[1] < mask_tensor.shape[2]:
        height = 480
        width = 832
    else:
        height = 832
        width = 480

    pipe.to("cuda")
    with torch.no_grad():
        out = pipe(
            images=img_tensor,
            masks=mask_tensor,
            num_frames=mask_tensor.shape[0],
            height=height,
            width=width,
            num_inference_steps=int(num_inference_steps),
            generator=torch.Generator(device=device).manual_seed(random_seed),
            iterations=int(dilation_iterations),
        ).frames[0]

    out = np.uint8(out * 255)
    output_frames = [img for img in out]

    video_file = f"/tmp/{time.time()}-{random.random()}-removed_output.mp4"
    clip = ImageSequenceClip(output_frames, fps=15)
    clip.write_videofile(video_file, codec='libx264', audio=False, verbose=False, logger=None)
    return video_file


@spaces.GPU(duration=100)
def track_video(n_frames, video_state):
= video_state["input_points"] input_labels = video_state["input_labels"] frame_idx = video_state["frame_idx"] obj_id = video_state["obj_id"] scaled_points = video_state["scaled_points"] vr = VideoReader(video_state["video_path"], ctx=cpu(0)) height, width = vr[0].shape[0:2] images = [vr[i].asnumpy() for i in range(min(len(vr), n_frames))] del vr if images[0].shape[0] > images[0].shape[1]: W_ = W H_ = int(W_ * images[0].shape[0] / images[0].shape[1]) else: H_ = H W_ = int(H_ * images[0].shape[1] / images[0].shape[0]) images = [cv2.resize(img, (W_, H_)) for img in images] video_state["origin_images"] = images images = np.array(images) sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt" config = "sam2_hiera_l.yaml" video_predictor_local = build_sam2_video_predictor(config, sam2_checkpoint, device="cuda") inference_state = video_predictor_local.init_state(images=images/255, device="cuda") #video_state["inference_state"] = inference_state #cause bug if len(torch.from_numpy(video_state["masks"][0]).shape) == 3: mask = torch.from_numpy(video_state["masks"][0])[:,:,0] else: mask = torch.from_numpy(video_state["masks"][0]) video_predictor_local.add_new_mask( inference_state=inference_state, frame_idx=0, obj_id=obj_id, mask=mask ) output_frames = [] mask_frames = [] color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0 color = color[None, None, :] for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(inference_state): frame = images[out_frame_idx].astype(np.float32) / 255.0 mask = np.zeros((H, W, 3), dtype=np.float32) for i, logit in enumerate(out_mask_logits): out_mask = logit.cpu().squeeze().detach().numpy() out_mask = (out_mask[:,:,None] > 0).astype(np.float32) mask += out_mask mask = np.clip(mask, 0, 1) mask = cv2.resize(mask, (W_, H_)) mask_frames.append(mask) painted = (1 - mask * 0.5) * frame + mask * 0.5 * color painted = np.uint8(np.clip(painted * 255, 0, 255)) output_frames.append(painted) video_state["masks"] =mask_frames video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4" clip = ImageSequenceClip(output_frames, fps=15) clip.write_videofile(video_file, codec='libx264', audio=False, verbose=False, logger=None) print("line 286 done") return video_file,video_state text = """