Spaces:
Runtime error
Runtime error
File size: 4,238 Bytes
f7c8c98 598a53d 9728a79 598a53d f7c8c98 598a53d e6ecb98 f7c8c98 598a53d f7c8c98 e6ecb98 f7c8c98 e6ecb98 9728a79 e6ecb98 f7c8c98 9728a79 f7c8c98 8a3f6bd e6ecb98 8a3f6bd e6ecb98 8a3f6bd 9728a79 8a3f6bd f7c8c98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from __future__ import annotations
import asyncio
from typing import AsyncIterator, Optional
from .chat import ChatSession
from .config import OLLAMA_HOST, MODEL_NAME, SYSTEM_PROMPT, JUNIOR_PROMPT
from .tools import execute_terminal
from .db import Message as DBMessage
# Names exported by ``from <module> import *``.
__all__ = ["TeamChatSession", "send_to_junior", "send_to_junior_async", "set_team"]

# Team registered by the latest ``set_team`` call; ``send_to_junior`` routes
# through it.  ``None`` means no team is currently registered.
_TEAM: Optional["TeamChatSession"] = None
def set_team(team: Optional["TeamChatSession"]) -> None:
    """Register *team* as the module-wide active team; pass ``None`` to clear it.

    :func:`send_to_junior` forwards messages to whichever team was registered last.
    """
    global _TEAM
    _TEAM = team
async def send_to_junior(message: str) -> str:
    """Forward ``message`` to the junior agent and await the response."""
    # NOTE: this coroutine is handed to the senior ChatSession as a tool, so its
    # docstring may be surfaced to the model; it is intentionally left unchanged.
    # The reply is returned directly to the caller (enqueue=False) rather than
    # also being queued into the senior's history.
    team = _TEAM
    if team is not None:
        return await team.queue_message_to_junior(message, enqueue=False)
    return "No active team"


