llmOS-Agent / src /chat.py
starsnatched
Remove unused EMBEDDING_MODEL_NAME from ChatSession initialization
e98e0c1
raw
history blame
8.07 kB
from __future__ import annotations
from typing import List
import json
import asyncio
import shutil
from pathlib import Path
from ollama import AsyncClient, ChatResponse, Message
from .config import (
MAX_TOOL_CALL_DEPTH,
MODEL_NAME,
NUM_CTX,
OLLAMA_HOST,
SYSTEM_PROMPT,
UPLOAD_DIR,
)
from .db import (
Conversation,
Message as DBMessage,
User,
_db,
init_db,
add_document,
)
from .log import get_logger
from .schema import Msg
from .tools import execute_terminal, execute_terminal_async, set_vm
from .vm import VMRegistry
_LOG = get_logger(__name__)
class ChatSession:
    """Stateful chat session binding one user/conversation to an Ollama model.

    All turns are persisted through the peewee models in ``.db``, and
    model-requested ``execute_terminal`` tool calls run inside a per-user VM
    acquired from ``VMRegistry`` on ``async with`` entry.
    """

    def __init__(
        self,
        user: str = "default",
        session: str = "default",
        host: str = OLLAMA_HOST,
        model: str = MODEL_NAME,
    ) -> None:
        # Make sure tables exist before the get_or_create calls below.
        init_db()
        self._client = AsyncClient(host=host)
        self._model = model
        # Idempotent lookups: reuse existing user/conversation rows if present.
        self._user, _ = User.get_or_create(username=user)
        self._conversation, _ = Conversation.get_or_create(
            user=self._user, session_name=session
        )
        # VM handle; populated in __aenter__, released in __aexit__.
        self._vm = None
        # In-memory transcript replayed from the DB (system prompts excluded).
        self._messages: List[Msg] = self._load_history()

    async def __aenter__(self) -> "ChatSession":
        # Acquire this user's VM and make it the active target for tool calls.
        self._vm = VMRegistry.acquire(self._user.username)
        set_vm(self._vm)
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Unbind the tool VM first so no tool call can target a released VM.
        set_vm(None)
        if self._vm:
            VMRegistry.release(self._user.username)
        if not _db.is_closed():
            _db.close()

    def upload_document(self, file_path: str) -> str:
        """Save a document for later access inside the VM.

        The file is copied into ``UPLOAD_DIR`` and recorded in the database.
        The returned path is the location inside the VM (prefixed with
        ``/data``).

        Raises:
            FileNotFoundError: if ``file_path`` does not exist on the host.
        """
        src = Path(file_path)
        if not src.exists():
            raise FileNotFoundError(file_path)
        # Uploads are grouped per user: <UPLOAD_DIR>/<username>/<basename>.
        dest = Path(UPLOAD_DIR) / self._user.username
        dest.mkdir(parents=True, exist_ok=True)
        target = dest / src.name
        shutil.copy(src, target)
        add_document(self._user.username, str(target), src.name)
        # NOTE(review): same-named uploads overwrite each other, and the VM
        # path omits the username even though the host path includes it —
        # confirm both are intended.
        return f"/data/{src.name}"

    def _load_history(self) -> List[Msg]:
        """Rebuild the chat transcript from persisted rows, oldest first."""
        messages: List[Msg] = []
        for msg in self._conversation.messages.order_by(DBMessage.created_at):
            if msg.role == "system":
                # Skip persisted system prompts from older versions
                continue
            if msg.role == "assistant":
                # Assistant rows hold either plain text or a JSON-encoded list
                # of tool calls (the format written by _store_assistant_message).
                try:
                    calls = json.loads(msg.content)
                except json.JSONDecodeError:
                    messages.append({"role": "assistant", "content": msg.content})
                else:
                    # NOTE(review): plain text that happens to parse as JSON
                    # (e.g. "123") would reach ToolCall(**c) and fail — assumes
                    # assistant text content is never valid JSON; verify.
                    messages.append(
                        {
                            "role": "assistant",
                            "tool_calls": [Message.ToolCall(**c) for c in calls],
                        }
                    )
            elif msg.role == "user":
                messages.append({"role": "user", "content": msg.content})
            else:
                # Any remaining role was persisted as a tool result.
                messages.append({"role": "tool", "content": msg.content})
        return messages

    @staticmethod
    def _store_assistant_message(conversation: Conversation, message: Message) -> None:
        """Persist assistant messages, storing tool calls when present."""
        if message.tool_calls:
            # Serialize tool calls to JSON; _load_history decodes this shape.
            content = json.dumps([c.model_dump() for c in message.tool_calls])
        else:
            content = message.content or ""
        DBMessage.create(conversation=conversation, role="assistant", content=content)

    async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
        """Send a chat request, automatically prepending the system prompt."""
        if not messages or messages[0].get("role") != "system":
            # Prepend a fresh system prompt; it is never persisted to the DB.
            payload = [{"role": "system", "content": SYSTEM_PROMPT}, *messages]
        else:
            payload = messages
        return await self._client.chat(
            self._model,
            messages=payload,
            think=think,
            tools=[execute_terminal],
            options={"num_ctx": NUM_CTX},
        )

    async def _handle_tool_calls(
        self,
        messages: List[Msg],
        response: ChatResponse,
        conversation: Conversation,
        depth: int = 0,
    ) -> ChatResponse:
        """Resolve the model's tool calls until it stops asking or depth runs out.

        For each ``execute_terminal`` call, the command is executed
        concurrently with a speculative follow-up chat request; whichever
        finishes first determines the order in which messages are appended
        and persisted. Unsupported tools are reported back to the model as
        tool results rather than raised.
        """
        # NOTE(review): the inner for-loop keeps iterating the ORIGINAL
        # response's tool_calls even after `response = nxt` — later responses'
        # calls are only handled on the next while pass. Confirm intended.
        while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
            for call in response.message.tool_calls:
                if call.function.name != "execute_terminal":
                    # Unknown tool: record the failure for the model and move on.
                    _LOG.warning("Unsupported tool call: %s", call.function.name)
                    result = f"Unsupported tool: {call.function.name}"
                    messages.append(
                        {
                            "role": "tool",
                            "name": call.function.name,
                            "content": result,
                        }
                    )
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    continue
                # Race the command against a speculative follow-up turn.
                exec_task = asyncio.create_task(
                    execute_terminal_async(**call.function.arguments)
                )
                follow_task = asyncio.create_task(self.ask(messages, think=True))
                done, _ = await asyncio.wait(
                    {exec_task, follow_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if exec_task in done:
                    # Command finished first: the speculative turn is stale
                    # (it never saw the tool output), so cancel and discard it.
                    follow_task.cancel()
                    try:
                        await follow_task
                    except asyncio.CancelledError:
                        pass
                    result = await exec_task
                    messages.append(
                        {
                            "role": "tool",
                            "name": call.function.name,
                            "content": result,
                        }
                    )
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    # Ask again now that the tool output is in the transcript.
                    nxt = await self.ask(messages, think=True)
                    self._store_assistant_message(conversation, nxt.message)
                    response = nxt
                else:
                    # Follow-up finished first: keep it as an interim assistant
                    # turn, then wait for the command and append its output.
                    followup = await follow_task
                    self._store_assistant_message(conversation, followup.message)
                    messages.append(followup.message.model_dump())
                    result = await exec_task
                    messages.append(
                        {
                            "role": "tool",
                            "name": call.function.name,
                            "content": result,
                        }
                    )
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    nxt = await self.ask(messages, think=True)
                    self._store_assistant_message(conversation, nxt.message)
                    response = nxt
            # One depth unit per batch of tool calls, not per individual call.
            depth += 1
        return response

    async def chat(self, prompt: str) -> str:
        """Run one user turn: persist it, query the model, resolve tool calls.

        Returns the model's final text content after any tool-call rounds.
        """
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})
        response = await self.ask(self._messages)
        self._messages.append(response.message.model_dump())
        self._store_assistant_message(self._conversation, response.message)
        _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")
        final_resp = await self._handle_tool_calls(
            self._messages, response, self._conversation
        )
        # NOTE(review): message.content may be None despite the -> str
        # annotation; callers should confirm (or this should coerce to "").
        return final_resp.message.content