llmOS-Agent / src /chat.py
tech-envision
document session registry
114361d
raw
history blame
14.6 kB
from __future__ import annotations
from typing import List, AsyncIterator
from dataclasses import dataclass, field
import json
import asyncio
import shutil
from pathlib import Path
from ollama import AsyncClient, ChatResponse, Message
from .config import (
MAX_TOOL_CALL_DEPTH,
MODEL_NAME,
NUM_CTX,
OLLAMA_HOST,
SYSTEM_PROMPT,
UPLOAD_DIR,
)
from .db import (
Conversation,
Message as DBMessage,
User,
_db,
init_db,
add_document,
)
from .log import get_logger
from .schema import Msg
from .tools import execute_terminal, execute_terminal_async, set_vm
from .vm import VMRegistry
@dataclass
class _SessionData:
    """Mutable state shared by every session attached to one conversation.

    Instances are stored in a module-level registry keyed by conversation
    id, so separate ``ChatSession`` objects for the same conversation see
    the same lock, state label and running tool task.
    """

    # Guards updates to ``state`` and ``tool_task``.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Phase label; this module uses "idle", "generating" and "awaiting_tool".
    state: str = field(default="idle")
    # Task of the tool execution currently in flight, if any.
    tool_task: asyncio.Task | None = field(default=None)
# Registry of shared session state, keyed by conversation id.
_SESSION_DATA: dict[int, _SessionData] = {}


def _get_session_data(conv_id: int) -> _SessionData:
    """Return the shared state for *conv_id*, creating it on first use."""
    try:
        return _SESSION_DATA[conv_id]
    except KeyError:
        created = _SessionData()
        _SESSION_DATA[conv_id] = created
        return created
# Logger for this module, obtained from the project's logging helper using
# the module's import name.
_LOG = get_logger(__name__)
class ChatSession:
    """A persisted conversation between one user and an Ollama-served model."""

    def __init__(
        self,
        user: str = "default",
        session: str = "default",
        host: str = OLLAMA_HOST,
        model: str = MODEL_NAME,
    ) -> None:
        """Open (or create) the conversation identified by *user* and *session*.

        Args:
            user: Username; the matching ``User`` row is fetched or created.
            session: Session name; the matching ``Conversation`` row for the
                user is fetched or created.
            host: Address handed to the Ollama ``AsyncClient``.
            model: Model name used for every chat request.
        """
        init_db()

        self._client = AsyncClient(host=host)
        self._model = model

        user_row, _ = User.get_or_create(username=user)
        self._user = user_row
        conversation_row, _ = Conversation.get_or_create(
            user=user_row, session_name=session
        )
        self._conversation = conversation_row

        # The VM is only acquired when the session is entered with
        # ``async with``.
        self._vm = None
        self._messages: List[Msg] = self._load_history()

        # State shared with any other session on the same conversation.
        shared = _get_session_data(conversation_row.id)
        self._data = shared
        self._lock = shared.lock
# Shared state properties -------------------------------------------------
@property
def _state(self) -> str:
    """Phase label of the conversation, read from the shared session data."""
    shared = self._data
    return shared.state


@_state.setter
def _state(self, value: str) -> None:
    """Write *value* as the phase label on the shared session data."""
    shared = self._data
    shared.state = value
@property
def _tool_task(self) -> asyncio.Task | None:
    """Tool task recorded on the shared session data, or ``None``."""
    shared = self._data
    return shared.tool_task


@_tool_task.setter
def _tool_task(self, task: asyncio.Task | None) -> None:
    """Record *task* (or ``None``) on the shared session data."""
    shared = self._data
    shared.tool_task = task
async def __aenter__(self) -> "ChatSession":
    """Acquire the user's VM, register it with the tools module, return self."""
    vm = VMRegistry.acquire(self._user.username)
    self._vm = vm
    set_vm(vm)
    return self
async def __aexit__(self, exc_type, exc, tb) -> None:
    """Detach the VM from the tools module, release it and close the DB.

    The database connection is closed in a ``finally`` block so a failure
    while releasing the VM cannot leave the connection open.  The VM
    reference is cleared before the release call so that exiting the same
    session twice does not release the VM a second time.

    Args:
        exc_type: Exception type raised inside the ``async with`` body, if any.
        exc: Exception instance raised inside the body, if any.
        tb: Traceback of that exception, if any.

    Exceptions from the body are not suppressed (the method returns ``None``).
    """
    set_vm(None)
    try:
        if self._vm:
            self._vm = None
            VMRegistry.release(self._user.username)
    finally:
        if not _db.is_closed():
            _db.close()
