# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from typing import Tuple, Union

import torch
from einops import rearrange
from torch import nn
from torch.nn.modules.utils import _triple

from common.cache import Cache
from common.distributed.ops import gather_outputs, slice_inputs

from .. import na

class PatchIn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
    ):
        super().__init__()
        t, h, w = _triple(patch_size)
        self.patch_size = t, h, w
        self.proj = nn.Linear(in_channels * t * h * w, dim)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        t, h, w = self.patch_size
        if t > 1:
            # Expect a temporal length of 1 mod t; prepend (t - 1) copies of
            # the first frame so the length becomes divisible by t.
            assert vid.size(2) % t == 1
            vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2)
        vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
        vid = self.proj(vid)
        return vid
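
# Shape sketch for PatchIn (the numbers below are illustrative assumptions,
# not values from any config): with patch_size=(4, 2, 2) and in_channels=16,
# a (b, 16, 17, 64, 64) clip satisfies 17 % 4 == 1, gets 3 copies of its first
# frame prepended (20 frames total), is regrouped into a (b, 5, 32, 32, 256)
# patch grid, and is projected to (b, 5, 32, 32, dim).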

class PatchOut(nn.Module):
    def __init__(
        self,
        out_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
    ):
        super().__init__()
        t, h, w = _triple(patch_size)
        self.patch_size = t, h, w
        self.proj = nn.Linear(dim, out_channels * t * h * w)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        t, h, w = self.patch_size
        vid = self.proj(vid)
        vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w)
        if t > 1:
            # Drop the (t - 1) leading frames that PatchIn prepended.
            vid = vid[:, :, (t - 1) :]
        return vid
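
# PatchOut inverts the sketch above (same illustrative numbers): a
# (b, 5, 32, 32, dim) token grid is projected to (b, 5, 32, 32, 256),
# regrouped to (b, 16, 20, 64, 64), and the (t - 1) = 3 prepended frames are
# dropped to recover the original (b, 16, 17, 64, 64) layout.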

class NaPatchIn(PatchIn):
    def forward(
        self,
        vid: torch.Tensor,  # l c
        vid_shape: torch.LongTensor,
        cache: Cache = Cache(disable=True),  # for test
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        cache = cache.namespace("patch")
        vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)
        t, h, w = self.patch_size
        if not (t == h == w == 1):
            vid = na.unflatten(vid, vid_shape)
            for i in range(len(vid)):
                if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
                    # Prepend copies of the first frame until this sample's
                    # temporal length is divisible by t.
                    vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0)
                vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w)
            vid, vid_shape = na.flatten(vid)
        # Slice vid after patching in when using sequence parallelism.
        vid = slice_inputs(vid, dim=0)
        vid = self.proj(vid)
        return vid, vid_shape
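
# Layout note (inferred from the calls above): the Na* variants operate on a
# packed token sequence `vid` of shape (l, c) covering all samples, with
# `vid_shape` holding one (T, H, W) row per sample so na.unflatten can rebuild
# the per-sample grids. The rows describe pre-patch grids on input and
# post-patch grids on output, and the sequence is sliced along dim 0 for
# sequence parallelism before the projection.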

class NaPatchOut(PatchOut):
    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,
        cache: Cache = Cache(disable=True),  # for test
    ) -> Tuple[
        torch.FloatTensor,
        torch.LongTensor,
    ]:
        cache = cache.namespace("patch")
        vid_shape_before_patchify = cache.get("vid_shape_before_patchify")
        t, h, w = self.patch_size
        vid = self.proj(vid)
        # Gather vid before patching out when enabling sequence parallelism.
        vid = gather_outputs(
            vid, gather_dim=0, padding_dim=0, unpad_shape=vid_shape, cache=cache.namespace("vid")
        )
        if not (t == h == w == 1):
            vid = na.unflatten(vid, vid_shape)
            for i in range(len(vid)):
                vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w)
                if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
                    # Drop the frames that NaPatchIn prepended for this sample.
                    vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :]
            vid, vid_shape = na.flatten(vid)
        return vid, vid_shape
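

# A minimal shape sanity check for the dense PatchIn/PatchOut pair. The channel
# count, patch size, and hidden dim are illustrative assumptions rather than
# values from any training config; the Na* variants are not exercised because
# they need the `na` helpers and the sequence-parallel ops, and the module's
# relative imports mean this should be run in-package (e.g. via `python -m`).
if __name__ == "__main__":
    patch_in = PatchIn(in_channels=16, patch_size=(4, 2, 2), dim=64)
    patch_out = PatchOut(out_channels=16, patch_size=(4, 2, 2), dim=64)
    vid = torch.randn(1, 16, 17, 32, 32)  # 17 % 4 == 1, as PatchIn asserts
    tokens = patch_in(vid)  # -> (1, 5, 16, 16, 64)
    recon = patch_out(tokens)  # -> (1, 16, 17, 32, 32)
    assert tokens.shape == (1, 5, 16, 16, 64)
    assert recon.shape == vid.shape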