File size: 2,581 Bytes
c1304d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import annotations

import asyncio
from collections import deque
from time import monotonic
from typing import Callable, Awaitable, MutableMapping

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from .config import API_KEYS, RATE_LIMIT


class APIKeyAuthMiddleware(BaseHTTPMiddleware):
    """Require a valid API key via the ``X-API-Key`` header.

    The accepted keys are taken from ``API_KEYS`` once, at construction time;
    each entry is stripped and blank entries are discarded. When no keys
    remain, the check is skipped and every request is passed through.
    """

    def __init__(self, app):
        super().__init__(app)
        self._keys = {k.strip() for k in API_KEYS if k.strip()}

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Forward the request if its key is accepted, otherwise answer 401.

        Args:
            request: The incoming request; only its ``X-API-Key`` header is read.
            call_next: Callable that passes the request to the next handler.

        Returns:
            The downstream response, or a 401 JSON response with body
            ``{"detail": "Invalid API key"}`` when the key is missing or
            not in the accepted set.
        """
        if self._keys and request.headers.get("X-API-Key") not in self._keys:
            # Build the response here instead of raising HTTPException:
            # middleware runs outside the application's exception handlers,
            # so a raised HTTPException would reach the client as a 500.
            return JSONResponse(
                status_code=401, content={"detail": "Invalid API key"}
            )
        return await call_next(request)


class RateLimiterMiddleware(BaseHTTPMiddleware):
    """Simple in-memory rate limiter per client.

    A client is identified by its ``X-API-Key`` header, falling back to the
    peer host when the header is absent or empty. Each client may make at
    most ``rate_limit`` requests per sliding 60-second window; further
    requests are answered with 429 and are not forwarded.

    All bookkeeping is held in a dict on this instance.
    """

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS = 60

    def __init__(self, app, rate_limit: int = RATE_LIMIT) -> None:
        """Wrap ``app`` and allow ``rate_limit`` requests per window per client."""
        super().__init__(app)
        self.rate_limit = rate_limit
        # identifier -> monotonic timestamps of accepted requests, oldest first.
        self._requests: MutableMapping[str, deque[float]] = {}
        self._lock = asyncio.Lock()
        # Monotonic time at which idle identifiers are next swept out.
        self._next_sweep = monotonic() + self._WINDOW_SECONDS

    def _evict_idle(self, now: float) -> None:
        """Forget clients whose newest request has already left the window."""
        idle = [
            key
            for key, stamps in self._requests.items()
            if not stamps or now - stamps[-1] > self._WINDOW_SECONDS
        ]
        for key in idle:
            del self._requests[key]

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Forward the request if the client is within its limit, else answer 429.

        Args:
            request: The incoming request; its ``X-API-Key`` header and
                client address are used to identify the caller.
            call_next: Callable that passes the request to the next handler.

        Returns:
            The downstream response, or a 429 JSON response with body
            ``{"detail": "Rate limit exceeded"}``.
        """
        client = request.client
        # ``request.client`` may be None; such callers without an API key
        # share a single "unknown" bucket instead of crashing the request.
        identifier = request.headers.get("X-API-Key") or (
            client.host if client is not None else "unknown"
        )
        now = monotonic()
        async with self._lock:
            if now >= self._next_sweep:
                # Once per window, drop identifiers that have gone quiet so the
                # mapping does not grow with every distinct key ever seen.
                self._evict_idle(now)
                self._next_sweep = now + self._WINDOW_SECONDS
            timestamps = self._requests.setdefault(identifier, deque())
            # Discard entries that have aged out of the window.
            while timestamps and now - timestamps[0] > self._WINDOW_SECONDS:
                timestamps.popleft()
            if len(timestamps) >= self.rate_limit:
                # Returned rather than raised: an HTTPException raised in
                # middleware bypasses the application's exception handlers
                # and would reach the client as a 500.
                return JSONResponse(
                    status_code=429, content={"detail": "Rate limit exceeded"}
                )
            timestamps.append(now)
        return await call_next(request)


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of security-related headers to every response.

    Each header is applied with ``setdefault`` on the response headers
    rather than by plain assignment.
    """

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Run the downstream handler, then add the security headers.

        Args:
            request: The incoming request, forwarded unchanged.
            call_next: Callable that passes the request to the next handler.

        Returns:
            The downstream response with the headers below applied.
        """
        response = await call_next(request)
        outgoing = response.headers
        # (header name, default value) pairs, applied in this order.
        defaults = (
            ("X-Frame-Options", "DENY"),
            ("X-Content-Type-Options", "nosniff"),
            ("Referrer-Policy", "same-origin"),
            ("Permissions-Policy", "geolocation=()"),
            ("Strict-Transport-Security", "max-age=63072000; includeSubDomains"),
        )
        for name, value in defaults:
            outgoing.setdefault(name, value)
        return response