# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, Callable
import torch
from torch import nn
from common.cache import Cache
from common.distributed.ops import slice_inputs
from . import na
from .embedding import TimeEmbedding
from .modulation import get_ada_layer
from .nablocks import get_nablock
from .normalization import get_norm_layer
from .patch import get_na_patch_layers
# Fake checkpointing function: no gradient checkpointing is required for inference.
def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs):
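    # ``enabled`` is accepted for interface compatibility but ignored; the module is called directly.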
return module(*args, **kwargs)
@dataclass
class NaDiTOutput:
vid_sample: torch.Tensor
class NaDiT(nn.Module):
"""
Native Resolution Diffusion Transformer (NaDiT)
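
    Operates on flattened (packed) video and text token sequences plus per-sample
    shape tensors, so inputs of varying resolution and duration can share a batch.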
"""
gradient_checkpointing = False
def __init__(
self,
vid_in_channels: int,
vid_out_channels: int,
vid_dim: int,
txt_in_dim: Union[int, List[int]],
txt_dim: Optional[int],
emb_dim: int,
heads: int,
head_dim: int,
expand_ratio: int,
norm: Optional[str],
norm_eps: float,
ada: str,
qk_bias: bool,
qk_norm: Optional[str],
patch_size: Union[int, Tuple[int, int, int]],
num_layers: int,
        block_type: Union[str, Tuple[str, ...]],
        mm_layers: Union[int, Tuple[bool, ...]],
mlp_type: str = "normal",
patch_type: str = "v1",
rope_type: Optional[str] = "rope3d",
rope_dim: Optional[int] = None,
window: Optional[Tuple] = None,
        window_method: Optional[Union[str, Tuple[str, ...]]] = None,
        msa_type: Optional[Union[str, Tuple[str, ...]]] = None,
        mca_type: Optional[Union[str, Tuple[str, ...]]] = None,
txt_in_norm: Optional[str] = None,
        txt_in_norm_scale_factor: float = 0.01,
txt_proj_type: Optional[str] = "linear",
vid_out_norm: Optional[str] = None,
**kwargs,
):
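        # Resolve string identifiers to the corresponding layer factories.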
ada = get_ada_layer(ada)
norm = get_norm_layer(norm)
qk_norm = get_norm_layer(qk_norm)
rope_dim = rope_dim if rope_dim is not None else head_dim // 2
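        # A single block type applies to every layer; otherwise one entry per layer is expected.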
if isinstance(block_type, str):
block_type = [block_type] * num_layers
elif len(block_type) != num_layers:
            raise ValueError("The length of ``block_type`` must equal ``num_layers``.")
super().__init__()
NaPatchIn, NaPatchOut = get_na_patch_layers(patch_type)
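        # Patch-in layer: converts the video input into a flattened token sequence of width vid_dim.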
self.vid_in = NaPatchIn(
in_channels=vid_in_channels,
patch_size=patch_size,
dim=vid_dim,
)
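        # Text projection: one projection (with its own norm) per text encoder when
        # ``txt_in_dim`` is a list, otherwise a single linear layer (or identity).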
if not isinstance(txt_in_dim, int):
self.txt_in = nn.ModuleList([])
for in_dim in txt_in_dim:
txt_norm_layer = get_norm_layer(txt_in_norm)(txt_dim, norm_eps, True)
if txt_proj_type == "linear":
txt_proj_layer = nn.Linear(in_dim, txt_dim)
else:
txt_proj_layer = nn.Sequential(
nn.Linear(in_dim, in_dim), nn.GELU("tanh"), nn.Linear(in_dim, txt_dim)
)
torch.nn.init.constant_(txt_norm_layer.weight, txt_in_norm_scale_factor)
self.txt_in.append(
nn.Sequential(
txt_proj_layer,
txt_norm_layer,
)
)
else:
self.txt_in = (
nn.Linear(txt_in_dim, txt_dim)
if txt_in_dim and txt_in_dim != txt_dim
else nn.Identity()
)
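        # Timestep embedding, projected to emb_dim and used for modulation in the blocks.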
self.emb_in = TimeEmbedding(
sinusoidal_dim=256,
hidden_dim=max(vid_dim, txt_dim),
output_dim=emb_dim,
)
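        # Broadcast per-layer attention settings when a single value is given.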
if window is None or isinstance(window[0], int):
window = [window] * num_layers
if window_method is None or isinstance(window_method, str):
window_method = [window_method] * num_layers
if msa_type is None or isinstance(msa_type, str):
msa_type = [msa_type] * num_layers
if mca_type is None or isinstance(mca_type, str):
mca_type = [mca_type] * num_layers
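        # Transformer blocks. ``mm_layers`` selects the blocks that keep separate
        # (non-shared) weights for the video and text streams; the rest share weights.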
self.blocks = nn.ModuleList(
[
get_nablock(block_type[i])(
vid_dim=vid_dim,
txt_dim=txt_dim,
emb_dim=emb_dim,
heads=heads,
head_dim=head_dim,
expand_ratio=expand_ratio,
norm=norm,
norm_eps=norm_eps,
ada=ada,
qk_bias=qk_bias,
qk_norm=qk_norm,
shared_weights=not (
(i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
),
mlp_type=mlp_type,
window=window[i],
window_method=window_method[i],
msa_type=msa_type[i],
mca_type=mca_type[i],
rope_type=rope_type,
rope_dim=rope_dim,
is_last_layer=(i == num_layers - 1),
**kwargs,
)
for i in range(num_layers)
]
)
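        # Optional output normalization on video tokens, modulated by the timestep embedding.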
self.vid_out_norm = None
if vid_out_norm is not None:
self.vid_out_norm = get_norm_layer(vid_out_norm)(
dim=vid_dim,
eps=norm_eps,
elementwise_affine=True,
)
self.vid_out_ada = ada(
dim=vid_dim,
emb_dim=emb_dim,
layers=["out"],
modes=["in"],
)
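        # Patch-out layer: projects tokens back to vid_out_channels and unpatchifies.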
self.vid_out = NaPatchOut(
out_channels=vid_out_channels,
patch_size=patch_size,
dim=vid_dim,
)
def set_gradient_checkpointing(self, enable: bool):
self.gradient_checkpointing = enable
def forward(
self,
vid: torch.FloatTensor, # l c
txt: Union[torch.FloatTensor, List[torch.FloatTensor]], # l c
vid_shape: torch.LongTensor, # b 3
txt_shape: Union[torch.LongTensor, List[torch.LongTensor]], # b 1
timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
disable_cache: bool = False, # for test
):
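        # Per-forward cache for values reused across blocks (e.g. sequence lengths).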
cache = Cache(disable=disable_cache)
        # vid is sliced for sequence parallelism after patch-in (inside self.vid_in).
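        # Text input: project each text stream to txt_dim and slice it for sequence parallelism.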
if isinstance(txt, list):
assert isinstance(self.txt_in, nn.ModuleList)
txt = [
na.unflatten(fc(i), s) for fc, i, s in zip(self.txt_in, txt, txt_shape)
] # B L D
txt, txt_shape = na.flatten([torch.cat(t, dim=0) for t in zip(*txt)])
txt = slice_inputs(txt, dim=0)
else:
txt = slice_inputs(txt, dim=0)
txt = self.txt_in(txt)
# Video input.
# Sequence parallel slicing is done inside patching class.
vid, vid_shape = self.vid_in(vid, vid_shape, cache)
# Embedding input.
emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
# Body
for i, block in enumerate(self.blocks):
vid, txt, vid_shape, txt_shape = gradient_checkpointing(
enabled=(self.gradient_checkpointing and self.training),
module=block,
vid=vid,
txt=txt,
vid_shape=vid_shape,
txt_shape=txt_shape,
emb=emb,
cache=cache,
)
# Video output norm.
        if self.vid_out_norm is not None:
vid = self.vid_out_norm(vid)
vid = self.vid_out_ada(
vid,
emb=emb,
layer="out",
mode="in",
hid_len=cache("vid_len", lambda: vid_shape.prod(-1)),
cache=cache,
branch_tag="vid",
)
# Video output.
vid, vid_shape = self.vid_out(vid, vid_shape, cache)
return NaDiTOutput(vid_sample=vid)