# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # // # // Licensed under the Apache License, Version 2.0 (the "License"); # // you may not use this file except in compliance with the License. # // You may obtain a copy of the License at # // # // http://www.apache.org/licenses/LICENSE-2.0 # // # // Unless required by applicable law or agreed to in writing, software # // distributed under the License is distributed on an "AS IS" BASIS, # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # // See the License for the specific language governing permissions and # // limitations under the License. from typing import Tuple, Union import torch from einops import rearrange from torch import nn from torch.nn.modules.utils import _triple from common.cache import Cache from common.distributed.ops import gather_outputs, slice_inputs from .. import na class PatchIn(nn.Module): def __init__( self, in_channels: int, patch_size: Union[int, Tuple[int, int, int]], dim: int, ): super().__init__() t, h, w = _triple(patch_size) self.patch_size = t, h, w self.proj = nn.Linear(in_channels * t * h * w, dim) def forward( self, vid: torch.Tensor, ) -> torch.Tensor: t, h, w = self.patch_size if t > 1: assert vid.size(2) % t == 1 vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) vid = self.proj(vid) return vid class PatchOut(nn.Module): def __init__( self, out_channels: int, patch_size: Union[int, Tuple[int, int, int]], dim: int, ): super().__init__() t, h, w = _triple(patch_size) self.patch_size = t, h, w self.proj = nn.Linear(dim, out_channels * t * h * w) def forward( self, vid: torch.Tensor, ) -> torch.Tensor: t, h, w = self.patch_size vid = self.proj(vid) vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) if t > 1: vid = vid[:, :, (t - 1) :] return vid class NaPatchIn(PatchIn): def forward( self, vid: torch.Tensor, # l c vid_shape: torch.LongTensor, cache: Cache = Cache(disable=True), # for test ) -> torch.Tensor: cache = cache.namespace("patch") vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) t, h, w = self.patch_size if not (t == h == w == 1): vid = na.unflatten(vid, vid_shape) for i in range(len(vid)): if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) vid, vid_shape = na.flatten(vid) # slice vid after patching in when using sequence parallelism vid = slice_inputs(vid, dim=0) vid = self.proj(vid) return vid, vid_shape class NaPatchOut(PatchOut): def forward( self, vid: torch.FloatTensor, # l c vid_shape: torch.LongTensor, cache: Cache = Cache(disable=True), # for test ) -> Tuple[ torch.FloatTensor, torch.LongTensor, ]: cache = cache.namespace("patch") vid_shape_before_patchify = cache.get("vid_shape_before_patchify") t, h, w = self.patch_size vid = self.proj(vid) # gather vid before patching out when enabling sequence parallelism vid = gather_outputs( vid, gather_dim=0, padding_dim=0, unpad_shape=vid_shape, cache=cache.namespace("vid") ) if not (t == h == w == 1): vid = na.unflatten(vid, vid_shape) for i in range(len(vid)): vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] vid, vid_shape = na.flatten(vid) return vid, vid_shape