File size: 6,521 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
from __future__ import annotations

import contextlib
import functools
from typing import Any, TypeVar, Callable, Awaitable, Iterator

import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import _check_minimum_version, Integration, DidNotEnable
from sentry_sdk.tracing import Span
from sentry_sdk.tracing_utils import add_query_source, record_sql_queries
from sentry_sdk.utils import (
    ensure_integration_enabled,
    parse_version,
    capture_internal_exceptions,
)

try:
    import asyncpg  # type: ignore[import-not-found]
    from asyncpg.cursor import BaseCursor  # type: ignore
except ImportError:
    raise DidNotEnable("asyncpg not installed.")
class AsyncPGIntegration(Integration):
    """Sentry integration that instruments asyncpg.

    ``setup_once`` replaces several ``asyncpg.Connection`` methods and
    ``asyncpg.connect_utils._connect_addr`` with wrappers defined in this
    module, so queries and connection attempts produce spans.
    """

    identifier = "asyncpg"
    origin = f"auto.db.{identifier}"

    # Read by ``_record``: when False, query parameters are discarded before
    # they are handed to ``record_sql_queries``.
    _record_params = False

    def __init__(self, *, record_params: bool = False):
        # NOTE: the flag is written onto the class, not the instance, so the
        # most recently constructed integration decides it for all instances.
        AsyncPGIntegration._record_params = record_params

    @staticmethod
    def setup_once() -> None:
        # asyncpg.__version__ is a "<major>.<minor>.<patch>" semantic version string.
        _check_minimum_version(AsyncPGIntegration, parse_version(asyncpg.__version__))

        conn_cls = asyncpg.Connection
        conn_cls.execute = _wrap_execute(conn_cls.execute)
        conn_cls._execute = _wrap_connection_method(conn_cls._execute)
        conn_cls._executemany = _wrap_connection_method(
            conn_cls._executemany, executemany=True
        )
        conn_cls.cursor = _wrap_cursor_creation(conn_cls.cursor)
        conn_cls.prepare = _wrap_connection_method(conn_cls.prepare)

        connect_utils = asyncpg.connect_utils
        connect_utils._connect_addr = _wrap_connect_addr(connect_utils._connect_addr)
# Generic result type of the wrapped coroutine functions.
T = TypeVar("T")


