"""
Real-time video classification using the V-JEPA 2 model over a live webcam stream.

This module implements a real-time video classification system that:
1. Captures video frames from a webcam
2. Processes batches of frames using the V-JEPA 2 model
3. Displays predictions overlaid on the video stream
4. Maintains a history of recent predictions

The system uses FastRTC for video streaming and Gradio for the web interface.
"""

import os
import cv2
import time
import torch
import random
import gradio as gr
import numpy as np

from loguru import logger
from gradio.utils import get_space
from fastrtc import (
    Stream,
    VideoStreamHandler,
    AdditionalOutputs,
    get_cloudflare_turn_credentials_async,
    get_cloudflare_turn_credentials,
)
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor


# Model configuration
CHECKPOINT = "facebook/vjepa2-vitl-fpc16-256-ssv2"  # Pre-trained VJEPA2 model checkpoint
TORCH_DTYPE = torch.float16  # Use half precision for faster inference
TORCH_DEVICE = "cuda"  # Use GPU for inference
UPDATE_EVERY_N_FRAMES = 64  # Run the model once every N incoming frames
HF_TOKEN = os.getenv("HF_TOKEN")


model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=TORCH_DTYPE)  # use TORCH_DTYPE so weights match the inputs cast below
model = model.to(TORCH_DEVICE)
video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT)
frames_per_clip = model.config.frames_per_clip
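
# For this checkpoint, "fpc16" in its name corresponds to
# model.config.frames_per_clip == 16: the model classifies 16-frame clips,
# and the cache and subsampling settings below are derived from that value.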


def add_text_on_image(image, text):
    """
    Overlays text on an image with a black background bar at the top.

    Args:
        image (np.ndarray): Input image to add text to
        text (str): Text to overlay on the image

    Returns:
        np.ndarray: Image with text overlaid
    """
    # Black out a 70 px bar at the top to serve as the text background
    image[:70] = 0

    line_spacing = 10
    top_margin = 20

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    color = (255, 255, 255)  # White

    words = text.split()
    lines = []
    current_line = ""

    img_width = image.shape[1]

    # Build lines that fit within the image width
    for word in words:
        test_line = current_line + (" " if current_line else "") + word
        (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
        if test_width > img_width - 20 and current_line:  # wrap, keeping ~10 px margin per side
            lines.append(current_line)
            current_line = word
        else:
            current_line = test_line
    if current_line:
        lines.append(current_line)

    # Draw each line, centered
    y = top_margin
    for line in lines:
        (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
        x = (img_width - line_width) // 2
        cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
        y += line_height + line_spacing

    return image
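
# Illustrative usage of add_text_on_image (commented out so importing this
# module has no side effects; `demo_frame` is a hypothetical stand-in frame):
#
#   demo_frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   labeled = add_text_on_image(demo_frame, "pouring something into something")
#   # `labeled` shows the caption centered on the black bar at the top.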


class RunningFramesCache:
    """
    Maintains a rolling buffer of video frames for model input.

    This class manages a fixed-size queue of frames, keeping only the most recent
    frames needed for model inference. It supports subsampling frames to reduce
    memory usage and processing requirements.

    Args:
        save_every_k_frame (int): Only save every k-th frame (for subsampling)
        max_frames (int): Maximum number of frames to keep in cache
    """

    def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
        self.save_every_k_frame = save_every_k_frame
        self.max_frames = max_frames
        self._frames = []
        self.counter = 0

    def add_frame(self, frame: np.ndarray):
        self.counter += 1
        # Subsample: keep only every k-th incoming frame, as documented above
        # (the counter still counts every arrival, which drives the update cadence)
        if self.counter % self.save_every_k_frame != 0:
            return
        self._frames.append(frame)
        if len(self._frames) > self.max_frames:
            self._frames.pop(0)

    def get_last_n_frames(self, n: int) -> list[np.ndarray]:
        return self._frames[-n:]

    def __len__(self) -> int:
        return len(self._frames)


class RunningResult:
    """
    Maintains a history of recent model predictions with timestamps.

    This class keeps track of the most recent predictions made by the model,
    including timestamps for each prediction. It provides formatted output
    for display in the UI.

    Args:
        max_predictions (int): Maximum number of predictions to keep in history
    """

    def __init__(self, max_predictions: int = 4):
        self.predictions = []
        self.max_predictions = max_predictions

    def add_prediction(self, prediction: str):
        # Timestamp the prediction in HH:MM:SS format (UTC)
        current_time_formatted = time.strftime("%H:%M:%S", time.gmtime())
        self.predictions.append((current_time_formatted, prediction))
        if len(self.predictions) > self.max_predictions:
            self.predictions.pop(0)

    def get_formatted_predictions(self) -> str:
        if not self.predictions:
            return "Starting..."

        current, *past = self.predictions[::-1]
        text = f">>> {current[1]}\n\n" + "\n".join(
            [f"[{time_formatted}] {prediction}" for time_formatted, prediction in past]
        )
        return text

    def get_last_prediction(self) -> str:
        return self.predictions[-1][1] if self.predictions else "Starting..."


def process_frames(image: np.ndarray, frames_state: list, result_state: list, session_cache: list):
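    """
    Per-frame stream callback: cache the frame, periodically classify, annotate.

    Mirrors the incoming frame, adds it to the rolling frame cache, runs the
    classifier every UPDATE_EVERY_N_FRAMES incoming frames once a full clip is
    cached, and returns the frame annotated with the latest prediction.

    Args:
        image (np.ndarray): Incoming video frame.
        frames_state (list): Gradio state wrapping a RunningFramesCache.
        result_state (list): Gradio state wrapping a RunningResult.
        session_cache (list): Gradio state wrapping a random session id used for logging.

    Returns:
        tuple: The annotated frame and AdditionalOutputs carrying the formatted
        prediction history for the text area.
    """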

    if not session_cache:
        session_id = random.randint(1, 1000)
        session_cache.append(session_id)
    else:
        session_id = session_cache[0]

    # Initialize frames cache if not exists (and put in gradio state)
    if not frames_state:
        logger.info(f"({session_id}) initialized frames cache")
        running_frames_cache = RunningFramesCache(
            # Spread one clip over ~128 incoming frames (integer subsampling step)
            save_every_k_frame=max(1, 128 // frames_per_clip),
            max_frames=frames_per_clip,
        )
        frames_state.append(running_frames_cache)
    else:
        running_frames_cache = frames_state[0]

    # Initialize result cache if not exists (and put in gradio state)
    if not result_state:
        logger.info(f"({session_id}) initialized result cache")
        running_result = RunningResult(4)
        result_state.append(running_result)
    else:
        running_result = result_state[0]

    # Mirror the frame horizontally (selfie view) and add it to the frames cache
    image = np.flip(image, axis=1).copy()
    running_frames_cache.add_frame(image)

    # Run model if enough frames are available
    if (
        running_frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0
        and len(running_frames_cache) >= model.config.frames_per_clip
    ):
        # Prepare frames for model
        frames = running_frames_cache.get_last_n_frames(model.config.frames_per_clip)
        frames = np.array(frames)

        inputs = video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
        inputs = inputs.to(dtype=TORCH_DTYPE)

        # Run model
        with torch.no_grad():
            logits = model(**inputs).logits

        # Get top prediction
        top_index = logits.argmax(dim=-1).item()
        class_name = model.config.id2label[top_index]
        logger.info(f"({session_id}) action: '{class_name}'")
        running_result.add_prediction(class_name)

    # Get formatted predictions and last prediction
    formatted_predictions = running_result.get_formatted_predictions()
    last_prediction = running_result.get_last_prediction()
    image = add_text_on_image(image, last_prediction)
    return image, AdditionalOutputs(formatted_predictions)
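
# Cadence note: with UPDATE_EVERY_N_FRAMES = 64 and a typical ~30 fps webcam
# feed (an assumption; actual frame rates vary), the prediction text refreshes
# roughly every two seconds, while the annotated video is returned every frame.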


async def get_credentials():
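    """Fetch short-lived Cloudflare TURN credentials for WebRTC clients."""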
    return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)


# Per-session Gradio state (each holds a single-element list used as a cache)
frames_cache = gr.State([])
result_cache = gr.State([])
session_cache = gr.State([])

# Initialize the video stream with processing callback
stream = Stream(
    handler=VideoStreamHandler(process_frames, skip_frames=True),
    modality="video",
    mode="send-receive",
    additional_inputs=[frames_cache, result_cache, session_cache],
    additional_outputs=[gr.TextArea(label="Actions", value="", lines=5)],
    additional_outputs_handler=lambda _, output: output,
    rtc_configuration=get_credentials if get_space() else None,
    server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000) if get_space() else None,
    concurrency_limit=3 if get_space() else None,
)
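
# Notes on the Stream configuration above: skip_frames=True asks fastrtc to
# drop frames that arrive while a previous frame is still being processed
# rather than queue them, and the TURN credentials and concurrency limit are
# only applied when running on a Hugging Face Space (get_space() is True).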


if __name__ == "__main__":
    stream.ui.launch()