#!/usr/bin/env python3
"""
Tests for the LeRobot Arena Video Client.

Basic tests that validate the video client implementation.
"""
import asyncio
import logging

import numpy as np
import pytest

from lerobot_arena_client.video import (
    CustomVideoTrack,
    ParticipantRole,
    Resolution,
    VideoConfig,
    VideoConsumer,
    VideoEncoding,
    VideoProducer,
    create_consumer_client,
    create_producer_client,
)
class TestVideoTypes:
    """Checks for the plain data types exported by the video module."""

    def test_resolution_creation(self):
        """A Resolution keeps the width and height it was constructed with."""
        full_hd = Resolution(width=1920, height=1080)
        assert (full_hd.width, full_hd.height) == (1920, 1080)

    def test_video_config_creation(self):
        """A VideoConfig exposes the encoding, resolution and framerate passed in."""
        vp8_config = VideoConfig(
            encoding=VideoEncoding.VP8,
            resolution=Resolution(640, 480),
            framerate=30,
            bitrate=1_000_000,
        )
        assert vp8_config.encoding == VideoEncoding.VP8
        assert vp8_config.resolution.width == 640
        assert vp8_config.framerate == 30

    def test_participant_role_enum(self):
        """Each ParticipantRole member maps to its lowercase string value."""
        expected_values = {
            ParticipantRole.PRODUCER: "producer",
            ParticipantRole.CONSUMER: "consumer",
        }
        for role, wire_value in expected_values.items():
            assert role.value == wire_value
class TestVideoCore:
    """Test core video client functionality"""

    def test_video_producer_creation(self):
        """Test VideoProducer initialization (no server needed)."""
        producer = VideoProducer("http://localhost:8000")
        assert producer.base_url == "http://localhost:8000"
        assert producer.api_base == "http://localhost:8000/video"
        assert not producer.connected
        assert producer.room_id is None

    def test_video_consumer_creation(self):
        """Test VideoConsumer initialization (no server needed)."""
        consumer = VideoConsumer("http://localhost:8000")
        assert consumer.base_url == "http://localhost:8000"
        assert consumer.api_base == "http://localhost:8000/video"
        assert not consumer.connected
        assert consumer.room_id is None

    async def test_producer_room_creation(self):
        """Test room creation (requires server).

        Only the server call is inside ``try``. ``AssertionError`` is a
        subclass of ``Exception``. Guarding the assertions as well would turn
        a genuine failure into a "Server not available" skip.
        """
        producer = VideoProducer("http://localhost:8000")
        try:
            room_id = await producer.create_room()
        except Exception as e:
            # pytest.skip raises, so room_id is always bound below.
            pytest.skip(f"Server not available: {e}")

        assert isinstance(room_id, str)
        assert len(room_id) > 0
        print(f"β Created room: {room_id}")

    async def test_consumer_list_rooms(self):
        """Test listing rooms (requires server).

        As above, only the server call is guarded, so assertion failures are
        reported as failures rather than skips.
        """
        consumer = VideoConsumer("http://localhost:8000")
        try:
            rooms = await consumer.list_rooms()
        except Exception as e:
            pytest.skip(f"Server not available: {e}")

        assert isinstance(rooms, list)
        print(f"β Listed {len(rooms)} rooms")
class TestVideoTracks:
    """Exercise the video track implementations against synthetic sources."""

    async def test_custom_video_track(self):
        """CustomVideoTrack should yield a frame for each of three source frames."""
        produced = 0

        async def mock_frame_source() -> np.ndarray | None:
            nonlocal produced
            if produced < 3:
                # Solid-colour 320x240 image: saturate channel 0, then 1, then 2.
                image = np.zeros((240, 320, 3), dtype=np.uint8)
                image[:, :, produced % 3] = 255
                produced += 1
                return image
            # Source is exhausted after three frames.
            return None

        track = CustomVideoTrack(mock_frame_source, frame_rate=10)

        for ordinal in range(1, 4):
            received = await track.recv()
            assert received is not None
            print(f"β Generated frame {ordinal}")

        print("β CustomVideoTrack test passed")