# Backwards-compatible alias for callers that still import the old name.
send_to_junior_async = send_to_junior
class TeamChatSession:
    """A senior/junior pair of :class:`ChatSession` objects that can message each other.

    The senior session is given the ``execute_terminal`` and ``send_to_junior``
    tools; the junior only gets ``execute_terminal``.  Requests for the junior are
    queued and answered one at a time by a background worker task.  Replies
    queued for the senior (``enqueue=True``) are appended to the senior's history
    as ``tool`` messages when the senior reports an ``"idle"`` state, and before
    and after every :meth:`chat_stream` call.

    Use as an async context manager: entering registers this instance as the
    module-wide team used by :func:`send_to_junior`.
    """

    def __init__(
        self,
        user: str = "default",
        session: str = "default",
        host: str = OLLAMA_HOST,
        model: str = MODEL_NAME,
    ) -> None:
        # Pending junior requests: (message, future for the reply, whether the
        # reply should also be forwarded into the senior's history).
        self._to_junior: asyncio.Queue[tuple[str, asyncio.Future[str], bool]] = asyncio.Queue()
        # Junior replies waiting to be appended to the senior's history.
        self._to_senior: asyncio.Queue[str] = asyncio.Queue()
        # Worker draining ``_to_junior``; None when no worker is running.
        self._junior_task: asyncio.Task | None = None
        self.senior = ChatSession(
            user=user,
            session=session,
            host=host,
            model=model,
            system_prompt=SYSTEM_PROMPT,
            tools=[execute_terminal, send_to_junior],
        )
        self.junior = ChatSession(
            user=user,
            session=f"{session}-junior",
            host=host,
            model=model,
            system_prompt=JUNIOR_PROMPT,
            tools=[execute_terminal],
        )

    async def __aenter__(self) -> "TeamChatSession":
        """Enter both sessions and register this team for :func:`send_to_junior`."""
        await self.senior.__aenter__()
        try:
            await self.junior.__aenter__()
        except BaseException as err:
            # Roll back the already-entered senior so a failed enter leaks nothing.
            await self.senior.__aexit__(type(err), err, err.__traceback__)
            raise
        set_team(self)
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Unregister the team, stop the junior worker, and exit both sessions."""
        # Only clear the registration if it still points at us; another session
        # may have registered itself after we entered.
        if _TEAM is self:
            set_team(None)
        # Stop the worker so it does not keep using sessions that are about to be
        # closed; its cleanup cancels any futures still being awaited.
        task = self._junior_task
        if task is not None and not task.done():
            task.cancel()
        try:
            await self.senior.__aexit__(exc_type, exc, tb)
        finally:
            # Exit the junior even if the senior's exit raised.
            await self.junior.__aexit__(exc_type, exc, tb)

    def upload_document(self, file_path: str) -> str:
        """Upload ``file_path`` into the senior session and return its result."""
        return self.senior.upload_document(file_path)

    async def queue_message_to_junior(
        self, message: str, *, enqueue: bool = True
    ) -> str:
        """Send ``message`` to the junior agent and wait for the reply.

        Args:
            message: Text appended to the junior's history as a ``tool`` message
                named ``"senior"``.
            enqueue: When true, a non-blank reply is also queued for delivery
                into the senior's history.

        Returns:
            The junior's non-empty streamed parts joined with newlines.

        Raises:
            Exception: Whatever the junior's stream raised while answering.
            asyncio.CancelledError: If the worker is cancelled (e.g. the team is
                closed) before this message is answered.
        """
        loop = asyncio.get_running_loop()
        fut: asyncio.Future[str] = loop.create_future()
        await self._to_junior.put((message, fut, enqueue))
        # Start a worker only if none is running; a running worker re-checks the
        # queue after each reply and will pick this message up.
        if self._junior_task is None or self._junior_task.done():
            self._junior_task = asyncio.create_task(self._process_junior())
        return await fut

    async def _run_junior_turn(self, msg: str) -> str:
        """Add ``msg`` to the junior's history (and DB), then collect its reply."""
        self.junior._messages.append({"role": "tool", "name": "senior", "content": msg})
        DBMessage.create(conversation=self.junior._conversation, role="tool", content=msg)
        parts = [part async for part in self.junior.continue_stream() if part]
        return "\n".join(parts)

    async def _process_junior(self) -> None:
        """Drain ``_to_junior``, answering requests in FIFO order.

        Every request's future is resolved with the reply, with the exception the
        junior raised, or (on cancellation or an error outside a junior turn) by
        cancellation. No caller is left awaiting forever.
        """
        current: asyncio.Future[str] | None = None
        try:
            while not self._to_junior.empty():
                msg, current, enqueue = await self._to_junior.get()
                try:
                    result = await self._run_junior_turn(msg)
                except asyncio.CancelledError:
                    raise
                except Exception as exc:
                    # Hand the failure to the waiting caller instead of letting it
                    # hang, and keep serving later requests.  If the caller already
                    # gave up (future done), the error is dropped.
                    if not current.done():
                        current.set_exception(exc)
                    continue
                if enqueue and result.strip():
                    await self._to_senior.put(result)
                if not current.done():
                    current.set_result(result)
                if self.senior._state == "idle":
                    await self._deliver_junior_messages()
        finally:
            self._junior_task = None
            # On a normal exit the last future is resolved and the queue is empty,
            # so this is a no-op; otherwise release every remaining waiter.
            if current is not None and not current.done():
                current.cancel()
            while not self._to_junior.empty():
                _, pending, _ = self._to_junior.get_nowait()
                if not pending.done():
                    pending.cancel()

    async def _deliver_junior_messages(self) -> None:
        """Append every queued junior reply to the senior's history (and DB)."""
        while not self._to_senior.empty():
            msg = await self._to_senior.get()
            self.senior._messages.append({"role": "tool", "name": "junior", "content": msg})
            DBMessage.create(conversation=self.senior._conversation, role="tool", content=msg)

    async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
        """Stream the senior's reply to ``prompt``.

        Pending junior replies are delivered into the senior's history before the
        prompt is sent and again once the stream is exhausted.
        """
        await self._deliver_junior_messages()
        async for part in self.senior.chat_stream(prompt):
            yield part
        await self._deliver_junior_messages()
|