""" | |
Consumer client for receiving video streams in LeRobot Arena | |
""" | |
import asyncio | |
import logging | |
from typing import Any | |
from aiortc import RTCIceCandidate, RTCSessionDescription | |
from .core import VideoClientCore | |
from .types import ( | |
ClientOptions, | |
FrameData, | |
FrameUpdateCallback, | |
ParticipantRole, | |
RecoveryTriggeredCallback, | |
StatusUpdateCallback, | |
StreamStartedCallback, | |
StreamStatsCallback, | |
StreamStoppedCallback, | |
VideoConfig, | |
VideoConfigUpdateCallback, | |
WebRTCStats, | |
) | |
logger = logging.getLogger(__name__) | |
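
# Typical consumer lifecycle (informal sketch of how the pieces below fit together):
#   1. connect()              -> create peer connection, add a recvonly transceiver,
#                                join the room as a consumer
#   2. "webrtc_offer" message -> handle_webrtc_offer() applies the remote description
#                                and sends an answer back through the server
#   3. "webrtc_ice" messages  -> handle_webrtc_ice() adds (or queues) ICE candidates
#   4. track received         -> _read_video_frames() pulls frames and fires the
#                                registered callbacks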


class VideoConsumer(VideoClientCore):
    """Consumer client for receiving video streams in LeRobot Arena"""

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        options: ClientOptions | None = None,
    ):
        super().__init__(base_url, options)

        # Event callbacks
        self.on_frame_update_callback: FrameUpdateCallback | None = None
        self.on_video_config_update_callback: VideoConfigUpdateCallback | None = None
        self.on_stream_started_callback: StreamStartedCallback | None = None
        self.on_stream_stopped_callback: StreamStoppedCallback | None = None
        self.on_recovery_triggered_callback: RecoveryTriggeredCallback | None = None
        self.on_status_update_callback: StatusUpdateCallback | None = None
        self.on_stream_stats_callback: StreamStatsCallback | None = None

        # ICE candidate queuing for proper timing
        self.ice_candidate_queue: list[dict[str, Any]] = []
        self.has_remote_description = False

        # Frame monitoring for stream health
        self._last_frame_time: float | None = None
        self._frame_timeout_task: asyncio.Task | None = None
        self._monitoring_frames = False
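
        # Note: ICE candidates that arrive before the remote description is set cannot
        # be applied to the peer connection yet, so handle_webrtc_ice() parks them in
        # ice_candidate_queue and _process_queued_ice_candidates() flushes them once
        # the producer's offer has been applied.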

    # ============= CONSUMER CONNECTION =============

    async def connect(
        self, workspace_id: str, room_id: str, participant_id: str | None = None
    ) -> bool:
        """Connect to a room as consumer"""
        # Create peer connection BEFORE connecting to avoid a race condition
        logger.info("Creating peer connection for consumer...")
        self.create_peer_connection()

        # Add transceiver to receive video
        if self.peer_connection:
            self.peer_connection.addTransceiver("video", direction="recvonly")
            logger.info("Added video transceiver for consumer")

        # Now connect to the room - we're ready for WebRTC offers
        connected = await self.connect_to_room(
            workspace_id, room_id, ParticipantRole.CONSUMER, participant_id
        )

        if connected:
            logger.info("Consumer connected and ready for WebRTC offers")
            await self.start_receiving()

        return connected

    # ============= CONSUMER METHODS =============

    async def start_receiving(self) -> None:
        """Start receiving video stream"""
        if not self.connected:
            raise ValueError("Must be connected to start receiving")

        # Reset WebRTC state
        self.has_remote_description = False
        self.ice_candidate_queue = []

        # Create peer connection for receiving (if not already created)
        if not self.peer_connection:
            self.create_peer_connection()

            # Set up to receive remote stream
            if self.peer_connection:
                # Add transceiver to receive video
                self.peer_connection.addTransceiver("video", direction="recvonly")
                logger.info("Added video transceiver for consumer")
        else:
            logger.info("Peer connection already exists for consumer")

    async def stop_receiving(self) -> None:
        """Stop receiving video stream"""
        # Stop frame monitoring
        self._monitoring_frames = False
        if self._frame_timeout_task and not self._frame_timeout_task.done():
            self._frame_timeout_task.cancel()

        if self.peer_connection:
            await self.peer_connection.close()
            self.peer_connection = None

        self.remote_stream = None

    # ============= WEBRTC NEGOTIATION =============

    async def handle_webrtc_offer(
        self, offer_data: dict[str, Any], from_producer: str
    ) -> None:
        """Handle WebRTC offer from producer"""
        try:
            logger.info(f"Received WebRTC offer from producer {from_producer}")

            # Check if we need to restart the connection (new offer from the same producer)
            if self.peer_connection and self.has_remote_description:
                logger.info("Restarting connection for new stream...")
                await self._restart_connection_for_new_stream()

            if not self.peer_connection:
                logger.info("Creating new peer connection for offer...")
                self.create_peer_connection()

                # Add transceiver to receive video
                if self.peer_connection:
                    self.peer_connection.addTransceiver("video", direction="recvonly")
                    logger.info("Added video transceiver for new connection")

            # Reset state for the new offer
            self.has_remote_description = False
            self.ice_candidate_queue = []

            # Set remote description (the offer)
            offer = RTCSessionDescription(
                sdp=offer_data["sdp"], type=offer_data["type"]
            )
            await self.set_remote_description(offer)
            self.has_remote_description = True

            # Process any queued ICE candidates now that we have the remote description
            await self._process_queued_ice_candidates()

            # Create answer
            answer = await self.create_answer(offer)
            logger.info(f"Sending WebRTC answer to producer {from_producer}")

            # Send answer back through the server to the producer
            if self.workspace_id and self.room_id and self.participant_id:
                await self.send_webrtc_signal(
                    self.workspace_id,
                    self.room_id,
                    self.participant_id,
                    {
                        "type": "answer",
                        "sdp": answer.sdp,
                        "target_producer": from_producer,
                    },
                )

            logger.info("WebRTC negotiation completed from consumer side")

        except Exception as e:
            logger.error(f"Failed to handle WebRTC offer: {e}")
            if self.on_error_callback:
                self.on_error_callback(f"Failed to handle WebRTC offer: {e}")
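
    # For reference, the answer signal sent above is a plain dict of this shape
    # (values illustrative, not from a real session):
    #   {"type": "answer", "sdp": "v=0\r\no=- ...", "target_producer": "<producer id>"}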

    async def _restart_connection_for_new_stream(self) -> None:
        """Restart connection for new stream (called when getting a new offer)"""
        try:
            logger.info("Restarting peer connection for new stream...")

            # Stop frame monitoring
            self._monitoring_frames = False
            if self._frame_timeout_task and not self._frame_timeout_task.done():
                self._frame_timeout_task.cancel()

            # Close existing peer connection
            if self.peer_connection:
                await self.peer_connection.close()
                self.peer_connection = None

            # Reset all WebRTC state
            self.remote_stream = None
            self.has_remote_description = False
            self.ice_candidate_queue = []
            self._last_frame_time = None

            logger.info("Connection restart completed")

        except Exception as e:
            logger.error(f"Error restarting connection: {e}")
            # Continue anyway - we'll try to create a new connection

    async def handle_webrtc_ice(
        self, ice_data: dict[str, Any], from_producer: str
    ) -> None:
        """Handle WebRTC ICE candidate from producer"""
        if not self.peer_connection:
            logger.warning("No peer connection available to handle ICE")
            return

        try:
            logger.info(f"Received WebRTC ICE from producer {from_producer}")

            # Parse ICE candidate string and create RTCIceCandidate
            candidate_str = ice_data["candidate"]
            parts = candidate_str.split()
            if len(parts) >= 8:
                candidate = RTCIceCandidate(
                    component=int(parts[1]),
                    foundation=parts[0].split(":")[1],  # Remove "candidate:" prefix
                    ip=parts[4],
                    port=int(parts[5]),
                    priority=int(parts[3]),
                    protocol=parts[2],
                    type=parts[7],  # value after the "typ" keyword
                    sdpMid=ice_data.get("sdpMid"),
                    sdpMLineIndex=ice_data.get("sdpMLineIndex"),
                )
            else:
                logger.warning(f"Invalid ICE candidate format: {candidate_str}")
                return
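
            # The indexing above follows the standard SDP candidate-attribute layout
            # (RFC 5245), e.g. an illustrative candidate string:
            #   "candidate:842163049 1 udp 1677729535 192.0.2.10 53742 typ srflx ..."
            # parts[0] is "candidate:<foundation>", parts[1] the component, parts[2]
            # the protocol, parts[3] the priority, parts[4]/[5] ip and port, and
            # parts[7] the candidate type that follows the literal "typ".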

            if not self.has_remote_description:
                # Queue ICE candidate until we have the remote description
                logger.info(
                    f"Queuing ICE candidate from {from_producer} (no remote description yet)"
                )
                self.ice_candidate_queue.append({
                    "candidate": candidate,
                    "from_producer": from_producer,
                })
                return

            # Add ICE candidate to peer connection
            await self.add_ice_candidate(candidate)
            logger.info(f"WebRTC ICE handled from producer {from_producer}")

        except Exception as e:
            logger.error(f"Failed to handle WebRTC ICE from {from_producer}: {e}")
            if self.on_error_callback:
                self.on_error_callback(f"Failed to handle WebRTC ICE: {e}")

    async def _process_queued_ice_candidates(self) -> None:
        """Process all queued ICE candidates after the remote description is set"""
        if not self.ice_candidate_queue:
            return

        logger.info(
            f"Processing {len(self.ice_candidate_queue)} queued ICE candidates"
        )

        for item in self.ice_candidate_queue:
            try:
                candidate = item["candidate"]
                from_producer = item["from_producer"]
                if self.peer_connection:
                    await self.peer_connection.addIceCandidate(candidate)
                    logger.info(
                        f"Processed queued ICE candidate from {from_producer}"
                    )
            except Exception as e:
                logger.error(
                    f"Failed to process queued ICE candidate from {item.get('from_producer', 'unknown')}: {e}"
                )

        # Clear the queue
        self.ice_candidate_queue = []

    # ============= EVENT CALLBACKS =============

    def on_frame_update(self, callback: FrameUpdateCallback) -> None:
        """Set callback for frame updates"""
        self.on_frame_update_callback = callback

    def on_video_config_update(self, callback: VideoConfigUpdateCallback) -> None:
        """Set callback for video config updates"""
        self.on_video_config_update_callback = callback

    def on_stream_started(self, callback: StreamStartedCallback) -> None:
        """Set callback for stream started events"""
        self.on_stream_started_callback = callback

    def on_stream_stopped(self, callback: StreamStoppedCallback) -> None:
        """Set callback for stream stopped events"""
        self.on_stream_stopped_callback = callback

    def on_recovery_triggered(self, callback: RecoveryTriggeredCallback) -> None:
        """Set callback for recovery triggered events"""
        self.on_recovery_triggered_callback = callback

    def on_status_update(self, callback: StatusUpdateCallback) -> None:
        """Set callback for status updates"""
        self.on_status_update_callback = callback

    def on_stream_stats(self, callback: StreamStatsCallback) -> None:
        """Set callback for stream statistics"""
        self.on_stream_stats_callback = callback
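
    # As invoked by the message handlers below, registered callbacks receive:
    #   on_frame_update(frame: FrameData)
    #   on_video_config_update(config: VideoConfig)
    #   on_stream_started(config: VideoConfig, participant_id: str)
    #   on_stream_stopped(participant_id: str, reason: str | None)
    #   on_recovery_triggered(policy: RecoveryPolicy, reason: str)
    #   on_status_update(status: str, data: Any)
    #   on_stream_stats(stats: StreamStats)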

    # ============= MESSAGE HANDLING =============

    async def _handle_role_specific_message(self, data: dict[str, Any]) -> None:
        """Handle consumer-specific messages"""
        msg_type = data.get("type")

        if msg_type == "frame_update":
            await self._handle_frame_update(data)
        elif msg_type == "video_config_update":
            await self._handle_video_config_update(data)
        elif msg_type == "stream_started":
            await self._handle_stream_started(data)
        elif msg_type == "stream_stopped":
            await self._handle_stream_stopped(data)
        elif msg_type == "recovery_triggered":
            await self._handle_recovery_triggered(data)
        elif msg_type == "status_update":
            await self._handle_status_update(data)
        elif msg_type == "stream_stats":
            await self._handle_stream_stats(data)
        elif msg_type == "participant_joined":
            logger.info(
                f"Participant joined: {data.get('participant_id')} as {data.get('role')}"
            )
            # If it's a producer joining, we should be ready for offers
            if data.get("role") == "producer":
                producer_id = data.get("participant_id", "")
                logger.info(
                    f"Producer {producer_id} joined - ready for WebRTC offers"
                )
        elif msg_type == "participant_left":
            logger.info(
                f"Participant left: {data.get('participant_id')} ({data.get('role')})"
            )
            # If it's a producer leaving, we should be ready for recovery
            if data.get("role") == "producer":
                producer_id = data.get("participant_id", "")
                logger.info(
                    f"Producer {producer_id} left - waiting for reconnection..."
                )
                # Reset state for potential reconnection
                self.has_remote_description = False
                self.ice_candidate_queue = []
        elif msg_type == "webrtc_offer":
            await self.handle_webrtc_offer(
                data.get("offer", {}), data.get("from_producer", "")
            )
        elif msg_type == "webrtc_answer":
            logger.info("Received WebRTC answer (consumer should not receive this)")
        elif msg_type == "webrtc_ice":
            await self.handle_webrtc_ice(
                data.get("candidate", {}), data.get("from_producer", "")
            )
        elif msg_type == "emergency_stop":
            logger.warning(f"Emergency stop: {data.get('reason', 'Unknown reason')}")
            if self.on_error_callback:
                self.on_error_callback(
                    f"Emergency stop: {data.get('reason', 'Unknown reason')}"
                )
        else:
            logger.warning(f"Unknown message type for consumer: {msg_type}")

    async def _handle_frame_update(self, data: dict[str, Any]) -> None:
        """Handle frame update message"""
        if self.on_frame_update_callback:
            frame_data = FrameData(
                data=data.get("data", b""), metadata=data.get("metadata", {})
            )
            self.on_frame_update_callback(frame_data)

    async def _handle_video_config_update(self, data: dict[str, Any]) -> None:
        """Handle video config update message"""
        if self.on_video_config_update_callback:
            config = self._dict_to_video_config(data.get("config", {}))
            self.on_video_config_update_callback(config)

    async def _handle_stream_started(self, data: dict[str, Any]) -> None:
        """Handle stream started message"""
        if self.on_stream_started_callback:
            config = self._dict_to_video_config(data.get("config", {}))
            participant_id = data.get("participant_id", "")
            self.on_stream_started_callback(config, participant_id)

        logger.info(
            f"Stream started by producer {data.get('participant_id')}, ready to receive video"
        )

    async def _handle_stream_stopped(self, data: dict[str, Any]) -> None:
        """Handle stream stopped message"""
        producer_id = data.get("participant_id", "")
        reason = data.get("reason")

        logger.info(f"Stream stopped by producer {producer_id}")
        if reason:
            logger.info(f"  Reason: {reason}")

        # Reset WebRTC state for potential restart
        self.has_remote_description = False
        self.ice_candidate_queue = []

        # Keep peer connection alive for potential restart
        logger.info("Ready for stream restart...")

        if self.on_stream_stopped_callback:
            self.on_stream_stopped_callback(producer_id, reason)

    async def _handle_recovery_triggered(self, data: dict[str, Any]) -> None:
        """Handle recovery triggered message"""
        if self.on_recovery_triggered_callback:
            from .types import RecoveryPolicy

            policy = RecoveryPolicy(data.get("policy", "freeze_last_frame"))
            reason = data.get("reason", "")
            self.on_recovery_triggered_callback(policy, reason)

    async def _handle_status_update(self, data: dict[str, Any]) -> None:
        """Handle status update message"""
        if self.on_status_update_callback:
            status = data.get("status", "")
            status_data = data.get("data")
            self.on_status_update_callback(status, status_data)

    async def _handle_stream_stats(self, data: dict[str, Any]) -> None:
        """Handle stream stats message"""
        if self.on_stream_stats_callback:
            from .types import StreamStats

            stats_data = data.get("stats", {})
            stats = StreamStats(
                stream_id=stats_data.get("stream_id", ""),
                duration_seconds=stats_data.get("duration_seconds", 0.0),
                frame_count=stats_data.get("frame_count", 0),
                total_bytes=stats_data.get("total_bytes", 0),
                average_fps=stats_data.get("average_fps", 0.0),
                average_bitrate=stats_data.get("average_bitrate", 0.0),
            )
            self.on_stream_stats_callback(stats)
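
    # The "stream_stats" message consumed above is expected to carry a "stats" dict
    # with the keys read here, e.g. (illustrative values):
    #   {"type": "stream_stats", "stats": {"stream_id": "abc", "duration_seconds": 12.5,
    #    "frame_count": 375, "total_bytes": 1048576, "average_fps": 30.0,
    #    "average_bitrate": 650000.0}}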

    # ============= TRACK HANDLING =============

    def _handle_track_received(self, track: Any) -> None:
        """Handle received video track"""
        logger.info(f"Received video track: {track.kind}")
        self.remote_stream = track

        # Start reading frames from the track
        if track.kind == "video":
            asyncio.create_task(self._read_video_frames(track))
            # Start frame monitoring
            asyncio.create_task(self._start_frame_monitoring())

    async def _read_video_frames(self, track: Any) -> None:
        """Read frames from the video track and trigger callbacks"""
        frame_count = 0
        self._monitoring_frames = True
        consecutive_errors = 0
        max_consecutive_errors = 5

        try:
            logger.info(f"Starting frame reading from track: {track.kind}")

            while self._monitoring_frames:
                try:
                    # Use a timeout to detect stream issues
                    frame = await asyncio.wait_for(track.recv(), timeout=5.0)
                    frame_count += 1
                    consecutive_errors = 0  # Reset error count on success

                    # Update frame monitoring
                    import time

                    self._last_frame_time = time.time()

                    # Convert frame to a numpy array - use RGB format to match the server
                    img = frame.to_ndarray(format="rgb24")

                    # Callbacks receive RGB data; callers can convert to BGR
                    # (e.g. for OpenCV) if they need to
                    frame_data = FrameData(
                        data=img.tobytes(),
                        metadata={
                            "width": frame.width,
                            "height": frame.height,
                            "format": "rgb24",  # Server sends RGB format
                            "pts": frame.pts,
                            "time_base": str(frame.time_base),
                            "frame_count": frame_count,
                        },
                    )

                    # Trigger frame update callback
                    if self.on_frame_update_callback:
                        self.on_frame_update_callback(frame_data)

                    if frame_count % 30 == 0:  # Log every 30 frames
                        logger.info(f"Processed {frame_count} video frames")

                except asyncio.TimeoutError:  # alias of TimeoutError on Python 3.11+
                    logger.warning(
                        "Timeout waiting for video frame - stream may have stopped"
                    )
                    consecutive_errors += 1
                    if consecutive_errors >= max_consecutive_errors:
                        logger.error(
                            "Too many consecutive frame timeouts - stopping frame reading"
                        )
                        break
                    await asyncio.sleep(1)  # Wait before retrying
                    continue

                except Exception as frame_error:
                    # Individual frame processing error - log but continue
                    consecutive_errors += 1
                    logger.warning(
                        f"Error processing frame {frame_count}: {frame_error}"
                    )
                    if consecutive_errors >= max_consecutive_errors:
                        logger.error(
                            f"Too many consecutive frame errors ({consecutive_errors}) - stopping frame reading"
                        )
                        break
                    await asyncio.sleep(0.1)  # Brief pause before retrying
                    continue

        except Exception as e:
            logger.error(f"Fatal error in frame reading loop: {e}")
        finally:
            logger.info(
                f"Frame reading stopped. Total frames processed: {frame_count}"
            )
            self._monitoring_frames = False

            # If we stopped due to errors and we're still connected, try to restart
            if consecutive_errors >= max_consecutive_errors and self.connected:
                logger.info(
                    "Frame reading stopped due to errors - triggering connection recovery"
                )
                asyncio.create_task(self._handle_connection_failure())
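
    # Sketch for consumers of FrameData (assumes numpy is installed): a frame callback
    # can rebuild the RGB image from the raw bytes and metadata produced above, e.g.
    #   import numpy as np
    #   h, w = frame.metadata["height"], frame.metadata["width"]
    #   img = np.frombuffer(frame.data, dtype=np.uint8).reshape(h, w, 3)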

    async def _start_frame_monitoring(self) -> None:
        """Start monitoring for frame timeouts"""
        if self._frame_timeout_task and not self._frame_timeout_task.done():
            self._frame_timeout_task.cancel()

        self._frame_timeout_task = asyncio.create_task(self._monitor_frame_timeout())

    async def _monitor_frame_timeout(self) -> None:
        """Monitor for frame timeouts and trigger recovery if needed"""
        timeout_seconds = 10.0  # No frames for 10 seconds = problem

        while self.connected and self._monitoring_frames:
            await asyncio.sleep(5)  # Check every 5 seconds

            if self._last_frame_time is not None:
                import time

                time_since_last_frame = time.time() - self._last_frame_time
                if time_since_last_frame > timeout_seconds:
                    logger.warning(
                        f"No frames received for {time_since_last_frame:.1f}s - stream may be stopped"
                    )
                    # Reset frame monitoring
                    self._last_frame_time = None

    # ============= UTILITY METHODS =============

    @staticmethod
    async def create_and_connect(
        workspace_id: str,
        room_id: str,
        base_url: str = "http://localhost:8000",
        participant_id: str | None = None,
    ) -> "VideoConsumer":
        """Create a consumer and automatically connect it to a room"""
        consumer = VideoConsumer(base_url)
        connected = await consumer.connect(workspace_id, room_id, participant_id)

        if not connected:
            raise ValueError("Failed to connect as video consumer")

        return consumer

    def attach_to_video_element(self, video_element: Any) -> None:
        """Attach the remote stream to a video element (for web frameworks)"""
        if self.remote_stream:
            # This would be used in web contexts
            # For now, just log that we have a stream
            logger.info("Video stream available for attachment")

    async def get_video_stats(self) -> WebRTCStats | None:
        """Get current video statistics"""
        return await self.get_stats()

    def _dict_to_video_config(self, data: dict[str, Any]) -> VideoConfig:
        """Convert a dictionary to VideoConfig"""
        from .types import Resolution, VideoEncoding

        config = VideoConfig()

        if "encoding" in data:
            config.encoding = VideoEncoding(data["encoding"])
        if "resolution" in data:
            res_data = data["resolution"]
            config.resolution = Resolution(
                width=res_data.get("width", 640), height=res_data.get("height", 480)
            )
        if "framerate" in data:
            config.framerate = data["framerate"]
        if "bitrate" in data:
            config.bitrate = data["bitrate"]
        if "quality" in data:
            config.quality = data["quality"]

        return config
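

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the client API): wiring up a consumer.
# Assumes a LeRobot Arena server at http://localhost:8000 and an existing
# workspace/room with an active producer; the IDs below are placeholders.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    async def _demo() -> None:
        consumer = await VideoConsumer.create_and_connect(
            workspace_id="my-workspace",  # placeholder
            room_id="my-room",  # placeholder
        )
        # Log basic metadata for every received frame
        consumer.on_frame_update(
            lambda frame: logger.info(
                "frame %s: %sx%s",
                frame.metadata.get("frame_count"),
                frame.metadata.get("width"),
                frame.metadata.get("height"),
            )
        )
        await asyncio.sleep(30)  # receive frames for a while
        await consumer.stop_receiving()

    asyncio.run(_demo())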