tech-envision commited on
Commit
c1304d4
·
1 Parent(s): dd8abbe

Add security middleware

Browse files
Files changed (4) hide show
  1. README.md +6 -0
  2. api_app/__init__.py +11 -2
  3. src/config.py +9 -0
  4. src/security.py +69 -0
README.md CHANGED
@@ -104,6 +104,12 @@ curl -N -X POST http://localhost:8000/chat/stream \
104
  -d '{"user":"demo","session":"default","prompt":"Hello"}'
105
  ```
106
 
 
 
 
 
 
 
107
  ## Command Line Interface
108
 
109
  Run the interactive CLI on any platform:
 
104
  -d '{"user":"demo","session":"default","prompt":"Hello"}'
105
  ```
106
 
107
+ ### Security
108
+
109
+ Set one or more API keys in the ``API_KEYS`` environment variable. Requests must
110
+ include the ``X-API-Key`` header when keys are configured. A simple rate limit is
111
+ also enforced per key or client IP, configurable via ``RATE_LIMIT``.
112
+
113
  ## Command Line Interface
114
 
115
  Run the interactive CLI on any platform:
api_app/__init__.py CHANGED
@@ -11,7 +11,12 @@ from pathlib import Path
11
  from typing import List
12
  import shutil
13
 
14
- from src.config import UPLOAD_DIR
 
 
 
 
 
15
 
16
  from src.team import TeamChatSession
17
  from src.log import get_logger
@@ -50,9 +55,13 @@ def _vm_host_path(user: str, vm_path: str) -> Path:
50
  def create_app() -> FastAPI:
51
  app = FastAPI(title="LLM Backend API")
52
 
 
 
 
 
53
  app.add_middleware(
54
  CORSMiddleware,
55
- allow_origins=["*"],
56
  allow_credentials=True,
57
  allow_methods=["*"],
58
  allow_headers=["*"],
 
11
  from typing import List
12
  import shutil
13
 
14
+ from src.config import UPLOAD_DIR, CORS_ORIGINS, RATE_LIMIT
15
+ from src.security import (
16
+ APIKeyAuthMiddleware,
17
+ RateLimiterMiddleware,
18
+ SecurityHeadersMiddleware,
19
+ )
20
 
21
  from src.team import TeamChatSession
22
  from src.log import get_logger
 
55
  def create_app() -> FastAPI:
56
  app = FastAPI(title="LLM Backend API")
57
 
58
+ app.add_middleware(APIKeyAuthMiddleware)
59
+ app.add_middleware(RateLimiterMiddleware, rate_limit=RATE_LIMIT)
60
+ app.add_middleware(SecurityHeadersMiddleware)
61
+
62
  app.add_middleware(
63
  CORSMiddleware,
64
+ allow_origins=CORS_ORIGINS,
65
  allow_credentials=True,
66
  allow_methods=["*"],
67
  allow_headers=["*"],
src/config.py CHANGED
@@ -59,3 +59,12 @@ Continue using tools until you have gathered everything the senior agent needs.
59
  Then send a brief, accurate summary so the senior agent can craft the final response.
60
  Remember: you never speak to the user directly; all communication flows through the senior agent.
61
  """.strip()
 
 
 
 
 
 
 
 
 
 
59
  Then send a brief, accurate summary so the senior agent can craft the final response.
60
  Remember: you never speak to the user directly; all communication flows through the senior agent.
61
  """.strip()
62
+
63
+ # Security settings
64
+ API_KEYS: Final[list[str]] = (
65
+ os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else []
66
+ )
67
+ RATE_LIMIT: Final[int] = int(os.getenv("RATE_LIMIT", "60"))
68
+ CORS_ORIGINS: Final[list[str]] = (
69
+ os.getenv("CORS_ORIGINS", "*").split(",") if os.getenv("CORS_ORIGINS") else ["*"]
70
+ )
src/security.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from fastapi import Request, HTTPException
4
+ from starlette.middleware.base import BaseHTTPMiddleware
5
+ from starlette.responses import Response
6
+ from typing import Callable, Awaitable, MutableMapping
7
+ from time import monotonic
8
+ import asyncio
9
+
10
+ from .config import API_KEYS, RATE_LIMIT
11
+
12
+
13
+ class APIKeyAuthMiddleware(BaseHTTPMiddleware):
14
+ """Require a valid API key via the ``X-API-Key`` header."""
15
+
16
+ def __init__(self, app):
17
+ super().__init__(app)
18
+ self._keys = {k.strip() for k in API_KEYS if k.strip()}
19
+
20
+ async def dispatch(
21
+ self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
22
+ ) -> Response:
23
+ if self._keys:
24
+ key = request.headers.get("X-API-Key")
25
+ if key not in self._keys:
26
+ raise HTTPException(status_code=401, detail="Invalid API key")
27
+ return await call_next(request)
28
+
29
+
30
+ class RateLimiterMiddleware(BaseHTTPMiddleware):
31
+ """Simple in-memory rate limiter per client."""
32
+
33
+ def __init__(self, app, rate_limit: int = RATE_LIMIT) -> None:
34
+ super().__init__(app)
35
+ self.rate_limit = rate_limit
36
+ self._requests: MutableMapping[str, list[float]] = {}
37
+ self._lock = asyncio.Lock()
38
+
39
+ async def dispatch(
40
+ self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
41
+ ) -> Response:
42
+ identifier = request.headers.get("X-API-Key") or request.client.host
43
+ now = monotonic()
44
+ async with self._lock:
45
+ timestamps = self._requests.setdefault(identifier, [])
46
+ while timestamps and now - timestamps[0] > 60:
47
+ timestamps.pop(0)
48
+ if len(timestamps) >= self.rate_limit:
49
+ raise HTTPException(status_code=429, detail="Rate limit exceeded")
50
+ timestamps.append(now)
51
+ return await call_next(request)
52
+
53
+
54
+ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
55
+ """Add common security-related HTTP headers."""
56
+
57
+ async def dispatch(
58
+ self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
59
+ ) -> Response:
60
+ response = await call_next(request)
61
+ headers = response.headers
62
+ headers.setdefault("X-Frame-Options", "DENY")
63
+ headers.setdefault("X-Content-Type-Options", "nosniff")
64
+ headers.setdefault("Referrer-Policy", "same-origin")
65
+ headers.setdefault("Permissions-Policy", "geolocation=()")
66
+ headers.setdefault(
67
+ "Strict-Transport-Security", "max-age=63072000; includeSubDomains"
68
+ )
69
+ return response