# Page banner captured along with this file (not part of the module): "Spaces: Runtime error".
from __future__ import annotations

import asyncio
import os
import shutil
import tempfile
from collections.abc import AsyncIterator
from pathlib import Path
from typing import List

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from src.config import UPLOAD_DIR, CORS_ORIGINS, RATE_LIMIT
from src.db import list_sessions, list_sessions_info
from src.log import get_logger
from src.security import (
    APIKeyAuthMiddleware,
    RateLimiterMiddleware,
    SecurityHeadersMiddleware,
)
from src.team import TeamChatSession
# Module-level logger, named after this module.
_LOG = get_logger(__name__)
class ChatRequest(BaseModel):
    # Request body for a chat prompt.
    # ``user`` and ``session`` fall back to "default" when omitted;
    # ``prompt`` has no default and is therefore required.
    user: str = "default"
    session: str = "default"
    prompt: str
class FileWriteRequest(BaseModel):
    # Request body for writing a text file.
    # ``path`` is the VM-side location of the file to write.
    # ``content`` is the full text to store at that location.
    path: str
    content: str
def _vm_host_path(user: str, vm_path: str) -> Path:
    """Return the host path for a given ``vm_path`` inside ``/data``.

    The VM-side ``/data`` tree of ``user`` is mapped onto
    ``UPLOAD_DIR/<user>`` on the host.

    Args:
        user: Name of the per-user directory below ``UPLOAD_DIR``.
        vm_path: Path as seen inside the VM; must be located under ``/data``.

    Returns:
        The resolved host path corresponding to ``vm_path``.

    Raises:
        HTTPException: 400 when ``vm_path`` is not under ``/data``, when
            ``user`` does not name a directory strictly inside ``UPLOAD_DIR``,
            or when the resolved target escapes the user's directory.
    """
    try:
        rel = Path(vm_path).relative_to("/data")
    except ValueError as exc:  # pragma: no cover - invalid path
        raise HTTPException(status_code=400, detail="Path must start with /data") from exc

    # ``user`` comes from the request as well: an empty, ``..`` or absolute
    # value would otherwise move ``base`` to (or outside of) the upload root
    # and make the containment check below meaningless.
    root = Path(UPLOAD_DIR).resolve()
    base = (root / user).resolve()
    if base == root or not base.is_relative_to(root):
        raise HTTPException(status_code=400, detail="Invalid user")

    # Resolve ``..`` segments and symlinks before checking containment.
    target = (base / rel).resolve()
    if not target.is_relative_to(base):
        raise HTTPException(status_code=400, detail="Invalid path")
    return target
