Spaces:
Runtime error
Runtime error
"""Discord bot implementation.""" | |
from __future__ import annotations | |
import logging | |
import os | |
import shutil | |
import tempfile | |
from pathlib import Path | |
from typing import Iterable | |
import discord | |
from discord.ext import commands | |
from dotenv import load_dotenv | |
from src.db import reset_history | |
from src.log import get_logger | |
from src.team import TeamChatSession | |
class DiscordTeamBot(commands.Bot):
    """Discord bot for interacting with :class:`TeamChatSession`.

    Every non-command message from a non-bot author opens a
    ``TeamChatSession`` keyed by the author's id and the channel id. Any
    attachments are uploaded to that session first, then the message text is
    passed to ``chat_stream`` and each streamed part is posted back as a reply.
    Messages starting with ``!`` are handled as commands only (``!reset``).
    """

    # Maximum length of a single Discord message's content; longer replies are
    # split across several messages.
    _MAX_MESSAGE_LEN = 2000

    def __init__(self) -> None:
        # NOTE(review): ``Intents.all()`` includes privileged intents, which per
        # the discord.py docs must be enabled for the application in the
        # Discord developer portal or connecting fails — confirm they are, or
        # narrow this to the intents actually needed (e.g. message content).
        intents = discord.Intents.all()
        super().__init__(command_prefix="!", intents=intents)
        self._log = get_logger(__name__, level=logging.INFO)
        self._register_commands()

    # ------------------------------------------------------------------
    # Lifecycle events
    # ------------------------------------------------------------------
    async def on_ready(self) -> None:  # noqa: D401 - callback signature
        """Log a message once the bot has connected."""
        self._log.info("Logged in as %s", self.user)

    async def on_message(self, message: discord.Message) -> None:  # noqa: D401 - callback signature
        """Process incoming messages and stream chat replies.

        Messages from bots are ignored. Commands are dispatched through
        ``process_commands``. Any other message has its attachments uploaded
        and its text forwarded to the team chat session. Failures in either
        step are logged with a traceback and reported to the user as
        ``Error: ...``.
        """
        if message.author.bot:
            return
        # Overriding on_message replaces discord.py's default handler, so
        # commands have to be dispatched explicitly here.
        await self.process_commands(message)
        if message.content.startswith("!"):
            return
        async with TeamChatSession(
            user=str(message.author.id), session=str(message.channel.id)
        ) as chat:
            try:
                docs = await self._handle_attachments(chat, message.attachments)
                if docs:
                    info = "\n".join(f"{name} -> {path}" for name, path in docs)
                    await self._reply(message, f"Uploaded:\n{info}")
                if message.content.strip():
                    async for part in chat.chat_stream(message.content):
                        await self._reply(message, part)
            except Exception as exc:  # pragma: no cover - runtime errors
                self._log.exception("Failed to process message: %s", exc)
                await self._reply(message, f"Error: {exc}")

    # ------------------------------------------------------------------
    # Commands
    # ------------------------------------------------------------------
    def _register_commands(self) -> None:
        """Create the bot's prefix commands and add them to this bot."""

        # ``self.command`` both builds the Command and adds it to the bot;
        # without it the coroutine below would never be reachable as !reset.
        @self.command(name="reset")
        async def reset(ctx: commands.Context) -> None:
            """Clear the stored chat history for this user in this channel."""
            # NOTE(review): reset_history is called synchronously here; if it
            # does blocking DB I/O it stalls the event loop — confirm.
            deleted = reset_history(str(ctx.author.id), str(ctx.channel.id))
            await ctx.reply(
                f"Chat history cleared ({deleted} messages deleted).",
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    async def _reply(self, message: discord.Message, text: str) -> None:
        """Reply to *message* with *text*, split to fit Discord's size limit.

        Nothing is sent when *text* is empty or whitespace-only.
        """
        for chunk in self._split_text(text):
            await message.reply(chunk, mention_author=False)

    @classmethod
    def _split_text(cls, text: str) -> list[str]:
        """Split *text* into chunks of at most ``_MAX_MESSAGE_LEN`` characters.

        Blank chunks are dropped because Discord refuses to send a message
        with empty content.
        """
        limit = cls._MAX_MESSAGE_LEN
        chunks = (text[start:start + limit] for start in range(0, len(text), limit))
        return [chunk for chunk in chunks if chunk.strip()]

    @staticmethod
    def _safe_filename(filename: str) -> str:
        """Return only the final path component of *filename*.

        Attachment names are supplied by the uploader, so any directory parts
        are discarded to keep the saved file inside the temporary directory.
        Names that reduce to nothing usable fall back to ``"attachment"``.
        """
        name = Path(filename).name
        return name if name not in {"", ".", ".."} else "attachment"

    async def _handle_attachments(
        self, chat: TeamChatSession, attachments: Iterable[discord.Attachment]
    ) -> list[tuple[str, str]]:
        """Download any attachments and return their VM paths.

        Each attachment is saved into a fresh temporary directory and handed
        to ``chat.upload_document``. The directory is removed afterwards, even
        on error.

        Returns:
            ``(original filename, path returned by upload_document)`` pairs in
            attachment order; an empty list when there are no attachments.
        """
        if not attachments:
            return []
        uploaded: list[tuple[str, str]] = []
        tmpdir = Path(tempfile.mkdtemp(prefix="discord_upload_"))
        try:
            for attachment in attachments:
                dest = tmpdir / self._safe_filename(attachment.filename)
                await attachment.save(dest)
                vm_path = chat.upload_document(str(dest))
                uploaded.append((attachment.filename, vm_path))
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
        return uploaded
def run_bot(token: str) -> None:
    """Build a :class:`DiscordTeamBot` and run it with the given bot *token*.

    ``run`` does not return until the bot's client shuts down.
    """
    bot = DiscordTeamBot()
    bot.run(token)
def main() -> None:
    """Load ``.env`` into the environment, then start the bot.

    Raises:
        RuntimeError: If ``DISCORD_TOKEN`` is missing or empty.
    """
    load_dotenv()
    discord_token = os.environ.get("DISCORD_TOKEN")
    if discord_token:
        run_bot(discord_token)
        return
    raise RuntimeError("DISCORD_TOKEN environment variable not set")
if __name__ == "__main__":  # pragma: no cover - manual execution
    # Start the bot only when this file is executed directly, not on import.
    main()