# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn

from common.cache import Cache
from common.distributed.ops import slice_inputs

from . import na
from .embedding import TimeEmbedding
from .modulation import get_ada_layer
from .nablocks import get_nablock
from .normalization import get_norm_layer
from .patch import get_na_patch_layers


# No-op stand-in: gradient checkpointing is not required for inference.
def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs):
    return module(*args, **kwargs)
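

# For reference only: a training-time counterpart would wrap the block in
# activation checkpointing instead of calling it directly. This helper is an
# illustrative sketch, is not used anywhere in this inference path, and assumes
# the standard torch.utils.checkpoint API (use_reentrant=False supports keyword
# arguments).
def _gradient_checkpointing_train(
    module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs
):
    from torch.utils.checkpoint import checkpoint

    if enabled:
        # Recompute activations in backward instead of storing them.
        return checkpoint(module, *args, use_reentrant=False, **kwargs)
    return module(*args, **kwargs)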


@dataclass
class NaDiTOutput:
    vid_sample: torch.Tensor


class NaDiT(nn.Module):
    """
    Native Resolution Diffusion Transformer (NaDiT)
    """

    gradient_checkpointing = False

    def __init__(
        self,
        vid_in_channels: int,
        vid_out_channels: int,
        vid_dim: int,
        txt_in_dim: Union[int, List[int]],
        txt_dim: Optional[int],
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: Optional[str],
        norm_eps: float,
        ada: str,
        qk_bias: bool,
        qk_norm: Optional[str],
        patch_size: Union[int, Tuple[int, int, int]],
        num_layers: int,
        block_type: Union[str, Tuple[str]],
        mm_layers: Union[int, Tuple[bool]],
        mlp_type: str = "normal",
        patch_type: str = "v1",
        rope_type: Optional[str] = "rope3d",
        rope_dim: Optional[int] = None,
        window: Optional[Tuple] = None,
        window_method: Optional[Tuple[str]] = None,
        msa_type: Optional[Tuple[str]] = None,
        mca_type: Optional[Tuple[str]] = None,
        txt_in_norm: Optional[str] = None,
        txt_in_norm_scale_factor: float = 0.01,
        txt_proj_type: Optional[str] = "linear",
        vid_out_norm: Optional[str] = None,
        **kwargs,
    ):
        ada = get_ada_layer(ada)
        norm = get_norm_layer(norm)
        qk_norm = get_norm_layer(qk_norm)
        rope_dim = rope_dim if rope_dim is not None else head_dim // 2
        if isinstance(block_type, str):
            block_type = [block_type] * num_layers
        elif len(block_type) != num_layers:
            raise ValueError("The length of ``block_type`` must equal ``num_layers``.")
        super().__init__()
        NaPatchIn, NaPatchOut = get_na_patch_layers(patch_type)
        self.vid_in = NaPatchIn(
            in_channels=vid_in_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )

        # Text input projection. With multiple text encoders, each stream gets its
        # own projection followed by a normalization layer whose weight is
        # initialized to a small scale factor.
        if not isinstance(txt_in_dim, int):
            self.txt_in = nn.ModuleList([])
            for in_dim in txt_in_dim:
                txt_norm_layer = get_norm_layer(txt_in_norm)(txt_dim, norm_eps, True)
                if txt_proj_type == "linear":
                    txt_proj_layer = nn.Linear(in_dim, txt_dim)
                else:
                    txt_proj_layer = nn.Sequential(
                        nn.Linear(in_dim, in_dim), nn.GELU("tanh"), nn.Linear(in_dim, txt_dim)
                    )
                torch.nn.init.constant_(txt_norm_layer.weight, txt_in_norm_scale_factor)
                self.txt_in.append(
                    nn.Sequential(
                        txt_proj_layer,
                        txt_norm_layer,
                    )
                )
        else:
            self.txt_in = (
                nn.Linear(txt_in_dim, txt_dim)
                if txt_in_dim and txt_in_dim != txt_dim
                else nn.Identity()
            )

        self.emb_in = TimeEmbedding(
            sinusoidal_dim=256,
            hidden_dim=max(vid_dim, txt_dim),
            output_dim=emb_dim,
        )

        # Broadcast per-layer attention settings when a single value is given.
        if window is None or isinstance(window[0], int):
            window = [window] * num_layers
        if window_method is None or isinstance(window_method, str):
            window_method = [window_method] * num_layers
        if msa_type is None or isinstance(msa_type, str):
            msa_type = [msa_type] * num_layers
        if mca_type is None or isinstance(mca_type, str):
            mca_type = [mca_type] * num_layers

        self.blocks = nn.ModuleList(
            [
                get_nablock(block_type[i])(
                    vid_dim=vid_dim,
                    txt_dim=txt_dim,
                    emb_dim=emb_dim,
                    heads=heads,
                    head_dim=head_dim,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    norm_eps=norm_eps,
                    ada=ada,
                    qk_bias=qk_bias,
                    qk_norm=qk_norm,
                    shared_weights=not (
                        (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
                    ),
                    mlp_type=mlp_type,
                    window=window[i],
                    window_method=window_method[i],
                    msa_type=msa_type[i],
                    mca_type=mca_type[i],
                    rope_type=rope_type,
                    rope_dim=rope_dim,
                    is_last_layer=(i == num_layers - 1),
                    **kwargs,
                )
                for i in range(num_layers)
            ]
        )

        self.vid_out_norm = None
        if vid_out_norm is not None:
            self.vid_out_norm = get_norm_layer(vid_out_norm)(
                dim=vid_dim,
                eps=norm_eps,
                elementwise_affine=True,
            )
            self.vid_out_ada = ada(
                dim=vid_dim,
                emb_dim=emb_dim,
                layers=["out"],
                modes=["in"],
            )

        self.vid_out = NaPatchOut(
            out_channels=vid_out_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )

    def set_gradient_checkpointing(self, enable: bool):
        self.gradient_checkpointing = enable

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: Union[torch.FloatTensor, List[torch.FloatTensor]],  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: Union[torch.LongTensor, List[torch.LongTensor]],  # b 1
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],  # b
        disable_cache: bool = False,  # for test
    ):
        cache = Cache(disable=disable_cache)

        # Text input.
        # Sequence parallel slicing of the text tokens happens here; the video
        # tokens are sliced after patching (see below).
        if isinstance(txt, list):
            assert isinstance(self.txt_in, nn.ModuleList)
            txt = [
                na.unflatten(fc(i), s) for fc, i, s in zip(self.txt_in, txt, txt_shape)
            ]  # B L D
            txt, txt_shape = na.flatten([torch.cat(t, dim=0) for t in zip(*txt)])
            txt = slice_inputs(txt, dim=0)
        else:
            txt = slice_inputs(txt, dim=0)
            txt = self.txt_in(txt)

        # Video input.
        # Sequence parallel slicing is done inside the patching class.
        vid, vid_shape = self.vid_in(vid, vid_shape, cache)

        # Embedding input.
        emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)

        # Body.
        for i, block in enumerate(self.blocks):
            vid, txt, vid_shape, txt_shape = gradient_checkpointing(
                enabled=(self.gradient_checkpointing and self.training),
                module=block,
                vid=vid,
                txt=txt,
                vid_shape=vid_shape,
                txt_shape=txt_shape,
                emb=emb,
                cache=cache,
            )

        # Video output norm.
        if self.vid_out_norm:
            vid = self.vid_out_norm(vid)
            vid = self.vid_out_ada(
                vid,
                emb=emb,
                layer="out",
                mode="in",
                hid_len=cache("vid_len", lambda: vid_shape.prod(-1)),
                cache=cache,
                branch_tag="vid",
            )

        # Video output.
        vid, vid_shape = self.vid_out(vid, vid_shape, cache)

        return NaDiTOutput(vid_sample=vid)
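

if __name__ == "__main__":
    # Self-contained illustration (an assumption about the layout, not the actual
    # `na` helpers imported above) of the packed-sequence convention the forward
    # pass expects: tokens of variable-resolution samples are concatenated along
    # dim 0 into a single (l, c) tensor like ``vid``, with the per-sample
    # (t, h, w) grids kept separately in a (b, 3) tensor like ``vid_shape``.
    shapes = torch.tensor([[4, 6, 6], [2, 8, 8]])              # b 3, like ``vid_shape``
    samples = [torch.randn(int(s.prod()), 8) for s in shapes]  # per-sample (l_i, c) tokens

    # Flatten: concatenate all tokens along dim 0 -> a single (sum l_i, c) tensor.
    flat = torch.cat(samples, dim=0)
    print(flat.shape)  # torch.Size([272, 8])

    # Unflatten: split back per sample using token counts derived from the shapes.
    lens = shapes.prod(-1).tolist()
    recovered = list(torch.split(flat, lens, dim=0))
    assert all(torch.equal(a, b) for a, b in zip(samples, recovered))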