tech-envision committed on
Commit
d0252db
·
1 Parent(s): 4477cf4

Add asynchronous message streaming

Browse files
Files changed (4) hide show
  1. README.md +2 -1
  2. bot/discord_bot.py +2 -3
  3. run.py +2 -2
  4. src/chat.py +125 -4
README.md CHANGED
@@ -31,7 +31,8 @@ Uploaded files are stored under the `uploads` directory and mounted inside the V
31
  ```python
32
  async with ChatSession() as chat:
33
  path_in_vm = chat.upload_document("path/to/file.pdf")
34
- reply = await chat.chat(f"Summarize {path_in_vm}")
 
35
  ```
36
 
37
  When using the Discord bot, attach one or more text files to a message to
 
31
  ```python
32
  async with ChatSession() as chat:
33
  path_in_vm = chat.upload_document("path/to/file.pdf")
34
+ async for part in chat.chat_stream(f"Summarize {path_in_vm}"):
35
+ print(part)
36
  ```
37
 
38
  When using the Discord bot, attach one or more text files to a message to
bot/discord_bot.py CHANGED
@@ -71,12 +71,11 @@ async def on_message(message: discord.Message) -> None:
71
 
72
  if message.content.strip():
73
  try:
74
- reply = await chat.chat(message.content)
 
75
  except Exception as exc: # pragma: no cover - runtime errors
76
  _LOG.error("Failed to process message: %s", exc)
77
  await message.reply(f"Error: {exc}", mention_author=False)
78
- else:
79
- await message.reply(reply, mention_author=False)
80
 
81
 
82
  def main() -> None:
 
71
 
72
  if message.content.strip():
73
  try:
74
+ async for part in chat.chat_stream(message.content):
75
+ await message.reply(part, mention_author=False)
76
  except Exception as exc: # pragma: no cover - runtime errors
77
  _LOG.error("Failed to process message: %s", exc)
78
  await message.reply(f"Error: {exc}", mention_author=False)
 
 
79
 
80
 
81
  def main() -> None:
run.py CHANGED
@@ -10,8 +10,8 @@ async def _main() -> None:
10
  doc_path = chat.upload_document("test.txt")
11
  # print(f"Document uploaded to VM at: {doc_path}")
12
  # answer = await chat.chat(f"Remove all contents of test.txt and add the text 'Hello, World!' to it.")
13
- answer = await chat.chat("What is in /data directory?")
14
- print("\n>>>", answer)
15
 
16
 
17
  if __name__ == "__main__":
 
10
  doc_path = chat.upload_document("test.txt")
11
  # print(f"Document uploaded to VM at: {doc_path}")
12
  # answer = await chat.chat(f"Remove all contents of test.txt and add the text 'Hello, World!' to it.")
13
+ async for resp in chat.chat_stream("What is in /data directory?"):
14
+ print("\n>>>", resp)
15
 
16
 
17
  if __name__ == "__main__":
src/chat.py CHANGED
@@ -1,6 +1,6 @@
1
  from __future__ import annotations
2
 
3
- from typing import List
4
  import json
5
  import asyncio
6
  import shutil
@@ -49,6 +49,9 @@ class ChatSession:
49
  )
50
  self._vm = None
51
  self._messages: List[Msg] = self._load_history()
 
 
 
52
 
53
  async def __aenter__(self) -> "ChatSession":
54
  self._vm = VMRegistry.acquire(self._user.username)
@@ -131,13 +134,13 @@ class ChatSession:
131
  options={"num_ctx": NUM_CTX},
132
  )
133
 
134
- async def _handle_tool_calls(
135
  self,
136
  messages: List[Msg],
137
  response: ChatResponse,
138
  conversation: Conversation,
139
  depth: int = 0,
140
- ) -> ChatResponse:
141
  while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
142
  for call in response.message.tool_calls:
143
  if call.function.name != "execute_terminal":
@@ -162,6 +165,10 @@ class ChatSession:
162
  )
163
  follow_task = asyncio.create_task(self.ask(messages, think=True))
164
 
 
 
 
 
165
  done, _ = await asyncio.wait(
166
  {exec_task, follow_task},
167
  return_when=asyncio.FIRST_COMPLETED,
@@ -186,13 +193,19 @@ class ChatSession:
186
  role="tool",
187
  content=result,
188
  )
 
 
 
189
  nxt = await self.ask(messages, think=True)
190
  self._store_assistant_message(conversation, nxt.message)
 
191
  response = nxt
 
192
  else:
193
  followup = await follow_task
194
  self._store_assistant_message(conversation, followup.message)
195
  messages.append(followup.message.model_dump())
 
196
  result = await exec_task
197
  messages.append(
198
  {
@@ -206,13 +219,32 @@ class ChatSession:
206
  role="tool",
207
  content=result,
208
  )
 
 
 
209
  nxt = await self.ask(messages, think=True)
210
  self._store_assistant_message(conversation, nxt.message)
 
211
  response = nxt
 
212
 
213
  depth += 1
214
 
215
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  async def chat(self, prompt: str) -> str:
218
  DBMessage.create(conversation=self._conversation, role="user", content=prompt)
@@ -228,3 +260,92 @@ class ChatSession:
228
  self._messages, response, self._conversation
229
  )
230
  return final_resp.message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from typing import List, AsyncIterator
4
  import json
5
  import asyncio
6
  import shutil
 
49
  )
50
  self._vm = None
51
  self._messages: List[Msg] = self._load_history()
52
+ self._lock = asyncio.Lock()
53
+ self._state = "idle"
54
+ self._tool_task: asyncio.Task | None = None
55
 
56
  async def __aenter__(self) -> "ChatSession":
57
  self._vm = VMRegistry.acquire(self._user.username)
 
134
  options={"num_ctx": NUM_CTX},