def _wrap_execute(f: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
    """Wrap ``Connection.execute`` so argument-less calls are recorded as a query span.

    Args:
        f: the original ``Connection.execute`` coroutine function.

    Returns:
        A coroutine function with the same call signature that delegates to
        ``f``. When the integration is enabled and the call carries no bind
        arguments, the call runs inside a ``record_sql_queries`` span.
    """

    @functools.wraps(f)
    async def _inner(*args: Any, **kwargs: Any) -> T:
        if sentry_sdk.get_client().get_integration(AsyncPGIntegration) is None:
            return await f(*args, **kwargs)

        # Avoid recording calls to _execute twice.
        # Calls to Connection.execute with args also call
        # Connection._execute, which is recorded separately
        # args[0] = the connection object, args[1] is the query
        if len(args) > 2:
            return await f(*args, **kwargs)

        # The query may also be passed by keyword; indexing args[1] blindly
        # would make the instrumentation raise IndexError on a valid call.
        query = args[1] if len(args) > 1 else kwargs.get("query")
        if query is None:
            # Nothing to describe: let asyncpg handle (or reject) the call itself.
            return await f(*args, **kwargs)

        with record_sql_queries(
            cursor=None,
            query=query,
            params_list=None,
            paramstyle=None,
            executemany=False,
            span_origin=AsyncPGIntegration.origin,
        ) as span:
            res = await f(*args, **kwargs)

        # Runs after the span has closed, presumably so the span's timing is
        # final when add_query_source inspects it — confirm in tracing_utils.
        with capture_internal_exceptions():
            add_query_source(span)

        return res

    return _inner
# Any asyncpg cursor type; accepted (optionally) by ``_record``.
SubCursor = TypeVar("SubCursor", bound=BaseCursor)


@contextlib.contextmanager
def _record(
    cursor: SubCursor | None,
    query: str,
    params_list: tuple[Any, ...] | None,
    *,
    executemany: bool = False,
) -> Iterator[Span]:
    """Open a SQL query span through ``record_sql_queries`` and yield it.

    Args:
        cursor: cursor object to describe, or None. Its repr is only requested
            (``record_cursor_repr``) when a cursor is given.
        query: the SQL text.
        params_list: bind parameters. They are discarded when the integration
            is active but was configured without ``record_params``.
        executemany: forwarded unchanged to ``record_sql_queries``.

    Yields:
        The span created by ``record_sql_queries``.
    """
    active = sentry_sdk.get_client().get_integration(AsyncPGIntegration)
    should_drop_params = active is not None and not active._record_params
    if should_drop_params:
        params_list = None

    with record_sql_queries(
        cursor=cursor,
        query=query,
        params_list=params_list,
        # A paramstyle is only reported when there are parameters to describe.
        paramstyle=None if not params_list else "pyformat",
        executemany=executemany,
        record_cursor_repr=cursor is not None,
        span_origin=AsyncPGIntegration.origin,
    ) as span:
        yield span
def _wrap_connection_method(
    f: Callable[..., Awaitable[T]], *, executemany: bool = False
) -> Callable[..., Awaitable[T]]:
    """Wrap an asyncpg ``Connection`` coroutine method so each call is recorded.

    Used for ``_execute``, ``_executemany`` and ``prepare``.

    Args:
        f: the original unbound coroutine method. Its first positional argument
            is the connection and its second is the query.
        executemany: forwarded to ``_record`` (True for ``_executemany``).

    Returns:
        A coroutine function with the same call signature that delegates to
        ``f`` inside a ``_record`` span.
    """

    @functools.wraps(f)
    async def _inner(*args: Any, **kwargs: Any) -> T:
        if sentry_sdk.get_client().get_integration(AsyncPGIntegration) is None:
            return await f(*args, **kwargs)

        conn = args[0] if args else None
        # The query may be passed by keyword (e.g. ``conn.prepare(query=...)``);
        # indexing args[1] unconditionally would raise IndexError on a valid call.
        query = args[1] if len(args) > 1 else kwargs.get("query")
        if conn is None or query is None:
            # Nothing to describe: let asyncpg handle (or reject) the call itself.
            return await f(*args, **kwargs)

        # Third positional argument, when present, holds the bind parameters.
        params_list = args[2] if len(args) > 2 else None

        with _record(None, query, params_list, executemany=executemany) as span:
            _set_db_data(span, conn)
            res = await f(*args, **kwargs)

        return res

    return _inner
def _wrap_cursor_creation(f: Callable[..., T]) -> Callable[..., T]:
    """Wrap ``Connection.cursor`` so creating a cursor is recorded as a query span.

    Args:
        f: the original (synchronous) ``Connection.cursor`` function.

    Returns:
        A function with the same call signature. When the integration is
        disabled it calls ``f`` directly (via ``ensure_integration_enabled``).
        Otherwise it calls ``f`` inside a ``_record`` span and stores the
        returned object on the span as ``db.cursor``.
    """

    @ensure_integration_enabled(AsyncPGIntegration, f)
    def _inner(*args: Any, **kwargs: Any) -> T:  # noqa: N807
        conn = args[0] if args else None
        query = args[1] if len(args) > 1 else kwargs.get("query")
        if conn is None or query is None:
            # Nothing to describe: let asyncpg handle (or reject) the call itself.
            return f(*args, **kwargs)

        # asyncpg's ``Connection.cursor(query, *args, ...)`` takes bind values
        # variadically, so every positional argument after the query is a
        # parameter; ``args[2]`` alone captured only the first one.
        params_list = args[2:] or None

        with _record(
            None,
            query,
            params_list,
            executemany=False,
        ) as span:
            _set_db_data(span, conn)
            res = f(*args, **kwargs)
            span.set_data("db.cursor", res)

        return res

    return _inner
def _wrap_connect_addr(f: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
    """Wrap asyncpg's ``connect_utils._connect_addr`` to trace connection attempts.

    Args:
        f: the original ``_connect_addr`` coroutine function. It is expected to
            receive ``params`` (with ``user``/``database`` attributes) and
            optionally ``addr`` as keyword arguments.

    Returns:
        A coroutine function with the same call signature. When the integration
        is enabled it runs ``f`` inside a ``connect`` span and adds a
        ``connect`` breadcrumb carrying the span's data.
    """

    @functools.wraps(f)
    async def _inner(*args: Any, **kwargs: Any) -> T:
        if sentry_sdk.get_client().get_integration(AsyncPGIntegration) is None:
            return await f(*args, **kwargs)

        user = kwargs["params"].user
        database = kwargs["params"].database

        with sentry_sdk.start_span(
            op=OP.DB,
            name="connect",
            origin=AsyncPGIntegration.origin,
        ) as span:
            span.set_data(SPANDATA.DB_SYSTEM, "postgresql")
            _set_server_addr_data(span, kwargs.get("addr"))
            span.set_data(SPANDATA.DB_NAME, database)
            span.set_data(SPANDATA.DB_USER, user)

            with capture_internal_exceptions():
                sentry_sdk.add_breadcrumb(
                    message="connect", category="query", data=span._data
                )
            res = await f(*args, **kwargs)

        return res

    return _inner


def _set_db_data(span: Span, conn: Any) -> None:
    """Annotate *span* with PostgreSQL server and database details read from *conn*.

    Reads ``conn._addr`` and ``conn._params.database`` / ``conn._params.user``.
    Database name and user are only set when truthy.
    """
    span.set_data(SPANDATA.DB_SYSTEM, "postgresql")
    _set_server_addr_data(span, conn._addr)

    database = conn._params.database
    if database:
        span.set_data(SPANDATA.DB_NAME, database)

    user = conn._params.user
    if user:
        span.set_data(SPANDATA.DB_USER, user)


def _split_server_addr(addr: Any) -> tuple[Any, Any]:
    """Return ``(address, port)`` from an asyncpg connection address; missing parts are None.

    TCP endpoints are ``(host, port)`` sequences. A Unix-domain socket may be
    given as a bare path string, which must be reported whole: indexing it
    would yield single characters instead of a host and port.
    """
    if not addr:
        return None, None
    if isinstance(addr, str):
        return addr, None
    host = addr[0] if len(addr) > 0 else None
    port = addr[1] if len(addr) > 1 else None
    return host, port


def _set_server_addr_data(span: Span, addr: Any) -> None:
    """Set the server address/port span data for whichever parts *addr* provides."""
    host, port = _split_server_addr(addr)
    if host is not None:
        span.set_data(SPANDATA.SERVER_ADDRESS, host)
    if port is not None:
        span.set_data(SPANDATA.SERVER_PORT, port)
|