File size: 7,201 Bytes
02eac4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import asyncio

import pytest
from lerobot_arena_client import RoboticsProducer


class TestRoboticsProducer:
    """Integration tests for ``RoboticsProducer`` against a live server.

    The ``producer``, ``connected_producer`` and ``test_room`` fixtures are
    not defined in this file (presumably they live in ``conftest.py`` —
    verify). ``connected_producer`` yields a ``(producer, room_id)`` pair
    for a producer that is already connected to ``room_id``.
    """

    # Server URL these tests expect. It must match the URL the shared
    # fixtures connect to; test_producer_connection_info asserts this.
    BASE_URL = "http://localhost:8000"

    @pytest.mark.asyncio
    async def test_producer_connection(self, producer, test_room):
        """Connect, check the reported connection state, then disconnect."""
        assert not producer.is_connected()

        success = await producer.connect(test_room)
        assert success is True
        assert producer.is_connected()
        assert producer.room_id == test_room
        assert producer.role == "producer"

        await producer.disconnect()
        assert not producer.is_connected()

    @pytest.mark.asyncio
    async def test_producer_connection_info(self, connected_producer):
        """``get_connection_info()`` reflects the active connection."""
        producer, room_id = connected_producer

        info = producer.get_connection_info()
        assert info["connected"] is True
        assert info["room_id"] == room_id
        assert info["role"] == "producer"
        assert info["participant_id"] is not None
        assert info["base_url"] == self.BASE_URL

    @pytest.mark.asyncio
    async def test_send_joint_update(self, connected_producer):
        """A list of ``{"name", "value"}`` joint dicts can be sent."""
        producer, _ = connected_producer

        joints = [
            {"name": "shoulder", "value": 45.0},
            {"name": "elbow", "value": -20.0},
            {"name": "wrist", "value": 10.0},
        ]

        # Should not raise an exception
        await producer.send_joint_update(joints)

    @pytest.mark.asyncio
    async def test_send_state_sync(self, connected_producer):
        """A joint-name -> value mapping can be sent as a state sync."""
        producer, _ = connected_producer

        state = {"shoulder": 45.0, "elbow": -20.0, "wrist": 10.0}

        # Should not raise an exception
        await producer.send_state_sync(state)

    @pytest.mark.asyncio
    async def test_send_emergency_stop(self, connected_producer):
        """Emergency stop can be sent with and without an explicit reason."""
        producer, _ = connected_producer

        # Should not raise an exception
        await producer.send_emergency_stop("Test emergency stop")
        await producer.send_emergency_stop()  # Default reason

    @pytest.mark.asyncio
    async def test_send_heartbeat(self, connected_producer):
        """A heartbeat can be sent on a connected producer."""
        producer, _ = connected_producer

        # Should not raise an exception
        await producer.send_heartbeat()

    @pytest.mark.asyncio
    async def test_producer_callbacks(self, producer, test_room):
        """Connected/disconnected callbacks fire on connect/disconnect."""
        connected_called = False
        disconnected_called = False
        error_called = False
        error_message = None

        def on_connected():
            nonlocal connected_called
            connected_called = True

        def on_disconnected():
            nonlocal disconnected_called
            disconnected_called = True

        def on_error(error):
            nonlocal error_called, error_message
            error_called = True
            error_message = error

        # Set callbacks
        producer.on_connected(on_connected)
        producer.on_disconnected(on_disconnected)
        producer.on_error(on_error)

        # Connect and disconnect
        await producer.connect(test_room)
        await asyncio.sleep(0.1)  # Give callbacks time to execute
        assert connected_called is True

        await producer.disconnect()
        await asyncio.sleep(0.1)  # Give callbacks time to execute
        assert disconnected_called is True

    @pytest.mark.asyncio
    async def test_send_without_connection(self, producer):
        """Send methods raise ``ValueError`` when the producer is not connected."""
        assert not producer.is_connected()

        with pytest.raises(ValueError, match="Must be connected"):
            await producer.send_joint_update([{"name": "test", "value": 0}])

        with pytest.raises(ValueError, match="Must be connected"):
            await producer.send_state_sync({"test": 0})

        with pytest.raises(ValueError, match="Must be connected"):
            await producer.send_emergency_stop()

    @pytest.mark.asyncio
    async def test_multiple_connections(self, producer, test_room):
        """The same producer can switch from one room to another."""
        # Connect to first room
        assert await producer.connect(test_room) is True
        assert producer.room_id == test_room

        # Create second room
        room_id_2 = await producer.create_room()

        try:
            # Connect to second room (should disconnect from first)
            assert await producer.connect(room_id_2) is True
            assert producer.room_id == room_id_2
            assert producer.is_connected()

        finally:
            await producer.delete_room(room_id_2)

    @pytest.mark.asyncio
    async def test_context_manager(self, test_room):
        """Exiting ``async with`` leaves the producer disconnected."""
        async with RoboticsProducer(self.BASE_URL) as producer:
            await producer.connect(test_room)
            assert producer.is_connected()

            await producer.send_state_sync({"test": 123.0})

        # Should be disconnected after context exit
        assert not producer.is_connected()

    @pytest.mark.asyncio
    async def test_duplicate_producer_connection(self, producer, test_room):
        """A second producer cannot join a room that already has one."""
        producer2 = RoboticsProducer(self.BASE_URL)

        try:
            # First producer connects successfully
            success1 = await producer.connect(test_room)
            assert success1 is True

            # Second producer should fail to connect as producer
            success2 = await producer2.connect(test_room)
            assert success2 is False  # Should fail since room already has producer

        finally:
            if producer2.is_connected():
                await producer2.disconnect()

    @pytest.mark.asyncio
    async def test_custom_participant_id(self, producer, test_room):
        """A caller-supplied participant ID is reported by the connection info."""
        custom_id = "custom-producer-123"

        # Check connect() first so a failed connection is reported here,
        # not as a confusing participant_id mismatch below.
        success = await producer.connect(test_room, participant_id=custom_id)
        assert success is True

        info = producer.get_connection_info()
        assert info["participant_id"] == custom_id

    @pytest.mark.asyncio
    async def test_large_joint_update(self, connected_producer):
        """A joint update with 100 entries is sent without error."""
        producer, _ = connected_producer

        joints = [{"name": f"joint_{i}", "value": float(i)} for i in range(100)]

        # Should handle large updates without issue
        await producer.send_joint_update(joints)

    @pytest.mark.asyncio
    async def test_rapid_updates(self, connected_producer):
        """Several state-sync updates sent in quick succession all succeed."""
        producer, _ = connected_producer

        # Send multiple rapid updates
        for i in range(10):
            await producer.send_state_sync({"joint1": float(i), "joint2": float(i * 2)})
            await asyncio.sleep(0.01)  # Small delay