def upload_document(self, file_path: str) -> str:
    """Store a copy of a local file for later use inside the VM.

    The file is copied to ``UPLOAD_DIR/<username>/<file name>`` (the
    directory is created when missing) and registered with
    ``add_document``.

    Args:
        file_path: Path of the file to upload.

    Returns:
        The path ``/data/<file name>``, i.e. the file's location as seen
        from inside the VM.

    Raises:
        FileNotFoundError: If *file_path* does not exist.
    """
    source = Path(file_path)
    if not source.exists():
        raise FileNotFoundError(file_path)

    filename = source.name
    username = self._user.username

    user_dir = Path(UPLOAD_DIR) / username
    user_dir.mkdir(parents=True, exist_ok=True)

    stored = user_dir / filename
    shutil.copy(source, stored)
    add_document(username, str(stored), filename)
    return f"/data/{filename}"
def _load_history(self) -> List[Msg]:
    """Rebuild the in-memory message list from the stored conversation.

    Rows are read in ``created_at`` order.  ``system`` rows are skipped,
    ``user`` rows are passed through, and every role other than
    ``system``/``assistant``/``user`` is loaded as a ``tool`` message.

    Assistant rows contain either plain text or a JSON list of tool calls
    (the format written by ``_store_assistant_message``).  A row is only
    treated as tool calls when it decodes to a non-empty list of JSON
    objects that ``Message.ToolCall`` accepts.  Anything else -- including
    text that merely happens to be valid JSON, such as ``42`` or ``{}`` --
    is kept as plain content instead of raising while the session is
    being constructed.

    Returns:
        The reconstructed messages, oldest first.
    """
    history: List[Msg] = []
    for row in self._conversation.messages.order_by(DBMessage.created_at):
        role = row.role
        if role == "system":
            # Skip persisted system prompts from older versions
            continue

        if role == "assistant":
            tool_calls = None
            try:
                decoded = json.loads(row.content)
                if (
                    isinstance(decoded, list)
                    and decoded
                    and all(isinstance(item, dict) for item in decoded)
                ):
                    tool_calls = [Message.ToolCall(**item) for item in decoded]
            except (TypeError, ValueError):
                # ValueError covers json.JSONDecodeError as well as
                # validation errors raised while building a ToolCall.
                tool_calls = None

            if tool_calls is None:
                history.append({"role": "assistant", "content": row.content})
            else:
                history.append({"role": "assistant", "tool_calls": tool_calls})
        elif role == "user":
            history.append({"role": "user", "content": row.content})
        else:
            history.append({"role": "tool", "content": row.content})
    return history
@staticmethod
def _store_assistant_message(conversation: Conversation, message: Message) -> None:
    """Save an assistant reply as a database row.

    When the reply carries tool calls, the row content is the JSON dump of
    those calls; otherwise it is the reply text (empty string if absent).
    """
    tool_calls = message.tool_calls
    payload = (
        json.dumps([call.model_dump() for call in tool_calls])
        if tool_calls
        else (message.content or "")
    )
    DBMessage.create(conversation=conversation, role="assistant", content=payload)
async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
    """Send *messages* to the model and return its response.

    ``SYSTEM_PROMPT`` is prepended unless the first message already has
    the ``system`` role; in that case *messages* is sent unchanged.  The
    request always offers the ``execute_terminal`` tool and sets the
    context size to ``NUM_CTX``.

    Args:
        messages: Conversation so far.
        think: Forwarded to the client's ``think`` option.
    """
    starts_with_system = bool(messages) and messages[0].get("role") == "system"
    if starts_with_system:
        payload = messages
    else:
        payload = [{"role": "system", "content": SYSTEM_PROMPT}, *messages]

    request_options = {"num_ctx": NUM_CTX}
    return await self._client.chat(
        self._model,
        messages=payload,
        think=think,
        tools=[execute_terminal],
        options=request_options,
    )
