qubvel-hf HF Staff committed on
Commit
79f197e
·
1 Parent(s): bc21b9d

Add comments and docstrings

Browse files
Files changed (1) hide show
  1. app.py +61 -4
app.py CHANGED
@@ -1,3 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import time
3
  import torch
@@ -8,13 +19,24 @@ from fastrtc import Stream, VideoStreamHandler, AdditionalOutputs
8
  from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
9
 
10
 
11
- CHECKPOINT = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2"
12
- TORCH_DTYPE = torch.float16
13
- TORCH_DEVICE = "cuda"
14
- UPDATE_EVERY_N_FRAMES = 64
 
15
 
16
 
17
  def add_text_on_image(image, text):
 
 
 
 
 
 
 
 
 
 
18
  # Add a black background to the text
19
  image[:70] = 0
20
 
@@ -56,6 +78,17 @@ def add_text_on_image(image, text):
56
 
57
 
58
  class RunningFramesCache:
 
 
 
 
 
 
 
 
 
 
 
59
  def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
60
  self.save_every_k_frame = save_every_k_frame
61
  self.max_frames = max_frames
@@ -74,6 +107,16 @@ class RunningFramesCache:
74
 
75
 
76
  class RunningResult:
 
 
 
 
 
 
 
 
 
 
77
  def __init__(self, max_predictions: int = 4):
78
  self.predictions = []
79
  self.max_predictions = max_predictions
@@ -100,6 +143,19 @@ class RunningResult:
100
 
101
 
102
  class FrameProcessingCallback:
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def __init__(self):
104
  # Loading model and processor
105
  self.model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16)
@@ -146,6 +202,7 @@ class FrameProcessingCallback:
146
  return image, AdditionalOutputs(formatted_predictions)
147
 
148
 
 
149
  stream = Stream(
150
  handler=VideoStreamHandler(FrameProcessingCallback(), skip_frames=True),
151
  modality="video",
 
1
+ """
2
+ Real-time video classification using the V-JEPA 2 model with streaming capabilities.
3
+
4
+ This module implements a real-time video classification system that:
5
+ 1. Captures video frames from a webcam
6
+ 2. Processes batches of frames using the V-JEPA 2 model
7
+ 3. Displays predictions overlaid on the video stream
8
+ 4. Maintains a history of recent predictions
9
+
10
+ The system uses FastRTC for video streaming and Gradio for the web interface.
11
+ """
12
  import cv2
13
  import time
14
  import torch
 
19
  from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
20
 
21
 
22
+ # Model configuration
23
+ CHECKPOINT = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2" # Pre-trained VJEPA2 model checkpoint
24
+ TORCH_DTYPE = torch.float16 # Use half precision for faster inference
25
+ TORCH_DEVICE = "cuda" # Use GPU for inference
26
+ UPDATE_EVERY_N_FRAMES = 64 # How often to update predictions (in frames)
27
 
28
 
29
  def add_text_on_image(image, text):
30
+ """
31
+ Overlays text on an image with a black background bar at the top.
32
+
33
+ Args:
34
+ image (np.ndarray): Input image to add text to
35
+ text (str): Text to overlay on the image
36
+
37
+ Returns:
38
+ np.ndarray: Image with text overlaid
39
+ """
40
  # Add a black background to the text
41
  image[:70] = 0
42
 
 
78
 
79
 
80
  class RunningFramesCache:
81
+ """
82
+ Maintains a rolling buffer of video frames for model input.
83
+
84
+ This class manages a fixed-size queue of frames, keeping only the most recent
85
+ frames needed for model inference. It supports subsampling frames to reduce
86
+ memory usage and processing requirements.
87
+
88
+ Args:
89
+ save_every_k_frame (int): Only save every k-th frame (for subsampling)
90
+ max_frames (int): Maximum number of frames to keep in cache
91
+ """
92
  def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
93
  self.save_every_k_frame = save_every_k_frame
94
  self.max_frames = max_frames
 
107
 
108
 
109
  class RunningResult:
110
+ """
111
+ Maintains a history of recent model predictions with timestamps.
112
+
113
+ This class keeps track of the most recent predictions made by the model,
114
+ including timestamps for each prediction. It provides formatted output
115
+ for display in the UI.
116
+
117
+ Args:
118
+ max_predictions (int): Maximum number of predictions to keep in history
119
+ """
120
  def __init__(self, max_predictions: int = 4):
121
  self.predictions = []
122
  self.max_predictions = max_predictions
 
143
 
144
 
145
  class FrameProcessingCallback:
146
+ """
147
+ Handles real-time video frame processing and model inference.
148
+
149
+ This class is responsible for:
150
+ 1. Loading and managing the V-JEPA 2 model
151
+ 2. Processing incoming video frames
152
+ 3. Running model inference at regular intervals
153
+ 4. Managing frame caching and prediction history
154
+ 5. Formatting output for display
155
+
156
+ The callback is called for each frame from the video stream and handles
157
+ the coordination between frame capture, model inference, and result display.
158
+ """
159
  def __init__(self):
160
  # Loading model and processor
161
  self.model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16)
 
202
  return image, AdditionalOutputs(formatted_predictions)
203
 
204
 
205
+ # Initialize the video stream with processing callback
206
  stream = Stream(
207
  handler=VideoStreamHandler(FrameProcessingCallback(), skip_frames=True),
208
  modality="video",