tech-envision committed on
Commit
f7c8c98
·
1 Parent(s): c8dee25

Add multi-agent team with communication

Browse files
Files changed (7) hide show
  1. api_app/__init__.py +3 -3
  2. bot/discord_bot.py +3 -3
  3. run.py +2 -2
  4. src/__init__.py +5 -0
  5. src/chat.py +43 -6
  6. src/config.py +27 -20
  7. src/team.py +107 -0
api_app/__init__.py CHANGED
@@ -9,7 +9,7 @@ import os
9
  import tempfile
10
  from pathlib import Path
11
 
12
- from src.chat import ChatSession
13
  from src.log import get_logger
14
  from src.db import list_sessions, list_sessions_info
15
 
@@ -29,7 +29,7 @@ def create_app() -> FastAPI:
29
  @app.post("/chat/stream")
30
  async def chat_stream(req: ChatRequest):
31
  async def stream() -> asyncio.AsyncIterator[str]:
32
- async with ChatSession(user=req.user, session=req.session) as chat:
33
  try:
34
  async for part in chat.chat_stream(req.prompt):
35
  yield part
@@ -45,7 +45,7 @@ def create_app() -> FastAPI:
45
  session: str = Form("default"),
46
  file: UploadFile = File(...),
47
  ):
48
- async with ChatSession(user=user, session=session) as chat:
49
  tmpdir = tempfile.mkdtemp(prefix="upload_")
50
  tmp_path = Path(tmpdir) / file.filename
51
  try:
 
9
  import tempfile
10
  from pathlib import Path
11
 
12
+ from src.team import TeamChatSession
13
  from src.log import get_logger
14
  from src.db import list_sessions, list_sessions_info
15
 
 
29
  @app.post("/chat/stream")
30
  async def chat_stream(req: ChatRequest):
31
  async def stream() -> asyncio.AsyncIterator[str]:
32
+ async with TeamChatSession(user=req.user, session=req.session) as chat:
33
  try:
34
  async for part in chat.chat_stream(req.prompt):
35
  yield part
 
45
  session: str = Form("default"),
46
  file: UploadFile = File(...),
47
  ):
48
+ async with TeamChatSession(user=user, session=session) as chat:
49
  tmpdir = tempfile.mkdtemp(prefix="upload_")
50
  tmp_path = Path(tmpdir) / file.filename
51
  try:
bot/discord_bot.py CHANGED
@@ -8,7 +8,7 @@ import discord
8
  from discord.ext import commands
9
  from dotenv import load_dotenv
10
 
11
- from src.chat import ChatSession
12
  from src.db import reset_history
13
  from src.log import get_logger
14
 
@@ -34,7 +34,7 @@ async def reset(ctx: commands.Context) -> None:
34
  await ctx.reply(f"Chat history cleared ({deleted} messages deleted).")
35
 
36
 
37
- async def _handle_attachments(chat: ChatSession, message: discord.Message) -> list[tuple[str, str]]:
38
  if not message.attachments:
39
  return []
40
 
@@ -61,7 +61,7 @@ async def on_message(message: discord.Message) -> None:
61
  if message.content.startswith("!"):
62
  return
63
 
64
- async with ChatSession(user=str(message.author.id), session=str(message.channel.id)) as chat:
65
  docs = await _handle_attachments(chat, message)
66
  if docs:
67
  info = "\n".join(f"{name} -> {path}" for name, path in docs)
 
8
  from discord.ext import commands
9
  from dotenv import load_dotenv
10
 
11
+ from src.team import TeamChatSession
12
  from src.db import reset_history
13
  from src.log import get_logger
14
 
 
34
  await ctx.reply(f"Chat history cleared ({deleted} messages deleted).")
35
 
36
 
37
+ async def _handle_attachments(chat: TeamChatSession, message: discord.Message) -> list[tuple[str, str]]:
38
  if not message.attachments:
39
  return []
40
 
 
61
  if message.content.startswith("!"):
62
  return
63
 
64
+ async with TeamChatSession(user=str(message.author.id), session=str(message.channel.id)) as chat:
65
  docs = await _handle_attachments(chat, message)
66
  if docs:
67
  info = "\n".join(f"{name} -> {path}" for name, path in docs)
run.py CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations
2
 
3
  import asyncio
4
 
5
- from src.chat import ChatSession
6
  from src.vm import VMRegistry
7
 
8
 
9
  async def _main() -> None:
10
- async with ChatSession(user="demo_user", session="demo_session") as chat:
11
  # doc_path = chat.upload_document("note.pdf")
