# MiniMax-Remover / pipeline_minimax_remover.py
from typing import Callable, Dict, List, Optional, Union
import torch
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.models import AutoencoderKLWan
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
import scipy.ndimage
import numpy as np
import torch.nn.functional as F
from transformer_minimax_remover import Transformer3DModel
from einops import rearrange
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


class Minimax_Remover_Pipeline(DiffusionPipeline):
    model_cpu_offload_seq = "transformer->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        transformer: Transformer3DModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )

        # VAE down-sampling factors (the Wan VAE spells the attribute `temperal_downsample`).
        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
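    # For reference: with the standard Wan 2.1 VAE these factors are expected to
    # come out as 4 (temporal) and 8 (spatial), i.e. 4 video frames per latent
    # frame and an 8x spatial reduction. This is an assumption about the usual
    # checkpoint, not something enforced by this file.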

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 720,
        width: int = 1280,
        num_latent_frames: int = 21,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Reuse caller-provided latents if given; otherwise sample Gaussian noise
        # in the VAE latent space.
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents
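    # Worked example of the latent geometry (assuming the usual Wan factors of
    # 4x temporal / 8x spatial and 16 latent channels):
    #   num_frames=81, height=720, width=1280
    #   -> num_latent_frames = (81 - 1) // 4 + 1 = 21
    #   -> latents shape = (1, 16, 21, 720 // 8, 1280 // 8) = (1, 16, 21, 90, 160)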

    def expand_masks(self, masks, iterations):
        # `masks` is expected as a float tensor in [0, 1] with layout (f, h, w, c).
        masks = masks.cpu().detach().numpy()
        masks2 = []
        for i in range(len(masks)):
            mask = masks[i]
            mask = mask > 0
            # Grow the binary mask by `iterations` dilation steps so the region
            # to be removed safely covers the object boundary.
            mask = scipy.ndimage.binary_dilation(mask, iterations=iterations)
            masks2.append(mask)
        masks = np.array(masks2).astype(np.float32)
        masks = torch.from_numpy(masks)
        # Replicate the single mask channel to 3 channels and reshape to
        # (1, c, f, h, w) so it matches the video tensor layout.
        masks = masks.repeat(1, 1, 1, 3)
        masks = rearrange(masks, "f h w c -> c f h w")
        masks = masks[None, ...]
        return masks
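    # For intuition (an illustrative note, not taken from this file): with
    # scipy's default structuring element, each dilation step grows the mask by
    # one pixel of 4-connectivity, so iterations=16 expands a region by roughly
    # 16 pixels along each image axis before it is fed to the VAE.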

    def resize(self, images, w, h):
        # Note: this is called as `resize(x, height, width)`, so `w` receives the
        # target height and `h` the target width, matching F.interpolate's
        # (out_h, out_w) size convention.
        bsz, _, _, _, _ = images.shape
        images = rearrange(images, "b c f w h -> (b f) c w h")
        images = F.interpolate(images, (w, h), mode="bilinear")
        images = rearrange(images, "(b f) c w h -> b c f w h", b=bsz)
        return images

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        height: int = 720,
        width: int = 1280,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        images: Optional[torch.Tensor] = None,
        masks: Optional[torch.Tensor] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        iterations: int = 16,
    ):
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device
        batch_size = 1
        transformer_dtype = torch.float16

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_channels_latents = 16
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            num_latent_frames,
            torch.float16,
            device,
            generator,
            latents,
        )

        # Dilate the masks, resize everything to the target resolution, and
        # build the masked (object-removed) video that conditions the model.
        masks = self.expand_masks(masks, iterations)
        masks = self.resize(masks, height, width).to(device).half()
        masks[masks > 0] = 1
        images = rearrange(images, "f h w c -> c f h w")
        images = self.resize(images[None, ...], height, width).to(device).half()
        masked_images = images * (1 - masks)

        # Wan VAE latents are normalized with per-channel mean/std from the VAE
        # config; `latents_std` is stored as a reciprocal, so encoding applies
        # (z - mean) * latents_std and decoding applies z / latents_std + mean.
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(self.vae.device, torch.float16)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            self.vae.device, torch.float16
        )

        with torch.no_grad():
            # Encode the masked video and the masks (rescaled to [-1, 1]) into
            # latent space and apply the same normalization as the noise latents.
            masked_latents = self.vae.encode(masked_images.half()).latent_dist.mode()
            masks_latents = self.vae.encode(2 * masks.half() - 1.0).latent_dist.mode()

            masked_latents = (masked_latents - latents_mean) * latents_std
            masks_latents = (masks_latents - latents_mean) * latents_std
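
        # The transformer is conditioned by channel-wise concatenation of
        # [noisy latents, masked-video latents, mask latents]; with 16 latent
        # channels each, the transformer presumably expects 48 input channels
        # (an inference from the concatenation below, not stated in this file).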

        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(transformer_dtype)
                latent_model_input = torch.cat([latent_model_input, masked_latents, masks_latents], dim=1)
                timestep = t.expand(latents.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input.half(),
                    timestep=timestep,
                )[0]

                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                progress_bar.update()

        # Undo the latent normalization and decode back to pixel space.
        latents = latents.half() / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
        return WanPipelineOutput(frames=video)
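

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This shows one way the pipeline could be
# assembled and called; the checkpoint path, subfolder names, the assumption
# that Transformer3DModel exposes `from_pretrained`, and the input value ranges
# are all assumptions rather than facts from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ckpt = "./minimax-remover"  # hypothetical local checkpoint directory

    vae = AutoencoderKLWan.from_pretrained(ckpt, subfolder="vae", torch_dtype=torch.float16)
    transformer = Transformer3DModel.from_pretrained(ckpt, subfolder="transformer", torch_dtype=torch.float16)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt, subfolder="scheduler")

    pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler).to("cuda")

    # images: (num_frames, H, W, 3) float tensor, assumed to already be in [-1, 1].
    # masks:  (num_frames, H, W, 1) float tensor where the object to remove is > 0.
    images = torch.randn(81, 480, 832, 3).clamp(-1, 1)
    masks = torch.zeros(81, 480, 832, 1)
    masks[:, 180:300, 300:500, :] = 1.0

    result = pipe(
        images=images,
        masks=masks,
        height=480,
        width=832,
        num_frames=81,
        num_inference_steps=12,
        iterations=6,
        generator=torch.Generator(device="cuda").manual_seed(0),
    )
    video = result.frames[0]  # numpy frames when output_type="np" (the default)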