Spaces:
Sleeping
Sleeping
import asyncio | |
import json | |
import logging | |
import time | |
import uuid | |
from collections.abc import Callable, Coroutine | |
from datetime import UTC, datetime, timedelta | |
from fractions import Fraction | |
from functools import lru_cache | |
import av | |
import cv2 | |
import numpy as np | |
from aiortc import ( | |
RTCConfiguration, | |
RTCIceCandidate, | |
RTCIceServer, | |
RTCPeerConnection, | |
RTCSessionDescription, | |
VideoStreamTrack, | |
) | |
from fastapi import WebSocket, WebSocketDisconnect | |
from .models import ( | |
EmergencyStopMessageDict, | |
ErrorMessageDict, | |
HeartbeatAckMessageDict, | |
JoinedMessageDict, | |
MessageType, | |
ParticipantJoinedMessageDict, | |
ParticipantRole, | |
RawWebRTCMessageType, | |
RawWebRTCSignalingMessage, | |
RecoveryConfig, | |
RecoveryPolicy, | |
RecoveryTriggeredMessageDict, | |
StatusUpdateMessageDict, | |
StreamStartedMessageDict, | |
StreamStatsMessageDict, | |
StreamStoppedMessageDict, | |
VideoConfig, | |
# Core data structures | |
VideoConfigUpdateMessageDict, | |
WebRTCAnswerMessageDict, | |
WebRTCIceMessageDict, | |
WebRTCOfferMessageDict, | |
WebSocketMessageDict, | |
) | |
logger = logging.getLogger(__name__) | |
# ============= FRAME CACHE ============= | |
@lru_cache(maxsize=8)  # Cache up to 8 different resolutions
def get_black_frame(width: int, height: int) -> np.ndarray:
    """Return an all-black RGB frame of the requested size, memoised per size.

    Args:
        width: Frame width in pixels.
        height: Frame height in pixels.

    Returns:
        A ``(height, width, 3)`` ``uint8`` array of zeros. Repeated calls with
        the same dimensions return the *same* array object, so callers must
        treat it as read-only and copy it before drawing on it.
    """
    return np.zeros((height, width, 3), dtype=np.uint8)
@lru_cache(maxsize=4)  # Cache up to 4 different info frame variants
def _render_connection_info_frame(
    width: int,
    height: int,
    bg_color: tuple[int, ...],
    text_color: tuple[int, ...],
) -> np.ndarray:
    """Render (and memoise) the "RECONNECTING..." placeholder frame."""
    frame = np.full((height, width, 3), bg_color, dtype=np.uint8)

    font = cv2.FONT_HERSHEY_SIMPLEX
    # Text size is tuned for 640x480 and scales with the smaller ratio.
    font_scale = min(width / 640, height / 480) * 1.2
    thickness = max(1, int(font_scale))

    # Main status message, horizontally centred on the vertical midline.
    text = "RECONNECTING..."
    text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
    text_x = (width - text_size[0]) // 2
    text_y = height // 2
    cv2.putText(frame, text, (text_x, text_y), font, font_scale, text_color, thickness)

    # Subtitle at half the main scale, placed below the main message.
    subtitle = "Video stream interrupted"
    subtitle_scale = font_scale * 0.5
    subtitle_size = cv2.getTextSize(subtitle, font, subtitle_scale, 1)[0]
    subtitle_x = (width - subtitle_size[0]) // 2
    subtitle_y = text_y + int(40 * font_scale)
    cv2.putText(
        frame,
        subtitle,
        (subtitle_x, subtitle_y),
        font,
        subtitle_scale,
        text_color,
        1,
    )
    return frame


def get_connection_info_frame(
    width: int,
    height: int,
    bg_color: tuple[int, int, int],
    text_color: tuple[int, int, int],
) -> np.ndarray:
    """Get cached connection info frame.

    Args:
        width: Frame width in pixels.
        height: Frame height in pixels.
        bg_color: Three-component background colour.
        text_color: Three-component colour for both text lines.

    Returns:
        A ``(height, width, 3)`` ``uint8`` array. The array is shared between
        calls with equal arguments, so callers must not modify it in place.
    """
    # Colours are normalised to tuples so list-valued colours (e.g. decoded
    # from JSON) remain hashable cache keys.
    return _render_connection_info_frame(
        width, height, tuple(bg_color), tuple(text_color)
    )