async def _handle_tool_calls_stream(
    self,
    messages: List[Msg],
    response: ChatResponse,
    conversation: Conversation,
    depth: int = 0,
) -> AsyncIterator[ChatResponse]:
    """Run the tool calls requested by *response* and yield model replies.

    For every ``execute_terminal`` call the command is started as a task
    and, at the same time, the model is asked again so it can talk while
    the tool is running.  If that interim reply arrives first it is stored
    and yielded; if the tool finishes first the interim request is
    cancelled.  The tool output is then appended to *messages*, persisted,
    and the model is asked for a follow-up, which is yielded.  Follow-ups
    that request tools again are processed the same way until
    ``MAX_TOOL_CALL_DEPTH`` batches have been handled.

    Calls to any other tool, and ``execute_terminal`` calls whose arguments
    do not match the tool's signature, are answered with an error tool
    message.  If a whole batch produced no follow-up that way, the model
    is asked once so it can react to the errors; otherwise the same batch
    would be re-processed (and the same errors re-recorded) until the
    depth limit is hit.

    The shared state is set to ``"awaiting_tool"`` while a command runs,
    ``"generating"`` while a follow-up is requested and ``"idle"`` when
    the generator finishes normally.

    Args:
        messages: Conversation history; extended in place.
        response: The model response whose tool calls should be handled.
        conversation: Database conversation used for persistence.
        depth: Number of tool-call batches already processed.

    Yields:
        Each model response produced along the way.  When *response* has no
        tool calls it is yielded itself, provided it has content.
    """

    def record_tool_result(name: str, result: str) -> None:
        # Append a tool message to the history and persist it.
        messages.append({"role": "tool", "name": name, "content": result})
        DBMessage.create(conversation=conversation, role="tool", content=result)

    async def follow_up() -> ChatResponse:
        # Ask the model to continue, then store and record its reply.
        reply = await self.ask(messages, think=True)
        self._store_assistant_message(conversation, reply.message)
        messages.append(reply.message.model_dump())
        return reply

    if not response.message.tool_calls:
        if response.message.content:
            yield response
        async with self._lock:
            self._state = "idle"
        return

    while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
        batch = response
        for call in batch.message.tool_calls:
            name = call.function.name
            if name != "execute_terminal":
                _LOG.warning("Unsupported tool call: %s", name)
                record_tool_result(name, f"Unsupported tool: {name}")
                continue

            try:
                # The arguments come from the model; a mismatch with the
                # tool's signature surfaces here as TypeError.
                tool_coro = execute_terminal_async(**call.function.arguments)
            except TypeError as err:
                _LOG.warning("Invalid arguments for %s: %s", name, err)
                record_tool_result(name, f"Invalid arguments for {name}: {err}")
                continue

            exec_task = asyncio.create_task(tool_coro)
            follow_task = asyncio.create_task(self.ask(messages, think=True))

            async with self._lock:
                self._state = "awaiting_tool"
                self._tool_task = exec_task

            done, _ = await asyncio.wait(
                {exec_task, follow_task},
                return_when=asyncio.FIRST_COMPLETED,
            )

            if exec_task in done:
                # The tool won the race: drop the interim request.
                follow_task.cancel()
                try:
                    await follow_task
                except asyncio.CancelledError:
                    pass
            else:
                # The model answered while the tool is still running.
                interim = await follow_task
                self._store_assistant_message(conversation, interim.message)
                messages.append(interim.message.model_dump())
                yield interim

            result = await exec_task
            record_tool_result(name, result)

            async with self._lock:
                self._state = "generating"
                self._tool_task = None

            response = await follow_up()
            yield response

        if response is batch:
            # No call in this batch was executed, so nothing replaced the
            # response yet; let the model react to the recorded errors.
            response = await follow_up()
            yield response

        depth += 1

    async with self._lock:
        self._state = "idle"
async def _handle_tool_calls(
    self,
    messages: List[Msg],
    response: ChatResponse,
    conversation: Conversation,
    depth: int = 0,
) -> ChatResponse:
    """Run the tool-call stream to completion and return its last response.

    When the stream yields nothing, *response* itself is returned.
    """
    latest = response
    stream = self._handle_tool_calls_stream(messages, response, conversation, depth)
    async for item in stream:
        latest = item
    return latest
async def chat(self, prompt: str) -> str:
    """Send *prompt*, resolve any tool calls and return the final reply text.

    The prompt and the model's first response are persisted and appended
    to the in-memory history before tool calls are handled.

    Returns:
        The content of the last response produced while handling tool
        calls, or the first response when there were none.  An empty
        string is returned when that response carries no text, so the
        result is always a ``str``.
    """
    DBMessage.create(conversation=self._conversation, role="user", content=prompt)
    self._messages.append({"role": "user", "content": prompt})

    response = await self.ask(self._messages)
    self._messages.append(response.message.model_dump())
    self._store_assistant_message(self._conversation, response.message)
    _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")

    final_resp = await self._handle_tool_calls(
        self._messages, response, self._conversation
    )
    # Same fallback as _store_assistant_message: content may be None.
    return final_resp.message.content or ""
