File size: 4,238 Bytes
f7c8c98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598a53d
9728a79
598a53d
f7c8c98
 
598a53d
e6ecb98
f7c8c98
 
598a53d
 
 
f7c8c98
 
 
 
 
 
 
 
 
 
e6ecb98
f7c8c98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6ecb98
 
 
9728a79
 
 
 
e6ecb98
f7c8c98
 
9728a79
f7c8c98
 
8a3f6bd
 
e6ecb98
8a3f6bd
 
 
 
 
 
 
e6ecb98
8a3f6bd
9728a79
 
8a3f6bd
 
 
 
 
f7c8c98
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations

import asyncio
from typing import AsyncIterator, Optional

from .chat import ChatSession
from .config import OLLAMA_HOST, MODEL_NAME, SYSTEM_PROMPT, JUNIOR_PROMPT
from .tools import execute_terminal
from .db import Message as DBMessage

# Public names exported by ``from <module> import *``.
__all__ = ["TeamChatSession", "send_to_junior", "send_to_junior_async", "set_team"]

# Team currently targeted by the ``send_to_junior`` tool. Updated through
# ``set_team``, which ``TeamChatSession.__aenter__``/``__aexit__`` call.
_TEAM: Optional["TeamChatSession"] = None


def set_team(team: Optional["TeamChatSession"]) -> None:
    """Register ``team`` as the target of ``send_to_junior``.

    Args:
        team: The active team, or ``None`` to clear the registration (after
            which ``send_to_junior`` answers ``"No active team"``).
    """
    # ``Optional[...]`` (matching ``_TEAM``) rather than ``"X" | None``: the
    # latter applies ``|`` to a str and raises TypeError if ever evaluated.
    global _TEAM
    _TEAM = team


async def send_to_junior(message: str) -> str:
    """Forward ``message`` to the junior agent and await the response."""

    # Registered as a tool on the senior session; it can only reach a junior
    # while a TeamChatSession has been registered through ``set_team``.
    team = _TEAM
    if team is not None:
        # ``enqueue=False``: the reply comes back here as the tool result and
        # is not additionally buffered for later delivery to the senior.
        return await team.queue_message_to_junior(message, enqueue=False)
    return "No active team"


# Backwards compatibility ---------------------------------------------------

# The ``_async`` name is kept as an alias of the very same coroutine function.
send_to_junior_async = send_to_junior


