"""
Real-time video classification using the V-JEPA 2 model with streaming capabilities.
This module implements a real-time video classification system that:
1. Captures video frames from a webcam
2. Processes batches of frames using the V-JEPA 2 model
3. Displays predictions overlaid on the video stream
4. Maintains a history of recent predictions
The system uses FastRTC for video streaming and Gradio for the web interface.
"""
import os
import cv2
import time
import torch
import random
import gradio as gr
import numpy as np
from loguru import logger
from gradio.utils import get_space
from fastrtc import (
    Stream,
    VideoStreamHandler,
    AdditionalOutputs,
    get_cloudflare_turn_credentials_async,
    get_cloudflare_turn_credentials,
)
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor

# Model configuration
CHECKPOINT = "facebook/vjepa2-vitl-fpc16-256-ssv2"  # Pre-trained V-JEPA 2 checkpoint
TORCH_DTYPE = torch.float16  # Use half precision for faster inference
TORCH_DEVICE = "cuda"  # Use GPU for inference
UPDATE_EVERY_N_FRAMES = 64  # How often to update predictions (in frames)
HF_TOKEN = os.getenv("HF_TOKEN")

# Load the model in the same dtype the inputs are cast to at inference time
model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=TORCH_DTYPE)
model = model.to(TORCH_DEVICE)
video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT)
frames_per_clip = model.config.frames_per_clip
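# Note: "fpc16" in the checkpoint name indicates 16 frames per clip, so
# frames_per_clip is expected to be 16 here.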


def add_text_on_image(image, text):
    """
    Overlays text on an image with a black background bar at the top.

    Args:
        image (np.ndarray): Input image to add text to
        text (str): Text to overlay on the image

    Returns:
        np.ndarray: Image with text overlaid
    """
    # Black out the top bar as a background for the text
    image[:70] = 0

    line_spacing = 10
    top_margin = 20
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    color = (255, 255, 255)  # White

    words = text.split()
    lines = []
    current_line = ""
    img_width = image.shape[1]

    # Build lines that fit within the image width
    for word in words:
        test_line = current_line + (" " if current_line else "") + word
        (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
        if test_width > img_width - 20:  # 20 px margin
            lines.append(current_line)
            current_line = word
        else:
            current_line = test_line
    if current_line:
        lines.append(current_line)

    # Draw each line, centered horizontally
    y = top_margin
    for line in lines:
        (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
        x = (img_width - line_width) // 2
        cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
        y += line_height + line_spacing
    return image
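
# Illustrative usage of the overlay helper (a sketch; the frame size, text,
# and output path are hypothetical, not part of the streaming pipeline):
#
#   frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   frame = add_text_on_image(frame, "someone waving their hand at the camera")
#   cv2.imwrite("overlay_preview.png", frame)  # hypothetical output path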


class RunningFramesCache:
    """
    Maintains a rolling buffer of video frames for model input.

    This class manages a fixed-size queue of frames, keeping only the most recent
    frames needed for model inference. It supports subsampling frames to reduce
    memory usage and processing requirements.

    Args:
        save_every_k_frame (int): Only save every k-th frame (for subsampling)
        max_frames (int): Maximum number of frames to keep in cache
    """

    def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
        self.save_every_k_frame = max(int(save_every_k_frame), 1)
        self.max_frames = max_frames
        self._frames = []
        self.counter = 0

    def add_frame(self, frame: np.ndarray):
        self.counter += 1
        # Keep only every k-th frame so the buffer spans a longer time window
        if self.counter % self.save_every_k_frame == 0:
            self._frames.append(frame)
            if len(self._frames) > self.max_frames:
                self._frames.pop(0)

    def get_last_n_frames(self, n: int) -> list[np.ndarray]:
        return self._frames[-n:]

    def __len__(self) -> int:
        return len(self._frames)
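
# Illustrative behavior of the cache (hypothetical values, run outside the app):
#
#   cache = RunningFramesCache(save_every_k_frame=2, max_frames=3)
#   for i in range(10):
#       cache.add_frame(np.full((2, 2, 3), i, dtype=np.uint8))
#   # Every 2nd frame is kept and only the last 3 survive: i = 5, 7, 9
#   assert [int(f[0, 0, 0]) for f in cache.get_last_n_frames(3)] == [5, 7, 9]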


class RunningResult:
    """
    Maintains a history of recent model predictions with timestamps.

    This class keeps track of the most recent predictions made by the model,
    including timestamps for each prediction. It provides formatted output
    for display in the UI.

    Args:
        max_predictions (int): Maximum number of predictions to keep in history
    """

    def __init__(self, max_predictions: int = 4):
        self.predictions = []
        self.max_predictions = max_predictions

    def add_prediction(self, prediction: str):
        # Timestamp the prediction in HH:MM:SS format (UTC)
        current_time_formatted = time.strftime("%H:%M:%S", time.gmtime(time.time()))
        self.predictions.append((current_time_formatted, prediction))
        if len(self.predictions) > self.max_predictions:
            self.predictions.pop(0)

    def get_formatted_predictions(self) -> str:
        if not self.predictions:
            return "Starting..."
        # Newest prediction first, then the timestamped history
        current, *past = self.predictions[::-1]
        text = f">>> {current[1]}\n\n" + "\n".join(
            [f"[{time_formatted}] {prediction}" for time_formatted, prediction in past]
        )
        return text

    def get_last_prediction(self) -> str:
        return self.predictions[-1][1] if self.predictions else "Starting..."
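
# Illustrative output of the history formatter (hypothetical predictions and
# timestamp):
#
#   results = RunningResult(max_predictions=4)
#   results.add_prediction("waving hand")
#   results.add_prediction("clapping")
#   print(results.get_formatted_predictions())
#   # >>> clapping
#   #
#   # [09:15:02] waving hand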


def process_frames(image: np.ndarray, frames_state: list, result_state: list, session_cache: list):
    """
    Per-frame callback for the video stream: caches incoming frames, runs the
    model every UPDATE_EVERY_N_FRAMES frames, and overlays the latest
    prediction on the outgoing frame.
    """
    if not session_cache:
        session_id = random.randint(1, 1000)
        session_cache.append(session_id)
    else:
        session_id = session_cache[0]

    # Initialize the frames cache if it does not exist yet (and store it in gradio state)
    if not frames_state:
        logger.info(f"({session_id}) initialized frames cache")
        running_frames_cache = RunningFramesCache(
            save_every_k_frame=128 // frames_per_clip,
            max_frames=frames_per_clip,
        )
        frames_state.append(running_frames_cache)
    else:
        running_frames_cache = frames_state[0]

    # Initialize the result cache if it does not exist yet (and store it in gradio state)
    if not result_state:
        logger.info(f"({session_id}) initialized result cache")
        running_result = RunningResult(4)
        result_state.append(running_result)
    else:
        running_result = result_state[0]

    # Mirror the frame horizontally and add it to the frames cache
    image = np.flip(image, axis=1).copy()
    running_frames_cache.add_frame(image)

    # Run the model if enough frames are available
    if (
        running_frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0
        and len(running_frames_cache) >= model.config.frames_per_clip
    ):
        # Prepare frames for the model
        frames = running_frames_cache.get_last_n_frames(model.config.frames_per_clip)
        frames = np.array(frames)
        inputs = video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
        inputs = inputs.to(dtype=TORCH_DTYPE)

        # Run the model
        with torch.no_grad():
            logits = model(**inputs).logits

        # Get the top prediction
        top_index = logits.argmax(dim=-1).item()
        class_name = model.config.id2label[top_index]
        logger.info(f"({session_id}) action: '{class_name}'")
        running_result.add_prediction(class_name)

    # Overlay the most recent prediction and return the running history for the UI
    formatted_predictions = running_result.get_formatted_predictions()
    last_prediction = running_result.get_last_prediction()
    image = add_text_on_image(image, last_prediction)
    return image, AdditionalOutputs(formatted_predictions)
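
# Minimal offline sanity check for the prediction path (a sketch, assuming a
# random dummy clip; in the app the input comes from the webcam stream):
#
#   clip = np.random.randint(0, 255, (frames_per_clip, 256, 256, 3), dtype=np.uint8)
#   inputs = video_processor(clip, device=TORCH_DEVICE, return_tensors="pt")
#   inputs = inputs.to(dtype=TORCH_DTYPE)
#   with torch.no_grad():
#       logits = model(**inputs).logits
#   print(model.config.id2label[logits.argmax(dim=-1).item()])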


async def get_credentials():
    return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)
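
# Per-session Gradio state passed to process_frames via additional_inputs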
frames_cache = gr.State([])
result_cache = gr.State([])
session_id = gr.State([])
# Initialize the video stream with processing callback
stream = Stream(
    handler=VideoStreamHandler(process_frames, skip_frames=True),
    modality="video",
    mode="send-receive",
    additional_inputs=[frames_cache, result_cache, session_id],
    additional_outputs=[gr.TextArea(label="Actions", value="", lines=5)],
    additional_outputs_handler=lambda _, output: output,
    rtc_configuration=get_credentials if get_space() else None,
    server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000) if get_space() else None,
    concurrency_limit=3 if get_space() else None,
)

if __name__ == "__main__":
    stream.ui.launch()