12
  async for resp in chat.chat_stream("using python, execute a code to remind me in 30 seconds to take a break."):
13
  print("\n>>>", resp)
 
2
 
3
  import asyncio
4
 
5
+ from src.team import TeamChatSession
6
  from src.vm import VMRegistry
7
 
8
 
9
  async def _main() -> None:
10
+ async with TeamChatSession(user="demo_user", session="demo_session") as chat:
11
  # doc_path = chat.upload_document("note.pdf")
12
  async for resp in chat.chat_stream("using python, execute a code to remind me in 30 seconds to take a break."):
13
  print("\n>>>", resp)
src/__init__.py CHANGED
@@ -1,12 +1,17 @@
1
  from .chat import ChatSession
 
2
  from .tools import execute_terminal, execute_terminal_async, set_vm
3
  from .utils import limit_chars
4
  from .vm import LinuxVM
5
 
6
  __all__ = [
7
  "ChatSession",
 
8
  "execute_terminal",
9
  "execute_terminal_async",
 
 
 
10
  "set_vm",
11
  "LinuxVM",
12
  "limit_chars",
 
1
  from .chat import ChatSession
2
+ from .team import TeamChatSession, send_to_junior, send_to_junior_async, set_team
3
  from .tools import execute_terminal, execute_terminal_async, set_vm
4
  from .utils import limit_chars
5
  from .vm import LinuxVM
6
 
7
  __all__ = [
8
  "ChatSession",
9
+ "TeamChatSession",
10
  "execute_terminal",
11
  "execute_terminal_async",
12
+ "send_to_junior",
13
+ "send_to_junior_async",
14
+ "set_team",
15
  "set_vm",
16
  "LinuxVM",
17
  "limit_chars",
src/chat.py CHANGED
@@ -61,6 +61,9 @@ class ChatSession:
61
  session: str = "default",
62
  host: str = OLLAMA_HOST,
63
  model: str = MODEL_NAME,
 
 
 
64
  ) -> None:
65
  init_db()
66
  self._client = AsyncClient(host=host)
@@ -70,6 +73,10 @@ class ChatSession:
70
  user=self._user, session_name=session
71
  )
72
  self._vm = None
 
 
 
 
73
  self._messages: List[Msg] = self._load_history()
74
  self._data = _get_session_data(self._conversation.id)
75
  self._lock = self._data.lock
@@ -190,7 +197,7 @@ class ChatSession:
190
  """Send a chat request, automatically prepending the system prompt."""
191
 
192
  if not messages or messages[0].get("role") != "system":
193
- payload = [{"role": "system", "content": SYSTEM_PROMPT}, *messages]
194
  else:
195
  payload = messages
196
 
@@ -198,10 +205,16 @@ class ChatSession:
198
  self._model,
199
  messages=payload,
200
  think=think,
201
- tools=[execute_terminal],
202
  options={"num_ctx": NUM_CTX},
203
  )
204
 
 
 
 
 
 
 
