Spaces:
Runtime error
Runtime error
File size: 5,399 Bytes
3e5cdc3 4f46c78 e75d7f2 6ebba8f 3e5cdc3 4f46c78 c1304d4 3e5cdc3 f7c8c98 e75d7f2 3e5cdc3 4f46c78 3e5cdc3 c1304d4 6ebba8f c1304d4 6ebba8f 3e5cdc3 f7c8c98 3e5cdc3 f7c8c98 3e5cdc3 bf3a897 5cbac45 3e5cdc3 4f46c78 3e5cdc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from __future__ import annotations
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import asyncio
import os
import tempfile
from pathlib import Path
from typing import List
import shutil
from src.config import UPLOAD_DIR, CORS_ORIGINS, RATE_LIMIT
from src.security import (
APIKeyAuthMiddleware,
RateLimiterMiddleware,
SecurityHeadersMiddleware,
)
from src.team import TeamChatSession
from src.log import get_logger
from src.db import list_sessions, list_sessions_info
# Module-level logger for this API module; used below to record streaming chat failures.
_LOG = get_logger(__name__)
# JSON request body for ``POST /chat/stream``.
class ChatRequest(BaseModel):
    # Passed as ``user=`` to ``TeamChatSession``; optional, defaults to "default".
    user: str = "default"
    # Passed as ``session=`` to ``TeamChatSession``; optional, defaults to "default".
    session: str = "default"
    # Required message text, forwarded to ``TeamChatSession.chat_stream``.
    prompt: str
# JSON request body for ``POST /vm/{user}/file``.
class FileWriteRequest(BaseModel):
    # Destination path as seen inside the VM; must lie under ``/data``.
    path: str
    # Text content written to the file (replaces any existing content).
    content: str
def _vm_host_path(user: str, vm_path: str) -> Path:
    """Return the host path for a given ``vm_path`` inside ``/data``.

    ``vm_path`` is a path as seen inside the user's VM and must lie under
    ``/data``.  It is mapped onto ``UPLOAD_DIR/<user>`` on the host.

    Args:
        user: Name of the user whose data directory is addressed.  Must be a
            single plain path component (non-empty, no separator, not ``.``
            or ``..``).
        vm_path: Absolute path inside the VM, starting with ``/data``.

    Returns:
        The fully resolved host path (symlinks followed), guaranteed to be
        inside the user's directory.

    Raises:
        HTTPException: 400 if ``user`` is not a plain name, ``vm_path`` is not
            under ``/data`` or contains invalid characters, or the resolved
            path escapes the user's directory.
    """
    # ``user`` arrives straight from the URL.  A value of "." or ".." (sent
    # raw or percent-encoded) would make ``base`` the upload root itself or
    # its parent, so the containment check below would no longer isolate
    # users from each other or from the rest of the filesystem.
    if user in ("", ".", "..") or Path(user).name != user:
        raise HTTPException(status_code=400, detail="Invalid user")
    try:
        rel = Path(vm_path).relative_to("/data")
    except ValueError as exc:  # pragma: no cover - invalid path
        raise HTTPException(status_code=400, detail="Path must start with /data") from exc
    try:
        base = (Path(UPLOAD_DIR) / user).resolve()
        target = (base / rel).resolve()
    except ValueError as exc:
        # e.g. an embedded NUL byte: a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Invalid path") from exc
    # ``relative_to`` above is purely lexical; "/data/../../x" passes it, so
    # containment is re-checked on the resolved paths.
    if not target.is_relative_to(base):
        raise HTTPException(status_code=400, detail="Invalid path")
    return target
def _safe_upload_name(filename: str | None) -> str:
    """Reduce a client-supplied upload filename to a safe bare file name.

    ``UploadFile.filename`` comes straight from the multipart request.  It may
    be ``None``, absolute (``/etc/passwd``) or contain ``..`` segments.
    Joining such a value onto a directory with ``Path`` escapes it, because an
    absolute right-hand operand replaces the left side entirely.  Only the
    final path component is therefore kept, falling back to ``"upload"`` when
    nothing usable remains.
    """
    name = Path(filename or "").name
    if name in ("", ".", ".."):
        return "upload"
    return name


