# Adapted from OpenSora
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# OpenSora: https://github.com/hpcaitech/Open-Sora
# --------------------------------------------------------
import os
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange
from transformers import PretrainedConfig, PreTrainedModel


class DiagonalGaussianDistribution(object):
    def __init__(
        self,
        parameters,
        deterministic=False,
    ):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device, dtype=self.mean.dtype)

    def sample(self):
        # torch.randn: standard normal distribution
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:  # SCH: assumes other is a standard normal distribution
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3, 4],
                )

    def nll(self, sample, dims=[1, 2, 3, 4]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean
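

def _posterior_example():
    # Added illustration (a minimal sketch, not part of the original module):
    # `parameters` packs mean and logvar along dim=1, so a (B, 2C, T, H, W)
    # moments tensor yields (B, C, T, H, W) samples.
    moments = torch.randn(1, 8, 4, 16, 16)  # 2C = 8 channels -> C = 4
    posterior = DiagonalGaussianDistribution(moments)
    z = posterior.sample()  # (1, 4, 4, 16, 16)
    kl = posterior.kl()  # shape (1,): KL against a standard normal
    return z, kl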


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    return (num % den) == 0


def is_odd(n):
    return not divisible_by(n, 2)


def pad_at_dim(t, pad, dim=-1):
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), mode="constant")


def exists(v):
    return v is not None


class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        strides=None,  # allow custom stride
        **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop("dilation", 1)
        stride = strides[0] if strides is not None else kwargs.pop("stride", 1)

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = strides if strides is not None else (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        x = self.conv(x)
        return x
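

def _causal_conv_example():
    # Sanity sketch (added; not in the original source): the asymmetric
    # (time_pad, 0) padding makes the convolution causal in time -- output
    # frame t depends only on input frames <= t, so perturbing the last
    # input frame must leave all earlier outputs unchanged.
    conv = CausalConv3d(1, 1, kernel_size=3)
    x = torch.randn(1, 1, 8, 16, 16)
    x_perturbed = x.clone()
    x_perturbed[:, :, -1] += 1.0  # touch only the final frame
    y, y_perturbed = conv(x), conv(x_perturbed)
    assert torch.allclose(y[:, :, :-1], y_perturbed[:, :, :-1])
    return y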


class ResBlock(nn.Module):
    def __init__(
        self,
        in_channels,  # SCH: added
        filters,
        conv_fn,
        activation_fn=nn.SiLU,
        use_conv_shortcut=False,
        num_groups=32,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filters = filters
        self.activate = activation_fn()
        self.use_conv_shortcut = use_conv_shortcut

        # SCH: MAGVIT uses GroupNorm by default
        self.norm1 = nn.GroupNorm(num_groups, in_channels)
        self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
        self.norm2 = nn.GroupNorm(num_groups, self.filters)
        self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
        if in_channels != filters:
            if self.use_conv_shortcut:
                self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
            else:
                self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = self.activate(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.activate(x)
        x = self.conv2(x)
        if self.in_channels != self.filters:  # SCH: ResBlock X->Y
            residual = self.conv3(residual)
        return x + residual


def get_activation_fn(activation):
    if activation == "relu":
        activation_fn = nn.ReLU
    elif activation == "swish":
        activation_fn = nn.SiLU
    else:
        raise NotImplementedError
    return activation_fn


class Encoder(nn.Module):
    """Encoder Blocks."""

    def __init__(
        self,
        in_out_channels=4,
        latent_embed_dim=512,  # num channels for latent vector
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(False, True, True),
        num_groups=32,  # for nn.GroupNorm
        activation_fn="swish",
    ):
        super().__init__()
        self.filters = filters
        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(channel_multipliers)
        self.channel_multipliers = channel_multipliers
        self.temporal_downsample = temporal_downsample
        self.num_groups = num_groups
        self.embedding_dim = latent_embed_dim

        self.activation_fn = get_activation_fn(activation_fn)
        self.activate = self.activation_fn()
        self.conv_fn = CausalConv3d
        self.block_args = dict(
            conv_fn=self.conv_fn,
            activation_fn=self.activation_fn,
            use_conv_shortcut=False,
            num_groups=self.num_groups,
        )

        # first layer conv
        self.conv_in = self.conv_fn(
            in_out_channels,
            filters,
            kernel_size=(3, 3, 3),
            bias=False,
        )

        # ResBlocks and conv downsample
        self.block_res_blocks = nn.ModuleList([])
        self.conv_blocks = nn.ModuleList([])

        filters = self.filters
        prev_filters = filters  # record for in_channels
        for i in range(self.num_blocks):
            filters = self.filters * self.channel_multipliers[i]
            block_items = nn.ModuleList([])
            for _ in range(self.num_res_blocks):
                block_items.append(ResBlock(prev_filters, filters, **self.block_args))
                prev_filters = filters  # update in_channels
            self.block_res_blocks.append(block_items)

            if i < self.num_blocks - 1:
                if self.temporal_downsample[i]:
                    t_stride = 2 if self.temporal_downsample[i] else 1
                    s_stride = 1
                    self.conv_blocks.append(
                        self.conv_fn(
                            prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride)
                        )
                    )
                    prev_filters = filters  # update in_channels
                else:
                    # if no t downsample, don't add since this does nothing for pipeline models
                    self.conv_blocks.append(nn.Identity(prev_filters))  # Identity
                    prev_filters = filters  # update in_channels

        # last layer res block
        self.res_blocks = nn.ModuleList([])
        for _ in range(self.num_res_blocks):
            self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args))
            prev_filters = filters  # update in_channels

        # MAGVIT uses Group Normalization
        self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)

        self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")

    def forward(self, x):
        x = self.conv_in(x)

        for i in range(self.num_blocks):
            for j in range(self.num_res_blocks):
                x = self.block_res_blocks[i][j](x)
            if i < self.num_blocks - 1:
                x = self.conv_blocks[i](x)

        for i in range(self.num_res_blocks):
            x = self.res_blocks[i](x)

        x = self.norm1(x)
        x = self.activate(x)
        x = self.conv2(x)
        return x


class Decoder(nn.Module):
    """Decoder Blocks."""

    def __init__(
        self,
        in_out_channels=4,
        latent_embed_dim=512,
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(False, True, True),
        num_groups=32,  # for nn.GroupNorm
        activation_fn="swish",
    ):
        super().__init__()
        self.filters = filters
        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(channel_multipliers)
        self.channel_multipliers = channel_multipliers
        self.temporal_downsample = temporal_downsample
        self.num_groups = num_groups
        self.embedding_dim = latent_embed_dim
        self.s_stride = 1

        self.activation_fn = get_activation_fn(activation_fn)
        self.activate = self.activation_fn()
        self.conv_fn = CausalConv3d
        self.block_args = dict(
            conv_fn=self.conv_fn,
            activation_fn=self.activation_fn,
            use_conv_shortcut=False,
            num_groups=self.num_groups,
        )

        filters = self.filters * self.channel_multipliers[-1]
        prev_filters = filters

        # last conv
        self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True)

        # last layer res block
        self.res_blocks = nn.ModuleList([])
        for _ in range(self.num_res_blocks):
            self.res_blocks.append(ResBlock(filters, filters, **self.block_args))

        # ResBlocks and conv upsample
        self.block_res_blocks = nn.ModuleList([])
        self.num_blocks = len(self.channel_multipliers)
        self.conv_blocks = nn.ModuleList([])
        # reverse to keep track of the in_channels, but append also in a reverse direction
        for i in reversed(range(self.num_blocks)):
            filters = self.filters * self.channel_multipliers[i]
            # resblock handling
            block_items = nn.ModuleList([])
            for _ in range(self.num_res_blocks):
                block_items.append(ResBlock(prev_filters, filters, **self.block_args))
                prev_filters = filters  # SCH: update in_channels
            self.block_res_blocks.insert(0, block_items)  # SCH: append in front

            # conv blocks with upsampling
            if i > 0:
                if self.temporal_downsample[i - 1]:
                    t_stride = 2 if self.temporal_downsample[i - 1] else 1
                    # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
                    self.conv_blocks.insert(
                        0,
                        self.conv_fn(
                            prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
                        ),
                    )
                else:
                    self.conv_blocks.insert(
                        0,
                        nn.Identity(prev_filters),
                    )

        self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)

        self.conv_out = self.conv_fn(filters, in_out_channels, 3)

    def forward(self, x):
        x = self.conv1(x)
        for i in range(self.num_res_blocks):
            x = self.res_blocks[i](x)
        for i in reversed(range(self.num_blocks)):
            for j in range(self.num_res_blocks):
                x = self.block_res_blocks[i][j](x)
            if i > 0:
                t_stride = 2 if self.temporal_downsample[i - 1] else 1
                x = self.conv_blocks[i - 1](x)
                x = rearrange(
                    x,
                    "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)",
                    ts=t_stride,
                    hs=self.s_stride,
                    ws=self.s_stride,
                )
        x = self.norm1(x)
        x = self.activate(x)
        x = self.conv_out(x)
        return x
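

def _depth_to_space_example():
    # Added illustration (a minimal sketch): the decoder upsamples time with a
    # depth-to-space rearrange -- the extra channels produced by the conv block
    # are folded into the T axis instead of using transposed convolutions.
    x = torch.randn(1, 8, 4, 16, 16)  # 8 = C(=4) * ts(=2) * hs(=1) * ws(=1)
    y = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=2, hs=1, ws=1)
    assert y.shape == (1, 4, 8, 16, 16)
    return y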


class VAE_Temporal(nn.Module):
    def __init__(
        self,
        in_out_channels=4,
        latent_embed_dim=4,
        embed_dim=4,
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(True, True, False),
        num_groups=32,  # for nn.GroupNorm
        activation_fn="swish",
    ):
        super().__init__()

        self.time_downsample_factor = 2 ** sum(temporal_downsample)
        # self.time_padding = self.time_downsample_factor - 1
        self.patch_size = (self.time_downsample_factor, 1, 1)
        self.out_channels = in_out_channels

        # NOTE: following MAGVIT, conv in bias=False in encoder first conv
        self.encoder = Encoder(
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim * 2,
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            activation_fn=activation_fn,
        )
        self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)

        self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
        self.decoder = Decoder(
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim,
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            activation_fn=activation_fn,
        )

    def get_latent_size(self, input_size):
        latent_size = []
        for i in range(3):
            if input_size[i] is None:
                lsize = None
            elif i == 0:
                time_padding = (
                    0
                    if (input_size[i] % self.time_downsample_factor == 0)
                    else self.time_downsample_factor - input_size[i] % self.time_downsample_factor
                )
                lsize = (input_size[i] + time_padding) // self.patch_size[i]
            else:
                lsize = input_size[i] // self.patch_size[i]
            latent_size.append(lsize)
        return latent_size

    def encode(self, x):
        time_padding = (
            0
            if (x.shape[2] % self.time_downsample_factor == 0)
            else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
        )
        x = pad_at_dim(x, (time_padding, 0), dim=2)
        encoded_feature = self.encoder(x)
        moments = self.quant_conv(encoded_feature).to(x.dtype)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, num_frames=None):
        time_padding = (
            0
            if (num_frames % self.time_downsample_factor == 0)
            else self.time_downsample_factor - num_frames % self.time_downsample_factor
        )
        z = self.post_quant_conv(z)
        x = self.decoder(z)
        x = x[:, :, time_padding:]
        return x

    def forward(self, x, sample_posterior=True):
        posterior = self.encode(x)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        recon_video = self.decode(z, num_frames=x.shape[2])
        return recon_video, posterior, z
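

def _vae_temporal_example():
    # Shape walk-through (added; a sketch under the default settings): two
    # temporal downsample stages give a 4x reduction in T after causal padding
    # to a multiple of 4; H and W pass through untouched, since spatial
    # compression is delegated to the 2D VAE further below.
    vae = VAE_Temporal()
    x = torch.randn(1, 4, 17, 32, 32)
    recon, posterior, z = vae(x)
    assert z.shape == (1, 4, 5, 32, 32)  # ceil(17 / 4) = 5 latent frames
    assert recon.shape == x.shape
    return recon, z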


def VAE_Temporal_SD(**kwargs):
    model = VAE_Temporal(
        in_out_channels=4,
        latent_embed_dim=4,
        embed_dim=4,
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(False, True, True),
        **kwargs,
    )
    return model


class VideoAutoencoderKL(nn.Module):
    def __init__(
        self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
    ):
        super().__init__()
        self.module = AutoencoderKL.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            subfolder=subfolder,
        )
        self.out_channels = self.module.config.latent_channels
        self.patch_size = (1, 8, 8)
        self.micro_batch_size = micro_batch_size

    def encode(self, x):
        # x: (B, C, T, H, W)
        B = x.shape[0]
        x = rearrange(x, "B C T H W -> (B T) C H W")

        if self.micro_batch_size is None:
            x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
        else:
            # NOTE: cannot be used for training
            bs = self.micro_batch_size
            x_out = []
            for i in range(0, x.shape[0], bs):
                x_bs = x[i : i + bs]
                x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
                x_out.append(x_bs)
            x = torch.cat(x_out, dim=0)
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def decode(self, x, **kwargs):
        # x: (B, C, T, H, W)
        B = x.shape[0]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        if self.micro_batch_size is None:
            x = self.module.decode(x / 0.18215).sample
        else:
            # NOTE: cannot be used for training
            bs = self.micro_batch_size
            x_out = []
            for i in range(0, x.shape[0], bs):
                x_bs = x[i : i + bs]
                x_bs = self.module.decode(x_bs / 0.18215).sample
                x_out.append(x_bs)
            x = torch.cat(x_out, dim=0)
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def get_latent_size(self, input_size):
        latent_size = []
        for i in range(3):
            # assert (
            #     input_size[i] is None or input_size[i] % self.patch_size[i] == 0
            # ), "Input size must be divisible by patch size"
            latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
        return latent_size

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype


class VideoAutoencoderKLTemporalDecoder(nn.Module):
    def __init__(self, from_pretrained=None, cache_dir=None, local_files_only=False):
        super().__init__()
        self.module = AutoencoderKLTemporalDecoder.from_pretrained(
            from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
        )
        self.out_channels = self.module.config.latent_channels
        self.patch_size = (1, 8, 8)

    def encode(self, x):
        raise NotImplementedError

    def decode(self, x, **kwargs):
        B, _, T = x.shape[:3]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self.module.decode(x / 0.18215, num_frames=T).sample
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def get_latent_size(self, input_size):
        latent_size = []
        for i in range(3):
            # assert (
            #     input_size[i] is None or input_size[i] % self.patch_size[i] == 0
            # ), "Input size must be divisible by patch size"
            latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
        return latent_size

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype


class VideoAutoencoderPipelineConfig(PretrainedConfig):
    model_type = "VideoAutoencoderPipeline"

    def __init__(
        self,
        vae_2d=None,
        vae_temporal=None,
        from_pretrained=None,
        freeze_vae_2d=False,
        cal_loss=False,
        micro_frame_size=None,
        shift=0.0,
        scale=1.0,
        **kwargs,
    ):
        self.vae_2d = vae_2d
        self.vae_temporal = vae_temporal
        self.from_pretrained = from_pretrained
        self.freeze_vae_2d = freeze_vae_2d
        self.cal_loss = cal_loss
        self.micro_frame_size = micro_frame_size
        self.shift = shift
        self.scale = scale
        super().__init__(**kwargs)


class VideoAutoencoderPipeline(PreTrainedModel):
    config_class = VideoAutoencoderPipelineConfig

    def __init__(self, config: VideoAutoencoderPipelineConfig):
        super().__init__(config=config)
        # NOTE: the sub-models are instantiated directly here; config.vae_2d and
        # config.vae_temporal are stored on the config but not read back.
        self.spatial_vae = VideoAutoencoderKL(
            from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
            local_files_only=False,
            micro_batch_size=4,
            subfolder="vae",
        )
        self.temporal_vae = VAE_Temporal_SD()
        self.cal_loss = config.cal_loss
        self.micro_frame_size = config.micro_frame_size
        self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]

        if config.freeze_vae_2d:
            for param in self.spatial_vae.parameters():
                param.requires_grad = False

        self.out_channels = self.temporal_vae.out_channels

        # normalization parameters
        scale = torch.tensor(config.scale)
        shift = torch.tensor(config.shift)
        if len(scale.shape) > 0:
            scale = scale[None, :, None, None, None]
        if len(shift.shape) > 0:
            shift = shift[None, :, None, None, None]
        self.register_buffer("scale", scale)
        self.register_buffer("shift", shift)

    def encode(self, x):
        x_z = self.spatial_vae.encode(x)

        if self.micro_frame_size is None:
            posterior = self.temporal_vae.encode(x_z)
            z = posterior.sample()
        else:
            z_list = []
            for i in range(0, x_z.shape[2], self.micro_frame_size):
                x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
                posterior = self.temporal_vae.encode(x_z_bs)
                z_list.append(posterior.sample())
            z = torch.cat(z_list, dim=2)

        if self.cal_loss:
            return z, posterior, x_z
        else:
            return (z - self.shift) / self.scale

    def decode(self, z, num_frames=None):
        if not self.cal_loss:
            z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)

        if self.micro_frame_size is None:
            x_z = self.temporal_vae.decode(z, num_frames=num_frames)
            x = self.spatial_vae.decode(x_z)
        else:
            x_z_list = []
            for i in range(0, z.size(2), self.micro_z_frame_size):
                z_bs = z[:, :, i : i + self.micro_z_frame_size]
                x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames))
                x_z_list.append(x_z_bs)
                num_frames -= self.micro_frame_size
            x_z = torch.cat(x_z_list, dim=2)
            x = self.spatial_vae.decode(x_z)

        if self.cal_loss:
            return x, x_z
        else:
            return x

    def forward(self, x):
        assert self.cal_loss, "This method is only available when cal_loss is True"
        z, posterior, x_z = self.encode(x)
        x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
        return x_rec, x_z_rec, z, posterior, x_z

    def get_latent_size(self, input_size):
        if self.micro_frame_size is None or input_size[0] is None:
            return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
        else:
            sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
            sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
            sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
            remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
            if remain_temporal_size[0] > 0:
                remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
                sub_latent_size[0] += remain_size[0]
            return sub_latent_size

    def get_temporal_last_layer(self):
        return self.temporal_vae.decoder.conv_out.conv.weight

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype


def OpenSoraVAE_V1_2(
    micro_batch_size=4,
    micro_frame_size=17,
    from_pretrained=None,
    local_files_only=False,
    freeze_vae_2d=False,
    cal_loss=False,
):
    vae_2d = dict(
        type="VideoAutoencoderKL",
        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        micro_batch_size=micro_batch_size,
        local_files_only=local_files_only,
    )
    vae_temporal = dict(
        type="VAE_Temporal_SD",
        from_pretrained=None,
    )
    shift = (-0.10, 0.34, 0.27, 0.98)
    scale = (3.85, 2.32, 2.33, 3.06)
    kwargs = dict(
        vae_2d=vae_2d,
        vae_temporal=vae_temporal,
        freeze_vae_2d=freeze_vae_2d,
        cal_loss=cal_loss,
        micro_frame_size=micro_frame_size,
        shift=shift,
        scale=scale,
    )

    if from_pretrained is not None and not os.path.isdir(from_pretrained):
        model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
    else:
        config = VideoAutoencoderPipelineConfig(**kwargs)
        model = VideoAutoencoderPipeline(config)
    return model
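

if __name__ == "__main__":
    # Minimal end-to-end sketch (added for illustration). Note that building
    # the pipeline downloads the 2D VAE weights from the Hugging Face Hub, and
    # the temporal VAE stays randomly initialized unless `from_pretrained`
    # points at a trained checkpoint.
    vae = OpenSoraVAE_V1_2(from_pretrained=None)
    vae.eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 17, 256, 256)  # (B, C, T, H, W) video
        z = vae.encode(x)  # (1, 4, 5, 32, 32) latents
        y = vae.decode(z, num_frames=17)  # back to (1, 3, 17, 256, 256)
    print(z.shape, y.shape)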