tech-envision committed
Commit 49e2d72 · 1 Parent(s): 8114c3f

Return tool calls when present

Files changed (1)
  src/chat.py  +35 -20
src/chat.py CHANGED
@@ -146,12 +146,27 @@ class ChatSession:
             messages.append({"role": "tool", "content": msg.content})
         return messages
 
+    # ------------------------------------------------------------------
+    @staticmethod
+    def _serialize_tool_calls(calls: List[Message.ToolCall]) -> str:
+        """Convert tool calls to a JSON string for storage or output."""
+
+        return json.dumps([c.model_dump() for c in calls])
+
+    @staticmethod
+    def _format_output(message: Message) -> str:
+        """Return tool calls as JSON or message content if present."""
+
+        if message.tool_calls:
+            return ChatSession._serialize_tool_calls(message.tool_calls)
+        return message.content or ""
+
     @staticmethod
     def _store_assistant_message(conversation: Conversation, message: Message) -> None:
         """Persist assistant messages, storing tool calls when present."""
 
         if message.tool_calls:
-            content = json.dumps([c.model_dump() for c in message.tool_calls])
+            content = ChatSession._serialize_tool_calls(message.tool_calls)
         else:
             content = message.content or ""
 
@@ -304,7 +319,7 @@ class ChatSession:
         final_resp = await self._handle_tool_calls(
             self._messages, response, self._conversation
         )
-        return final_resp.message.content
+        return self._format_output(final_resp.message)
 
     async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
         async with self._lock:
@@ -329,10 +344,9 @@ class ChatSession:
         async for resp in self._handle_tool_calls_stream(
             self._messages, response, self._conversation
         ):
-            if resp.message.tool_calls:
-                continue
-            if resp.message.content:
-                yield resp.message.content
+            text = self._format_output(resp.message)
+            if text:
+                yield text
 
     async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
         DBMessage.create(conversation=self._conversation, role="user", content=prompt)
@@ -365,23 +379,24 @@ class ChatSession:
             nxt = await self.ask(self._messages, think=True)
             self._store_assistant_message(self._conversation, nxt.message)
             self._messages.append(nxt.message.model_dump())
-            if not nxt.message.tool_calls and nxt.message.content:
-                yield nxt.message.content
+            text = self._format_output(nxt.message)
+            if text:
+                yield text
             async for part in self._handle_tool_calls_stream(
                 self._messages, nxt, self._conversation
            ):
-                if part.message.tool_calls:
-                    continue
-                if part.message.content:
-                    yield part.message.content
+                text = self._format_output(part.message)
+                if text:
+                    yield text
        else:
             resp = await user_task
             self._store_assistant_message(self._conversation, resp.message)
             self._messages.append(resp.message.model_dump())
             async with self._lock:
                 self._state = "awaiting_tool"
-            if not resp.message.tool_calls and resp.message.content:
-                yield resp.message.content
+            text = self._format_output(resp.message)
+            if text:
+                yield text
         result = await exec_task
         self._tool_task = None
         self._messages.append(
@@ -395,12 +410,12 @@ class ChatSession:
             nxt = await self.ask(self._messages, think=True)
             self._store_assistant_message(self._conversation, nxt.message)
             self._messages.append(nxt.message.model_dump())
-            if not nxt.message.tool_calls and nxt.message.content:
-                yield nxt.message.content
+            text = self._format_output(nxt.message)
+            if text:
+                yield text
             async for part in self._handle_tool_calls_stream(
                 self._messages, nxt, self._conversation
             ):
-                if part.message.tool_calls:
-                    continue
-                if part.message.content:
-                    yield part.message.content
+                text = self._format_output(part.message)
+                if text:
+                    yield text
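The new helpers are small enough to reason about in isolation. The sketch below is a minimal, self-contained illustration of the behaviour this commit introduces: messages carrying tool calls are returned as a JSON array, while plain messages pass through as their text content. The ToolCall and Message pydantic models here are simplified stand-ins invented for the example (the repo's real models are only consumed, not defined, in this diff), and the module-level functions mirror the private _serialize_tool_calls / _format_output helpers.

import json
from typing import List, Optional

from pydantic import BaseModel


class ToolCall(BaseModel):
    """Stand-in for the project's Message.ToolCall model (assumed shape)."""
    name: str
    arguments: dict


class Message(BaseModel):
    """Stand-in for the project's assistant Message model (assumed shape)."""
    content: Optional[str] = None
    tool_calls: List[ToolCall] = []


def serialize_tool_calls(calls: List[ToolCall]) -> str:
    # Mirrors _serialize_tool_calls: dump each call to a dict, then to JSON.
    return json.dumps([c.model_dump() for c in calls])


def format_output(message: Message) -> str:
    # Mirrors _format_output: prefer tool calls, fall back to text content.
    if message.tool_calls:
        return serialize_tool_calls(message.tool_calls)
    return message.content or ""


print(format_output(Message(content="Hello")))
# Hello
print(format_output(Message(tool_calls=[ToolCall(name="search", arguments={"q": "weather"})])))
# [{"name": "search", "arguments": {"q": "weather"}}]

With that in place, both chat() and chat_stream() emit the serialized tool calls instead of silently dropping them, which is what the commit message ("Return tool calls when present") refers to.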