class TeamChatSession:
    """Pair a user-facing *senior* ChatSession with a helper *junior* one.

    The senior gets the ``execute_terminal`` and ``send_to_junior`` tools; the
    junior gets ``execute_terminal`` only and uses the session name
    ``"<session>-junior"``. Requests to the junior are handled one at a time
    by a background task (see ``queue_message_to_junior``). Replies to
    requests made with ``enqueue=True`` are also buffered and appended to the
    senior's history as ``tool`` messages named ``"junior"``. This happens
    either when the worker finishes while the senior's state is ``"idle"``,
    or before/after each ``chat_stream`` turn.

    Use as an async context manager. Entering opens both sessions and
    registers this instance for the module-level ``send_to_junior`` tool.
    Exiting unregisters it, stops junior work and closes both sessions.
    """

    def __init__(
        self,
        user: str = "default",
        session: str = "default",
        host: str = OLLAMA_HOST,
        model: str = MODEL_NAME,
    ) -> None:
        # Pending junior requests: (message, future for the reply, enqueue flag).
        self._to_junior: asyncio.Queue[tuple[str, asyncio.Future[str], bool]] = asyncio.Queue()
        # Junior replies waiting to be appended to the senior's history.
        self._to_senior: asyncio.Queue[str] = asyncio.Queue()
        # Background worker draining ``_to_junior``; None when not running.
        self._junior_task: asyncio.Task | None = None
        self.senior = ChatSession(
            user=user,
            session=session,
            host=host,
            model=model,
            system_prompt=SYSTEM_PROMPT,
            tools=[execute_terminal, send_to_junior],
        )
        self.junior = ChatSession(
            user=user,
            session=f"{session}-junior",
            host=host,
            model=model,
            system_prompt=JUNIOR_PROMPT,
            tools=[execute_terminal],
        )

    async def __aenter__(self) -> "TeamChatSession":
        """Open both sessions and register this team for ``send_to_junior``.

        If opening the junior fails, the already-opened senior is closed
        before the error propagates, so nothing is left half-open.
        """
        await self.senior.__aenter__()
        try:
            await self.junior.__aenter__()
        except BaseException as err:
            await self.senior.__aexit__(type(err), err, err.__traceback__)
            raise
        set_team(self)
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Unregister the team, stop junior work, and close both sessions.

        Every step runs even if an earlier one raises. A failure while
        closing the senior therefore cannot leave the junior open.
        """
        set_team(None)
        try:
            await self._stop_junior()
        finally:
            try:
                await self.senior.__aexit__(exc_type, exc, tb)
            finally:
                await self.junior.__aexit__(exc_type, exc, tb)

    async def _stop_junior(self) -> None:
        """Cancel the junior worker and cancel every still-pending request."""
        task = self._junior_task
        if task is not None and not task.done():
            task.cancel()
            # gather(..., return_exceptions=True) waits for the worker to
            # unwind without swallowing a cancellation aimed at this task.
            await asyncio.gather(task, return_exceptions=True)
        # Requests the worker never reached would otherwise leave their
        # callers awaiting forever.
        while not self._to_junior.empty():
            _msg, fut, _enqueue = self._to_junior.get_nowait()
            if not fut.done():
                fut.cancel()

    def upload_document(self, file_path: str) -> str:
        """Delegate document upload to the senior session."""
        return self.senior.upload_document(file_path)

    async def queue_message_to_junior(
        self, message: str, *, enqueue: bool = True
    ) -> str:
        """Send ``message`` to the junior agent and wait for the reply.

        Args:
            message: Text appended to the junior's history as a ``tool``
                message named ``"senior"``.
            enqueue: When true, a non-blank reply is also buffered for
                delivery into the senior's history.

        Returns:
            The junior's non-empty streamed parts joined with newlines.

        Raises:
            Exception: Whatever the junior raised while handling this message.
            asyncio.CancelledError: If the team shuts down before the reply.
        """
        loop = asyncio.get_running_loop()
        fut: asyncio.Future[str] = loop.create_future()
        await self._to_junior.put((message, fut, enqueue))
        # Start a worker unless one is already running; a running worker
        # re-checks the queue after each message and will pick this one up.
        if not self._junior_task or self._junior_task.done():
            self._junior_task = asyncio.create_task(self._process_junior())
        return await fut

    async def _ask_junior(self, msg: str) -> str:
        """Record ``msg`` in the junior's history and return its joined reply."""
        self.junior._messages.append({"role": "tool", "name": "senior", "content": msg})
        DBMessage.create(conversation=self.junior._conversation, role="tool", content=msg)
        # Falsy stream chunks are dropped before joining.
        parts = [part async for part in self.junior.continue_stream() if part]
        return "\n".join(parts)

    async def _process_junior(self) -> None:
        """Drain ``_to_junior`` one request at a time, resolving each future.

        A failure while handling one message is reported to that message's
        caller instead of killing the worker. Otherwise that caller would
        wait forever.
        """
        try:
            while not self._to_junior.empty():
                msg, fut, enqueue = await self._to_junior.get()
                try:
                    result = await self._ask_junior(msg)
                except asyncio.CancelledError:
                    # Shutting down: release the waiting caller, then unwind.
                    if not fut.done():
                        fut.cancel()
                    raise
                except Exception as err:
                    # Report to the caller; if its future is already done
                    # (e.g. the caller was cancelled) nobody is listening.
                    if not fut.done():
                        fut.set_exception(err)
                    continue
                if enqueue and result.strip():
                    await self._to_senior.put(result)
                if not fut.done():
                    fut.set_result(result)

            # Presumably "idle" means the senior is between turns; only then
            # are buffered replies injected into its history.
            if self.senior._state == "idle":
                await self._deliver_junior_messages()
        finally:
            self._junior_task = None

    async def _deliver_junior_messages(self) -> None:
        """Append every buffered junior reply to the senior's history."""
        while not self._to_senior.empty():
            msg = await self._to_senior.get()
            self.senior._messages.append({"role": "tool", "name": "junior", "content": msg})
            DBMessage.create(conversation=self.senior._conversation, role="tool", content=msg)

    async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
        """Stream the senior's answer to ``prompt``.

        Buffered junior replies are flushed into the senior's history before
        the turn starts and again after it finishes.
        """
        await self._deliver_junior_messages()
        async for part in self.senior.chat_stream(prompt):
            yield part
        await self._deliver_junior_messages()