135
  )
136
 
137
+ async def _handle_tool_calls_stream(
138
  self,
139
  messages: List[Msg],
140
  response: ChatResponse,
141
  conversation: Conversation,
142
  depth: int = 0,
143
+ ) -> AsyncIterator[ChatResponse]:
144
  while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
145
  for call in response.message.tool_calls:
146
  if call.function.name != "execute_terminal":
 
165
  )
166
  follow_task = asyncio.create_task(self.ask(messages, think=True))
167
 
168
+ async with self._lock:
169
+ self._state = "awaiting_tool"
170
+ self._tool_task = exec_task
171
+
172
  done, _ = await asyncio.wait(
173
  {exec_task, follow_task},
174
  return_when=asyncio.FIRST_COMPLETED,
 
193
  role="tool",
194
  content=result,
195
  )
196
+ async with self._lock:
197
+ self._state = "generating"
198
+ self._tool_task = None
199
  nxt = await self.ask(messages, think=True)
200
  self._store_assistant_message(conversation, nxt.message)
201
+ messages.append(nxt.message.model_dump())
202
  response = nxt
203
+ yield nxt
204
  else:
205
  followup = await follow_task
206
  self._store_assistant_message(conversation, followup.message)
207
  messages.append(followup.message.model_dump())
208
+ yield followup
209
  result = await exec_task
210
  messages.append(
211
  {
 
219
  role="tool",
220
  content=result,
221
  )
222
+ async with self._lock:
223
+ self._state = "generating"
224
+ self._tool_task = None
225
  nxt = await self.ask(messages, think=True)
226
  self._store_assistant_message(conversation, nxt.message)
227
+ messages.append(nxt.message.model_dump())
228
  response = nxt
229
+ yield nxt
230
 
231
  depth += 1
232
 
233
+ async with self._lock:
234
+ self._state = "idle"
235
+
236
+ async def _handle_tool_calls(
237
+ self,
238
+ messages: List[Msg],
239
+ response: ChatResponse,
240
+ conversation: Conversation,
241
+ depth: int = 0,
242
+ ) -> ChatResponse:
243
+ final = response
244
+ gen = self._handle_tool_calls_stream(messages, response, conversation, depth)
245
+ async for final in gen:
246
+ pass
247
+ return final
248
 
249
  async def chat(self, prompt: str) -> str:
250
  DBMessage.create(conversation=self._conversation, role="user", content=prompt)
 
260
  self._messages, response, self._conversation
261
  )
262
  return final_resp.message.content
263
+
264
+ async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
265
+ async with self._lock:
266
+ if self._state == "generating":
267
+ _LOG.info("Ignoring message while generating")
268
+ return
269
+ if self._state == "awaiting_tool" and self._tool_task:
270
+ async for part in self._chat_during_tool(prompt):
271
+ yield part
272
+ return
273
+ self._state = "generating"
274
+
275
+ DBMessage.create(conversation=self._conversation, role="user", content=prompt)
276
+ self._messages.append({"role": "user", "content": prompt})
277
+
278
+ response = await self.ask(self._messages)
279
+ self._messages.append(response.message.model_dump())
280
+ self._store_assistant_message(self._conversation, response.message)
281
+
282
+ _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")
283
+
284
+ async for resp in self._handle_tool_calls_stream(
285
+ self._messages, response, self._conversation
286
+ ):
287
+ yield resp.message.content
288
+
289
+ async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
290
+ DBMessage.create(conversation=self._conversation, role="user", content=prompt)
291
+ self._messages.append({"role": "user", "content": prompt})
292
+
293
+ user_task = asyncio.create_task(self.ask(self._messages))
294
+ exec_task = self._tool_task
295
+
296
+ done, _ = await asyncio.wait(
297
+ {exec_task, user_task},
298
+ return_when=asyncio.FIRST_COMPLETED,
299
+ )
300
+
301
+ if exec_task in done:
302
+ user_task.cancel()
303
+ try:
304
+ await user_task
305
+ except asyncio.CancelledError:
306
+ pass
307
+ result = await exec_task
308
+ self._tool_task = None
309
+ self._messages.append(
310
+ {"role": "tool", "name": "execute_terminal", "content": result}
311
+ )
312
+ DBMessage.create(
313
+ conversation=self._conversation, role="tool", content=result
314
+ )
315
+ async with self._lock:
316
+ self._state = "generating"
317
+ nxt = await self.ask(self._messages, think=True)
318
+ self._store_assistant_message(self._conversation, nxt.message)
319
+ self._messages.append(nxt.message.model_dump())
320
+ yield nxt.message.content
321
+ async for part in self._handle_tool_calls_stream(
322
+ self._messages, nxt, self._conversation
323
+ ):
324
+ yield part.message.content
325
+ else:
326
+ resp = await user_task
327
+ self._store_assistant_message(self._conversation, resp.message)
328
+ self._messages.append(resp.message.model_dump())
329
+ async with self._lock:
330
+ self._state = "awaiting_tool"
331
+ yield resp.message.content
332
+ result = await exec_task
333
+ self._tool_task = None
334
+ self._messages.append(
335
+ {"role": "tool", "name": "execute_terminal", "content": result}
336
+ )
337
+ DBMessage.create(
338
+ conversation=self._conversation, role="tool", content=result
339
+ )
340
+ async with self._lock:
341
+ self._state = "generating"
342
+ nxt = await self.ask(self._messages, think=True)
343
+ self._store_assistant_message(self._conversation, nxt.message)
344
+ self._messages.append(nxt.message.model_dump())
345
+ yield nxt.message.content
346
+ async for part in self._handle_tool_calls_stream(
347
+ self._messages, nxt, self._conversation
348
+ ):
349
+ yield part.message.content
350
+
351
+