File size: 6,551 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
from functools import wraps
from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import SPANDATA
from typing import Any, Iterable, Callable
import sentry_sdk
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
)
try:
import huggingface_hub.inference._client
from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput
except ImportError:
raise DidNotEnable("Huggingface not installed")
class HuggingfaceHubIntegration(Integration):
    """Sentry integration that instruments ``InferenceClient.text_generation``.

    ``include_prompts`` is stored on the instance; the text-generation
    wrapper reads it (together with the default-PII setting) before it
    attaches prompt or response text to a span.
    """

    identifier = "huggingface_hub"
    origin = f"auto.ai.{identifier}"

    def __init__(self, include_prompts=True):
        # type: (HuggingfaceHubIntegration, bool) -> None
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        """Swap ``InferenceClient.text_generation`` for its wrapped version."""
        client_cls = huggingface_hub.inference._client.InferenceClient
        client_cls.text_generation = _wrap_text_generation(
            client_cls.text_generation
        )
def _capture_exception(exc):
    # type: (Any) -> None
    """Send ``exc`` to Sentry as an unhandled ``huggingface_hub`` error.

    The event is built with ``event_from_exception`` using the current
    client's options and then passed to ``sentry_sdk.capture_event``.
    """
    options = sentry_sdk.get_client().options
    error_mechanism = {"type": "huggingface_hub", "handled": False}
    sentry_event, event_hint = event_from_exception(
        exc,
        client_options=options,
        mechanism=error_mechanism,
    )
    sentry_sdk.capture_event(sentry_event, hint=event_hint)
def _wrap_text_generation(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap ``InferenceClient.text_generation`` so each call produces a span.

    The returned callable behaves like ``f`` towards the caller:

    * If the integration is not enabled, or no prompt was supplied, ``f`` is
      called untouched.
    * Exceptions raised by ``f`` are reported via ``_capture_exception`` and
      re-raised.
    * The value returned by ``f`` is always handed back to the caller. For
      streaming results it is wrapped in a generator that yields the very
      same items and finishes the span once the stream ends, fails, or is
      closed by the consumer.

    Span bookkeeping (model id, streaming flag, token usage, and -- only when
    default PII is allowed and ``include_prompts`` is set -- prompt and
    response text) runs inside ``capture_internal_exceptions()`` so that a
    failure there never changes the caller-visible result.
    """

    @wraps(f)
    def new_text_generation(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
        if integration is None:
            return f(*args, **kwargs)

        if "prompt" in kwargs:
            prompt = kwargs["prompt"]
        elif len(args) >= 2:
            # Normalise a positional prompt (args[0] is the client instance)
            # into a keyword argument.
            kwargs["prompt"] = args[1]
            prompt = kwargs["prompt"]
            args = (args[0],) + args[2:]
        else:
            # invalid call, let it return error
            return f(*args, **kwargs)

        model = kwargs.get("model")
        streaming = kwargs.get("stream")

        span = sentry_sdk.start_span(
            op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE,
            name="Text Generation",
            origin=HuggingfaceHubIntegration.origin,
        )
        span.__enter__()
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)
            raise e from None

        def may_record_content():
            # type: () -> bool
            # Prompt/response text is only attached when PII is allowed.
            return should_send_default_pii() and integration.include_prompts

        # Request metadata. Kept in its own guarded section so a failure
        # here cannot prevent the result from being returned below.
        with capture_internal_exceptions():
            if may_record_content():
                set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)
            set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
            set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)

        # Same dispatch order as before: str, TextGenerationOutput and
        # non-iterables are complete results; anything else is a stream.
        is_stream = not isinstance(res, (str, TextGenerationOutput)) and isinstance(
            res, Iterable
        )

        if not is_stream:
            try:
                with capture_internal_exceptions():
                    if isinstance(res, str):
                        if may_record_content():
                            set_data_normalized(span, SPANDATA.AI_RESPONSES, [res])
                    elif isinstance(res, TextGenerationOutput):
                        if may_record_content():
                            set_data_normalized(
                                span,
                                SPANDATA.AI_RESPONSES,
                                [res.generated_text],
                            )
                        if res.details is not None and res.details.generated_tokens > 0:
                            record_token_usage(
                                span, total_tokens=res.details.generated_tokens
                            )
                    else:
                        # we only know how to deal with strings and iterables, ignore
                        set_data_normalized(span, "unknown_response", True)
            finally:
                # Always finish the span, even if bookkeeping blew up.
                span.__exit__(None, None, None)
            return res

        # With details=True the items are expected to be stream-output
        # objects (token text + optional details); otherwise plain strings.
        with_details = kwargs.get("details", False)

        def instrumented_stream():
            # type: () -> Iterable[Any]
            chunks: list[str] = []
            tokens_used = 0
            try:
                # The iteration itself is deliberately NOT guarded: errors
                # raised by the underlying stream must reach the consumer.
                for item in res:
                    with capture_internal_exceptions():
                        if with_details:
                            if hasattr(item, "token") and hasattr(item.token, "text"):
                                chunks.append(item.token.text)
                            if hasattr(item, "details") and hasattr(
                                item.details, "generated_tokens"
                            ):
                                tokens_used = item.details.generated_tokens
                        elif isinstance(item, str):
                            chunks.append(item)
                    yield item

                # Reached only when the stream was consumed to the end.
                with capture_internal_exceptions():
                    if chunks and may_record_content():
                        set_data_normalized(
                            span, SPANDATA.AI_RESPONSES, "".join(chunks)
                        )
                    if tokens_used > 0:
                        record_token_usage(span, total_tokens=tokens_used)
            finally:
                # Runs on exhaustion, on a stream error, and when the
                # consumer closes the generator early.
                span.__exit__(None, None, None)

        return instrumented_stream()

    return new_text_generation
|