# llmOS-Agent / src/chat.py (tech-envision)
# Commit 280fbf3: "stream only final responses"
from __future__ import annotations
from typing import List, AsyncIterator
import json
import asyncio
import shutil
from pathlib import Path
from ollama import AsyncClient, ChatResponse, Message
from .config import (
MAX_TOOL_CALL_DEPTH,
MODEL_NAME,
NUM_CTX,
OLLAMA_HOST,
SYSTEM_PROMPT,
UPLOAD_DIR,
)
from .db import (
Conversation,
Message as DBMessage,
User,
_db,
init_db,
add_document,
)
from .log import get_logger
from .schema import Msg
from .tools import execute_terminal, execute_terminal_async, set_vm
from .vm import VMRegistry
_LOG = get_logger(__name__)
class ChatSession:
    """Conversation manager that talks to an Ollama model with tool support.

    Each session is bound to a ``User``/``Conversation`` row in the database
    and — while used as an async context manager — to a per-user VM in which
    the ``execute_terminal`` tool runs.  History is persisted and replayed on
    construction.  A small state machine (``idle`` / ``generating`` /
    ``awaiting_tool``) guarded by ``self._lock`` allows a user to send a new
    prompt while a tool command is still executing.
    """

    def __init__(
        self,
        user: str = "default",
        session: str = "default",
        host: str = OLLAMA_HOST,
        model: str = MODEL_NAME,
    ) -> None:
        init_db()
        self._client = AsyncClient(host=host)
        self._model = model
        self._user, _ = User.get_or_create(username=user)
        self._conversation, _ = Conversation.get_or_create(
            user=self._user, session_name=session
        )
        self._vm = None
        self._messages: List[Msg] = self._load_history()
        # Protects _state and _tool_task.  asyncio.Lock is NOT reentrant:
        # a coroutine must never call another lock-acquiring coroutine while
        # holding it (see chat_stream below).
        self._lock = asyncio.Lock()
        self._state = "idle"  # "idle" | "generating" | "awaiting_tool"
        self._tool_task: asyncio.Task | None = None

    async def __aenter__(self) -> "ChatSession":
        # Bind a VM for this user and expose it to the tool implementations.
        self._vm = VMRegistry.acquire(self._user.username)
        set_vm(self._vm)
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        set_vm(None)
        if self._vm:
            VMRegistry.release(self._user.username)
        if not _db.is_closed():
            _db.close()

    def upload_document(self, file_path: str) -> str:
        """Save a document for later access inside the VM.

        The file is copied into ``UPLOAD_DIR`` and recorded in the database.
        The returned path is the location inside the VM (prefixed with
        ``/data``).

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        src = Path(file_path)
        if not src.exists():
            raise FileNotFoundError(file_path)
        dest = Path(UPLOAD_DIR) / self._user.username
        dest.mkdir(parents=True, exist_ok=True)
        target = dest / src.name
        shutil.copy(src, target)
        add_document(self._user.username, str(target), src.name)
        return f"/data/{src.name}"

    def _load_history(self) -> List[Msg]:
        """Rebuild the in-memory message list from persisted rows."""
        messages: List[Msg] = []
        for msg in self._conversation.messages.order_by(DBMessage.created_at):
            if msg.role == "system":
                # Skip persisted system prompts from older versions
                continue
            if msg.role == "assistant":
                # Assistant rows hold either plain text or a JSON-encoded
                # list of tool calls (see _store_assistant_message).  Only
                # treat the content as tool calls when it really decodes to
                # a list of mappings — ordinary replies such as "42" are
                # valid JSON too and previously crashed ToolCall(**c).
                calls = None
                try:
                    calls = json.loads(msg.content)
                except json.JSONDecodeError:
                    pass
                if isinstance(calls, list) and all(
                    isinstance(c, dict) for c in calls
                ):
                    messages.append(
                        {
                            "role": "assistant",
                            "tool_calls": [Message.ToolCall(**c) for c in calls],
                        }
                    )
                else:
                    messages.append({"role": "assistant", "content": msg.content})
            elif msg.role == "user":
                messages.append({"role": "user", "content": msg.content})
            else:
                messages.append({"role": "tool", "content": msg.content})
        return messages

    @staticmethod
    def _store_assistant_message(conversation: Conversation, message: Message) -> None:
        """Persist assistant messages, storing tool calls when present."""
        if message.tool_calls:
            # Serialize the structured calls; _load_history decodes them.
            content = json.dumps([c.model_dump() for c in message.tool_calls])
        else:
            content = message.content or ""
        DBMessage.create(conversation=conversation, role="assistant", content=content)

    async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
        """Send a chat request, automatically prepending the system prompt."""
        if not messages or messages[0].get("role") != "system":
            payload = [{"role": "system", "content": SYSTEM_PROMPT}, *messages]
        else:
            payload = messages
        return await self._client.chat(
            self._model,
            messages=payload,
            think=think,
            tools=[execute_terminal],
            options={"num_ctx": NUM_CTX},
        )

    async def _handle_tool_calls_stream(
        self,
        messages: List[Msg],
        response: ChatResponse,
        conversation: Conversation,
        depth: int = 0,
    ) -> AsyncIterator[ChatResponse]:
        """Execute requested tool calls, yielding each follow-up response.

        For every ``execute_terminal`` call the command is started
        concurrently with a speculative follow-up model request; whichever
        finishes first drives the next step.  Iterates until the model stops
        requesting tools or ``MAX_TOOL_CALL_DEPTH`` is reached.
        """
        while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
            prev = response
            for call in response.message.tool_calls:
                if call.function.name != "execute_terminal":
                    _LOG.warning("Unsupported tool call: %s", call.function.name)
                    result = f"Unsupported tool: {call.function.name}"
                    messages.append(
                        {
                            "role": "tool",
                            "name": call.function.name,
                            "content": result,
                        }
                    )
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    continue
                # Race the command against a speculative follow-up request.
                exec_task = asyncio.create_task(
                    execute_terminal_async(**call.function.arguments)
                )
                follow_task = asyncio.create_task(self.ask(messages, think=True))
                async with self._lock:
                    self._state = "awaiting_tool"
                    self._tool_task = exec_task
                done, _ = await asyncio.wait(
                    {exec_task, follow_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if exec_task in done:
                    # Tool finished first: discard the speculative reply,
                    # record the result, then ask again with it in context.
                    follow_task.cancel()
                    try:
                        await follow_task
                    except asyncio.CancelledError:
                        pass
                    result = await exec_task
                    messages.append(
                        {
                            "role": "tool",
                            "name": call.function.name,
                            "content": result,
                        }
                    )
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    async with self._lock:
                        self._state = "generating"
                        self._tool_task = None
                    nxt = await self.ask(messages, think=True)
                    self._store_assistant_message(conversation, nxt.message)
                    messages.append(nxt.message.model_dump())
                    response = nxt
                    yield nxt
                else:
                    # Model answered first: surface that answer, then wait
                    # for the tool and ask once more with its output.
                    followup = await follow_task
                    self._store_assistant_message(conversation, followup.message)
                    messages.append(followup.message.model_dump())
                    yield followup
                    result = await exec_task
                    messages.append(
                        {
                            "role": "tool",
                            "name": call.function.name,
                            "content": result,
                        }
                    )
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    async with self._lock:
                        self._state = "generating"
                        self._tool_task = None
                    nxt = await self.ask(messages, think=True)
                    self._store_assistant_message(conversation, nxt.message)
                    messages.append(nxt.message.model_dump())
                    response = nxt
                    yield nxt
            depth += 1
            if response is prev:
                # Every call in this round was unsupported, so no new model
                # response was produced.  Bail out rather than re-recording
                # the same "Unsupported tool" rows until the depth limit.
                break
        async with self._lock:
            self._state = "idle"

    async def _handle_tool_calls(
        self,
        messages: List[Msg],
        response: ChatResponse,
        conversation: Conversation,
        depth: int = 0,
    ) -> ChatResponse:
        """Drain the streaming handler and return the final response."""
        final = response
        gen = self._handle_tool_calls_stream(messages, response, conversation, depth)
        async for final in gen:
            pass
        return final

    async def chat(self, prompt: str) -> str:
        """Send ``prompt`` and return the final assistant reply (non-stream)."""
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})
        response = await self.ask(self._messages)
        self._messages.append(response.message.model_dump())
        self._store_assistant_message(self._conversation, response.message)
        _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")
        final_resp = await self._handle_tool_calls(
            self._messages, response, self._conversation
        )
        return final_resp.message.content

    async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
        """Send ``prompt`` and yield only final (non-tool-call) replies.

        A prompt arriving while a previous turn is generating is dropped;
        one arriving while a tool command runs is handed to
        ``_chat_during_tool``.
        """
        # Decide what to do while holding the lock, but RELEASE it before
        # delegating: _chat_during_tool acquires self._lock itself, and
        # asyncio.Lock is not reentrant — delegating while still holding
        # the lock (as the previous version did) deadlocks.
        async with self._lock:
            if self._state == "generating":
                _LOG.info("Ignoring message while generating")
                return
            awaiting_tool = self._state == "awaiting_tool" and bool(self._tool_task)
            if not awaiting_tool:
                self._state = "generating"
        if awaiting_tool:
            async for part in self._chat_during_tool(prompt):
                yield part
            return
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})
        response = await self.ask(self._messages)
        self._messages.append(response.message.model_dump())
        self._store_assistant_message(self._conversation, response.message)
        _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")
        async for resp in self._handle_tool_calls_stream(
            self._messages, response, self._conversation
        ):
            if resp.message.tool_calls:
                continue
            if resp.message.content:
                yield resp.message.content

    async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
        """Handle a prompt that arrives while a tool command is running.

        Races the pending tool task against a model request for the new
        prompt; either way the tool result is eventually recorded and a
        follow-up request is issued with it in context.
        """
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})
        user_task = asyncio.create_task(self.ask(self._messages))
        # NOTE(review): _tool_task is read without the lock here; the caller
        # checked it was set, but a concurrent completion could clear it —
        # verify this window is acceptable.
        exec_task = self._tool_task
        done, _ = await asyncio.wait(
            {exec_task, user_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        if exec_task in done:
            # Tool finished first: drop the in-flight reply to the new
            # prompt, record the tool output, then ask with full context.
            user_task.cancel()
            try:
                await user_task
            except asyncio.CancelledError:
                pass
            result = await exec_task
            self._tool_task = None
            self._messages.append(
                {"role": "tool", "name": "execute_terminal", "content": result}
            )
            DBMessage.create(
                conversation=self._conversation, role="tool", content=result
            )
            async with self._lock:
                self._state = "generating"
            nxt = await self.ask(self._messages, think=True)
            self._store_assistant_message(self._conversation, nxt.message)
            self._messages.append(nxt.message.model_dump())
            if not nxt.message.tool_calls and nxt.message.content:
                yield nxt.message.content
            async for part in self._handle_tool_calls_stream(
                self._messages, nxt, self._conversation
            ):
                if part.message.tool_calls:
                    continue
                if part.message.content:
                    yield part.message.content
        else:
            # Model replied to the new prompt first: surface that reply,
            # then wait for the tool and continue the normal loop.
            resp = await user_task
            self._store_assistant_message(self._conversation, resp.message)
            self._messages.append(resp.message.model_dump())
            async with self._lock:
                self._state = "awaiting_tool"
            if not resp.message.tool_calls and resp.message.content:
                yield resp.message.content
            result = await exec_task
            self._tool_task = None
            self._messages.append(
                {"role": "tool", "name": "execute_terminal", "content": result}
            )
            DBMessage.create(
                conversation=self._conversation, role="tool", content=result
            )
            async with self._lock:
                self._state = "generating"
            nxt = await self.ask(self._messages, think=True)
            self._store_assistant_message(self._conversation, nxt.message)
            self._messages.append(nxt.message.model_dump())
            if not nxt.message.tool_calls and nxt.message.content:
                yield nxt.message.content
            async for part in self._handle_tool_calls_stream(
                self._messages, nxt, self._conversation
            ):
                if part.message.tool_calls:
                    continue
                if part.message.content:
                    yield part.message.content