def add_frame_hold_indicator(frame: np.ndarray, reuse_count: int) -> np.ndarray:
    """Add a subtle indicator that this frame is being held.

    A small square is painted near the top-right corner. Its colour moves from
    yellow towards red as ``reuse_count`` grows.

    Args:
        frame: Three-channel image array; it is modified in place.
        reuse_count: How many times the frame has been re-sent. Values below 1
            use the first colour, values above 4 use the last one.

    Returns:
        The same ``frame`` object. It is left untouched when the square does
        not fit inside the image.
    """
    height, width = frame.shape[:2]

    indicator_size = max(6, min(width, height) // 80)  # Scale with frame size
    colors = (
        (255, 200, 0),
        (255, 150, 0),
        (255, 100, 0),
        (255, 50, 0),
    )  # Yellow to red

    # Clamp on both sides: without the lower bound a count of 0 would index
    # colors[-1] and paint the most severe colour.
    color = colors[max(0, min(reuse_count - 1, len(colors) - 1))]

    y_start = 10
    y_end = y_start + indicator_size
    # Keep the 20 px offset for small squares, but move further left once the
    # square would otherwise run past the right edge.
    x_start = width - max(20, indicator_size + 1)
    x_end = x_start + indicator_size

    if x_start >= 0 and y_end < height and x_end < width:
        frame[y_start:y_end, x_start:x_end] = color
    return frame
# ============= VIDEO FRAME TRACK ============= | |
class VideoFrameTrack(VideoStreamTrack):
    """Video track for WebRTC streaming with recovery support.

    JPEG-encoded frames are pushed in through :meth:`add_frame` and pulled out
    through :meth:`recv`. When no fresh frame arrives within 100 ms, or a frame
    cannot be decoded, a substitute frame is produced according to the
    policies in the ``RecoveryConfig``.
    """

    def __init__(self, recovery_config: RecoveryConfig | None = None):
        """Create the track.

        Args:
            recovery_config: Recovery settings; ``RecoveryConfig()`` is used
                when omitted.
        """
        super().__init__()
        self.frame_queue = asyncio.Queue(maxsize=2)  # Small buffer for low latency
        self.pts = 0
        self.time_base = 1 / 30  # 30 FPS

        # Frame recovery system
        self.config = recovery_config or RecoveryConfig()
        self.last_good_frame = None
        self.last_good_frame_time = 0  # milliseconds, same clock as recv()
        self.frame_reuse_count = 0
        self.last_frame_dimensions = (480, 640)  # (height, width) fallback

        logger.info(
            f"VideoFrameTrack created with recovery policy: {self.config.recovery_policy.value}"
        )

    async def recv(self) -> av.VideoFrame:
        """Get next video frame for WebRTC transmission with smart recovery.

        Waits up to 100 ms for a queued JPEG. A decodable frame is returned as
        RGB; a timeout, empty payload or decode failure yields a recovery
        frame. ``pts`` advances by one per returned frame on a 1/30 time base.
        """
        current_time = time.time() * 1000  # Convert to milliseconds

        try:
            frame_data = await asyncio.wait_for(self.frame_queue.get(), timeout=0.1)
        except TimeoutError:
            # Timeout waiting for frame - use smart recovery
            frame = self._create_recovery_frame(current_time)
        else:
            frame = self._decode_frame(frame_data, current_time)

        # Set timing
        frame.pts = self.pts
        frame.time_base = Fraction(1, 30)
        self.pts += 1
        return frame

    def _decode_frame(self, frame_data: bytes, current_time: float) -> av.VideoFrame:
        """Decode one JPEG payload, falling back to a recovery frame on failure."""
        if not frame_data:
            # No frame data - use recovery
            return self._create_recovery_frame(current_time)

        nparr = np.frombuffer(frame_data, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            # JPEG decode failed - use recovery
            return self._create_recovery_frame(current_time)

        # OpenCV decodes to BGR; the outgoing frame is built as rgb24.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Remember this frame so recovery policies can reuse it.
        self.last_good_frame = img_rgb.copy()
        self.last_good_frame_time = current_time
        self.last_frame_dimensions = (img_rgb.shape[0], img_rgb.shape[1])
        self.frame_reuse_count = 0

        return av.VideoFrame.from_ndarray(img_rgb, format="rgb24")

    def _create_recovery_frame(self, current_time: float) -> av.VideoFrame:
        """Create a recovery frame using the configured policy.

        The primary policy applies while a good frame exists, is younger than
        ``frame_timeout_ms`` and has been reused fewer than
        ``max_frame_reuse_count`` times. Otherwise the fallback policy is used.

        Args:
            current_time: Current time in milliseconds.
        """
        time_since_last_good = (
            current_time - self.last_good_frame_time
            if self.last_good_frame is not None
            else float("inf")
        )

        if (
            self.last_good_frame is not None
            and time_since_last_good < self.config.frame_timeout_ms
            and self.frame_reuse_count < self.config.max_frame_reuse_count
        ):
            # Use primary recovery policy
            policy = self.config.recovery_policy
            self.frame_reuse_count += 1
        else:
            # Use fallback policy
            policy = self.config.fallback_policy

        recovery_frame = self._apply_recovery_policy(policy)
        frame = av.VideoFrame.from_ndarray(recovery_frame, format="rgb24")
        frame.pts = self.pts
        frame.time_base = Fraction(1, 30)
        return frame

    def _apply_recovery_policy(self, policy: RecoveryPolicy) -> np.ndarray:
        """Apply the specified recovery policy and return the resulting image.

        Policies that need the last good frame fall back to a black frame when
        none has been received yet. Unknown policies are logged and also
        produce a black frame.
        """
        height, width = self.last_frame_dimensions

        if policy == RecoveryPolicy.FREEZE_LAST_FRAME:
            if self.last_good_frame is not None:
                frame = self.last_good_frame.copy()
                if self.config.show_hold_indicators:
                    # The fallback path can get here with a reuse count of 0;
                    # pass at least 1 so the indicator starts at its first colour.
                    frame = add_frame_hold_indicator(
                        frame, max(1, self.frame_reuse_count)
                    )
                return frame
            return get_black_frame(width, height)

        if policy == RecoveryPolicy.CONNECTION_INFO:
            return get_connection_info_frame(
                width,
                height,
                self.config.info_frame_bg_color,
                self.config.info_frame_text_color,
            )

        if policy == RecoveryPolicy.BLACK_SCREEN:
            return get_black_frame(width, height)

        if policy == RecoveryPolicy.FADE_TO_BLACK:
            if self.last_good_frame is not None:
                frame = self.last_good_frame.copy()
                max_reuse = self.config.max_frame_reuse_count
                if max_reuse > 0:
                    fade_amount = (
                        self.frame_reuse_count * self.config.fade_intensity / max_reuse
                    )
                else:
                    # No reuse budget: avoid dividing by zero and apply the
                    # full configured fade at once.
                    fade_amount = self.config.fade_intensity
                fade_factor = max(0.0, 1.0 - fade_amount)
                return (frame * fade_factor).astype(np.uint8)
            return get_black_frame(width, height)

        if policy == RecoveryPolicy.OVERLAY_STATUS:
            if self.last_good_frame is not None:
                frame = self.last_good_frame.copy()

                # Solid-colour layer carrying the status text.
                overlay = np.full_like(
                    frame, self.config.info_frame_bg_color, dtype=np.uint8
                )
                font = cv2.FONT_HERSHEY_SIMPLEX
                font_scale = min(width / 640, height / 480) * 0.8
                text = "RECONNECTING"
                text_size = cv2.getTextSize(text, font, font_scale, 2)[0]
                text_x = (width - text_size[0]) // 2
                text_y = height // 2
                cv2.putText(
                    overlay,
                    text,
                    (text_x, text_y),
                    font,
                    font_scale,
                    self.config.info_frame_text_color,
                    2,
                )

                # Blend overlay with original frame
                alpha = self.config.overlay_opacity
                return cv2.addWeighted(frame, 1 - alpha, overlay, alpha, 0)
            return get_black_frame(width, height)

        # Unknown policy, fallback to black
        logger.error(f"Unknown recovery policy: {policy}")
        return get_black_frame(width, height)

    def add_frame(self, frame_data: bytes) -> None:
        """Add frame to queue (non-blocking).

        When the queue is full the oldest queued frame is discarded to make
        room, keeping latency low.
        """
        try:
            self.frame_queue.put_nowait(frame_data)
        except asyncio.QueueFull:
            # Drop oldest frame for low latency
            try:
                self.frame_queue.get_nowait()
                self.frame_queue.put_nowait(frame_data)
                logger.debug("Dropped old frame to maintain low latency")
            except asyncio.QueueEmpty:
                pass
# ============= WEBRTC CONNECTION ============= | |
class WebRTCConnection:
    """WebRTC connection handling both producers and consumers.

    A producer connection receives a remote video track, re-encodes each frame
    as JPEG and hands it to ``on_frame`` and to the core's consumer broadcast.
    A consumer connection owns a :class:`VideoFrameTrack` that is fed through
    :meth:`send_video_frame`.
    """

    def __init__(
        self,
        client_id: str,
        room_id: str,
        on_frame_callback: Callable,
        video_core: "VideoCore",
        workspace_id: str | None = None,
    ):
        """Store identifiers and collaborators; call :meth:`initialize` next.

        Args:
            client_id: Identifier used in logs and passed to ``on_frame``.
            room_id: Room this connection belongs to.
            on_frame_callback: Awaitable callable ``(client_id, jpeg_bytes)``.
            video_core: Core used for room lookup and broadcasting.
            workspace_id: Workspace owning the room. When omitted it is
                discovered by searching the core's workspaces for ``room_id``.
        """
        self.client_id = client_id
        self.room_id = room_id
        self.workspace_id = workspace_id
        self.on_frame = on_frame_callback
        self.video_core = video_core  # Reference to core for broadcasting
        self.pc = None  # RTCPeerConnection
        self.video_track: VideoFrameTrack | None = None
        self.is_producer = False
        self.is_consumer = False
        self.background_tasks: set[asyncio.Task] = set()

    async def initialize(self):
        """Initialize peer connection and register its event handlers."""
        config = RTCConfiguration(
            iceServers=[
                RTCIceServer(urls=["stun:stun.l.google.com:19302"]),
                RTCIceServer(urls=["stun:stun1.l.google.com:19302"]),
            ]
        )
        self.pc = RTCPeerConnection(configuration=config)

        # Set up event handlers
        self.pc.on("track", self._on_track)
        self.pc.on("connectionstatechange", self._on_connection_state)
        self.pc.on("iceconnectionstatechange", self._on_ice_state)
        logger.info(f"WebRTC connection {self.client_id} initialized")

    def _on_connection_state(self) -> None:
        """Log peer connection state changes."""
        logger.info(
            f"WebRTC {self.client_id} connection state: {self.pc.connectionState}"
        )

    def _on_ice_state(self) -> None:
        """Log ICE connection state changes."""
        logger.info(f"WebRTC {self.client_id} ICE state: {self.pc.iceConnectionState}")

    def _on_track(self, track: av.VideoStream) -> None:
        """Handle incoming video track from producer."""
        logger.info(f"WebRTC {self.client_id} received track: {track.kind}")
        # ``type`` is read defensively: a track object without that attribute
        # must not raise when its kind is not "video".
        if track.kind == "video" or getattr(track, "type", None) == "video":
            self.is_producer = True
            logger.info(f"WebRTC {self.client_id} is now a PRODUCER")
            # Process incoming video frames
            self._add_background_task(self._process_incoming_video(track))

    def _find_room(self):
        """Return this connection's room from the core, or ``None``.

        When ``workspace_id`` is unknown, the core's workspaces are searched
        for ``room_id`` and the owning workspace is remembered.
        """
        core = self.video_core
        if core is None:
            return None

        workspaces = getattr(core, "workspaces", None) or {}
        if self.workspace_id is not None:
            return workspaces.get(self.workspace_id, {}).get(self.room_id)

        for workspace_id, rooms in workspaces.items():
            if self.room_id in rooms:
                self.workspace_id = workspace_id
                return rooms[self.room_id]

        # NOTE(review): kept for cores exposing a flat ``rooms`` mapping, as the
        # previous code assumed; the VideoCore visible in this file has none.
        legacy_rooms = getattr(core, "rooms", None)
        if legacy_rooms is not None:
            return legacy_rooms.get(self.room_id)
        return None

    def _recovery_config_for_room(self) -> RecoveryConfig:
        """Return the room's recovery config, or a default one if no room."""
        room = self._find_room()
        return room.recovery_config if room else RecoveryConfig()

    async def _broadcast_frame(self, payload: bytes) -> None:
        """Forward one JPEG frame to the room's consumers via the core."""
        if not self.video_core:
            return
        if self.workspace_id is None:
            self._find_room()  # resolves self.workspace_id when possible
        if self.workspace_id is None:
            return
        await self.video_core.broadcast_to_consumers(
            self.workspace_id, self.room_id, payload
        )

    async def _process_incoming_video(self, track):
        """Process video frames from producer until the track fails or ends."""
        frame_count = 0
        try:
            while True:
                frame = await track.recv()
                frame_count += 1

                # Convert to OpenCV format
                img = frame.to_ndarray(format="rgb24")
                img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

                # Encode as JPEG
                success, jpeg_data = cv2.imencode(
                    ".jpg", img_bgr, [cv2.IMWRITE_JPEG_QUALITY, 80]
                )
                if success:
                    payload = jpeg_data.tobytes()  # serialise once, reuse below
                    # Send to processing pipeline
                    await self.on_frame(self.client_id, payload)
                    # Broadcast frames to consumers
                    await self._broadcast_frame(payload)

                if frame_count % 30 == 0:  # Log every 30 frames
                    logger.info(
                        f"WebRTC {self.client_id} processed {frame_count} frames"
                    )
        except Exception:
            logger.exception(f"Error processing video track {self.client_id}")

    def _attach_consumer_track(self) -> None:
        """Create the outgoing video track and add it to the peer connection."""
        self.video_track = VideoFrameTrack(self._recovery_config_for_room())
        self.pc.addTrack(self.video_track)

    def _sdp_requests_consumer(self, sdp: str) -> bool:
        """Guess from the offer SDP whether the client wants to receive video."""
        has_recvonly = "a=recvonly" in sdp
        has_sendonly = "a=sendonly" in sdp
        has_video = "m=video" in sdp
        # Check if there are video track sources in the offer (indicates producer)
        has_video_sources = "a=ssrc:" in sdp and "m=video" in sdp

        logger.info(f"🔍 Role detection for {self.client_id}:")
        logger.info(f" - has_recvonly: {has_recvonly}")
        logger.info(f" - has_sendonly: {has_sendonly}")
        logger.info(f" - has_video: {has_video}")
        logger.info(f" - has_video_sources: {has_video_sources}")

        if has_recvonly:
            logger.info(" - CONSUMER detected: has a=recvonly")
            return True
        if has_sendonly or has_video_sources:
            logger.info(" - PRODUCER detected: has a=sendonly or video sources")
            return False
        if has_video:
            # Default: if it has video but no clear direction, assume consumer
            logger.info(" - CONSUMER detected: has video but no clear direction")
            return True
        # No video at all, treat as consumer
        logger.info(" - CONSUMER detected: no video")
        return True

    async def handle_offer(self, sdp: str, participant_role: str | None = None) -> str:
        """Handle WebRTC offer and determine if producer or consumer.

        Args:
            sdp: SDP text of the remote offer.
            participant_role: ``"producer"`` or ``"consumer"`` to set the role
                explicitly; any other value triggers detection from the SDP.

        Returns:
            The SDP of the local answer.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            logger.debug(
                f"Processing offer for {self.client_id}, SDP length: {len(sdp)}"
            )
            offer = RTCSessionDescription(sdp=sdp, type="offer")
            await self.pc.setRemoteDescription(offer)

            # Use explicit role if provided, otherwise fall back to detection
            if participant_role == "producer":
                self.is_producer = True
                self.is_consumer = False
                logger.info(f"WebRTC {self.client_id} set as PRODUCER (explicit role)")
            elif participant_role == "consumer":
                self.is_consumer = True
                self.is_producer = False
                self._attach_consumer_track()
                logger.info(
                    f"WebRTC {self.client_id} set as CONSUMER (explicit role) - video track added"
                )
            elif self._sdp_requests_consumer(sdp):
                # This is a consumer - add video track for sending TO the consumer
                self.is_consumer = True
                self._attach_consumer_track()
                logger.info(
                    f"WebRTC {self.client_id} is now a CONSUMER - video track added"
                )
            else:
                # This is a producer
                self.is_consumer = False
                self.is_producer = True
                logger.info(f"WebRTC {self.client_id} is a PRODUCER")

            # Create answer
            answer = await self.pc.createAnswer()
            await self.pc.setLocalDescription(answer)

            # Wait for ICE gathering, polling every 100 ms for at most 50 polls.
            timeout_count = 0
            while self.pc.iceGatheringState != "complete" and timeout_count < 50:
                await asyncio.sleep(0.1)
                timeout_count += 1
            # Warn only if gathering really did not finish, not merely because
            # the last allowed poll was used.
            if self.pc.iceGatheringState != "complete":
                logger.warning(f"ICE gathering timeout for {self.client_id}")
        except Exception:
            logger.exception(f"Error in handle_offer for {self.client_id}")
            raise
        else:
            return self.pc.localDescription.sdp

    async def add_ice_candidate(self, candidate_data: dict):
        """Add ICE candidate.

        Args:
            candidate_data: Mapping with ``candidate`` (candidate line),
                ``sdpMid`` and ``sdpMLineIndex``. A truthy ``end`` key or an
                empty candidate string is ignored. Malformed candidates are
                logged and skipped; nothing is raised to the caller.
        """
        try:
            if candidate_data.get("end"):
                return

            candidate_str = candidate_data.get("candidate", "")
            sdp_mid = candidate_data.get("sdpMid")
            sdp_m_line_index = candidate_data.get("sdpMLineIndex")

            if not candidate_str:
                logger.debug(f"Skipping empty ICE candidate for {self.client_id}")
                return

            # Expected fields: candidate:<foundation> <component> <protocol>
            # <priority> <ip> <port> typ <type> ...
            parts = candidate_str.split()
            if len(parts) < 8 or not parts[0].startswith("candidate:"):
                logger.warning(
                    f"Invalid candidate format for {self.client_id}: {candidate_str}"
                )
                return

            try:
                foundation = parts[0].split(":")[1]
                component = int(parts[1])
                protocol = parts[2].lower()
                priority = int(parts[3])
                ip = parts[4]
                port = int(parts[5])
                typ = parts[7] if parts[6] == "typ" else "host"

                candidate = RTCIceCandidate(
                    foundation=foundation,
                    component=component,
                    protocol=protocol,
                    priority=priority,
                    ip=ip,
                    port=port,
                    type=typ,
                    sdpMid=sdp_mid,
                    sdpMLineIndex=sdp_m_line_index,
                )
                await self.pc.addIceCandidate(candidate)
                logger.debug(f"ICE candidate added for {self.client_id}: {typ}")
            except (ValueError, IndexError):
                logger.exception(f"Failed to parse ICE candidate for {self.client_id}")
        except Exception:
            logger.exception(f"Failed to add ICE candidate for {self.client_id}")

    def send_video_frame(self, frame_data: bytes):
        """Send video frame to consumer; ignored unless this is a consumer with a track."""
        if self.is_consumer and self.video_track:
            self.video_track.add_frame(frame_data)

    async def close(self):
        """Close connection if a peer connection was created."""
        if self.pc:
            await self.pc.close()

    def _add_background_task(self, coro: Coroutine):
        """Add a background task with automatic cleanup.

        A reference to the task is kept until it finishes so it is not
        garbage-collected while running.
        """
        task = asyncio.create_task(coro)
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)
        return task
# ============= VIDEO ROOM (simplified) ============= | |
class VideoRoom: | |
"""Simple video room with producer/consumer pattern""" | |
def __init__( | |
self, | |
room_id: str, | |
workspace_id: str, | |
config: VideoConfig = None, | |
recovery_config: RecoveryConfig = None, | |
): | |
self.id = room_id | |
self.workspace_id = workspace_id | |
self.config = config or VideoConfig() | |
self.recovery_config = recovery_config or RecoveryConfig() | |
# Participants (same pattern as robotics) | |
self.producer: str | None = None | |
self.consumers: list[str] = [] | |
# Video state | |
self.last_frame: bytes | None = None | |
self.frame_count = 0 | |
self.total_bytes = 0 | |
self.start_time = datetime.now(tz=UTC) | |
self.last_frame_time: datetime | None = None | |
# Activity tracking | |
self.created_at = datetime.now(tz=UTC) | |
self.last_activity = datetime.now(tz=UTC) | |
# ============= VIDEO CORE (main class) ============= | |
class VideoCore: | |
"""Core video system""" | |
def __init__(self):
    """Create empty registries and attempt to start the cleanup task."""
    # Cleanup policy: how long a room may stay idle, and how often to check.
    self.cleanup_interval = timedelta(minutes=15)
    self.inactivity_timeout = timedelta(hours=1)

    # workspace_id -> room_id -> VideoRoom
    self.workspaces: dict[str, dict[str, VideoRoom]] = {}

    # Per-connection registries
    self.websocket_connections: dict[str, WebSocket] = {}
    self.webrtc_connections: dict[str, WebRTCConnection] = {}
    self.connection_metadata: dict[str, dict] = {}

    # Strong references to background tasks so they are not garbage-collected
    self.background_tasks: set = set()

    # Handle of the periodic cleanup task; None until it has been started
    self._cleanup_task = None
    self._start_cleanup_task()
def _start_cleanup_task(self):
    """Start the background cleanup task if an event loop is running.

    The task sleeps for ``cleanup_interval`` and then calls
    ``_cleanup_inactive_rooms``, forever. Errors in one pass are logged and
    the loop continues.

    When called outside a running event loop nothing is scheduled and
    ``_cleanup_task`` is left unchanged, so a later call can start it.
    """

    async def cleanup_loop():
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval.total_seconds())
                await self._cleanup_inactive_rooms()
            except Exception:
                logger.exception("Error in video cleanup task")

    try:
        # get_running_loop() raises RuntimeError when no loop is running,
        # which is the case the handler below is written for.
        # get_event_loop() could instead hand back a loop that is never run,
        # leaving a dead task stored in _cleanup_task.
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop running yet, cleanup will start when first room is created
        logger.info("No event loop running, video cleanup task will start later")
        return

    self._cleanup_task = loop.create_task(cleanup_loop())
    logger.info("Started video room cleanup task")
async def _cleanup_inactive_rooms(self): | |
"""Remove rooms that have been inactive for more than the timeout period""" | |
current_time = datetime.now(tz=UTC) | |
rooms_to_remove = [] | |
for workspace_id, rooms in self.workspaces.items(): | |
for room_id, room in rooms.items(): | |
# Check if room has any active connections | |
has_active_connections = False | |
room_last_activity = room.last_activity | |
# Check all connections for this room to find most recent activity | |
for metadata in self.connection_metadata.values(): | |
if ( | |
metadata.get("workspace_id") == workspace_id | |
and metadata.get("room_id") == room_id | |
): | |
has_active_connections = True | |
if ( | |
metadata.get("last_activity") | |
and metadata["last_activity"] > room_last_activity | |
): | |
room_last_activity = metadata["last_activity"] | |
# If no active connections, use room's last activity | |
if not has_active_connections: | |
time_since_activity = current_time - room_last_activity | |
if time_since_activity > self.inactivity_timeout: | |
rooms_to_remove.append((workspace_id, room_id)) | |
logger.info( | |
f"Marking video room {room_id} in workspace {workspace_id} for cleanup " | |
f"(inactive for {time_since_activity})" | |
) | |
# Remove inactive rooms | |
for workspace_id, room_id in rooms_to_remove: | |
if self.delete_room(workspace_id, room_id): | |
logger.info( | |
f"Auto-removed inactive video room {room_id} from workspace {workspace_id}" | |
) | |
if rooms_to_remove: | |
logger.info(f"Cleaned up {len(rooms_to_remove)} inactive video rooms") | |
def _update_room_activity(self, workspace_id: str, room_id: str): | |
"""Update the last activity timestamp for a room""" | |
room = self._get_room(workspace_id, room_id) | |
if room: | |
room.last_activity = datetime.now(tz=UTC) | |
# ============= ROOM MANAGEMENT (same pattern as robotics) ============= | |
def create_room(
    self,
    workspace_id: str | None = None,
    room_id: str | None = None,
    config: VideoConfig = None,
    recovery_config: RecoveryConfig = None,
) -> tuple[str, str]:
    """Create video room and return (workspace_id, room_id).

    Missing or empty identifiers are replaced by random UUID4 strings. A room
    already stored under the same identifiers is replaced.

    Args:
        workspace_id: Workspace to create the room in.
        room_id: Identifier for the new room.
        config: Passed through to ``VideoRoom``.
        recovery_config: Passed through to ``VideoRoom``.
    """
    if not workspace_id:
        workspace_id = str(uuid.uuid4())
    if not room_id:
        room_id = str(uuid.uuid4())

    # Fetch the workspace's room table, creating it on first use.
    workspace_rooms = self.workspaces.setdefault(workspace_id, {})
    workspace_rooms[room_id] = VideoRoom(
        room_id, workspace_id, config, recovery_config
    )

    # The cleanup task may not have started yet if no task handle is stored.
    if self._cleanup_task is None:
        self._start_cleanup_task()

    logger.info(f"Created video room {room_id} in workspace {workspace_id}")
    return workspace_id, room_id
def list_rooms(self, workspace_id: str) -> list[dict]:
    """List all video rooms in a specific workspace.

    Returns:
        One summary dict per room, or an empty list for an unknown workspace.
    """

    def summarize(room) -> dict:
        consumer_total = len(room.consumers)
        settings = room.config
        return {
            "id": room.id,
            "workspace_id": room.workspace_id,
            "participants": {
                "producer": room.producer,
                "consumers": room.consumers,
                "total": consumer_total + (1 if room.producer else 0),
            },
            "frame_count": room.frame_count,
            # NOTE: unlike get_room_info, no "encoding" key is emitted here.
            "config": {
                "resolution": settings.resolution,
                "framerate": settings.framerate,
                "bitrate": settings.bitrate,
                "quality": settings.quality,
            },
            "has_producer": room.producer is not None,
            "active_consumers": consumer_total,
        }

    try:
        rooms = self.workspaces[workspace_id]
    except KeyError:
        return []
    return [summarize(room) for room in rooms.values()]
def delete_room(self, workspace_id: str, room_id: str) -> bool:
    """Delete video room from workspace.

    WebRTC connections whose ``room_id`` and ``workspace_id`` both match are
    closed in the background and unregistered first.

    Returns:
        ``True`` when the room existed and was removed, ``False`` otherwise.
    """
    rooms = self.workspaces.get(workspace_id)
    if rooms is None or room_id not in rooms:
        return False

    # Iterate over a snapshot because entries are deleted inside the loop.
    for conn_id, conn in list(self.webrtc_connections.items()):
        if (
            conn.room_id != room_id
            or getattr(conn, "workspace_id", None) != workspace_id
        ):
            continue
        self._add_background_task(conn.close())
        del self.webrtc_connections[conn_id]

    del rooms[room_id]
    logger.info(f"Deleted video room {room_id} from workspace {workspace_id}")
    return True
def get_room_state(self, workspace_id: str, room_id: str) -> dict: | |
"""Get room state""" | |
room = self._get_room(workspace_id, room_id) | |
if not room: | |
return {"error": "Room not found"} | |
return { | |
"room_id": room_id, | |
"workspace_id": workspace_id, | |
"participants": { | |
"producer": room.producer, | |
"consumers": room.consumers, | |
"total": len(room.consumers) + (1 if room.producer else 0), | |
}, | |
"frame_count": room.frame_count, | |
"last_frame_time": room.last_frame_time, | |
"current_config": { | |
"resolution": room.config.resolution, | |
"framerate": room.config.framerate, | |
"encoding": room.config.encoding.value | |
if room.config.encoding | |
else "vp8", | |
"bitrate": room.config.bitrate, | |
"quality": room.config.quality, | |
}, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
def get_room_info(self, workspace_id: str, room_id: str) -> dict:
    """Get basic room info.

    Returns:
        A summary with identifiers, participants, frame count and video
        configuration. An unknown room yields ``{"error": "Room not found"}``.
    """
    room = self._get_room(workspace_id, room_id)
    if not room:
        return {"error": "Room not found"}

    settings = room.config
    consumer_total = len(room.consumers)
    # "vp8" is reported when no encoding is configured.
    encoding = settings.encoding.value if settings.encoding else "vp8"

    info = {
        "id": room.id,
        "workspace_id": room.workspace_id,
        "participants": {
            "producer": room.producer,
            "consumers": room.consumers,
            "total": consumer_total + (1 if room.producer else 0),
        },
        "frame_count": room.frame_count,
        "config": {
            "resolution": settings.resolution,
            "framerate": settings.framerate,
            "encoding": encoding,
            "bitrate": settings.bitrate,
            "quality": settings.quality,
        },
        "has_producer": room.producer is not None,
        "active_consumers": consumer_total,
    }
    return info
def _get_room(self, workspace_id: str, room_id: str) -> VideoRoom | None:
    """Look up a room by workspace and room ID; ``None`` if either is unknown."""
    try:
        rooms = self.workspaces[workspace_id]
    except KeyError:
        return None
    return rooms.get(room_id)
# ============= PARTICIPANT MANAGEMENT ============= | |
def join_room(
    self,
    workspace_id: str,
    room_id: str,
    participant_id: str,
    role: ParticipantRole,
) -> bool:
    """Join room as producer or consumer.

    A room accepts one producer at a time and each consumer id once. A
    successful join refreshes room activity and schedules a
    participant-joined broadcast in the background.

    Returns:
        ``True`` on success. ``False`` when the room is unknown, the producer
        slot is taken, the consumer is already present, or the role is
        neither producer nor consumer.
    """
    room = self._get_room(workspace_id, room_id)
    if not room:
        return False

    if role == ParticipantRole.PRODUCER:
        if room.producer is not None:
            logger.warning(
                f"Producer {participant_id} failed to join room {room_id} - room already has producer"
            )
            return False
        room.producer = participant_id
        self._update_room_activity(workspace_id, room_id)
        logger.info(
            f"Producer {participant_id} joined video room {room_id} in workspace {workspace_id}"
        )
    elif role == ParticipantRole.CONSUMER:
        if participant_id in room.consumers:
            return False
        room.consumers.append(participant_id)
        self._update_room_activity(workspace_id, room_id)
        logger.info(
            f"Consumer {participant_id} joined video room {room_id} in workspace {workspace_id}"
        )
    else:
        return False

    # Announce the new participant without blocking the caller.
    self._add_background_task(
        self._broadcast_participant_joined(
            workspace_id, room_id, participant_id, role
        )
    )
    return True
def leave_room(self, workspace_id: str, room_id: str, participant_id: str):
    """Leave room.

    Clears the producer slot or removes the consumer entry matching
    ``participant_id`` and schedules a participant-left broadcast. Unknown
    rooms and unknown participants are ignored.
    """
    room = self._get_room(workspace_id, room_id)
    if not room:
        return

    role = None
    if room.producer == participant_id:
        role = ParticipantRole.PRODUCER
        room.producer = None
        logger.info(
            f"Producer {participant_id} left video room {room_id} in workspace {workspace_id}"
        )
    elif participant_id in room.consumers:
        role = ParticipantRole.CONSUMER
        room.consumers.remove(participant_id)
        logger.info(
            f"Consumer {participant_id} left video room {room_id} in workspace {workspace_id}"
        )

    if not role:
        return

    # Broadcast participant left event
    self._add_background_task(
        self._broadcast_participant_left(
            workspace_id, room_id, participant_id, role
        )
    )
# ============= WEBRTC HANDLING ============= | |
async def handle_webrtc_signal( | |
self, | |
workspace_id: str, | |
room_id: str, | |
client_id: str, | |
message: RawWebRTCSignalingMessage, | |
participant_role: str | None = None, | |
): | |
"""Handle WebRTC signaling for peer-to-peer connections""" | |
if ( | |
workspace_id not in self.workspaces | |
or room_id not in self.workspaces[workspace_id] | |
): | |
msg = f"Room {room_id} not found in workspace {workspace_id}" | |
raise ValueError(msg) | |
if message["type"] == RawWebRTCMessageType.OFFER: | |
# Check if this is a targeted offer from producer to consumer | |
target_consumer = message.get("target_consumer") | |
if target_consumer and participant_role == "producer": | |
# Producer sending offer to specific consumer - forward it | |
logger.info( | |
f"🔄 Forwarding offer from producer {client_id} to consumer {target_consumer}" | |
) | |
output_message: WebRTCOfferMessageDict = { | |
"type": MessageType.WEBRTC_OFFER, | |
"offer": {"type": "offer", "sdp": message["sdp"]}, | |
"from_producer": client_id, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._send_to_participant(target_consumer, output_message) | |
return {"success": True, "message": "Offer forwarded to consumer"} | |
# For peer-to-peer, we don't handle server-side WebRTC connections | |
logger.info( | |
f"Ignoring server WebRTC offer from {client_id} - using peer-to-peer" | |
) | |
return { | |
"success": True, | |
"message": "Peer-to-peer mode - no server WebRTC processing", | |
} | |
if message["type"] == RawWebRTCMessageType.ANSWER: | |
# Handle answer from consumer back to producer | |
from_consumer = client_id | |
target_producer = message.get("target_producer") | |
if target_producer: | |
logger.info( | |
f"🔄 Forwarding answer from consumer {from_consumer} to producer {target_producer}" | |
) | |
output_message: WebRTCAnswerMessageDict = { | |
"type": MessageType.WEBRTC_ANSWER, | |
"answer": {"type": "answer", "sdp": message["sdp"]}, | |
"from_consumer": from_consumer, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._send_to_participant(target_producer, output_message) | |
return {"success": True, "message": "Answer forwarded to producer"} | |
elif message["type"] == RawWebRTCMessageType.ICE: | |
# Forward ICE candidates between peers | |
target_consumer = message.get("target_consumer") | |
target_producer = message.get("target_producer") | |
if target_consumer and participant_role == "producer": | |
output_message: WebRTCIceMessageDict = { | |
"type": MessageType.WEBRTC_ICE, | |
"candidate": message["candidate"], | |
"from_producer": client_id, | |
"from_consumer": None, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._send_to_participant(target_consumer, output_message) | |
return { | |
"success": True, | |
"message": "ICE candidate forwarded to consumer", | |
} | |
if target_producer and participant_role == "consumer": | |
output_message: WebRTCIceMessageDict = { | |
"type": MessageType.WEBRTC_ICE, | |
"candidate": message["candidate"], | |
"from_producer": None, | |
"from_consumer": client_id, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._send_to_participant(target_producer, output_message) | |
return { | |
"success": True, | |
"message": "ICE candidate forwarded to producer", | |
} | |
return None | |
async def broadcast_to_consumers(
    self, workspace_id: str, room_id: str, frame_data: bytes
):
    """Broadcast frame to all consumers.

    The frame is handed to every consumer of the room that has a registered
    WebRTC connection flagged as consumer. A failure for one consumer is
    logged and does not stop delivery to the others.

    Returns:
        Number of consumers the frame was handed to; 0 for an unknown room.
    """
    room = self._get_room(workspace_id, room_id)
    if not room:
        return 0

    consumer_count = 0
    for consumer_id in room.consumers:
        connection = self.webrtc_connections.get(consumer_id)
        if not (connection and connection.is_consumer):
            continue
        try:
            connection.send_video_frame(frame_data)
        except Exception:
            logger.exception(f"Error sending frame to {consumer_id}")
        else:
            consumer_count += 1

    if consumer_count > 0:
        # Update room activity when frames are being broadcast
        self._update_room_activity(workspace_id, room_id)
        logger.debug(f"Broadcasted frame to {consumer_count} consumers")
    return consumer_count
# ============= WEBSOCKET HANDLING ============= | |
async def handle_websocket(
    self, websocket: WebSocket, workspace_id: str, room_id: str
):
    """Serve one room-management WebSocket connection from join to disconnect.

    The connection is accepted, then the first text frame is parsed as a JSON
    join request carrying ``participant_id`` and ``role``. If ``join_room``
    refuses the request an error message is sent and the socket is closed.
    Otherwise the socket and its metadata are registered, a join confirmation
    is sent, and every following text frame is decoded and passed to
    ``_handle_websocket_message`` until the peer disconnects.

    Args:
        websocket: The not-yet-accepted WebSocket for this connection.
        workspace_id: Workspace containing the room.
        room_id: Room the participant wants to join.
    """
    await websocket.accept()
    participant_id: str | None = None
    role: ParticipantRole | None = None
    # True only once THIS connection has joined the room. Cleanup is keyed by
    # participant_id, so without this flag a rejected connection that reuses
    # the id of a live participant would tear down that participant's state.
    joined = False
    try:
        # The first frame must be the join request.
        data = await websocket.receive_text()
        join_msg = json.loads(data)
        participant_id = join_msg["participant_id"]
        role = ParticipantRole(join_msg["role"])

        if not self.join_room(workspace_id, room_id, participant_id, role):
            error_message: ErrorMessageDict = {
                "type": MessageType.ERROR,
                "message": "Cannot join room",
                "code": None,
                "timestamp": datetime.now(tz=UTC).isoformat(),
            }
            await websocket.send_text(json.dumps(error_message))
            await websocket.close()
            return
        joined = True

        self.websocket_connections[participant_id] = websocket
        self.connection_metadata[participant_id] = {
            "workspace_id": workspace_id,
            "room_id": room_id,
            "participant_id": participant_id,
            "role": role,
            "connected_at": datetime.now(tz=UTC),
            "last_activity": datetime.now(tz=UTC),
            "message_count": 0,
        }

        # Confirm the join to the participant.
        joined_message: JoinedMessageDict = {
            "type": MessageType.JOINED,
            "room_id": room_id,
            "role": role,
            "timestamp": datetime.now(tz=UTC).isoformat(),
        }
        await websocket.send_text(json.dumps(joined_message))

        # Message loop: a bad message is logged and skipped, the connection
        # stays open.
        async for message in websocket.iter_text():
            try:
                msg = json.loads(message)
                await self._handle_websocket_message(
                    workspace_id, room_id, participant_id, role, msg
                )
            except json.JSONDecodeError:
                logger.exception(f"Invalid JSON from {participant_id}")
            except Exception:
                logger.exception("WebSocket message error")
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {participant_id}")
    except Exception:
        logger.exception("WebSocket error")
    finally:
        # Undo only what this connection registered.
        if joined and participant_id:
            metadata = self.connection_metadata.get(participant_id)
            if metadata:
                self.leave_room(
                    metadata["workspace_id"], metadata["room_id"], participant_id
                )
            self.websocket_connections.pop(participant_id, None)
            self.connection_metadata.pop(participant_id, None)
async def _handle_websocket_message( | |
self, | |
workspace_id: str, | |
room_id: str, | |
participant_id: str, | |
role: ParticipantRole, | |
message: WebSocketMessageDict, | |
): | |
"""Handle incoming WebSocket message""" | |
# Update activity tracking | |
if participant_id in self.connection_metadata: | |
self.connection_metadata[participant_id]["last_activity"] = datetime.now( | |
tz=UTC | |
) | |
self.connection_metadata[participant_id]["message_count"] += 1 | |
# Update room activity | |
self._update_room_activity(workspace_id, room_id) | |
# Handle heartbeat | |
if message["type"] == MessageType.HEARTBEAT: | |
heartbeat_ack: HeartbeatAckMessageDict = { | |
"type": MessageType.HEARTBEAT_ACK, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._send_to_participant(participant_id, heartbeat_ack) | |
return | |
# Handle stream started notification | |
if message["type"] == MessageType.STREAM_STARTED: | |
logger.info( | |
f"Stream started by {participant_id} in room {room_id} (workspace {workspace_id})" | |
) | |
config = message.get("config", {}) | |
# Broadcast to other participants | |
broadcast_message: StreamStartedMessageDict = { | |
"type": MessageType.STREAM_STARTED, | |
"config": config, | |
"participant_id": participant_id, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._broadcast_to_room( | |
workspace_id, room_id, broadcast_message, exclude=participant_id | |
) | |
return | |
# Handle stream stopped notification | |
if message["type"] == MessageType.STREAM_STOPPED: | |
logger.info( | |
f"Stream stopped by {participant_id} in room {room_id} (workspace {workspace_id})" | |
) | |
reason = message.get("reason") | |
# Broadcast to other participants | |
broadcast_message: StreamStoppedMessageDict = { | |
"type": MessageType.STREAM_STOPPED, | |
"participant_id": participant_id, | |
"reason": reason, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._broadcast_to_room( | |
workspace_id, room_id, broadcast_message, exclude=participant_id | |
) | |
return | |
# Handle video config update | |
if message["type"] == MessageType.VIDEO_CONFIG_UPDATE: | |
logger.info( | |
f"Video config updated by {participant_id} in room {room_id} (workspace {workspace_id})" | |
) | |
config = message.get("config", {}) | |
# Update room config if producer | |
if role == ParticipantRole.PRODUCER: | |
room = self._get_room(workspace_id, room_id) | |
if room: | |
# Update room's video config | |
if "resolution" in config: | |
room.config.resolution = config["resolution"] | |
if "framerate" in config: | |
room.config.framerate = config["framerate"] | |
if "quality" in config: | |
room.config.quality = config["quality"] | |
if "encoding" in config: | |
room.config.encoding = config["encoding"] | |
if "bitrate" in config: | |
room.config.bitrate = config["bitrate"] | |
# Broadcast to other participants | |
broadcast_message: VideoConfigUpdateMessageDict = { | |
"type": MessageType.VIDEO_CONFIG_UPDATE, | |
"config": config, | |
"source": participant_id, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._broadcast_to_room( | |
workspace_id, room_id, broadcast_message, exclude=participant_id | |
) | |
return | |
# Handle status update | |
if message["type"] == MessageType.STATUS_UPDATE: | |
logger.info( | |
f"Status update from {participant_id} in room {room_id} (workspace {workspace_id})" | |
) | |
status = message.get("status", "unknown") | |
data = message.get("data") | |
# Broadcast to other participants | |
broadcast_message: StatusUpdateMessageDict = { | |
"type": MessageType.STATUS_UPDATE, | |
"status": status, | |
"data": data, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._broadcast_to_room( | |
workspace_id, room_id, broadcast_message, exclude=participant_id | |
) | |
return | |
# Handle stream stats | |
if message["type"] == MessageType.STREAM_STATS: | |
logger.debug( | |
f"Stream stats from {participant_id} in room {room_id} (workspace {workspace_id})" | |
) | |
stats = message.get("stats", {}) | |
# Broadcast to other participants (typically from producer to consumers) | |
broadcast_message: StreamStatsMessageDict = { | |
"type": MessageType.STREAM_STATS, | |
"stats": stats, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._broadcast_to_room( | |
workspace_id, room_id, broadcast_message, exclude=participant_id | |
) | |
return | |
# Handle emergency stop | |
if message["type"] == MessageType.EMERGENCY_STOP: | |
reason = message.get("reason", "Emergency stop triggered") | |
logger.warning( | |
f"Emergency stop by {participant_id} in room {room_id} (workspace {workspace_id}): {reason}" | |
) | |
# Broadcast to all participants | |
broadcast_message: EmergencyStopMessageDict = { | |
"type": MessageType.EMERGENCY_STOP, | |
"reason": reason, | |
"source": participant_id, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._broadcast_to_room(workspace_id, room_id, broadcast_message) | |
return | |
# Handle recovery triggered (from consumer typically) | |
if message["type"] == MessageType.RECOVERY_TRIGGERED: | |
policy = message.get("policy") | |
reason = message.get("reason", "Recovery triggered") | |
logger.info( | |
f"Recovery triggered by {participant_id} in room {room_id} (workspace {workspace_id}): {policy} - {reason}" | |
) | |
# Broadcast to other participants | |
broadcast_message: RecoveryTriggeredMessageDict = { | |
"type": MessageType.RECOVERY_TRIGGERED, | |
"policy": policy, | |
"reason": reason, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
await self._broadcast_to_room( | |
workspace_id, room_id, broadcast_message, exclude=participant_id | |
) | |
return | |
# Log unhandled message types | |
logger.info(f"Unhandled message type {message['type']} from {participant_id}") | |
async def _broadcast_to_room( | |
self, | |
workspace_id: str, | |
room_id: str, | |
message: WebSocketMessageDict, | |
exclude: str | None = None, | |
): | |
"""Broadcast message to all participants in a room""" | |
room = self._get_room(workspace_id, room_id) | |
if not room: | |
return | |
participants = [] | |
if room.producer: | |
participants.append(room.producer) | |
participants.extend(room.consumers) | |
for participant_id in participants: | |
if exclude and participant_id == exclude: | |
continue | |
await self._send_to_participant(participant_id, message) | |
async def _send_to_participant( | |
self, participant_id: str, message: WebSocketMessageDict | |
): | |
"""Send message to specific participant""" | |
if participant_id in self.websocket_connections: | |
try: | |
await self.websocket_connections[participant_id].send_text( | |
json.dumps(message) | |
) | |
except Exception: | |
logger.exception(f"Error sending message to {participant_id}") | |
if participant_id in self.websocket_connections: | |
del self.websocket_connections[participant_id] | |
async def _broadcast_participant_joined( | |
self, | |
workspace_id: str, | |
room_id: str, | |
participant_id: str, | |
role: ParticipantRole, | |
): | |
"""Broadcast participant joined event to other participants in the room""" | |
room = self._get_room(workspace_id, room_id) | |
if not room: | |
return | |
participants: list[str] = [] | |
if room.producer: | |
participants.append(room.producer) | |
participants.extend(room.consumers) | |
participant_joined_message: ParticipantJoinedMessageDict = { | |
"type": MessageType.PARTICIPANT_JOINED, | |
"room_id": room_id, | |
"participant_id": participant_id, | |
"role": role, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
for other_participant_id in participants: | |
if other_participant_id == participant_id: | |
continue | |
await self._send_to_participant( | |
other_participant_id, participant_joined_message | |
) | |
async def _broadcast_participant_left( | |
self, | |
workspace_id: str, | |
room_id: str, | |
participant_id: str, | |
role: ParticipantRole, | |
): | |
"""Broadcast participant left event to other participants in the room""" | |
room = self._get_room(workspace_id, room_id) | |
if not room: | |
return | |
participants: list[str] = [] | |
if room.producer: | |
participants.append(room.producer) | |
participants.extend(room.consumers) | |
participant_left_message: dict = { | |
"type": MessageType.PARTICIPANT_LEFT, | |
"room_id": room_id, | |
"participant_id": participant_id, | |
"role": role, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
for other_participant_id in participants: | |
if other_participant_id == participant_id: | |
continue | |
await self._send_to_participant( | |
other_participant_id, participant_left_message | |
) | |
def _add_background_task(self, coro: Coroutine): | |
"""Add a background task with automatic cleanup""" | |
task = asyncio.create_task(coro) | |
self.background_tasks.add(task) | |
task.add_done_callback(self.background_tasks.discard) | |
return task | |
# ============= CLEANUP MANAGEMENT ============= | |
async def manual_cleanup(self) -> dict: | |
"""Manually trigger room cleanup and return results""" | |
logger.info("Manual video room cleanup triggered") | |
rooms_before = sum(len(rooms) for rooms in self.workspaces.values()) | |
await self._cleanup_inactive_rooms() | |
rooms_after = sum(len(rooms) for rooms in self.workspaces.values()) | |
return { | |
"cleanup_triggered": True, | |
"rooms_before": rooms_before, | |
"rooms_after": rooms_after, | |
"rooms_removed": rooms_before - rooms_after, | |
"timestamp": datetime.now(tz=UTC).isoformat(), | |
} | |
def get_cleanup_status(self) -> dict: | |
"""Get cleanup system status and configuration""" | |
current_time = datetime.now(tz=UTC) | |
# Calculate room ages and activity | |
room_info = [] | |
for workspace_id, rooms in self.workspaces.items(): | |
for room_id, room in rooms.items(): | |
# Find latest activity for this room | |
latest_activity = room.last_activity | |
for metadata in self.connection_metadata.values(): | |
if ( | |
metadata.get("workspace_id") == workspace_id | |
and metadata.get("room_id") == room_id | |
): | |
if ( | |
metadata.get("last_activity") | |
and metadata["last_activity"] > latest_activity | |
): | |
latest_activity = metadata["last_activity"] | |
age = current_time - room.created_at | |
inactivity = current_time - latest_activity | |
room_info.append({ | |
"workspace_id": workspace_id, | |
"room_id": room_id, | |
"age_minutes": age.total_seconds() / 60, | |
"inactivity_minutes": inactivity.total_seconds() / 60, | |
"has_connections": any( | |
metadata.get("workspace_id") == workspace_id | |
and metadata.get("room_id") == room_id | |
for metadata in self.connection_metadata.values() | |
), | |
"will_be_cleaned": inactivity > self.inactivity_timeout, | |
}) | |
return { | |
"service": "video", | |
"cleanup_enabled": self._cleanup_task is not None, | |
"inactivity_timeout_minutes": self.inactivity_timeout.total_seconds() / 60, | |
"cleanup_interval_minutes": self.cleanup_interval.total_seconds() / 60, | |
"total_rooms": len(room_info), | |
"rooms": room_info, | |
"timestamp": current_time.isoformat(), | |
} | |