Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
from fastapi import Request, HTTPException | |
from starlette.middleware.base import BaseHTTPMiddleware | |
from starlette.responses import Response | |
from typing import Callable, Awaitable, MutableMapping | |
from time import monotonic | |
import asyncio | |
from .config import API_KEYS, RATE_LIMIT | |
class APIKeyAuthMiddleware(BaseHTTPMiddleware): | |
"""Require a valid API key via the ``X-API-Key`` header.""" | |
def __init__(self, app): | |
super().__init__(app) | |
self._keys = {k.strip() for k in API_KEYS if k.strip()} | |
async def dispatch( | |
self, request: Request, call_next: Callable[[Request], Awaitable[Response]] | |
) -> Response: | |
if self._keys: | |
key = request.headers.get("X-API-Key") | |
if key not in self._keys: | |
raise HTTPException(status_code=401, detail="Invalid API key") | |
return await call_next(request) | |
class RateLimiterMiddleware(BaseHTTPMiddleware): | |
"""Simple in-memory rate limiter per client.""" | |
def __init__(self, app, rate_limit: int = RATE_LIMIT) -> None: | |
super().__init__(app) | |
self.rate_limit = rate_limit | |
self._requests: MutableMapping[str, list[float]] = {} | |
self._lock = asyncio.Lock() | |
async def dispatch( | |
self, request: Request, call_next: Callable[[Request], Awaitable[Response]] | |
) -> Response: | |
identifier = request.headers.get("X-API-Key") or request.client.host | |
now = monotonic() | |
async with self._lock: | |
timestamps = self._requests.setdefault(identifier, []) | |
while timestamps and now - timestamps[0] > 60: | |
timestamps.pop(0) | |
if len(timestamps) >= self.rate_limit: | |
raise HTTPException(status_code=429, detail="Rate limit exceeded") | |
timestamps.append(now) | |
return await call_next(request) | |
class SecurityHeadersMiddleware(BaseHTTPMiddleware): | |
"""Add common security-related HTTP headers.""" | |
async def dispatch( | |
self, request: Request, call_next: Callable[[Request], Awaitable[Response]] | |
) -> Response: | |
response = await call_next(request) | |
headers = response.headers | |
headers.setdefault("X-Frame-Options", "DENY") | |
headers.setdefault("X-Content-Type-Options", "nosniff") | |
headers.setdefault("Referrer-Policy", "same-origin") | |
headers.setdefault("Permissions-Policy", "geolocation=()") | |
headers.setdefault( | |
"Strict-Transport-Security", "max-age=63072000; includeSubDomains" | |
) | |
return response | |