def _register_chat_routes(app: FastAPI) -> None:
    """Register the chat streaming and document upload endpoints on *app*."""
    # Only needed for the annotation of the streaming generator below.
    from collections.abc import AsyncIterator

    @app.post("/chat/stream")
    async def chat_stream(req: ChatRequest):
        """Stream the team's reply to ``req.prompt`` as plain text.

        Errors raised while streaming are reported in-band as a final
        ``"Error: ..."`` chunk and logged.
        """

        async def stream() -> AsyncIterator[str]:
            async with TeamChatSession(user=req.user, session=req.session) as chat:
                try:
                    async for part in chat.chat_stream(req.prompt):
                        yield part
                except Exception as exc:  # pragma: no cover - runtime failures
                    _LOG.error("Streaming chat failed: %s", exc)
                    yield f"Error: {exc}"

        return StreamingResponse(stream(), media_type="text/plain")

    @app.post("/upload")
    async def upload_document(
        user: str = Form(...),
        session: str = Form("default"),
        file: UploadFile = File(...),
    ):
        """Stage an uploaded file on disk and hand it to the chat session.

        Returns ``{"path": ...}`` with whatever
        ``TeamChatSession.upload_document`` returns for the staged file.
        """
        async with TeamChatSession(user=user, session=session) as chat:
            tmpdir = tempfile.mkdtemp(prefix="upload_")
            try:
                # Sanitised name: the raw client filename could point outside tmpdir.
                tmp_path = Path(tmpdir) / _safe_upload_name(file.filename)
                tmp_path.write_bytes(await file.read())
                vm_path = chat.upload_document(str(tmp_path))
            finally:
                # Best-effort cleanup.  Removing the whole directory also works
                # when the file was never written (e.g. reading the upload failed).
                shutil.rmtree(tmpdir, ignore_errors=True)
        return {"path": vm_path}


def _register_session_routes(app: FastAPI) -> None:
    """Register the endpoints that list a user's stored chat sessions."""
    # NOTE(review): ``list_sessions*`` are called synchronously from async
    # handlers.  If they hit a slow database they block the event loop;
    # consider plain ``def`` handlers once their thread-safety is confirmed.

    @app.get("/sessions/{user}")
    async def list_user_sessions(user: str):
        return {"sessions": list_sessions(user)}

    @app.get("/sessions/{user}/info")
    async def list_user_sessions_info(user: str):
        data = list_sessions_info(user)
        if not data:
            raise HTTPException(status_code=404, detail="User not found")
        return {"sessions": data}


def _register_vm_routes(app: FastAPI) -> None:
    """Register endpoints that browse and edit a user's ``/data`` tree.

    The handlers are plain ``def`` on purpose.  They only do blocking
    filesystem I/O, and FastAPI runs sync handlers in its worker threadpool
    instead of on the event loop.
    """

    @app.get("/vm/{user}/list")
    def list_vm_dir(user: str, path: str = "/data"):
        target = _vm_host_path(user, path)
        if not target.exists():
            raise HTTPException(status_code=404, detail="Directory not found")
        if not target.is_dir():
            raise HTTPException(status_code=400, detail="Not a directory")
        entries: List[dict[str, str | bool]] = [
            {"name": entry.name, "is_dir": entry.is_dir()}
            for entry in sorted(target.iterdir())
        ]
        return {"entries": entries}

    @app.get("/vm/{user}/file")
    def read_vm_file(user: str, path: str):
        target = _vm_host_path(user, path)
        if not target.exists():
            raise HTTPException(status_code=404, detail="File not found")
        if target.is_dir():
            raise HTTPException(status_code=400, detail="Path is a directory")
        try:
            # Explicit UTF-8: the locale default varies between hosts/containers.
            content = target.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            raise HTTPException(status_code=400, detail="Binary file not supported") from None
        return {"content": content}

    @app.post("/vm/{user}/file")
    def write_vm_file(user: str, req: FileWriteRequest):
        target = _vm_host_path(user, req.path)
        # Never replace a directory with a regular file.  This includes the
        # ``/data`` root itself, which may not exist on the host yet.
        if target == _vm_host_path(user, "/data") or target.is_dir():
            raise HTTPException(status_code=400, detail="Path is a directory")
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(req.content, encoding="utf-8")
        return {"status": "ok"}

    @app.delete("/vm/{user}/file")
    def delete_vm_file(user: str, path: str):
        target = _vm_host_path(user, path)
        # "/data" maps to the user's whole host directory; a file-delete call
        # must not be able to wipe it in one request.
        if target == _vm_host_path(user, "/data"):
            raise HTTPException(status_code=400, detail="Cannot delete /data")
        if target.is_dir():
            shutil.rmtree(target)
        elif target.exists():
            target.unlink()
        else:
            raise HTTPException(status_code=404, detail="File not found")
        return {"status": "deleted"}


def create_app() -> FastAPI:
    """Build the FastAPI application with its middleware stack and routes.

    Returns:
        A configured ``FastAPI`` instance with API-key auth, rate limiting,
        security headers and CORS middleware, plus the chat, session, health
        and VM file routes.
    """
    app = FastAPI(title="LLM Backend API")
    # Starlette wraps each newly added middleware around the existing stack,
    # so CORS (added last) is the outermost layer.
    app.add_middleware(APIKeyAuthMiddleware)
    app.add_middleware(RateLimiterMiddleware, rate_limit=RATE_LIMIT)
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Registration order matches the original single-function layout, so
    # the generated OpenAPI schema lists routes in the same order.
    _register_chat_routes(app)
    _register_session_routes(app)

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    _register_vm_routes(app)
    return app
# Application instance built at import time; presumably the object an ASGI
# server (e.g. ``uvicorn <module>:app``) is pointed at — confirm deploy config.
app = create_app()
|