Spaces:
Runtime error
Runtime error
tech-envision
commited on
Commit
·
49e2d72
1
Parent(s):
8114c3f
Return tool calls when present
Browse files- src/chat.py +35 -20
src/chat.py
CHANGED
@@ -146,12 +146,27 @@ class ChatSession:
|
|
146 |
messages.append({"role": "tool", "content": msg.content})
|
147 |
return messages
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
@staticmethod
|
150 |
def _store_assistant_message(conversation: Conversation, message: Message) -> None:
|
151 |
"""Persist assistant messages, storing tool calls when present."""
|
152 |
|
153 |
if message.tool_calls:
|
154 |
-
content =
|
155 |
else:
|
156 |
content = message.content or ""
|
157 |
|
@@ -304,7 +319,7 @@ class ChatSession:
|
|
304 |
final_resp = await self._handle_tool_calls(
|
305 |
self._messages, response, self._conversation
|
306 |
)
|
307 |
-
return final_resp.message
|
308 |
|
309 |
async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
|
310 |
async with self._lock:
|
@@ -329,10 +344,9 @@ class ChatSession:
|
|
329 |
async for resp in self._handle_tool_calls_stream(
|
330 |
self._messages, response, self._conversation
|
331 |
):
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
yield resp.message.content
|
336 |
|
337 |
async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
|
338 |
DBMessage.create(conversation=self._conversation, role="user", content=prompt)
|
@@ -365,23 +379,24 @@ class ChatSession:
|
|
365 |
nxt = await self.ask(self._messages, think=True)
|
366 |
self._store_assistant_message(self._conversation, nxt.message)
|
367 |
self._messages.append(nxt.message.model_dump())
|
368 |
-
|
369 |
-
|
|
|
370 |
async for part in self._handle_tool_calls_stream(
|
371 |
self._messages, nxt, self._conversation
|
372 |
):
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
yield part.message.content
|
377 |
else:
|
378 |
resp = await user_task
|
379 |
self._store_assistant_message(self._conversation, resp.message)
|
380 |
self._messages.append(resp.message.model_dump())
|
381 |
async with self._lock:
|
382 |
self._state = "awaiting_tool"
|
383 |
-
|
384 |
-
|
|
|
385 |
result = await exec_task
|
386 |
self._tool_task = None
|
387 |
self._messages.append(
|
@@ -395,12 +410,12 @@ class ChatSession:
|
|
395 |
nxt = await self.ask(self._messages, think=True)
|
396 |
self._store_assistant_message(self._conversation, nxt.message)
|
397 |
self._messages.append(nxt.message.model_dump())
|
398 |
-
|
399 |
-
|
|
|
400 |
async for part in self._handle_tool_calls_stream(
|
401 |
self._messages, nxt, self._conversation
|
402 |
):
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
yield part.message.content
|
|
|
146 |
messages.append({"role": "tool", "content": msg.content})
|
147 |
return messages
|
148 |
|
149 |
+
# ------------------------------------------------------------------
|
150 |
+
@staticmethod
|
151 |
+
def _serialize_tool_calls(calls: List[Message.ToolCall]) -> str:
|
152 |
+
"""Convert tool calls to a JSON string for storage or output."""
|
153 |
+
|
154 |
+
return json.dumps([c.model_dump() for c in calls])
|
155 |
+
|
156 |
+
@staticmethod
|
157 |
+
def _format_output(message: Message) -> str:
|
158 |
+
"""Return tool calls as JSON or message content if present."""
|
159 |
+
|
160 |
+
if message.tool_calls:
|
161 |
+
return ChatSession._serialize_tool_calls(message.tool_calls)
|
162 |
+
return message.content or ""
|
163 |
+
|
164 |
@staticmethod
|
165 |
def _store_assistant_message(conversation: Conversation, message: Message) -> None:
|
166 |
"""Persist assistant messages, storing tool calls when present."""
|
167 |
|
168 |
if message.tool_calls:
|
169 |
+
content = ChatSession._serialize_tool_calls(message.tool_calls)
|
170 |
else:
|
171 |
content = message.content or ""
|
172 |
|
|
|
319 |
final_resp = await self._handle_tool_calls(
|
320 |
self._messages, response, self._conversation
|
321 |
)
|
322 |
+
return self._format_output(final_resp.message)
|
323 |
|
324 |
async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
|
325 |
async with self._lock:
|
|
|
344 |
async for resp in self._handle_tool_calls_stream(
|
345 |
self._messages, response, self._conversation
|
346 |
):
|
347 |
+
text = self._format_output(resp.message)
|
348 |
+
if text:
|
349 |
+
yield text
|
|
|
350 |
|
351 |
async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
|
352 |
DBMessage.create(conversation=self._conversation, role="user", content=prompt)
|
|
|
379 |
nxt = await self.ask(self._messages, think=True)
|
380 |
self._store_assistant_message(self._conversation, nxt.message)
|
381 |
self._messages.append(nxt.message.model_dump())
|
382 |
+
text = self._format_output(nxt.message)
|
383 |
+
if text:
|
384 |
+
yield text
|
385 |
async for part in self._handle_tool_calls_stream(
|
386 |
self._messages, nxt, self._conversation
|
387 |
):
|
388 |
+
text = self._format_output(part.message)
|
389 |
+
if text:
|
390 |
+
yield text
|
|
|
391 |
else:
|
392 |
resp = await user_task
|
393 |
self._store_assistant_message(self._conversation, resp.message)
|
394 |
self._messages.append(resp.message.model_dump())
|
395 |
async with self._lock:
|
396 |
self._state = "awaiting_tool"
|
397 |
+
text = self._format_output(resp.message)
|
398 |
+
if text:
|
399 |
+
yield text
|
400 |
result = await exec_task
|
401 |
self._tool_task = None
|
402 |
self._messages.append(
|
|
|
410 |
nxt = await self.ask(self._messages, think=True)
|
411 |
self._store_assistant_message(self._conversation, nxt.message)
|
412 |
self._messages.append(nxt.message.model_dump())
|
413 |
+
text = self._format_output(nxt.message)
|
414 |
+
if text:
|
415 |
+
yield text
|
416 |
async for part in self._handle_tool_calls_stream(
|
417 |
self._messages, nxt, self._conversation
|
418 |
):
|
419 |
+
text = self._format_output(part.message)
|
420 |
+
if text:
|
421 |
+
yield text
|
|