from __future__ import annotations

import json
import shutil
from pathlib import Path
from typing import List

from ollama import AsyncClient, ChatResponse, Message

from .config import (
    MAX_TOOL_CALL_DEPTH,
    MODEL_NAME,
    EMBEDDING_MODEL_NAME,
    NUM_CTX,
    OLLAMA_HOST,
    SYSTEM_PROMPT,
    UPLOAD_DIR,
)
from .db import (
    Conversation,
    Message as DBMessage,
    User,
    _db,
    init_db,
    add_document,
)
from .log import get_logger
from .schema import Msg
from .tools import execute_terminal, set_vm
from .vm import VMRegistry

_LOG = get_logger(__name__)


class ChatSession:
    """A persisted chat conversation with an Ollama model and terminal tool.

    The session loads earlier messages of the ``(user, session)``
    conversation from the database, sends chat requests through an
    ``AsyncClient`` and executes ``execute_terminal`` tool calls requested
    by the model.  Use it as an async context manager: entering acquires a
    VM from ``VMRegistry`` for the user, exiting releases it and closes the
    database connection.
    """

    def __init__(
        self,
        user: str = "default",
        session: str = "default",
        host: str = OLLAMA_HOST,
        model: str = MODEL_NAME,
        embedding_model: str = EMBEDDING_MODEL_NAME,
    ) -> None:
        """Create the session and load its stored history.

        Args:
            user: Username; the ``User`` row is created when missing.
            session: Conversation name; the ``Conversation`` row is created
                when missing.
            host: Address passed to ``AsyncClient``.
            model: Model name used for chat requests.
            embedding_model: Accepted for interface compatibility; not used
                by this class.
        """
        init_db()
        self._client = AsyncClient(host=host)
        self._model = model
        self._user, _ = User.get_or_create(username=user)
        self._conversation, _ = Conversation.get_or_create(
            user=self._user, session_name=session
        )
        # The VM is only acquired once the session is entered.
        self._vm = None
        self._messages: List[Msg] = self._load_history()

    async def __aenter__(self) -> "ChatSession":
        """Acquire the user's VM and register it with the tools module."""
        self._vm = VMRegistry.acquire(self._user.username)
        set_vm(self._vm)
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Unregister and release the VM, then close the database if open."""
        set_vm(None)
        if self._vm:
            VMRegistry.release(self._user.username)
        if not _db.is_closed():
            _db.close()

    def upload_document(self, file_path: str) -> str:
        """Save a document for later access inside the VM.

        The file is copied into ``UPLOAD_DIR/<username>`` (an existing file
        with the same name is overwritten by the copy) and recorded in the
        database.

        Args:
            file_path: Path of the file to upload.

        Returns:
            The location of the file inside the VM (prefixed with ``/data``).

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        src = Path(file_path)
        if not src.exists():
            raise FileNotFoundError(file_path)

        dest = Path(UPLOAD_DIR) / self._user.username
        dest.mkdir(parents=True, exist_ok=True)
        target = dest / src.name
        shutil.copy(src, target)
        add_document(self._user.username, str(target), src.name)
        return f"/data/{src.name}"

    @staticmethod
    def _decode_tool_calls(content: str) -> List[Message.ToolCall] | None:
        """Return the tool calls encoded in ``content``, or ``None`` for text.

        Assistant rows hold either plain text or the JSON list written by
        ``_store_assistant_message``.  Plain text can itself be valid JSON
        (``"42"``, ``"true"``, ``{"a": 1}``), so successful parsing alone is
        not enough: only a non-empty list of objects that validates as tool
        calls is accepted.
        """
        try:
            decoded = json.loads(content)
        except json.JSONDecodeError:
            return None

        if not isinstance(decoded, list) or not decoded:
            return None
        if not all(isinstance(item, dict) for item in decoded):
            return None

        try:
            return [Message.ToolCall(**item) for item in decoded]
        except (TypeError, ValueError):
            # Validation errors of the model class are ValueError subclasses;
            # the content merely looked like a tool-call list.
            return None

    def _load_history(self) -> List[Msg]:
        """Rebuild the in-memory message list from the stored conversation.

        Messages are read in ``created_at`` order.  System rows are skipped,
        assistant rows become either a text message or a tool-call message,
        user rows are copied, and every other role is loaded as a tool
        message.
        """
        messages: List[Msg] = []
        for msg in self._conversation.messages.order_by(DBMessage.created_at):
            if msg.role == "system":
                # Skip persisted system prompts from older versions
                continue

            if msg.role == "assistant":
                tool_calls = self._decode_tool_calls(msg.content)
                if tool_calls is None:
                    messages.append({"role": "assistant", "content": msg.content})
                else:
                    messages.append({"role": "assistant", "tool_calls": tool_calls})
            elif msg.role == "user":
                messages.append({"role": "user", "content": msg.content})
            else:
                messages.append({"role": "tool", "content": msg.content})
        return messages

    @staticmethod
    def _store_assistant_message(conversation: Conversation, message: Message) -> None:
        """Persist assistant messages, storing tool calls when present.

        Tool calls are stored as a JSON list of their dumped fields; without
        tool calls the text content (or an empty string) is stored.
        """
        if message.tool_calls:
            content = json.dumps([c.model_dump() for c in message.tool_calls])
        else:
            content = message.content or ""

        DBMessage.create(conversation=conversation, role="assistant", content=content)

    async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
        """Send a chat request, automatically prepending the system prompt.

        Args:
            messages: Conversation so far.  The list is not modified; the
                system prompt is added to a copy unless the first message
                already has the ``system`` role.
            think: Forwarded to the client's ``think`` option.

        Returns:
            The response returned by the client.
        """
        if not messages or messages[0].get("role") != "system":
            payload = [{"role": "system", "content": SYSTEM_PROMPT}, *messages]
        else:
            payload = messages

        return await self._client.chat(
            self._model,
            messages=payload,
            think=think,
            tools=[execute_terminal],
            options={"num_ctx": NUM_CTX},
        )

    async def _handle_tool_calls(
        self,
        messages: List[Msg],
        response: ChatResponse,
        conversation: Conversation,
        depth: int = 0,
    ) -> ChatResponse:
        """Run requested tool calls until the model stops asking for them.

        Each round executes every tool call of ``response``, appends the
        results to ``messages`` and to the database, then asks the model
        again.  At most ``MAX_TOOL_CALL_DEPTH`` rounds are performed.

        Args:
            messages: Message list to extend in place; the caller must have
                appended ``response.message`` to it already.
            response: The response whose tool calls should be handled.
            conversation: Conversation the new rows are stored under.
            depth: Number of rounds already performed.

        Returns:
            The last response received, which may still contain tool calls
            when the depth limit was reached.
        """
        while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
            for call in response.message.tool_calls:
                if call.function.name == "execute_terminal":
                    result = execute_terminal(**call.function.arguments)
                else:
                    _LOG.warning("Unsupported tool call: %s", call.function.name)
                    result = f"Unsupported tool: {call.function.name}"

                messages.append(
                    {
                        "role": "tool",
                        "name": call.function.name,
                        "content": str(result),
                    }
                )
                DBMessage.create(
                    conversation=conversation,
                    role="tool",
                    content=str(result),
                )

            nxt = await self.ask(messages, think=True)
            # Keep the in-memory history in step with the stored one;
            # otherwise later requests would lack the assistant's own
            # follow-up tool calls and answers.
            messages.append(nxt.message.model_dump())
            self._store_assistant_message(conversation, nxt.message)
            response = nxt
            depth += 1

        return response

    async def chat(self, prompt: str) -> str:
        """Send ``prompt`` as a user message and return the final reply text.

        The prompt and every assistant/tool message produced while answering
        are stored in the database and appended to the in-memory history.

        Returns:
            The content of the final response, or an empty string when the
            response carries no text.
        """
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})

        response = await self.ask(self._messages)
        self._messages.append(response.message.model_dump())
        self._store_assistant_message(self._conversation, response.message)

        _LOG.info("Thinking:\n%s", response.message.thinking or "")

        final_resp = await self._handle_tool_calls(
            self._messages, response, self._conversation
        )
        # ``content`` can be None (e.g. a tool-call-only message at the
        # depth limit); honour the declared ``str`` return type.
        return final_resp.message.content or ""