File size: 5,399 Bytes
3e5cdc3
 
4f46c78
e75d7f2
6ebba8f
3e5cdc3
 
 
 
 
4f46c78
 
 
c1304d4
 
 
 
 
 
3e5cdc3
f7c8c98
e75d7f2
 
3e5cdc3
 
 
 
 
 
 
 
 
 
 
4f46c78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e5cdc3
 
 
c1304d4
 
 
 
6ebba8f
 
c1304d4
6ebba8f
 
 
 
 
3e5cdc3
 
 
f7c8c98
3e5cdc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7c8c98
3e5cdc3
 
 
 
 
 
 
 
 
 
 
 
 
 
bf3a897
 
 
 
5cbac45
 
 
 
 
 
 
3e5cdc3
 
 
 
4f46c78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e5cdc3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from __future__ import annotations

import asyncio
import os
import shutil
import tempfile
from collections.abc import AsyncIterator
from pathlib import Path
from typing import List

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from src.config import UPLOAD_DIR, CORS_ORIGINS, RATE_LIMIT
from src.security import (
    APIKeyAuthMiddleware,
    RateLimiterMiddleware,
    SecurityHeadersMiddleware,
)
from src.team import TeamChatSession
from src.log import get_logger
from src.db import list_sessions, list_sessions_info


# Module-level logger from the project's logging helper, named after this module.
_LOG = get_logger(__name__)


class ChatRequest(BaseModel):
    """Request body for ``POST /chat/stream``."""

    # Owner of the chat session; "default" when the client omits it.
    user: str = "default"
    # Session identifier within that user's sessions; "default" when omitted.
    session: str = "default"
    # Required message forwarded to the chat session's streaming reply.
    prompt: str


class FileWriteRequest(BaseModel):
    """Request body for ``POST /vm/{user}/file``."""

    # Destination path inside the VM; must live under /data.
    path: str
    # Text to store at ``path`` (replaces any existing content).
    content: str


def _vm_host_path(user: str, vm_path: str) -> Path:
    """Return the host path backing ``vm_path`` inside the user's ``/data``.

    Each user's ``/data`` maps to ``UPLOAD_DIR/<user>`` on the host. Both the
    ``user`` component and ``vm_path`` are validated so the returned path can
    never point outside that per-user directory.

    Args:
        user: Owner of the VM filesystem. Must be a single plain path
            component (no separators, not ``"."`` or ``".."``).
        vm_path: Path inside the VM; must start with ``/data``.

    Returns:
        The resolved host path corresponding to ``vm_path``.

    Raises:
        HTTPException: 400 if ``vm_path`` is not under ``/data``, if ``user``
            is not a plain directory name, or if the resolved path escapes the
            user's directory.
    """
    try:
        rel = Path(vm_path).relative_to("/data")
    except ValueError as exc:  # pragma: no cover - invalid path
        raise HTTPException(status_code=400, detail="Path must start with /data") from exc

    # ``user`` arrives straight from the URL. Without this check ".." would
    # point ``base`` at the parent of UPLOAD_DIR (and "." at UPLOAD_DIR itself),
    # letting callers browse or modify other users' data. ``Path(user).name``
    # differs from ``user`` whenever it contains a separator or is "."/"".
    if not user or user == ".." or Path(user).name != user:
        raise HTTPException(status_code=400, detail="Invalid user")

    try:
        base = (Path(UPLOAD_DIR) / user).resolve()
        target = (base / rel).resolve()
    except ValueError as exc:  # e.g. embedded NUL bytes rejected by the OS layer
        raise HTTPException(status_code=400, detail="Invalid path") from exc

    # Resolution follows ".." segments and symlinks, so re-check containment.
    if not target.is_relative_to(base):
        raise HTTPException(status_code=400, detail="Invalid path")
    return target


