File size: 4,019 Bytes
57fa15b
 
 
 
7765906
 
 
 
 
57fa15b
7765906
 
 
 
 
 
 
57fa15b
7765906
 
57fa15b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7765906
 
 
57fa15b
 
7765906
 
 
 
 
57fa15b
7765906
 
57fa15b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Discord bot implementation."""

from __future__ import annotations

import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Iterable

import discord
from discord.ext import commands
from dotenv import load_dotenv

from src.db import reset_history
from src.log import get_logger
from src.team import TeamChatSession


class DiscordTeamBot(commands.Bot):
    """Discord bot for interacting with :class:`TeamChatSession`.

    Every message that is not from a bot and is not a command is forwarded to
    a :class:`TeamChatSession` keyed by the author's ID and the channel ID.
    Attachments are uploaded into the session first, and the session's
    streamed reply parts are posted back as replies to the original message.
    """

    # Single source of truth for the command prefix: used both for command
    # parsing and for the "is this a command?" check in ``on_message``.
    COMMAND_PREFIX = "!"
    # Discord rejects message content longer than this many characters.
    _MAX_MESSAGE_LEN = 2000

    def __init__(self) -> None:
        intents = discord.Intents.all()
        super().__init__(command_prefix=self.COMMAND_PREFIX, intents=intents)
        self._log = get_logger(__name__, level=logging.INFO)
        self._register_commands()

    # ------------------------------------------------------------------
    # Lifecycle events
    # ------------------------------------------------------------------
    async def on_ready(self) -> None:  # noqa: D401 - callback signature
        """Log a message once the bot has connected."""

        self._log.info("Logged in as %s", self.user)

    async def on_message(self, message: discord.Message) -> None:  # noqa: D401 - callback signature
        """Process incoming messages and stream chat replies.

        Bot-authored messages are ignored. Commands are dispatched and
        otherwise skipped. Any other message has its attachments uploaded
        and its text sent to the chat session.
        """

        if message.author.bot:
            return

        await self.process_commands(message)
        if message.content.startswith(self.COMMAND_PREFIX):
            return

        has_text = bool(message.content.strip())
        if not has_text and not message.attachments:
            # Nothing to forward (e.g. sticker/embed-only message); avoid
            # opening a chat session for it.
            return

        async with TeamChatSession(
            user=str(message.author.id), session=str(message.channel.id)
        ) as chat:
            try:
                docs = await self._handle_attachments(chat, message.attachments)
            except Exception as exc:  # pragma: no cover - runtime errors
                # Previously this escaped the handler with no user feedback.
                self._log.exception("Failed to upload attachments: %s", exc)
                await self._reply(message, f"Error: {exc}")
                return

            if docs:
                info = "\n".join(f"{name} -> {path}" for name, path in docs)
                await self._reply(message, f"Uploaded:\n{info}")

            if not has_text:
                return

            try:
                async for part in chat.chat_stream(message.content):
                    await self._reply(message, part)
            except Exception as exc:  # pragma: no cover - runtime errors
                # ``exception`` (not ``error``) keeps the traceback in the log.
                self._log.exception("Failed to process message: %s", exc)
                await self._reply(message, f"Error: {exc}")

    # ------------------------------------------------------------------
    # Commands
    # ------------------------------------------------------------------
    def _register_commands(self) -> None:
        """Attach the bot's prefix commands to this instance."""

        @self.command(name="reset")
        async def reset(ctx: commands.Context) -> None:
            """Clear this user's chat history for the current channel."""
            deleted = reset_history(str(ctx.author.id), str(ctx.channel.id))
            await ctx.reply(
                f"Chat history cleared ({deleted} messages deleted).",
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    async def _reply(self, message: discord.Message, content: object) -> None:
        """Reply to *message* with *content*, split to fit Discord's limit.

        *content* is converted with ``str()``, as ``Message.reply`` would do.
        It is then sent in chunks of at most ``_MAX_MESSAGE_LEN`` characters.
        Empty or whitespace-only chunks are skipped because Discord rejects
        empty messages.
        """

        text = str(content)
        limit = self._MAX_MESSAGE_LEN
        for start in range(0, len(text), limit):
            chunk = text[start : start + limit]
            if chunk.strip():
                await message.reply(chunk, mention_author=False)

    async def _handle_attachments(
        self, chat: TeamChatSession, attachments: Iterable[discord.Attachment]
    ) -> list[tuple[str, str]]:
        """Download attachments, upload them to *chat*, return their VM paths.

        Each attachment is saved into a private temporary directory, handed
        to ``chat.upload_document``, and the directory is removed afterwards
        even if a download or upload fails.

        Returns:
            ``(original_filename, vm_path)`` pairs in attachment order.
        """

        if not attachments:
            return []

        uploaded: list[tuple[str, str]] = []
        tmpdir = Path(tempfile.mkdtemp(prefix="discord_upload_"))
        try:
            for attachment in attachments:
                # The filename comes from the remote side: keep only its final
                # component so it cannot point outside ``tmpdir``.
                safe_name = Path(attachment.filename).name
                if safe_name in ("", ".", ".."):
                    safe_name = "attachment"
                dest = tmpdir / safe_name
                await attachment.save(dest)
                vm_path = chat.upload_document(str(dest))
                uploaded.append((attachment.filename, vm_path))
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

        return uploaded


def run_bot(token: str) -> None:
    """Build a :class:`DiscordTeamBot` and run it with the given *token*.

    ``run`` is the discord.py blocking entry point, so this call returns
    only after the bot stops.
    """

    bot = DiscordTeamBot()
    bot.run(token)


def main() -> None:
    """Start the bot with the ``DISCORD_TOKEN`` taken from the environment.

    A ``.env`` file is loaded first, so the token may be supplied there.

    Raises:
        RuntimeError: If ``DISCORD_TOKEN`` is unset or empty.
    """

    load_dotenv()
    discord_token = os.environ.get("DISCORD_TOKEN")
    if discord_token:
        run_bot(discord_token)
        return

    raise RuntimeError("DISCORD_TOKEN environment variable not set")


if __name__ == "__main__":  # pragma: no cover - manual execution
    main()