|
|
import torch |
|
|
from typing import Tuple |
|
|
|
|
|
class EfficientVideoSampling: |
|
|
@staticmethod |
|
|
def compute_retention_mask( |
|
|
*, |
|
|
video_embeds: torch.FloatTensor, |
|
|
thw: torch.LongTensor, |
|
|
spatial_merge_size: int, |
|
|
q: float, |
|
|
): |
|
|
""" |
|
|
Computes the retention mask for video embeddings based on the grid dimensions. |
|
|
|
|
|
Args: |
|
|
video_embeds (`torch.FloatTensor` of shape `(T * H * W, hidden_size)`): |
|
|
The video embeddings to compute the retention mask for. |
|
|
thw (`torch.LongTensor` of shape `(3)`): |
|
|
The temporal, height and width of feature shape of each video in LLM. |
|
|
spatial_merge_size (`int`): The spatial merge size of the video embeddings. |
|
|
If embeddings will be downsampled *later*, this should be the downsampling factor. |
|
|
q: (`float`): Pruning rate factor, indicating number of tokens to prune (remove) |
|
|
|
|
|
Returns: |
|
|
`torch.Tensor`: The retention mask for the video embeddings (T * H * W). |
|
|
1 for tokens to keep, 0 for tokens to prune. |
|
|
""" |
|
|
T, H, W = thw |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_embeds = video_embeds.reshape( |
|
|
T, H // spatial_merge_size, W // spatial_merge_size, video_embeds.size(-1) |
|
|
) |
|
|
|
|
|
|
|
|
similarity = torch.nn.functional.cosine_similarity( |
|
|
video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1 |
|
|
) |
|
|
dissimilarity = 1 - similarity |
|
|
|
|
|
|
|
|
dissimilarity = torch.cat( |
|
|
[255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0 |
|
|
) |
|
|
dissimilarity_flat = dissimilarity.view(-1) |
|
|
|
|
|
min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size) |
|
|
evs_num_tokens = int(T * min_num_tokens * (1 - q)) |
|
|
num_tokens_to_keep = max(min_num_tokens, evs_num_tokens) |
|
|
|
|
|
order = torch.argsort(dissimilarity_flat, |
|
|
dim=-1, |
|
|
descending=True, |
|
|
stable=True) |
|
|
topk_indices = order[:num_tokens_to_keep] |
|
|
|
|
|
retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool) |
|
|
retention_mask[topk_indices] = True |
|
|
retention_mask = retention_mask.reshape(dissimilarity.size()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mask = retention_mask.view(-1) |
|
|
return mask |
|
|
|