async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
    """Send *prompt* and yield the text of each reply as it becomes available.

    Routing depends on the shared conversation state:

    * ``"generating"`` -- the prompt is ignored and nothing is yielded.
    * ``"awaiting_tool"`` with a tool task recorded -- the prompt is
      handled by ``_chat_during_tool``.
    * otherwise -- the state becomes ``"generating"`` and the prompt is
      processed normally, including any tool calls.

    The routing decision is made while holding the session lock, but the
    lock is released before any work is delegated: ``asyncio.Lock`` is not
    reentrant and both ``_chat_during_tool`` and the tool-call handler
    acquire it themselves.

    If the normal path fails with an exception, the shared state is reset
    to ``"idle"`` before the exception propagates, so that later prompts
    are not ignored forever.

    Yields:
        The content of every response that has text and no tool calls.
    """
    async with self._lock:
        if self._state == "generating":
            _LOG.info("Ignoring message while generating")
            return
        during_tool = self._state == "awaiting_tool" and bool(self._tool_task)
        if not during_tool:
            self._state = "generating"

    if during_tool:
        async for part in self._chat_during_tool(prompt):
            yield part
        return

    try:
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})

        response = await self.ask(self._messages)
        self._messages.append(response.message.model_dump())
        self._store_assistant_message(self._conversation, response.message)
        _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")

        async for resp in self._handle_tool_calls_stream(
            self._messages, response, self._conversation
        ):
            if resp.message.tool_calls:
                continue
            if resp.message.content:
                yield resp.message.content
    except Exception:
        # The state lives in a per-conversation registry; leaving it at
        # "generating"/"awaiting_tool" would mute the conversation.
        async with self._lock:
            self._state = "idle"
            self._tool_task = None
        raise
async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
    """Handle a user prompt that arrives while a tool is still running.

    The prompt is stored and sent to the model while the recorded tool
    task keeps running.  Whichever finishes first decides the flow:

    * tool first -- the pending model request is cancelled;
    * model first -- its reply is stored and, if it is plain text,
      yielded; the state is put back to ``"awaiting_tool"``.

    In both cases the tool output is then appended to the history and
    persisted, the model is asked for a follow-up, and that follow-up is
    passed to ``_handle_tool_calls_stream``.  The stream yields a plain
    follow-up itself, so it is not yielded separately here (doing both
    would emit the same text twice).

    NOTE(review): the stream that started the tool awaits the same task and
    records its output as well -- confirm whether both writers are intended.

    Yields:
        Text of every reply that has content and no tool calls.
    """
    DBMessage.create(conversation=self._conversation, role="user", content=prompt)
    self._messages.append({"role": "user", "content": prompt})

    user_task = asyncio.create_task(self.ask(self._messages))
    exec_task = self._tool_task

    done, _ = await asyncio.wait(
        {exec_task, user_task},
        return_when=asyncio.FIRST_COMPLETED,
    )

    if exec_task in done:
        # The tool finished first: the reply to the bare prompt is obsolete.
        user_task.cancel()
        try:
            await user_task
        except asyncio.CancelledError:
            pass
    else:
        # The model answered while the tool is still running.
        resp = await user_task
        self._store_assistant_message(self._conversation, resp.message)
        self._messages.append(resp.message.model_dump())
        async with self._lock:
            self._state = "awaiting_tool"
        if not resp.message.tool_calls and resp.message.content:
            yield resp.message.content

    result = await exec_task
    self._messages.append(
        {"role": "tool", "name": "execute_terminal", "content": result}
    )
    DBMessage.create(conversation=self._conversation, role="tool", content=result)

    async with self._lock:
        self._state = "generating"
        self._tool_task = None

    nxt = await self.ask(self._messages, think=True)
    self._store_assistant_message(self._conversation, nxt.message)
    self._messages.append(nxt.message.model_dump())

    async for part in self._handle_tool_calls_stream(
        self._messages, nxt, self._conversation
    ):
        if part.message.tool_calls:
            continue
        if part.message.content:
            yield part.message.content