Spaces:
Running
on
L4
Running
on
L4
Add comments and docstrings
Browse files
app.py
CHANGED
@@ -1,3 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import cv2
|
2 |
import time
|
3 |
import torch
|
@@ -8,13 +19,24 @@ from fastrtc import Stream, VideoStreamHandler, AdditionalOutputs
|
|
8 |
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
|
9 |
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
15 |
|
16 |
|
17 |
def add_text_on_image(image, text):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
# Add a black background to the text
|
19 |
image[:70] = 0
|
20 |
|
@@ -56,6 +78,17 @@ def add_text_on_image(image, text):
|
|
56 |
|
57 |
|
58 |
class RunningFramesCache:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
|
60 |
self.save_every_k_frame = save_every_k_frame
|
61 |
self.max_frames = max_frames
|
@@ -74,6 +107,16 @@ class RunningFramesCache:
|
|
74 |
|
75 |
|
76 |
class RunningResult:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
def __init__(self, max_predictions: int = 4):
|
78 |
self.predictions = []
|
79 |
self.max_predictions = max_predictions
|
@@ -100,6 +143,19 @@ class RunningResult:
|
|
100 |
|
101 |
|
102 |
class FrameProcessingCallback:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
def __init__(self):
|
104 |
# Loading model and processor
|
105 |
self.model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16)
|
@@ -146,6 +202,7 @@ class FrameProcessingCallback:
|
|
146 |
return image, AdditionalOutputs(formatted_predictions)
|
147 |
|
148 |
|
|
|
149 |
stream = Stream(
|
150 |
handler=VideoStreamHandler(FrameProcessingCallback(), skip_frames=True),
|
151 |
modality="video",
|
|
|
1 |
+
"""
|
2 |
+
Real-time video classification using VJEPA2 model with streaming capabilities.
|
3 |
+
|
4 |
+
This module implements a real-time video classification system that:
|
5 |
+
1. Captures video frames from a webcam
|
6 |
+
2. Processes batches of frames using the V-JEPA 2 model
|
7 |
+
3. Displays predictions overlaid on the video stream
|
8 |
+
4. Maintains a history of recent predictions
|
9 |
+
|
10 |
+
The system uses FastRTC for video streaming and Gradio for the web interface.
|
11 |
+
"""
|
12 |
import cv2
|
13 |
import time
|
14 |
import torch
|
|
|
19 |
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
|
20 |
|
21 |
|
22 |
+
# Model configuration
|
23 |
+
CHECKPOINT = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2" # Pre-trained VJEPA2 model checkpoint
|
24 |
+
TORCH_DTYPE = torch.float16 # Use half precision for faster inference
|
25 |
+
TORCH_DEVICE = "cuda" # Use GPU for inference
|
26 |
+
UPDATE_EVERY_N_FRAMES = 64 # How often to update predictions (in frames)
|
27 |
|
28 |
|
29 |
def add_text_on_image(image, text):
|
30 |
+
"""
|
31 |
+
Overlays text on an image with a black background bar at the top.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
image (np.ndarray): Input image to add text to
|
35 |
+
text (str): Text to overlay on the image
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
np.ndarray: Image with text overlaid
|
39 |
+
"""
|
40 |
# Add a black background to the text
|
41 |
image[:70] = 0
|
42 |
|
|
|
78 |
|
79 |
|
80 |
class RunningFramesCache:
|
81 |
+
"""
|
82 |
+
Maintains a rolling buffer of video frames for model input.
|
83 |
+
|
84 |
+
This class manages a fixed-size queue of frames, keeping only the most recent
|
85 |
+
frames needed for model inference. It supports subsampling frames to reduce
|
86 |
+
memory usage and processing requirements.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
save_every_k_frame (int): Only save every k-th frame (for subsampling)
|
90 |
+
max_frames (int): Maximum number of frames to keep in cache
|
91 |
+
"""
|
92 |
def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
|
93 |
self.save_every_k_frame = save_every_k_frame
|
94 |
self.max_frames = max_frames
|
|
|
107 |
|
108 |
|
109 |
class RunningResult:
|
110 |
+
"""
|
111 |
+
Maintains a history of recent model predictions with timestamps.
|
112 |
+
|
113 |
+
This class keeps track of the most recent predictions made by the model,
|
114 |
+
including timestamps for each prediction. It provides formatted output
|
115 |
+
for display in the UI.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
max_predictions (int): Maximum number of predictions to keep in history
|
119 |
+
"""
|
120 |
def __init__(self, max_predictions: int = 4):
|
121 |
self.predictions = []
|
122 |
self.max_predictions = max_predictions
|
|
|
143 |
|
144 |
|
145 |
class FrameProcessingCallback:
|
146 |
+
"""
|
147 |
+
Handles real-time video frame processing and model inference.
|
148 |
+
|
149 |
+
This class is responsible for:
|
150 |
+
1. Loading and managing the V-JEPA 2 model
|
151 |
+
2. Processing incoming video frames
|
152 |
+
3. Running model inference at regular intervals
|
153 |
+
4. Managing frame caching and prediction history
|
154 |
+
5. Formatting output for display
|
155 |
+
|
156 |
+
The callback is called for each frame from the video stream and handles
|
157 |
+
the coordination between frame capture, model inference, and result display.
|
158 |
+
"""
|
159 |
def __init__(self):
|
160 |
# Loading model and processor
|
161 |
self.model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16)
|
|
|
202 |
return image, AdditionalOutputs(formatted_predictions)
|
203 |
|
204 |
|
205 |
+
# Initialize the video stream with processing callback
|
206 |
stream = Stream(
|
207 |
handler=VideoStreamHandler(FrameProcessingCallback(), skip_frames=True),
|
208 |
modality="video",
|