def create_app() -> FastAPI:
    """Build the FastAPI application with its middleware and HTTP routes.

    Routes:
        POST   /chat/stream           stream a chat reply as plain text
        POST   /upload                upload a document into a chat session
        GET    /sessions/{user}       list a user's session names
        GET    /sessions/{user}/info  list session info (404 when empty)
        GET    /health                liveness probe
        GET    /vm/{user}/list        list a directory under the user's /data
        GET    /vm/{user}/file        read a UTF-8 text file under /data
        POST   /vm/{user}/file        create/overwrite a UTF-8 text file
        DELETE /vm/{user}/file        delete a file or directory tree

    Returns:
        The configured application instance.
    """
    app = FastAPI(title="LLM Backend API")

    # Starlette wraps middleware in reverse registration order, so CORS
    # (registered last) ends up outermost and sees every request first.
    app.add_middleware(APIKeyAuthMiddleware)
    app.add_middleware(RateLimiterMiddleware, rate_limit=RATE_LIMIT)
    app.add_middleware(SecurityHeadersMiddleware)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.post("/chat/stream")
    async def chat_stream(req: ChatRequest):
        """Stream the reply to ``req.prompt`` as ``text/plain`` chunks."""

        async def stream() -> AsyncIterator[str]:
            async with TeamChatSession(user=req.user, session=req.session) as chat:
                try:
                    async for part in chat.chat_stream(req.prompt):
                        yield part
                except Exception as exc:  # pragma: no cover - runtime failures
                    # The response has already started, so the failure can
                    # only be reported in-band as a final text chunk.
                    _LOG.error("Streaming chat failed: %s", exc)
                    yield f"Error: {exc}"

        return StreamingResponse(stream(), media_type="text/plain")

    @app.post("/upload")
    async def upload_document(
        user: str = Form(...),
        session: str = Form("default"),
        file: UploadFile = File(...),
    ):
        """Store ``file`` in the user's chat session and return its VM path."""
        # ``file.filename`` is client-controlled: keep only the final
        # component so names like "../../x" or "/etc/x" cannot escape the
        # temporary directory, and fall back when it is missing or empty.
        name = Path(file.filename or "").name
        if not name or name == "..":
            name = "upload"

        async with TeamChatSession(user=user, session=session) as chat:
            tmpdir = tempfile.mkdtemp(prefix="upload_")
            try:
                tmp_path = Path(tmpdir) / name
                tmp_path.write_bytes(await file.read())
                vm_path = chat.upload_document(str(tmp_path))
            finally:
                # Remove the whole temp dir even if the write failed part-way
                # (the previous remove-then-rmdir leaked it in that case).
                shutil.rmtree(tmpdir, ignore_errors=True)
        return {"path": vm_path}

    @app.get("/sessions/{user}")
    async def list_user_sessions(user: str):
        """Return the session names stored for ``user``."""
        return {"sessions": list_sessions(user)}

    @app.get("/sessions/{user}/info")
    async def list_user_sessions_info(user: str):
        """Return session info for ``user``; 404 when nothing is found."""
        data = list_sessions_info(user)
        if not data:
            raise HTTPException(status_code=404, detail="User not found")
        return {"sessions": data}

    @app.get("/health")
    async def health():
        """Liveness probe."""
        return {"status": "ok"}

    @app.get("/vm/{user}/list")
    async def list_vm_dir(user: str, path: str = "/data"):
        """List the entries of a directory under the user's ``/data``."""
        target = _vm_host_path(user, path)
        if not target.exists():
            raise HTTPException(status_code=404, detail="Directory not found")
        if not target.is_dir():
            raise HTTPException(status_code=400, detail="Not a directory")
        entries: list[dict[str, str | bool]] = [
            {"name": entry.name, "is_dir": entry.is_dir()}
            for entry in sorted(target.iterdir())
        ]
        return {"entries": entries}

    @app.get("/vm/{user}/file")
    async def read_vm_file(user: str, path: str):
        """Return the UTF-8 text content of a file under the user's ``/data``."""
        target = _vm_host_path(user, path)
        if not target.exists():
            raise HTTPException(status_code=404, detail="File not found")
        if target.is_dir():
            raise HTTPException(status_code=400, detail="Path is a directory")
        try:
            # Explicit encoding: the platform default varies by locale.
            content = target.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            raise HTTPException(
                status_code=400, detail="Binary file not supported"
            ) from None
        return {"content": content}

    @app.post("/vm/{user}/file")
    async def write_vm_file(user: str, req: FileWriteRequest):
        """Create or overwrite a UTF-8 text file, creating parent directories."""
        target = _vm_host_path(user, req.path)
        if target.is_dir():
            # Mirror read_vm_file instead of failing with IsADirectoryError.
            raise HTTPException(status_code=400, detail="Path is a directory")
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(req.content, encoding="utf-8")
        return {"status": "ok"}

    @app.delete("/vm/{user}/file")
    async def delete_vm_file(user: str, path: str):
        """Delete a file, or a directory tree recursively, under ``/data``."""
        target = _vm_host_path(user, path)
        if target.is_dir():
            shutil.rmtree(target)
        elif target.exists():
            target.unlink()
        else:
            raise HTTPException(status_code=404, detail="File not found")
        return {"status": "deleted"}

    return app


# Application instance built at import time; presumably the target handed to
# an ASGI server (e.g. ``uvicorn <module>:app``) — confirm against deployment.
app = create_app()