class TestVideoClientIntegration:
    """Integration tests for video client"""

    async def test_producer_consumer_setup(self):
        """Test producer and consumer setup without server connection"""
        # Test producer setup
        producer = VideoProducer("http://localhost:8000")
        assert producer.get_video_track() is None

        # Test consumer setup
        consumer = VideoConsumer("http://localhost:8000")
        assert consumer.get_remote_stream() is None

        print("β Producer/Consumer setup test passed")

    async def test_custom_stream_setup(self):
        """Starting a custom stream on an unconnected producer must be rejected.

        ``pytest.raises`` fails the test if no ``ValueError`` is raised. The
        previous ``try/except ValueError`` passed silently in that case.
        """

        async def test_frame_source() -> np.ndarray | None:
            # Random 640x480 frame source. It is presumably never consumed,
            # because the start call below is expected to be rejected.
            return np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

        producer = VideoProducer("http://localhost:8000")
        with pytest.raises(ValueError, match="Must be connected"):
            await producer.start_custom_stream(test_frame_source)
        print("β Custom stream setup validation passed")

    async def test_factory_functions(self):
        """Test that factory functions create the right client types.

        Without a server the factory calls are expected to raise, and that is
        tolerated. When a call succeeds, the returned client's type is checked.
        The check sits in ``else`` so that an ``AssertionError`` is not
        swallowed by the broad ``except``.
        """
        try:
            producer = await create_producer_client("http://localhost:8000")
        except Exception:
            pass  # Expected to fail without server
        else:
            assert isinstance(producer, VideoProducer)

        try:
            consumer = await create_consumer_client(
                "test-room", "http://localhost:8000"
            )
        except Exception:
            pass  # Expected to fail without server
        else:
            assert isinstance(consumer, VideoConsumer)

        print("β Factory functions test passed")
def _render_animated_frame(
    frame_index: int, height: int = 240, width: int = 320
) -> np.ndarray:
    """Render one frame of the interactive-test animation.

    Channel layout:

    * Channel 0 is ``128 + 127*sin(t + x*0.05)`` and varies along x.
    * Channel 1 is ``128 + 127*sin(t + y*0.05)`` and varies along y.
    * Channel 2 is ``128 + 127*sin(t)`` and is constant over the frame.

    Here ``t = frame_index * 0.5``. The frame is computed with NumPy
    broadcasting instead of a per-pixel Python loop.

    Args:
        frame_index: Animation step; selects the phase ``t``.
        height: Frame height in pixels.
        width: Frame width in pixels.

    Returns:
        A ``(height, width, 3)`` array of dtype ``uint8``.
    """
    t = frame_index * 0.5
    # 128 + 127*sin(...) lies in [1, 255]. astype(uint8) therefore truncates
    # toward zero exactly as int() did in the per-pixel version.
    red = (128 + 127 * np.sin(t + np.arange(width) * 0.05)).astype(np.uint8)
    green = (128 + 127 * np.sin(t + np.arange(height) * 0.05)).astype(np.uint8)
    blue = int(128 + 127 * np.sin(t))

    frame = np.empty((height, width, 3), dtype=np.uint8)
    frame[:, :, 0] = red[np.newaxis, :]  # same value down each column
    frame[:, :, 1] = green[:, np.newaxis]  # same value along each row
    frame[:, :, 2] = blue
    return frame


async def run_interactive_tests():
    """Run interactive tests for manual verification.

    The run has three steps, each logged at INFO level:

    1. Construct a ``VideoProducer`` and a ``VideoConsumer``.
    2. Pull five frames from a ``CustomVideoTrack`` fed by an animated source.
    3. List rooms and create a room on ``http://localhost:8000``. Any
       exception in this step is logged as a warning rather than raised, so
       the script still completes when no server is running.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("π§ͺ Running Interactive Video Client Tests")

    # Test 1: Basic client creation
    logger.info("π Test 1: Creating video clients...")
    producer = VideoProducer("http://localhost:8000")
    consumer = VideoConsumer("http://localhost:8000")
    logger.info("β Clients created successfully")

    # Test 2: Custom video track
    logger.info("π Test 2: Testing custom video track...")
    frame_count = 0

    async def animated_frame_source() -> np.ndarray | None:
        # Returns five animated frames, then None to signal exhaustion.
        nonlocal frame_count
        if frame_count >= 5:
            return None
        frame = _render_animated_frame(frame_count)
        frame_count += 1
        return frame

    track = CustomVideoTrack(animated_frame_source, frame_rate=5)
    for i in range(5):
        frame = await track.recv()
        logger.info(
            f"πΊ Generated animated frame {i + 1}: {frame.width}x{frame.height}"
        )
    logger.info("β Custom video track test completed")

    # Test 3: Server communication (if available)
    logger.info("π Test 3: Testing server communication...")
    try:
        rooms = await consumer.list_rooms()
        logger.info(f"β Server communication successful - found {len(rooms)} rooms")

        # Try creating a room
        room_id = await producer.create_room()
        logger.info(f"β Room created successfully: {room_id}")
    except Exception as e:
        # Best-effort step: a missing server is expected during manual runs.
        logger.warning(
            f"β οΈ Server communication failed (expected if server not running): {e}"
        )

    logger.info("π All interactive tests completed!")
if __name__ == "__main__":
    # Script entry point: drive the interactive checks on a fresh asyncio event loop.
    asyncio.run(run_interactive_tests())