# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import glob
import os
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import logging
from einops import rearrange

logging.set_verbosity_error()


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


def tensor_to_video(x):
    x = x.detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2
    x = x.permute(1, 0, 2, 3).float().numpy()  # c t h w -> t c h w
    x = (255 * x).astype(np.uint8)
    return x


def nonlinearity(x):
    return x * torch.sigmoid(x)
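
# Note (illustrative, not in the original source): `nonlinearity` is the
# Swish/SiLU activation, equivalent to torch.nn.functional.silu:
#   torch.allclose(nonlinearity(t), F.silu(t))  # True for any float tensor t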


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean
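
# Usage sketch (illustrative): `parameters` stacks mean and log-variance along
# dim=1, so a (B, 8, T, H, W) moments tensor yields a distribution over
# (B, 4, T, H, W) latents.
#   posterior = DiagonalGaussianDistribution(torch.randn(1, 8, 5, 32, 32))
#   z = posterior.sample()   # stochastic draw, shape (1, 4, 5, 32, 32)
#   z = posterior.mode()     # deterministic: the mean
#   kl = posterior.kl()      # KL to a standard normal, shape (1,)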


def resolve_str_to_obj(str_val, append=True):
    """Resolve a class name to an object defined in this module.

    `append` is unused here; it is kept for signature compatibility with the
    original Open-Sora-Plan helper, which could also resolve dotted paths.
    """
    return globals()[str_val]


class VideoBaseAE_PL(ModelMixin, ConfigMixin):
    config_name = "config.json"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def encode(self, x: torch.Tensor, *args, **kwargs):
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        pass

    def num_training_steps(self) -> int:
        """Total training steps inferred from the datamodule and devices (PyTorch Lightning)."""
        if self.trainer.max_steps:
            return self.trainer.max_steps
        limit_batches = self.trainer.limit_train_batches
        batches = len(self.train_dataloader())
        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)
        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)
        effective_accum = self.trainer.accumulate_grad_batches * num_devices
        return (batches // effective_accum) * self.trainer.max_epochs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt"))
        if ckpt_files:
            # Adapt to PyTorch Lightning checkpoints.
            last_ckpt_file = ckpt_files[-1]
            config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            model = cls.from_config(config_file)
            print("init from {}".format(last_ckpt_file))
            model.init_from_ckpt(last_ckpt_file)
            return model
        else:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
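
# Loading sketch (illustrative; the directory layout is an assumption): a
# pretrained directory either contains a Lightning "*.ckpt" plus "config.json"
# (handled by the branch above) or a standard diffusers layout (delegated to
# ModelMixin.from_pretrained).
#   model = CausalVAEModel.from_pretrained("/path/to/causal_vae_488")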


class Encoder(nn.Module):
    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CausalConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock3D",
        ),
        spatial_downsample: Tuple[str] = (
            "Downsample",
            "Downsample",
            "Downsample",
            "",
        ),
        temporal_downsample: Tuple[str] = ("", "", "TimeDownsampleRes2x", ""),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
        double_z: bool = True,
    ) -> None:
        super().__init__()
        assert len(resnet_blocks) == len(hidden_size_mult), f"{hidden_size_mult} vs {resnet_blocks}"
        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks
        # ---- In ----
        self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1)
        # ---- Downsample ----
        curr_res = resolution
        in_ch_mult = (1,) + tuple(hidden_size_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = hidden_size * in_ch_mult[i_level]
            block_out = hidden_size * hidden_size_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    resolve_str_to_obj(resnet_blocks[i_level])(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(resolve_str_to_obj(attention)(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if spatial_downsample[i_level]:
                down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in)
                curr_res = curr_res // 2
            if temporal_downsample[i_level]:
                down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in)
            self.down.append(down)
        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        # ---- Out ----
        self.norm_out = Normalize(block_in)
        self.conv_out = resolve_str_to_obj(conv_out)(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if hasattr(self.down[i_level], "downsample"):
                hs.append(self.down[i_level].downsample(hs[-1]))
            if hasattr(self.down[i_level], "time_downsample"):
                hs_down = self.down[i_level].time_downsample(hs[-1])
                hs.append(hs_down)
        # Feed the last feature map, including any downsample outputs, to the mid blocks.
        h = self.mid.block_1(hs[-1])
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
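
# Shape sketch for the default CausalVAEModel config (illustrative): three
# SpatialDownsample2x stages and two TimeDownsample2x stages yield the 4x8x8
# stride of "CausalVAEModel_4x8x8". For an input of (B, 3, 17, 256, 256):
#   T: 17 -> 9 -> 5   (causal pooling halves all frames but the first)
#   H, W: 256 -> 128 -> 64 -> 32
# so the encoder output is (B, 2 * z_channels, 5, 32, 32) when double_z=True.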


class Decoder(nn.Module):
    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CausalConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        temporal_upsample: Tuple[str] = ("", "", "", "TimeUpsampleRes2x"),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
    ):
        super().__init__()
        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks
        # ---- In ----
        block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.conv_in = resolve_str_to_obj(conv_in)(z_channels, block_in, kernel_size=3, padding=1)
        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        # ---- Upsample ----
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = hidden_size * hidden_size_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    resolve_str_to_obj(resnet_blocks[i_level])(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(resolve_str_to_obj(attention)(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if spatial_upsample[i_level]:
                up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in)
                curr_res = curr_res * 2
            if temporal_upsample[i_level]:
                up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in)
            self.up.insert(0, up)  # prepend to keep index order consistent with levels
        # ---- Out ----
        self.norm_out = Normalize(block_in)
        self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1)

    def forward(self, z):
        h = self.conv_in(z)
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)
            if hasattr(self.up[i_level], "time_upsample"):
                h = self.up[i_level].time_upsample(h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class CausalVAEModel(VideoBaseAE_PL):
    # Register the init arguments so that `from_config` (used by
    # `from_pretrained` above) can reconstruct the model.
    @register_to_config
    def __init__(
        self,
        lr: float = 1e-5,
        hidden_size: int = 128,
        z_channels: int = 4,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (),
        dropout: float = 0.0,
        resolution: int = 256,
        double_z: bool = True,
        embed_dim: int = 4,
        num_res_blocks: int = 2,
        loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
        loss_params: dict = {
            "kl_weight": 0.000001,
            "logvar_init": 0.0,
            "disc_start": 2001,
            "disc_weight": 0.5,
        },
        q_conv: str = "CausalConv3d",
        encoder_conv_in: str = "CausalConv3d",
        encoder_conv_out: str = "CausalConv3d",
        encoder_attention: str = "AttnBlock3D",
        encoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        encoder_spatial_downsample: Tuple[str] = (
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "",
        ),
        encoder_temporal_downsample: Tuple[str] = (
            "",
            "TimeDownsample2x",
            "TimeDownsample2x",
            "",
        ),
        encoder_mid_resnet: str = "ResnetBlock3D",
        decoder_conv_in: str = "CausalConv3d",
        decoder_conv_out: str = "CausalConv3d",
        decoder_attention: str = "AttnBlock3D",
        decoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        decoder_spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsample2x", "TimeUpsample2x"),
        decoder_mid_resnet: str = "ResnetBlock3D",
    ) -> None:
        super().__init__()
        self.tile_sample_min_size = 256
        self.tile_sample_min_size_t = 65
        self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
        t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
        self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1
        self.tile_overlap_factor = 0.25
        self.use_tiling = False
        self.learning_rate = lr
        self.lr_g_factor = 1.0
        # NOTE: the training steps below reference `self.loss`; it is expected
        # to be built from `loss_type` / `loss_params` by the training setup
        # and is not instantiated here.
        self.encoder = Encoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=encoder_conv_in,
            conv_out=encoder_conv_out,
            attention=encoder_attention,
            resnet_blocks=encoder_resnet_blocks,
            spatial_downsample=encoder_spatial_downsample,
            temporal_downsample=encoder_temporal_downsample,
            mid_resnet=encoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
            double_z=double_z,
        )
        self.decoder = Decoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=decoder_conv_in,
            conv_out=decoder_conv_out,
            attention=decoder_attention,
            resnet_blocks=decoder_resnet_blocks,
            spatial_upsample=decoder_spatial_upsample,
            temporal_upsample=decoder_temporal_upsample,
            mid_resnet=decoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
        )
        quant_conv_cls = resolve_str_to_obj(q_conv)
        self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
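
    # Quant-conv sketch (illustrative note, not in the original source): the
    # encoder emits 2 * z_channels "moment" channels (mean and log-variance);
    # quant_conv maps them to 2 * embed_dim, and post_quant_conv maps a sampled
    # embed_dim latent back to z_channels for the decoder. With the defaults
    # (z_channels = embed_dim = 4), both are 1x1x1 convolutions.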

    def encode(self, x):
        if self.use_tiling and (
            x.shape[-1] > self.tile_sample_min_size
            or x.shape[-2] > self.tile_sample_min_size
            or x.shape[-3] > self.tile_sample_min_size_t
        ):
            return self.tiled_encode(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        if self.use_tiling and (
            z.shape[-1] > self.tile_latent_min_size
            or z.shape[-2] > self.tile_latent_min_size
            or z.shape[-3] > self.tile_latent_min_size_t
        ):
            return self.tiled_decode(z)
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
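
    # End-to-end sketch (illustrative; weights are random unless loaded):
    #   vae = CausalVAEModel()
    #   video = torch.randn(1, 3, 17, 256, 256)   # (B, C, T, H, W) in [-1, 1]
    #   recon, posterior = vae(video)             # recon: (1, 3, 17, 256, 256)
    #   vae.enable_tiling()                       # tile large inputs to bound memory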

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx):
        if hasattr(self.loss, "discriminator"):
            return self._training_step_gan(batch, batch_idx=batch_idx)
        else:
            return self._training_step(batch, batch_idx=batch_idx)

    def _training_step(self, batch, batch_idx):
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def _training_step_gan(self, batch, batch_idx):
        # Requires PyTorch Lightning manual optimization
        # (automatic_optimization = False), since two optimizers are stepped here.
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        opt1, opt2 = self.optimizers()
        # ---- AE Loss ----
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt1.zero_grad()
        self.manual_backward(aeloss)
        self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt1.step()
        # ---- GAN Loss ----
        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "discloss",
            discloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt2.zero_grad()
        self.manual_backward(discloss)
        self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt2.step()
        self.log_dict(
            {**log_dict_ae, **log_dict_disc},
            prog_bar=False,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

    def configure_optimizers(self):
        from itertools import chain

        lr = self.learning_rate
        modules_to_train = [
            self.encoder.named_parameters(),
            self.decoder.named_parameters(),
            self.post_quant_conv.named_parameters(),
            self.quant_conv.named_parameters(),
        ]
        params_with_time = []
        params_without_time = []
        for name, param in chain(*modules_to_train):
            if "time" in name:
                params_with_time.append(param)
            else:
                params_without_time.append(param)
        optimizers = []
        opt_ae = torch.optim.Adam(
            [
                {"params": params_with_time, "lr": lr},
                {"params": params_without_time, "lr": lr},
            ],
            lr=lr,
            betas=(0.5, 0.9),
        )
        optimizers.append(opt_ae)
        if hasattr(self.loss, "discriminator"):
            opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
            optimizers.append(opt_disc)
        return optimizers, []

    def get_last_layer(self):
        if hasattr(self.decoder.conv_out, "conv"):
            return self.decoder.conv_out.conv.weight
        else:
            return self.decoder.conv_out.weight

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b
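
    # Blend sketch (illustrative): with blend_extent = 4, row y of tile `b`
    # becomes a linear cross-fade between the bottom rows of `a` and the top
    # rows of `b` (y = 0 -> 100% a, y = 3 -> 25% a / 75% b), hiding the seam
    # where adjacent tiles are concatenated.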

    def tiled_encode(self, x):
        t = x.shape[2]
        # Chunk along time with a one-frame overlap so every chunk starts on a
        # causally valid "first frame".
        t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)]
        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
            t_chunk_start_end = [[0, t]]
        else:
            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
            if t_chunk_start_end[-1][-1] > t:
                t_chunk_start_end[-1][-1] = t
            elif t_chunk_start_end[-1][-1] < t:
                last_start_end = [t_chunk_idx[-1], t]
                t_chunk_start_end.append(last_start_end)
        moments = []
        for idx, (start, end) in enumerate(t_chunk_start_end):
            chunk_x = x[:, :, start:end]
            if idx != 0:
                # Drop the first latent frame: it duplicates the previous chunk's last frame.
                moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
            else:
                moment = self.tiled_encode2d(chunk_x, return_moments=True)
            moments.append(moment)
        moments = torch.cat(moments, dim=2)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def tiled_decode(self, x):
        t = x.shape[2]
        t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)]
        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
            t_chunk_start_end = [[0, t]]
        else:
            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
            if t_chunk_start_end[-1][-1] > t:
                t_chunk_start_end[-1][-1] = t
            elif t_chunk_start_end[-1][-1] < t:
                last_start_end = [t_chunk_idx[-1], t]
                t_chunk_start_end.append(last_start_end)
        dec_ = []
        for idx, (start, end) in enumerate(t_chunk_start_end):
            chunk_x = x[:, :, start:end]
            if idx != 0:
                # Drop the first decoded frame: it duplicates the previous chunk's last frame.
                dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
            else:
                dec = self.tiled_decode2d(chunk_x)
            dec_.append(dec)
        dec_ = torch.cat(dec_, dim=2)
        return dec_

    def tiled_encode2d(self, x, return_moments=False):
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent
        # Split the frames into overlapping tiles of
        # tile_sample_min_size x tile_sample_min_size pixels and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # Blend the tile above and the tile to the left into the
                # current tile, then crop and append it to the result row.
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))
        moments = torch.cat(result_rows, dim=3)
        if return_moments:
            return moments
        return DiagonalGaussianDistribution(moments)

    def tiled_decode2d(self, z):
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent
        # Split z into overlapping tiles of
        # tile_latent_min_size x tile_latent_min_size latents and decode them
        # separately. The tiles overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # Blend the tile above and the tile to the left into the
                # current tile, then crop and append it to the result row.
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))
        dec = torch.cat(result_rows, dim=3)
        return dec

    def enable_tiling(self, use_tiling: bool = True):
        self.use_tiling = use_tiling

    def disable_tiling(self):
        self.enable_tiling(False)

    def init_from_ckpt(self, path, ignore_keys=list(), remove_loss=False):
        sd = torch.load(path, map_location="cpu")
        print("init from " + path)
        if "state_dict" in sd:
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, "video")
        latents = self.encode(inputs).sample()
        video_recon = self.decode(latents)
        for idx in range(len(video_recon)):
            self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(video_recon[idx])], fps=[10])


class CausalVAEModelWrapper(nn.Module):
    def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
        super(CausalVAEModelWrapper, self).__init__()
        # if os.path.exists(ckpt):
        #     self.vae = CausalVAEModel.load_from_checkpoint(ckpt)
        self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)

    def encode(self, x):  # b c t h w
        # 0.18215 is the Stable-Diffusion latent scaling factor.
        x = self.vae.encode(x).sample().mul_(0.18215)
        return x

    def decode(self, x):
        x = self.vae.decode(x / 0.18215)
        x = rearrange(x, "b c t h w -> b t c h w").contiguous()
        return x

    def dtype(self):
        return self.vae.dtype

    # def device(self):
    #     return self.vae.device
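
# Wrapper sketch (illustrative; the path is a placeholder): encode scales
# latents by the Stable-Diffusion factor 0.18215 and decode undoes it, so
# downstream diffusion code sees roughly unit-variance latents.
#   vae = CausalVAEModelWrapper("/path/to/causal_vae")
#   z = vae.encode(video)     # (b, c, t', h', w'), scaled latents
#   frames = vae.decode(z)    # (b, t, c, h, w) frames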


videobase_ae_stride = {
    "CausalVAEModel_4x8x8": [4, 8, 8],
}

videobase_ae_channel = {
    "CausalVAEModel_4x8x8": 4,
}

videobase_ae = {
    "CausalVAEModel_4x8x8": CausalVAEModelWrapper,
}

ae_stride_config = {}
ae_stride_config.update(videobase_ae_stride)

ae_channel_config = {}
ae_channel_config.update(videobase_ae_channel)


def getae_wrapper(ae):
    """Deprecated: look up an autoencoder wrapper class by name."""
    ae_cls = videobase_ae.get(ae, None)
    assert ae_cls is not None, f"unknown autoencoder name: {ae}"
    return ae_cls


def video_to_image(func):
    """Apply an image (4-D) op to a video (5-D) tensor by folding time into the batch."""

    def wrapper(self, x, *args, **kwargs):
        if x.dim() == 5:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = func(self, x, *args, **kwargs)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x

    return wrapper
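
# Decorator sketch (illustrative; FrameConv is a hypothetical example class):
# wrapping a Conv2d-style forward lets the same module accept video tensors.
#   class FrameConv(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.conv = nn.Conv2d(3, 8, 3, padding=1)
#
#       @video_to_image
#       def forward(self, x):
#           return self.conv(x)   # sees (b*t, c, h, w) for 5-D inputs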


class Block(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)


class LinearAttention(Block):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        return self.to_out(out)


class LinAttnBlock(LinearAttention):
    """To match AttnBlock usage."""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock3D(Block):
    """Kept for compatibility with old checkpoints. Known to be buggy (it folds
    time into the batch without permuting first); use with caution, or prefer
    AttnBlock3DFix below."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention
        b, c, t, h, w = q.shape
        q = q.reshape(b * t, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b * t, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw  w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
        # attend to values
        v = v.reshape(b * t, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, t, h, w)
        h_ = self.proj_out(h_)
        return x + h_


class AttnBlock3DFix(nn.Module):
    """
    Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention
        # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
        b, c, t, h, w = q.shape
        q = q.permute(0, 2, 1, 3, 4)
        q = q.reshape(b * t, c, h * w)
        q = q.permute(0, 2, 1)
        # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
        k = k.permute(0, 2, 1, 3, 4)
        k = k.reshape(b * t, c, h * w)
        # w_: (b*t hw hw)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
        # attend to values
        # v: (b c t h w) -> (b t c h w) -> (b*t c h*w)
        v = v.permute(0, 2, 1, 3, 4)
        v = v.reshape(b * t, c, h * w)
        w_ = w_.permute(0, 2, 1)  # (b*t hw hw) (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # (b*t c hw) (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        # h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
        h_ = h_.reshape(b, t, c, h, w)
        h_ = h_.permute(0, 2, 1, 3, 4)
        h_ = self.proj_out(h_)
        return x + h_


class AttnBlock(Block):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw  w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)
        h_ = self.proj_out(h_)
        return x + h_


class TemporalAttnBlock(Block):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention over the time axis only
        b, c, t, h, w = q.shape
        q = rearrange(q, "b c t h w -> (b h w) t c")
        k = rearrange(k, "b c t h w -> (b h w) c t")
        v = rearrange(v, "b c t h w -> (b h w) c t")
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
        # attend to values
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
        h_ = self.proj_out(h_)
        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "vanilla3D":
        return AttnBlock3D(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)


class Conv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        padding: Union[str, int, Tuple[int]] = 0,
        dilation: Union[int, Tuple[int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    def forward(self, x):
        return super().forward(x)


class CausalConv3d(nn.Module):
    def __init__(
        self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.time_kernel_size = self.kernel_size[0]
        self.chan_in = chan_in
        self.chan_out = chan_out
        stride = kwargs.pop("stride", 1)
        padding = kwargs.pop("padding", 0)
        padding = list(cast_tuple(padding, 3))
        padding[0] = 0  # no temporal padding here; the forward pass pads causally
        stride = cast_tuple(stride, 3)
        self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
        self._init_weights(init_method)

    def _init_weights(self, init_method):
        if init_method == "avg":
            assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
            assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
            weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
            eyes = torch.concat(
                [
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                ],
                dim=-1,
            )
            weight[:, :, :, 0, 0] = eyes
            self.conv.weight = nn.Parameter(
                weight,
                requires_grad=True,
            )
        elif init_method == "zero":
            self.conv.weight = nn.Parameter(
                torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
                requires_grad=True,
            )
        if self.conv.bias is not None:
            nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        # Causal temporal padding: replicate the first frame (time_kernel_size - 1)
        # times so no output frame depends on future frames, e.g. a 1 + 16 clip
        # (1 image frame + 16 video frames) stays 1 + 16 at the output.
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))  # b c t h w
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return self.conv(x)
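
# Padding sketch (illustrative): with time_kernel_size = 3 and stride 1, a clip
# of T frames is left-padded with two copies of its first frame, so the output
# keeps T frames and frame t never sees frames > t. A single image (T = 1)
# passes through with its length unchanged.
#   conv = CausalConv3d(8, 8, kernel_size=3, padding=1)
#   y = conv(torch.randn(1, 8, 17, 32, 32))   # -> (1, 8, 17, 32, 32)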


class GroupNorm(Block):
    def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True)

    def forward(self, x):
        return self.norm(x)


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init
        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
            std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False
        _, _, height, width = input.shape
        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)
        h = self.scale * (input + self.loc)
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet
        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)
        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False
        h = output / self.scale - self.loc
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim
    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
    dims = list(range(n_dims))
    del dims[src_dim]
    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x
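
# Example (illustrative): shift_dim moves one axis to a new position while
# preserving the relative order of the remaining axes.
#   x = torch.randn(2, 4, 5, 8, 8)    # (b, c, t, h, w)
#   shift_dim(x, 1, -1).shape         # -> (2, 5, 8, 8, 4), channels last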


class Codebook(nn.Module):
    def __init__(self, n_codes, embedding_dim):
        super().__init__()
        self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
        self.register_buffer("N", torch.zeros(n_codes))
        self.register_buffer("z_avg", self.embeddings.data.clone())
        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True

    def _tile(self, x):
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, c, t, h, w]
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        y = self._tile(flat_inputs)
        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        # z: [b, c, t, h, w]
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.t()
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        )
        encoding_indices = torch.argmin(distances, dim=1)
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
        embeddings = F.embedding(encoding_indices, self.embeddings)
        embeddings = shift_dim(embeddings, -1, 1)
        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
        # EMA codebook update
        if self.training:
            n_total = encode_onehot.sum(dim=0)
            encode_sum = flat_inputs.t() @ encode_onehot
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)
            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
            n = self.N.sum()
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            encode_normalized = self.z_avg / weights.unsqueeze(1)
            self.embeddings.data.copy_(encode_normalized)
            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)
            # Re-seed codes that are (almost) never used.
            usage = (self.N.view(self.n_codes, 1) >= 1).float()
            self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
        # Straight-through estimator: gradients flow to z as if quantization were identity.
        embeddings_st = (embeddings - z).detach() + z
        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        return dict(
            embeddings=embeddings_st,
            encodings=encoding_indices,
            commitment_loss=commitment_loss,
            perplexity=perplexity,
        )

    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings


class ResnetBlock2D(Block):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(self.out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        x = x + h
        return x


class ResnetBlock3D(Block):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = Normalize(in_channels)
        self.conv1 = CausalConv3d(in_channels, self.out_channels, 3, padding=1)
        self.norm2 = Normalize(self.out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = CausalConv3d(self.out_channels, self.out_channels, 3, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CausalConv3d(in_channels, self.out_channels, 3, padding=1)
            else:
                self.nin_shortcut = CausalConv3d(in_channels, self.out_channels, 1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h


class Upsample(Block):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.with_conv = True
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(Block):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.with_conv = True
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class SpatialDownsample2x(Block):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]] = (3, 3),
        stride: Union[int, Tuple[int]] = (2, 2),
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 2)
        stride = cast_tuple(stride, 2)
        self.chan_in = chan_in
        self.chan_out = chan_out
        self.kernel_size = kernel_size
        self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0)

    def forward(self, x):
        # Asymmetric spatial padding, matching Downsample above.
        pad = (0, 1, 0, 1, 0, 0)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class SpatialUpsample2x(Block):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]] = (3, 3),
        stride: Union[int, Tuple[int]] = (1, 1),
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 2)
        stride = cast_tuple(stride, 2)
        self.chan_in = chan_in
        self.chan_out = chan_out
        self.kernel_size = kernel_size
        self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)

    def forward(self, x):
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> b (c t) h w")
        x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
        x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
        x = self.conv(x)
        return x


class TimeDownsample2x(Block):
    def __init__(self, chan_in, chan_out, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))

    def forward(self, x):
        # Causal padding with the first frame, then strided temporal average pooling.
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1))
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return self.conv(x)


class TimeUpsample2x(Block):
    def __init__(self, chan_in, chan_out):
        super().__init__()

    def forward(self, x):
        if x.size(2) > 1:
            # Keep the first frame fixed and interpolate the rest, so T -> 2T - 1.
            x, x_ = x[:, :, :1], x[:, :, 1:]
            x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.concat([x, x_], dim=2)
        return x
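
# Temporal sketch (illustrative): these modules are near-inverses on the
# 1 + 4k frame layout used by the causal VAE.
#   TimeDownsample2x: T = 17 -> pad 2, pool stride 2 -> 9   (and 9 -> 5)
#   TimeUpsample2x:   T = 5  -> keep frame 0, 4 -> 8, concat -> 9   (and 9 -> 17)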


class TimeDownsampleRes2x(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 2.0,
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
        self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        # Learnable mix between temporal average pooling and a strided convolution.
        alpha = torch.sigmoid(self.mix_factor)
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)


class TimeUpsampleRes2x(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 2.0,
    ):
        super().__init__()
        self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        alpha = torch.sigmoid(self.mix_factor)
        if x.size(2) > 1:
            x, x_ = x[:, :, :1], x[:, :, 1:]
            x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.concat([x, x_], dim=2)
        return alpha * x + (1 - alpha) * self.conv(x)


class TimeDownsampleResAdv2x(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 1.5,
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
        self.attn = TemporalAttnBlock(in_channels)
        self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
        self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
        x = torch.concatenate((first_frame_pad, x), dim=2)
        alpha = torch.sigmoid(self.mix_factor)
        return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn(self.res(x)))


class TimeUpsampleResAdv2x(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 1.5,
    ):
        super().__init__()
        self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
        self.attn = TemporalAttnBlock(in_channels)
        self.norm = Normalize(in_channels=in_channels)
        self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        if x.size(2) > 1:
            x, x_ = x[:, :, :1], x[:, :, 1:]
            x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.concat([x, x_], dim=2)
        alpha = torch.sigmoid(self.mix_factor)
        return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x)))