def create_app() -> FastAPI:
    """Build and configure the FastAPI application.

    Installs the API-key, rate-limit, security-header and CORS middleware and
    defines the handlers for chat streaming, document upload, session
    listing, health checks and file operations on a user's ``/data`` tree.

    NOTE(review): no route registrations (``@app.get`` / ``@app.post`` ...)
    are visible in this view of the file, so the handlers below are defined
    but never attached to ``app`` here -- confirm against the full source.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(title="LLM Backend API")

    app.add_middleware(APIKeyAuthMiddleware)
    app.add_middleware(RateLimiterMiddleware, rate_limit=RATE_LIMIT)
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    async def chat_stream(req: ChatRequest):
        """Stream the chat reply for ``req.prompt`` as ``text/plain``."""

        async def stream() -> AsyncIterator[str]:
            async with TeamChatSession(user=req.user, session=req.session) as chat:
                try:
                    async for part in chat.chat_stream(req.prompt):
                        yield part
                except Exception as exc:  # pragma: no cover - runtime failures
                    # The response is already being streamed at this point,
                    # so the failure is reported in-band as a final chunk.
                    _LOG.error("Streaming chat failed: %s", exc)
                    yield f"Error: {exc}"

        return StreamingResponse(stream(), media_type="text/plain")

    async def upload_document(
        user: str = Form(...),
        session: str = Form("default"),
        file: UploadFile = File(...),
    ):
        """Hand an uploaded file to the chat session and return its VM path.

        The upload is spooled into a private temporary directory, passed to
        ``chat.upload_document`` and the temporary directory is removed
        afterwards, whether or not the hand-off succeeded.
        """
        # The client controls ``file.filename``. Keep only its final path
        # component so an absolute or ``../`` name cannot escape ``tmpdir``.
        safe_name = Path(file.filename or "").name
        if safe_name in ("", ".", ".."):
            safe_name = "upload"

        async with TeamChatSession(user=user, session=session) as chat:
            tmpdir = tempfile.mkdtemp(prefix="upload_")
            tmp_path = Path(tmpdir) / safe_name
            try:
                contents = await file.read()
                tmp_path.write_bytes(contents)
                vm_path = chat.upload_document(str(tmp_path))
            finally:
                # Best-effort cleanup: removes the directory even when the
                # file was never written, and never masks the real error.
                shutil.rmtree(tmpdir, ignore_errors=True)
        return {"path": vm_path}

    async def list_user_sessions(user: str):
        """Return the sessions reported by ``list_sessions`` for ``user``."""
        return {"sessions": list_sessions(user)}

    async def list_user_sessions_info(user: str):
        """Return ``list_sessions_info(user)``; 404 when it yields nothing."""
        data = list_sessions_info(user)
        if not data:
            raise HTTPException(status_code=404, detail="User not found")
        return {"sessions": data}

    async def health():
        """Liveness probe; always reports ``ok``."""
        return {"status": "ok"}

    async def list_vm_dir(user: str, path: str = "/data"):
        """List the entries of a directory in the user's ``/data`` tree.

        Raises:
            HTTPException: 404 when the path does not exist, 400 when it is
                not a directory (or is rejected by ``_vm_host_path``).
        """
        target = _vm_host_path(user, path)
        if not target.exists():
            raise HTTPException(status_code=404, detail="Directory not found")
        if not target.is_dir():
            raise HTTPException(status_code=400, detail="Not a directory")
        entries: List[dict[str, str | bool]] = [
            {"name": entry.name, "is_dir": entry.is_dir()}
            for entry in sorted(target.iterdir())
        ]
        return {"entries": entries}

    async def read_vm_file(user: str, path: str):
        """Return the text content of a file in the user's ``/data`` tree.

        Raises:
            HTTPException: 404 when the file does not exist, 400 when the
                path is a directory or the file cannot be decoded as text.
        """
        target = _vm_host_path(user, path)
        if not target.exists():
            raise HTTPException(status_code=404, detail="File not found")
        if target.is_dir():
            raise HTTPException(status_code=400, detail="Path is a directory")
        try:
            content = target.read_text()
        except UnicodeDecodeError as exc:
            raise HTTPException(status_code=400, detail="Binary file not supported") from exc
        return {"content": content}

    async def write_vm_file(user: str, req: FileWriteRequest):
        """Create or overwrite a text file in the user's ``/data`` tree.

        Missing parent directories are created.

        Raises:
            HTTPException: 400 when the target is an existing directory (or
                the path is rejected by ``_vm_host_path``).
        """
        target = _vm_host_path(user, req.path)
        # Writing onto a directory would otherwise surface as an unhandled
        # ``IsADirectoryError``; answer the same way ``read_vm_file`` does.
        if target.is_dir():
            raise HTTPException(status_code=400, detail="Path is a directory")
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(req.content)
        return {"status": "ok"}

    async def delete_vm_file(user: str, path: str):
        """Delete a file or a whole directory in the user's ``/data`` tree.

        NOTE(review): ``path="/data"`` resolves to the user's root directory,
        which is then removed recursively -- confirm that this is intended.

        Raises:
            HTTPException: 404 when nothing exists at the path.
        """
        target = _vm_host_path(user, path)
        if target.is_dir():
            shutil.rmtree(target)
        elif target.exists():
            target.unlink()
        else:
            raise HTTPException(status_code=404, detail="File not found")
        return {"status": "deleted"}

    return app
# Application instance built once at import time.
app = create_app()