tech-envision committed
Commit b8c6bee · 1 Parent(s): 97c8016

Add message queue for agent chat

Files changed (1):
  1. src/chat.py +37 -12
src/chat.py CHANGED
@@ -80,6 +80,10 @@ class ChatSession:
         self._messages: List[Msg] = self._load_history()
         self._data = _get_session_data(self._conversation.id)
         self._lock = self._data.lock
+        self._prompt_queue: asyncio.Queue[
+            tuple[str, asyncio.Queue[str | None]]
+        ] = asyncio.Queue()
+        self._worker: asyncio.Task | None = None
 
     # Shared state properties -------------------------------------------------
 
@@ -191,7 +195,9 @@ class ChatSession:
         content = message.content or ""
 
         if content.strip():
-            DBMessage.create(conversation=conversation, role="assistant", content=content)
+            DBMessage.create(
+                conversation=conversation, role="assistant", content=content
+            )
 
     async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
         """Send a chat request, automatically prepending the system prompt."""
@@ -333,12 +339,8 @@ class ChatSession:
         async with self._lock:
             self._state = "idle"
 
-
-    async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
+    async def _generate_stream(self, prompt: str) -> AsyncIterator[str]:
         async with self._lock:
-            if self._state == "generating":
-                _LOG.info("Ignoring message while generating")
-                return
             if self._state == "awaiting_tool" and self._tool_task:
                 async for part in self._chat_during_tool(prompt):
                     yield part
@@ -359,6 +361,33 @@ class ChatSession:
             if text:
                 yield text
 
+    async def _process_prompt_queue(self) -> None:
+        try:
+            while not self._prompt_queue.empty():
+                prompt, result_q = await self._prompt_queue.get()
+                try:
+                    async for part in self._generate_stream(prompt):
+                        await result_q.put(part)
+                except Exception as exc:  # pragma: no cover - unforeseen errors
+                    _LOG.exception("Error processing prompt: %s", exc)
+                    await result_q.put(f"Error: {exc}")
+                finally:
+                    await result_q.put(None)
+        finally:
+            self._worker = None
+
+    async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
+        result_q: asyncio.Queue[str | None] = asyncio.Queue()
+        await self._prompt_queue.put((prompt, result_q))
+        if not self._worker or self._worker.done():
+            self._worker = asyncio.create_task(self._process_prompt_queue())
+
+        while True:
+            part = await result_q.get()
+            if part is None:
+                break
+            yield part
+
     async def continue_stream(self) -> AsyncIterator[str]:
         async with self._lock:
             if self._state != "idle":
@@ -399,9 +428,7 @@ class ChatSession:
         self._tool_task = None
         name = self._current_tool_name or "tool"
         self._current_tool_name = None
-        self._messages.append(
-            {"role": "tool", "name": name, "content": result}
-        )
+        self._messages.append({"role": "tool", "name": name, "content": result})
         DBMessage.create(
             conversation=self._conversation, role="tool", content=result
         )
@@ -433,9 +460,7 @@ class ChatSession:
         self._remove_tool_placeholder(self._messages)
         name = self._current_tool_name or "tool"
        self._current_tool_name = None
-        self._messages.append(
-            {"role": "tool", "name": name, "content": result}
-        )
+        self._messages.append({"role": "tool", "name": name, "content": result})
         DBMessage.create(
             conversation=self._conversation, role="tool", content=result
        )
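
Usage sketch: the snippet below illustrates the queued streaming behaviour this commit introduces, where a prompt submitted while another is still generating is placed on _prompt_queue and served by the single _process_prompt_queue worker instead of being silently ignored. The import path and the no-argument ChatSession() construction are assumptions made for illustration; only chat_stream and its queueing semantics come from the diff above.

import asyncio

# Assumed import path; the commit only shows src/chat.py, not how ChatSession
# is normally constructed or configured.
from src.chat import ChatSession


async def main() -> None:
    session = ChatSession()  # hypothetical constructor call

    async def consume(tag: str, prompt: str) -> None:
        # chat_stream enqueues (prompt, result_q) on the session's prompt queue
        # and yields parts until the worker puts the None sentinel.
        async for part in session.chat_stream(prompt):
            print(f"[{tag}] {part}", end="", flush=True)

    # Two prompts submitted concurrently: both are queued and drained one
    # after another by the single worker task started on demand.
    await asyncio.gather(
        consume("a", "Summarise the last tool result"),
        consume("b", "What should we do next?"),
    )


if __name__ == "__main__":
    asyncio.run(main())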