File size: 3,341 Bytes
0e02b97
 
 
ec335c4
0e02b97
 
 
 
bf45c7d
0e02b97
 
 
 
 
 
 
 
bf45c7d
bedb8e2
0e02b97
 
bf45c7d
0e02b97
 
 
 
bedb8e2
 
 
 
ec335c4
 
 
 
 
 
 
 
 
 
 
 
 
0e02b97
 
 
 
 
 
 
 
 
 
 
 
bedb8e2
0e02b97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bedb8e2
 
 
 
 
0e02b97
ec335c4
bedb8e2
 
 
0e02b97
 
 
 
bf45c7d
bedb8e2
0e02b97
 
 
ec335c4
0e02b97
 
 
bedb8e2
0e02b97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations

from typing import List
import json

from ollama import AsyncClient, ChatResponse

from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
from .db import Conversation, Message, User, _db, init_db
from .log import get_logger
from .schema import Msg
from .tools import add_two_numbers

# Logger obtained for this module's name; ChatSession.chat uses it to log the
# model's thinking trace.
_LOG = get_logger(__name__)


class ChatSession:
    """Async chat session that talks to an Ollama model and persists messages.

    Every user prompt, assistant reply and tool result is written to the
    database through ``Message.create``.  The session is usable as an async
    context manager; leaving the context closes the shared ``_db`` connection.
    """

    def __init__(self, user: str = "default", host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
        """Prepare the database, the Ollama client and the user record.

        Args:
            user: Username looked up (or created) via ``User.get_or_create``.
            host: Host passed to ``AsyncClient``.
            model: Model name sent with every chat request.
        """
        init_db()
        self._client = AsyncClient(host=host)
        self._model = model
        # get_or_create returns (instance, created); only the instance is kept.
        self._user, _ = User.get_or_create(username=user)

    async def __aenter__(self) -> "ChatSession":
        """Return the session itself; all setup already happened in ``__init__``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the shared database connection if it is still open.

        Returns ``None``, so an exception raised inside the ``async with``
        body is not suppressed.
        """
        if not _db.is_closed():
            _db.close()

    @staticmethod
    def _store_assistant_message(
        conversation: Conversation, message: ChatResponse.Message
    ) -> None:
        """Persist an assistant message, storing tool calls when present.

        When the message carries tool calls, the stored content is the JSON
        dump of those calls; otherwise it is the message text (empty string
        when the text is missing).
        """
        if message.tool_calls:
            content = json.dumps([c.model_dump() for c in message.tool_calls])
        else:
            content = message.content or ""

        Message.create(conversation=conversation, role="assistant", content=content)

    async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
        """Send ``messages`` to the model and return its response.

        Args:
            messages: Chat history forwarded to the client unchanged.
            think: Forwarded as the client's ``think`` argument.

        Returns:
            The ``ChatResponse`` produced by ``AsyncClient.chat``.
        """
        return await self._client.chat(
            self._model,
            messages=messages,
            think=think,
            tools=[add_two_numbers],
        )

    async def _handle_tool_calls(
        self,
        messages: List[Msg],
        response: ChatResponse,
        conversation: Conversation,
        depth: int = 0,
    ) -> ChatResponse:
        """Execute requested tool calls and re-query the model until it is done.

        All ``add_two_numbers`` calls contained in ``response`` are executed
        and their results appended to ``messages`` (and persisted) before the
        model is asked again, so no call from a multi-call response is lost.
        Calls to any other tool name are ignored.

        Args:
            messages: Running chat history; mutated in place with tool results
                and follow-up assistant messages.
            response: The latest model response to inspect for tool calls.
            conversation: Conversation row the new messages are stored under.
            depth: Current recursion depth; recursion stops once it reaches
                ``MAX_TOOL_CALL_DEPTH``.

        Returns:
            The last response obtained from the model, or ``response`` itself
            when there was nothing to execute or the depth limit was hit.
        """
        if depth >= MAX_TOOL_CALL_DEPTH or not response.message.tool_calls:
            return response

        handled_any = False
        for call in response.message.tool_calls:
            if call.function.name != "add_two_numbers":
                # Only add_two_numbers is supported; other tools are skipped.
                continue

            result = str(add_two_numbers(**call.function.arguments))
            messages.append(
                {
                    "role": "tool",
                    "name": call.function.name,
                    "content": result,
                }
            )
            Message.create(
                conversation=conversation,
                role="tool",
                content=result,
            )
            handled_any = True

        if not handled_any:
            # No supported tool was requested, so there is nothing new to
            # report back to the model.
            return response

        nxt = await self.ask(messages, think=True)
        # Keep the history consistent with chat(): the assistant turn must be
        # in `messages` before any tool results that answer it are appended.
        messages.append(nxt.message.model_dump())
        self._store_assistant_message(conversation, nxt.message)
        return await self._handle_tool_calls(messages, nxt, conversation, depth + 1)

    async def chat(self, prompt: str) -> str:
        """Run one full exchange for ``prompt`` and return the final reply text.

        A new conversation is created for every call.  The prompt and every
        assistant/tool message produced along the way are persisted.

        Args:
            prompt: The user's message.

        Returns:
            The text of the final assistant message, or ``""`` when the model
            returned no text content.
        """
        conversation = Conversation.create(user=self._user)
        Message.create(conversation=conversation, role="user", content=prompt)
        messages: List[Msg] = [{"role": "user", "content": prompt}]
        response = await self.ask(messages)
        messages.append(response.message.model_dump())
        self._store_assistant_message(conversation, response.message)

        _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")

        final_resp = await self._handle_tool_calls(messages, response, conversation)
        # content may be None (e.g. a tool-call-only reply at the depth limit);
        # honour the declared `str` return type.
        return final_resp.message.content or ""