pcuenq's picture
pcuenq HF Staff
Upload folder using huggingface_hub
493df70 verified
import torch
from typing import Tuple
class EfficientVideoSampling:
@staticmethod
def compute_retention_mask(
*,
video_embeds: torch.FloatTensor,
thw: torch.LongTensor,
spatial_merge_size: int,
q: float,
):
"""
Computes the retention mask for video embeddings based on the grid dimensions.
Args:
video_embeds (`torch.FloatTensor` of shape `(T * H * W, hidden_size)`):
The video embeddings to compute the retention mask for.
thw (`torch.LongTensor` of shape `(3)`):
The temporal, height and width of feature shape of each video in LLM.
spatial_merge_size (`int`): The spatial merge size of the video embeddings.
If embeddings will be downsampled *later*, this should be the downsampling factor.
q: (`float`): Pruning rate factor, indicating number of tokens to prune (remove)
Returns:
`torch.Tensor`: The retention mask for the video embeddings (T * H * W).
1 for tokens to keep, 0 for tokens to prune.
"""
T, H, W = thw
# video_embeds = einops.rearrange(
# video_embeds,
# "(T H W) C -> T H W C",
# T=T,
# H=H // spatial_merge_size,
# W=W // spatial_merge_size,
# )
# Use reshape instead of einops to avoid graph breaks
video_embeds = video_embeds.reshape(
T, H // spatial_merge_size, W // spatial_merge_size, video_embeds.size(-1)
)
# Core EVS
similarity = torch.nn.functional.cosine_similarity(
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
)
dissimilarity = 1 - similarity
# Always ensure we include all tokens from the first frame
dissimilarity = torch.cat(
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
)
dissimilarity_flat = dissimilarity.view(-1)
min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size) # a single frame
evs_num_tokens = int(T * min_num_tokens * (1 - q))
num_tokens_to_keep = max(min_num_tokens, evs_num_tokens)
order = torch.argsort(dissimilarity_flat,
dim=-1,
descending=True,
stable=True)
topk_indices = order[:num_tokens_to_keep]
retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
retention_mask[topk_indices] = True
retention_mask = retention_mask.reshape(dissimilarity.size())
# print(
# f"Computed retention mask of shape {retention_mask.shape=} with sparsity {retention_mask.float().mean().item():.4f} for {q=}",
# )
mask = retention_mask.view(-1) # "T H W -> (T H W)"
return mask