|
from __future__ import annotations |
|
|
|
import functools |
|
import inspect |
|
import sys |
|
import typing |
|
from urllib.parse import urlencode |
|
|
|
if sys.version_info >= (3, 10): |
|
from typing import ParamSpec |
|
else: |
|
from typing_extensions import ParamSpec |
|
|
|
from starlette._utils import is_async_callable |
|
from starlette.exceptions import HTTPException |
|
from starlette.requests import HTTPConnection, Request |
|
from starlette.responses import RedirectResponse |
|
from starlette.websockets import WebSocket |
|
|
|
# Parameter specification of the endpoint wrapped by ``requires``; it lets the
# decorated callable keep the original call signature for type checkers.
_P = ParamSpec("_P")
|
|
|
|
|
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Return ``True`` when every scope in *scopes* is granted on *conn*.

    Granted scopes are looked up in ``conn.auth.scopes``. The check stops at
    the first missing scope, and an empty *scopes* sequence is trivially
    satisfied without touching ``conn.auth`` at all.
    """
    return all(wanted in conn.auth.scopes for wanted in scopes)
|
|
|
|
|
def _deny_http_request(
    request: Request,
    status_code: int,
    redirect: str | None,
) -> RedirectResponse:
    """Handle an HTTP request that is missing one or more required scopes.

    Args:
        request: The incoming request that failed the scope check.
        status_code: Status code for the ``HTTPException`` raised when no
            redirect route is configured.
        redirect: Name of the route to redirect to, or ``None`` to raise.

    Returns:
        A 303 ``RedirectResponse`` to the URL of the ``redirect`` route, with
        the original request URL passed along in a ``next`` query parameter.

    Raises:
        HTTPException: With ``status_code`` when ``redirect`` is ``None``.
    """
    if redirect is None:
        raise HTTPException(status_code=status_code)
    orig_request_qparam = urlencode({"next": str(request.url)})
    next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
    return RedirectResponse(url=next_url, status_code=303)


def requires(
    scopes: str | typing.Sequence[str],
    status_code: int = 403,
    redirect: str | None = None,
) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
    """Decorate an endpoint so it only runs when the connection has *scopes*.

    The decorated function must declare a parameter named ``request`` or
    ``websocket``. The first such parameter, by position, is used at call
    time to find the connection. It is read from keyword arguments if
    present, otherwise from the matching positional slot.

    Args:
        scopes: A single scope name or a sequence of scope names. All of them
            must be present in ``conn.auth.scopes``.
        status_code: Status code raised via ``HTTPException`` when an HTTP
            request lacks a scope and no ``redirect`` is given.
        redirect: Optional route name. When set, an HTTP request lacking a
            scope receives a 303 redirect to that route instead of an error.

    Returns:
        A decorator wrapping the endpoint. WebSocket endpoints are closed,
        without being called, when a scope is missing. ``status_code`` and
        ``redirect`` do not apply to them. HTTP endpoints may be sync or
        async, and the wrapper matches the endpoint's kind.

    Raises:
        Exception: At decoration time, if the function has neither a
            ``request`` nor a ``websocket`` parameter.
    """
    # A bare string names one scope; it must not be split into characters.
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(
        func: typing.Callable[_P, typing.Any],
    ) -> typing.Callable[_P, typing.Any]:
        sig = inspect.signature(func)
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(f'No "request" or "websocket" argument on function "{func}"')

        def get_connection(args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]) -> typing.Any:
            # Prefer the keyword argument. Otherwise fall back to the positional
            # slot found in the signature (None if it was not passed that way).
            return kwargs.get(type_, args[idx] if idx < len(args) else None)

        if type_ == "websocket":

            @functools.wraps(func)
            async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
                websocket = get_connection(args, kwargs)
                assert isinstance(websocket, WebSocket)

                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif is_async_callable(func):

            @functools.wraps(func)
            async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
                request = get_connection(args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return _deny_http_request(request, status_code, redirect)
                return await func(*args, **kwargs)

            return async_wrapper

        else:

            @functools.wraps(func)
            def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
                request = get_connection(args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return _deny_http_request(request, status_code, redirect)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator
|
|
|
|
|
class AuthenticationError(Exception):
    """Exception type for authentication failures."""
|
|
|
|
|
class AuthenticationBackend:
    """Base class for authentication backends.

    Subclasses override :meth:`authenticate`; the base implementation only
    raises ``NotImplementedError``.
    """

    async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
        """Authenticate *conn*.

        Returns:
            A ``(credentials, user)`` pair, or ``None``. NOTE(review): the
            meaning of ``None`` (presumably "no credentials presented") is
            defined by the caller, which is not visible here — confirm.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
class AuthCredentials:
    """The scopes granted to an authenticated connection.

    Args:
        scopes: The granted scope names. ``None`` means no scopes. A single
            string is treated as one scope name rather than as a sequence of
            characters.
    """

    def __init__(self, scopes: str | typing.Sequence[str] | None = None) -> None:
        if scopes is None:
            scope_list: list[str] = []
        elif isinstance(scopes, str):
            # ``str`` is itself a ``Sequence[str]``; without this branch
            # "admin" would become ["a", "d", "m", "i", "n"].
            scope_list = [scopes]
        else:
            # Copy so later mutation of the caller's sequence has no effect.
            scope_list = list(scopes)
        self.scopes = scope_list
|
|
|
|
|
class BaseUser:
    """Interface for user objects.

    Each property below raises ``NotImplementedError`` here; subclasses
    provide concrete values.
    """

    @property
    def is_authenticated(self) -> bool:
        """Whether this user is authenticated."""
        raise NotImplementedError

    @property
    def display_name(self) -> str:
        """A human-readable name for the user."""
        raise NotImplementedError

    @property
    def identity(self) -> str:
        """An identifier for the user."""
        raise NotImplementedError
|
|
|
|
|
class SimpleUser(BaseUser):
    """An authenticated user described only by a username.

    ``identity`` is not overridden here, so it falls back to ``BaseUser``.
    """

    def __init__(self, username: str) -> None:
        # Stored as-is; also serves as the display name.
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def display_name(self) -> str:
        """The username given at construction."""
        return self.username
|
|
|
|
|
class UnauthenticatedUser(BaseUser):
    """A user placeholder that is never authenticated.

    ``identity`` is not overridden here, so it falls back to ``BaseUser``.
    """

    @property
    def is_authenticated(self) -> bool:
        """Always ``False``."""
        return False

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""
|
|