File size: 10,120 Bytes
02eac4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import asyncio

import pytest
from lerobot_arena_client import RoboticsConsumer


class TestRoboticsConsumer:
    """Exercise ``RoboticsConsumer``: connection, callbacks and message receipt.

    The ``consumer``, ``connected_consumer``, ``producer_consumer_pair`` and
    ``test_room`` arguments are pytest fixtures defined outside this file
    (presumably in a ``conftest.py`` -- confirm there for setup/teardown).

    Tests that wait for an asynchronous delivery poll with ``_wait_until``
    rather than sleeping for a fixed time, so they finish as soon as the
    expected condition holds and tolerate a slow server up to
    ``RECEIVE_TIMEOUT`` seconds.
    """

    # Address used by the tests that construct their own consumer, and the
    # value expected back from ``get_connection_info()``.
    BASE_URL = "http://localhost:8000"
    # Pause after registering callbacks, before the producer starts sending.
    SETTLE_DELAY = 0.1
    # Pause before reading server-side state back after a producer update.
    PROPAGATION_DELAY = 0.2
    # Upper bound for waiting on a callback-driven condition.
    RECEIVE_TIMEOUT = 2.0
    # How often a pending condition is re-checked.
    POLL_INTERVAL = 0.01

    @staticmethod
    async def _wait_until(
        condition,
        timeout: float = RECEIVE_TIMEOUT,
        interval: float = POLL_INTERVAL,
    ) -> bool:
        """Poll ``condition`` until it is truthy or ``timeout`` seconds pass.

        Args:
            condition: Zero-argument callable, re-evaluated on every poll.
            timeout: Maximum number of seconds to keep polling.
            interval: Seconds to yield to the event loop between polls.

        Returns:
            ``True`` if ``condition()`` became truthy in time, else ``False``.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not condition():
            if loop.time() >= deadline:
                return False
            # Yield so the client's receive path can run and fire callbacks.
            await asyncio.sleep(interval)
        return True

    @pytest.mark.asyncio
    async def test_consumer_connection(self, consumer, test_room):
        """Connect to a room, check the reported state, then disconnect."""
        assert not consumer.is_connected()

        success = await consumer.connect(test_room)
        assert success is True
        assert consumer.is_connected()
        assert consumer.room_id == test_room
        assert consumer.role == "consumer"

        await consumer.disconnect()
        assert not consumer.is_connected()

    @pytest.mark.asyncio
    async def test_consumer_connection_info(self, connected_consumer):
        """Check every field of ``get_connection_info()`` while connected."""
        consumer, room_id = connected_consumer

        info = consumer.get_connection_info()
        assert info["connected"] is True
        assert info["room_id"] == room_id
        assert info["role"] == "consumer"
        assert info["participant_id"] is not None
        assert info["base_url"] == self.BASE_URL

    @pytest.mark.asyncio
    async def test_get_state_sync(self, connected_consumer):
        """A freshly joined room reports an empty state dict."""
        consumer, _ = connected_consumer

        state = await consumer.get_state_sync()
        assert isinstance(state, dict)
        # Initial state should be empty
        assert len(state) == 0

    @pytest.mark.asyncio
    async def test_consumer_callbacks_setup(self, consumer, test_room):
        """Register every callback and check the connect/disconnect ones fire."""
        # Names of the callbacks that have been invoked so far.
        fired = set()

        def on_state_sync(state):
            fired.add("state_sync")

        def on_joint_update(joints):
            fired.add("joint_update")

        def on_error(error):
            fired.add("error")

        def on_connected():
            fired.add("connected")

        def on_disconnected():
            fired.add("disconnected")

        # Set all callbacks; only the two connection callbacks are asserted
        # on below, the rest just have to be accepted by the client.
        consumer.on_state_sync(on_state_sync)
        consumer.on_joint_update(on_joint_update)
        consumer.on_error(on_error)
        consumer.on_connected(on_connected)
        consumer.on_disconnected(on_disconnected)

        # Connect and test connection callbacks
        await consumer.connect(test_room)
        await self._wait_until(lambda: "connected" in fired)
        assert "connected" in fired

        await consumer.disconnect()
        await self._wait_until(lambda: "disconnected" in fired)
        assert "disconnected" in fired

    @pytest.mark.asyncio
    async def test_multiple_consumers(self, test_room):
        """Two independent consumers can join the same room."""
        consumer1 = RoboticsConsumer(self.BASE_URL)
        consumer2 = RoboticsConsumer(self.BASE_URL)

        try:
            # Both consumers should be able to connect
            success1 = await consumer1.connect(test_room)
            success2 = await consumer2.connect(test_room)

            assert success1 is True
            assert success2 is True
            assert consumer1.is_connected()
            assert consumer2.is_connected()

        finally:
            # Nested so the second consumer is still released when
            # disconnecting the first one raises.
            try:
                if consumer1.is_connected():
                    await consumer1.disconnect()
            finally:
                if consumer2.is_connected():
                    await consumer2.disconnect()

    @pytest.mark.asyncio
    async def test_consumer_receive_state_sync(self, producer_consumer_pair):
        """A producer state sync reaches the consumer as a joint update."""
        producer, consumer, _ = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Producer sends state sync (which gets converted to joint updates)
        await producer.send_state_sync({"shoulder": 45.0, "elbow": -20.0})

        # Wait for message to be received
        await self._wait_until(lambda: len(received_updates) >= 1)

        # Consumer should have received the joint updates from the state sync
        # The initial state sync during connection might be empty, so we check for joint updates
        assert len(received_updates) >= 1

    @pytest.mark.asyncio
    async def test_consumer_receive_joint_updates(self, producer_consumer_pair):
        """A producer joint update arrives as a two-element list."""
        producer, consumer, _ = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Producer sends joint updates
        test_joints = [
            {"name": "shoulder", "value": 45.0},
            {"name": "elbow", "value": -20.0},
        ]
        await producer.send_joint_update(test_joints)

        def latest_is_expected():
            # Mirrors the assertions below so polling stops exactly when
            # they would pass.
            if not received_updates:
                return False
            latest = received_updates[-1]
            return isinstance(latest, list) and len(latest) == 2

        # Wait for message to be received
        await self._wait_until(latest_is_expected)

        # Consumer should have received the joint update
        assert len(received_updates) >= 1
        received_joints = received_updates[-1]
        assert isinstance(received_joints, list)
        assert len(received_joints) == 2

    @pytest.mark.asyncio
    async def test_consumer_multiple_updates(self, producer_consumer_pair):
        """At least three of five paced producer updates are delivered."""
        producer, consumer, _ = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Send multiple updates
        for i in range(5):
            await producer.send_state_sync({
                "joint1": float(i * 10),
                "joint2": float(i * -5),
            })
            await asyncio.sleep(0.05)

        # Wait for enough messages to be received
        await self._wait_until(lambda: len(received_updates) >= 3)

        # Should have received multiple updates
        assert len(received_updates) >= 3

    @pytest.mark.asyncio
    async def test_consumer_emergency_stop(self, producer_consumer_pair):
        """A producer emergency stop surfaces through the error callback."""
        producer, consumer, _ = producer_consumer_pair

        received_errors = []

        def on_error(error):
            received_errors.append(error)

        consumer.on_error(on_error)

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Producer sends emergency stop
        await producer.send_emergency_stop("Test emergency stop")

        # Wait for message to be received
        await self._wait_until(
            lambda: bool(received_errors)
            and "emergency stop" in received_errors[-1].lower()
        )

        # Consumer should have received emergency stop as error
        assert len(received_errors) >= 1
        assert "emergency stop" in received_errors[-1].lower()

    @pytest.mark.asyncio
    async def test_custom_participant_id(self, consumer, test_room):
        """A caller-supplied participant ID is reported back unchanged."""
        custom_id = "custom-consumer-456"

        await consumer.connect(test_room, participant_id=custom_id)

        info = consumer.get_connection_info()
        assert info["participant_id"] == custom_id

    @pytest.mark.asyncio
    async def test_context_manager(self, test_room):
        """Leaving the ``async with`` block disconnects the consumer."""
        async with RoboticsConsumer(self.BASE_URL) as consumer:
            await consumer.connect(test_room)
            assert consumer.is_connected()

            state = await consumer.get_state_sync()
            assert isinstance(state, dict)

        # Should be disconnected after context exit
        assert not consumer.is_connected()

    @pytest.mark.asyncio
    async def test_get_state_without_connection(self, consumer):
        """``get_state_sync`` raises ``ValueError`` when not connected."""
        assert not consumer.is_connected()

        with pytest.raises(ValueError, match="Must be connected to a room"):
            await consumer.get_state_sync()

    @pytest.mark.asyncio
    async def test_consumer_reconnection(self, consumer, test_room):
        """A consumer can rejoin the same room after disconnecting."""
        # First connection
        await consumer.connect(test_room)
        assert consumer.is_connected()

        await consumer.disconnect()
        assert not consumer.is_connected()

        # Reconnect to same room
        await consumer.connect(test_room)
        assert consumer.is_connected()
        assert consumer.room_id == test_room

    @pytest.mark.asyncio
    async def test_consumer_state_after_producer_updates(self, producer_consumer_pair):
        """State fetched after a producer update holds the joints that were sent."""
        producer, consumer, _ = producer_consumer_pair

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Producer sends some state updates
        await producer.send_state_sync({
            "shoulder": 45.0,
            "elbow": -20.0,
            "wrist": 10.0,
        })

        # Wait for state to propagate
        await asyncio.sleep(self.PROPAGATION_DELAY)

        # Consumer should be able to get updated state
        state = await consumer.get_state_sync()
        assert isinstance(state, dict)
        # State should contain the joints we sent
        expected_joints = {"shoulder", "elbow", "wrist"}
        # NOTE(review): an empty state skips the key check, so this test
        # passes even when nothing propagated -- confirm whether the server
        # guarantees a populated state here and, if so, drop the guard.
        if state:  # Only check if state is not empty
            assert set(state.keys()) == expected_joints