205
  async def _handle_tool_calls_stream(
206
  self,
207
  messages: List[Msg],
@@ -217,7 +230,8 @@ class ChatSession:
217
  return
218
  while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
219
  for call in response.message.tool_calls:
220
- if call.function.name != "execute_terminal":
 
221
  _LOG.warning("Unsupported tool call: %s", call.function.name)
222
  result = f"Unsupported tool: {call.function.name}"
223
  messages.append(
@@ -235,9 +249,11 @@ class ChatSession:
235
  continue
236
 
237
  exec_task = asyncio.create_task(
238
- execute_terminal_async(**call.function.arguments)
239
  )
240
 
 
 
241
  placeholder = {
242
  "role": "tool",
243
  "name": call.function.name,
@@ -343,6 +359,23 @@ class ChatSession:
343
  if text:
344
  yield text
345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
347
  DBMessage.create(conversation=self._conversation, role="user", content=prompt)
348
  self._messages.append({"role": "user", "content": prompt})
@@ -364,8 +397,10 @@ class ChatSession:
364
  self._remove_tool_placeholder(self._messages)
365
  result = await exec_task
366
  self._tool_task = None
 
 
367
  self._messages.append(
368
- {"role": "tool", "name": "execute_terminal", "content": result}
369
  )
370
  DBMessage.create(
371
  conversation=self._conversation, role="tool", content=result
@@ -396,8 +431,10 @@ class ChatSession:
396
  result = await exec_task
397
  self._tool_task = None
398
  self._remove_tool_placeholder(self._messages)
 
 
399
  self._messages.append(
400
- {"role": "tool", "name": "execute_terminal", "content": result}
401
  )
402
  DBMessage.create(
403
  conversation=self._conversation, role="tool", content=result
 
61
  session: str = "default",
62
  host: str = OLLAMA_HOST,
63
  model: str = MODEL_NAME,
64
+ *,
65
+ system_prompt: str = SYSTEM_PROMPT,
66
+ tools: list[callable] | None = None,
67
  ) -> None:
68
  init_db()
69
  self._client = AsyncClient(host=host)
 
73
  user=self._user, session_name=session
74
  )
75
  self._vm = None
76
+ self._system_prompt = system_prompt
77
+ self._tools = tools or [execute_terminal]
78
+ self._tool_funcs = {func.__name__: func for func in self._tools}
79
+ self._current_tool_name: str | None = None
80
  self._messages: List[Msg] = self._load_history()
81
  self._data = _get_session_data(self._conversation.id)
82
  self._lock = self._data.lock
 
197
  """Send a chat request, automatically prepending the system prompt."""
198
 
199
  if not messages or messages[0].get("role") != "system":
200
+ payload = [{"role": "system", "content": self._system_prompt}, *messages]
201
  else:
202
  payload = messages
203
 
 
205
  self._model,
206
  messages=payload,
207
  think=think,
208
+ tools=self._tools,
209
  options={"num_ctx": NUM_CTX},
210
  )
211
 
212
+ async def _run_tool_async(self, func, **kwargs) -> str:
213
+ if asyncio.iscoroutinefunction(func):
214
+ return await func(**kwargs)
215
+ loop = asyncio.get_running_loop()
216
+ return await loop.run_in_executor(None, lambda: func(**kwargs))
217
+
218
  async def _handle_tool_calls_stream(
219
  self,
220
  messages: List[Msg],
 
230
  return
231
  while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
232
  for call in response.message.tool_calls:
233
+ func = self._tool_funcs.get(call.function.name)
234
+ if not func:
235
  _LOG.warning("Unsupported tool call: %s", call.function.name)
236
  result = f"Unsupported tool: {call.function.name}"
237
  messages.append(
 
249
  continue
250
 
251
  exec_task = asyncio.create_task(
252
+ self._run_tool_async(func, **call.function.arguments)
253
  )
254
 
255
+ self._current_tool_name = call.function.name
256
+
257
  placeholder = {
258
  "role": "tool",
259
  "name": call.function.name,
 
359
  if text:
360
  yield text
361
 
362
async def continue_stream(self) -> AsyncIterator[str]:
    """Resume generation from the current history without a new user prompt.

    Intended for the case where tool messages (e.g. a junior agent's reply)
    were appended externally and the model should react to them. Yields
    formatted output chunks; yields nothing when a turn is already in flight.
    """
    async with self._lock:
        # Do not interleave with an in-flight generation.
        if self._state != "idle":
            return
        # NOTE(review): nothing in this method resets _state back to "idle";
        # presumably _handle_tool_calls_stream does — confirm.
        self._state = "generating"

    reply = await self.ask(self._messages)
    self._messages.append(reply.message.model_dump())
    self._store_assistant_message(self._conversation, reply.message)

    stream = self._handle_tool_calls_stream(
        self._messages, reply, self._conversation
    )
    async for chunk in stream:
        rendered = self._format_output(chunk.message)
        if rendered:
            yield rendered
378
+
379
  async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
380
  DBMessage.create(conversation=self._conversation, role="user", content=prompt)
381
  self._messages.append({"role": "user", "content": prompt})
 
397
  self._remove_tool_placeholder(self._messages)
398
  result = await exec_task
399
  self._tool_task = None
400
+ name = self._current_tool_name or "tool"
401
+ self._current_tool_name = None
402
  self._messages.append(
403
+ {"role": "tool", "name": name, "content": result}
404
  )
405
  DBMessage.create(
406
  conversation=self._conversation, role="tool", content=result
 
431
  result = await exec_task
432
  self._tool_task = None
433
  self._remove_tool_placeholder(self._messages)
434
+ name = self._current_tool_name or "tool"
435
+ self._current_tool_name = None
436
  self._messages.append(
437
+ {"role": "tool", "name": name, "content": result}
438
  )
439
  DBMessage.create(
440
  conversation=self._conversation, role="tool", content=result
src/config.py CHANGED
@@ -18,27 +18,34 @@ VM_STATE_DIR: Final[str] = os.getenv(
18
  )
19
 
20
  SYSTEM_PROMPT: Final[str] = (
21
- "You are Starlette, a professional AI assistant with advanced tool orchestration. "
 
 
 
 
 
22
  "You were developed by Envision to assist users with a wide range of tasks. "
23
  "Always analyze the user's objective before responding. If tools are needed, "
24
- "outline a step-by-step plan and invoke each tool sequentially. Use "
25
- "execute_terminal with its built-in Python whenever possible to perform "
26
  "calculations, inspect files and search the web. Shell commands execute "
27
- "asynchronously, so provide a brief interim reply while waiting. Once a tool "
28
- "returns its result you will receive a tool message and must continue from "
29
- "there. If the result arrives before your interim reply is complete, cancel the "
30
- "reply and incorporate the tool output instead. Uploaded files live under /data "
31
- "and are accessible via the execute_terminal tool. When a user prompt ends with "
32
- "'/think', ignore that suffix. When you are unsure about any detail, use "
33
- "execute_terminal to search the internet or inspect files before answering. "
34
- "Continue using tools until you have gathered everything required to produce "
35
- "an accurate answer, then craft a clear and precise final response that fully "
36
- "addresses the request. Always assume the user has no knowledge of computers "
37
- "or programming, so take the initiative to run terminal commands yourself and "
38
- "minimize the steps the user must perform. When replying, avoid technical "
39
- "jargon entirely. Speak in plain language that anyone can understand, "
40
- "explaining concepts as simply as possible. Remember, you must always "
41
- "prioritize using execute_terminal tool for everything unless it is "
42
- "absolutely unnecessary or impossible to do so. Even if you have executed a command before, "
43
- "always re-run it to ensure you have the most up-to-date information upon user request."
 
 
44
  ).strip()
 
18
  )
19
 
20
  SYSTEM_PROMPT: Final[str] = (
21
+ "You are Starlette, the senior agent leading a two-agent team. "
22
+ "A junior agent named Starlette Jr. assists you but never speaks to the user. "
23
+ "Use the send_to_junior tool whenever you want the junior's help. "
24
+ "Messages from the junior arrive as tool outputs named 'junior'. "
25
+ "Handle them when you are not actively generating so replies are never interrupted. "
26
+ "Both agents operate asynchronously and communicate through queued messages. "
27
  "You were developed by Envision to assist users with a wide range of tasks. "
28
  "Always analyze the user's objective before responding. If tools are needed, "
29
+ "outline a step-by-step plan and invoke each tool sequentially. "
30
+ "Use execute_terminal with its built-in Python whenever possible to perform "
31
  "calculations, inspect files and search the web. Shell commands execute "
32
+ "asynchronously, so provide a brief interim reply while waiting. "
33
+ "Once a tool returns its result you will receive a tool message and must continue from there. "
34
+ "If the result arrives before your interim reply is complete, cancel the reply and incorporate the tool output instead. "
35
+ "Uploaded files live under /data and are accessible via the execute_terminal tool. "
36
+ "When a user prompt ends with '/think', ignore that suffix. "
37
+ "When you are unsure about any detail, use execute_terminal to search the internet or inspect files before answering. "
38
+ "Continue using tools until you have gathered everything required to produce an accurate answer, "
39
+ "then craft a clear and precise final response that fully addresses the request. "
40
+ "Always assume the user has no knowledge of computers or programming, so take the initiative to run terminal commands yourself and minimize the steps the user must perform. "
41
+ "When replying, avoid technical jargon entirely. Speak in plain language that anyone can understand, explaining concepts as simply as possible. "
42
+ "Remember, you must always prioritize using execute_terminal tool for everything unless it is absolutely unnecessary or impossible to do so. "
43
+ "Even if you have executed a command before, always re-run it to ensure you have the most up-to-date information upon user request."
44
+ ).strip()
45
+
46
+ JUNIOR_PROMPT: Final[str] = (
47
+ "You are Starlette Jr., a junior assistant working under the senior agent. "
48
+ "You never communicate with the user directly. All messages from the senior agent "
49
+ "arrive as tool outputs named 'senior'. Provide concise, helpful responses and "
50
+ "use execute_terminal whenever necessary. Your replies are sent back to the senior agent as tool messages."
51
  ).strip()
src/team.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from typing import AsyncIterator, Optional
5
+
6
+ from .chat import ChatSession
7
+ from .config import OLLAMA_HOST, MODEL_NAME, SYSTEM_PROMPT, JUNIOR_PROMPT
8
+ from .tools import execute_terminal
9
+ from .db import Message as DBMessage
10
+
11
+ __all__ = [
12
+ "TeamChatSession",
13
+ "send_to_junior",
14
+ "send_to_junior_async",
15
+ "set_team",
16
+ ]
17
+
18
# Process-wide registry of the currently active team. The senior agent's
# send_to_junior tool looks the team up here at call time.
# NOTE(review): a single module-level slot means concurrently open
# TeamChatSessions overwrite each other — confirm only one team is active
# at a time per process.
_TEAM: Optional["TeamChatSession"] = None


def set_team(team: "TeamChatSession" | None) -> None:
    """Register *team* as the active team, or clear the slot with ``None``."""
    global _TEAM
    _TEAM = team


def send_to_junior(message: str) -> str:
    """Queue *message* for the junior agent.

    Returns a short status string that is surfaced to the model as the
    tool call's result.
    """
    team = _TEAM
    if team is None:
        return "No active team"
    team.queue_message_to_junior(message)
    return "Message sent to junior"


async def send_to_junior_async(message: str) -> str:
    """Awaitable wrapper around :func:`send_to_junior` for async tool runners."""
    return send_to_junior(message)
35
+
36
+
37
class TeamChatSession:
    """A two-agent team: a user-facing senior session plus a hidden junior.

    The senior agent owns the conversation with the user and is given the
    ``send_to_junior`` tool. Messages it queues are processed by the junior
    in a background task, and the junior's replies are delivered back into
    the senior's history as tool messages named ``junior``.
    """

    def __init__(
        self,
        user: str = "default",
        session: str = "default",
        host: str = OLLAMA_HOST,
        model: str = MODEL_NAME,
    ) -> None:
        # Mailboxes: senior -> junior and junior -> senior.
        self._to_junior: asyncio.Queue[str] = asyncio.Queue()
        self._to_senior: asyncio.Queue[str] = asyncio.Queue()
        self._junior_task: asyncio.Task | None = None
        self.senior = ChatSession(
            user=user,
            session=session,
            host=host,
            model=model,
            system_prompt=SYSTEM_PROMPT,
            tools=[execute_terminal, send_to_junior],
        )
        # The junior gets its own conversation, keyed off the session name.
        self.junior = ChatSession(
            user=user,
            session=f"{session}-junior",
            host=host,
            model=model,
            system_prompt=JUNIOR_PROMPT,
            tools=[execute_terminal],
        )

    async def __aenter__(self) -> "TeamChatSession":
        await self.senior.__aenter__()
        await self.junior.__aenter__()
        # Register globally so the send_to_junior tool can find this team.
        set_team(self)
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        set_team(None)
        # Tear down BOTH sessions even if the senior's teardown raises;
        # previously a failure here leaked the junior session.
        try:
            await self.senior.__aexit__(exc_type, exc, tb)
        finally:
            await self.junior.__aexit__(exc_type, exc, tb)

    def upload_document(self, file_path: str) -> str:
        """Attach a document to the senior's session; returns the stored path."""
        return self.senior.upload_document(file_path)

    def queue_message_to_junior(self, message: str) -> None:
        """Enqueue *message* for the junior and ensure a worker task is running."""
        self._to_junior.put_nowait(message)
        if not self._junior_task or self._junior_task.done():
            self._junior_task = asyncio.create_task(self._process_junior())

    async def _process_junior(self) -> None:
        """Drain the senior->junior queue, running the junior once per message."""
        while not self._to_junior.empty():
            msg = await self._to_junior.get()
            # NOTE(review): reaches into ChatSession internals (_messages,
            # _conversation); consider exposing a public API on ChatSession.
            self.junior._messages.append(
                {"role": "tool", "name": "senior", "content": msg}
            )
            DBMessage.create(
                conversation=self.junior._conversation, role="tool", content=msg
            )
            parts = [part async for part in self.junior.continue_stream() if part]
            result = "\n".join(parts)
            if result.strip():
                await self._to_senior.put(result)

    async def _deliver_junior_messages(self) -> None:
        """Move any queued junior replies into the senior's history."""
        while not self._to_senior.empty():
            msg = await self._to_senior.get()
            self.senior._messages.append(
                {"role": "tool", "name": "junior", "content": msg}
            )
            DBMessage.create(
                conversation=self.senior._conversation, role="tool", content=msg
            )

    async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
        """Stream the senior's reply to *prompt*.

        Junior replies queued before the turn are injected first; replies
        that arrived during the turn are injected after streaming finishes.
        """
        await self._deliver_junior_messages()
        async for part in self.senior.chat_stream(prompt):
            yield part
        await self._deliver_junior_messages()