File size: 4,948 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from __future__ import annotations
import functools
import inspect
import sys
import typing
from urllib.parse import urlencode
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket
# Parameter specification used in the type hints of ``requires`` so that the
# wrapper it returns advertises the same call signature as the decorated function.
_P = ParamSpec("_P")
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Return ``True`` if every scope in *scopes* appears in ``conn.auth.scopes``.

    An empty *scopes* sequence is trivially satisfied, and in that case
    ``conn.auth`` is never accessed.
    """
    return all(scope in conn.auth.scopes for scope in scopes)
def requires(
    scopes: str | typing.Sequence[str],
    status_code: int = 403,
    redirect: str | None = None,
) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
    """Restrict an endpoint to connections that hold every scope in *scopes*.

    The decorated function must declare a parameter named ``request`` or
    ``websocket``. The first such parameter, in signature order, selects the
    wrapper kind and tells the wrapper where to find the connection object.

    Args:
        scopes: A single scope name or a sequence of scope names. All of them
            must be present in the connection's ``auth.scopes``.
        status_code: Status code of the ``HTTPException`` raised when an HTTP
            request lacks a scope and no *redirect* is configured.
        redirect: Optional route name. When set, an HTTP request lacking a
            scope is answered with a 303 redirect to ``request.url_for(redirect)``.
            The original URL is carried in a ``next`` query parameter.

    Returns:
        A decorator. For a ``websocket`` endpoint, a websocket lacking a scope
        is closed and the endpoint is not called. Request endpoints get an
        async or sync wrapper, chosen by ``is_async_callable``.

    Raises:
        Exception: At decoration time, if the function has neither a
            ``request`` nor a ``websocket`` parameter.
    """
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def _extract_conn(
        name: str,
        idx: int,
        args: tuple[typing.Any, ...],
        kwargs: dict[str, typing.Any],
    ) -> typing.Any:
        # Prefer the keyword form. Otherwise read the positional slot the
        # parameter occupies in the endpoint's signature (None if not passed).
        return kwargs.get(name, args[idx] if idx < len(args) else None)

    def _reject(request: Request) -> RedirectResponse:
        # Shared by the sync and async wrappers for a request missing a scope.
        # Redirect to the named route when configured, otherwise raise.
        if redirect is not None:
            orig_request_qparam = urlencode({"next": str(request.url)})
            next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
            return RedirectResponse(url=next_url, status_code=303)
        raise HTTPException(status_code=status_code)

    def decorator(
        func: typing.Callable[_P, typing.Any],
    ) -> typing.Callable[_P, typing.Any]:
        sig = inspect.signature(func)
        # Find the first "request"/"websocket" parameter. Its index locates the
        # connection object when the endpoint is called positionally.
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(f'No "request" or "websocket" argument on function "{func}"')

        if type_ == "websocket":
            # Handle websocket functions. (The wrapper is always async.)
            @functools.wraps(func)
            async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
                websocket = _extract_conn("websocket", idx, args, kwargs)
                # Runtime type check that also narrows for type checkers;
                # note that asserts are skipped under ``python -O``.
                assert isinstance(websocket, WebSocket)
                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        if is_async_callable(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
                request = _extract_conn("request", idx, args, kwargs)
                assert isinstance(request, Request)
                if not has_required_scope(request, scopes_list):
                    return _reject(request)
                return await func(*args, **kwargs)

            return async_wrapper

        # Handle sync request/response functions.
        @functools.wraps(func)
        def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
            request = _extract_conn("request", idx, args, kwargs)
            assert isinstance(request, Request)
            if not has_required_scope(request, scopes_list):
                return _reject(request)
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator
class AuthenticationError(Exception):
    """Exception type for authentication failures; adds nothing to ``Exception``."""
class AuthenticationBackend:
    """Base class for authentication backends; ``authenticate`` is left unimplemented."""

    async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
        """Authenticate *conn*, returning ``(credentials, user)`` or ``None``.

        The base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError  # pragma: no cover
class AuthCredentials:
    """Holds the scopes granted to a connection.

    ``scopes`` is always stored as a new ``list``. It is empty when no scopes
    are given. A single scope may be passed as a plain string. That string is
    wrapped in a list rather than split into characters, the same way
    ``requires`` normalizes its own ``scopes`` argument.
    """

    def __init__(self, scopes: str | typing.Sequence[str] | None = None):
        if scopes is None:
            granted: list[str] = []
        elif isinstance(scopes, str):
            # A str is itself a Sequence[str]. list("admin") would yield
            # ['a', 'd', 'm', 'i', 'n'], which is never a meaningful scope list.
            granted = [scopes]
        else:
            granted = list(scopes)
        self.scopes = granted
class BaseUser:
    """Interface for user objects; every property raises until a subclass overrides it."""

    @property
    def is_authenticated(self) -> bool:
        """Whether this user counts as authenticated."""
        raise NotImplementedError  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Name suitable for showing to people."""
        raise NotImplementedError  # pragma: no cover

    @property
    def identity(self) -> str:
        """Identifier for this user."""
        raise NotImplementedError  # pragma: no cover
class SimpleUser(BaseUser):
    """An always-authenticated user described only by a username.

    Only ``is_authenticated`` and ``display_name`` are overridden here;
    ``identity`` comes from ``BaseUser`` unchanged.
    """

    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def display_name(self) -> str:
        """The username supplied at construction."""
        return self.username

    @property
    def is_authenticated(self) -> bool:
        """Always ``True``."""
        return True
class UnauthenticatedUser(BaseUser):
    """A user that is never authenticated and has an empty display name."""

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""

    @property
    def is_authenticated(self) -> bool:
        """Always ``False``."""
        return False
|