import torch

class EfficientVideoSampling:
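    """Efficient Video Sampling (EVS): prunes video tokens that are temporally redundant."""
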
    @staticmethod
    def compute_retention_mask(
        *,
        video_embeds: torch.FloatTensor,
        thw: torch.LongTensor,
        spatial_merge_size: int,
        q: float,
    ) -> torch.Tensor:
        """
        Computes the retention mask for video embeddings based on the grid dimensions.

        Args:
            video_embeds (`torch.FloatTensor` of shape `(seq_len, hidden_size)`):
                The video embeddings to compute the retention mask for, where
                `seq_len = T * (H // spatial_merge_size) * (W // spatial_merge_size)`.
            thw (`torch.LongTensor` of shape `(3,)`):
                The temporal (T), height (H) and width (W) dimensions of the video feature grid.
            spatial_merge_size (`int`): The spatial merge size of the video embeddings.
                If the embeddings will be downsampled *later*, this should be the downsampling factor.
            q (`float`): Pruning rate, i.e. the fraction of video tokens to prune (remove).

        Returns:
            `torch.Tensor` of shape `(seq_len,)`: Boolean retention mask for the video embeddings.
                `True` (1) for tokens to keep, `False` (0) for tokens to prune.
        """
        T, H, W = thw

        # video_embeds = einops.rearrange(
        #     video_embeds,
        #     "(T H W) C -> T H W C",
        #     T=T,
        #     H=H // spatial_merge_size,
        #     W=W // spatial_merge_size,
        # )
        # Use reshape instead of einops to avoid graph breaks
        video_embeds = video_embeds.reshape(
            T, H // spatial_merge_size, W // spatial_merge_size, video_embeds.size(-1)
        )

        # Core EVS: score each token by its dissimilarity to the token at the same
        # spatial location in the previous frame.
        similarity = torch.nn.functional.cosine_similarity(
            video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
        )
        dissimilarity = 1 - similarity

        # Always keep all tokens from the first frame by giving them a dissimilarity
        # (255) larger than any real value (cosine dissimilarity is at most 2).
        dissimilarity = torch.cat(
            [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
        )
        dissimilarity_flat = dissimilarity.view(-1)

        min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)  # a single frame
        evs_num_tokens = int(T * min_num_tokens * (1 - q))
        num_tokens_to_keep = max(min_num_tokens, evs_num_tokens)

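        # Keep the `num_tokens_to_keep` highest-dissimilarity tokens across the whole
        # video; the stable descending sort preserves the original order among ties.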
        order = torch.argsort(dissimilarity_flat,
                              dim=-1,
                              descending=True,
                              stable=True)
        topk_indices = order[:num_tokens_to_keep]

        retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
        retention_mask[topk_indices] = True
        retention_mask = retention_mask.reshape(dissimilarity.size())

        # print(
        #     f"Computed retention mask of shape {retention_mask.shape=} with sparsity {retention_mask.float().mean().item():.4f} for {q=}",
        # )
        mask = retention_mask.view(-1)  # "T H W -> (T H W)"
        return mask
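

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: the grid shape,
    # merge factor and pruning rate below are assumptions chosen for demonstration.
    T, H, W = 8, 16, 16                     # frames, grid height, grid width
    merge_size, hidden_size = 2, 64
    video_embeds = torch.randn(T * (H // merge_size) * (W // merge_size), hidden_size)

    mask = EfficientVideoSampling.compute_retention_mask(
        video_embeds=video_embeds,
        thw=torch.tensor([T, H, W]),
        spatial_merge_size=merge_size,
        q=0.75,  # prune roughly 75% of the video tokens
    )
    pruned_embeds = video_embeds[mask]      # keep only the retained tokens
    print(f"kept {int(mask.sum())} of {mask.numel()} tokens")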