from collections import OrderedDict
from functools import wraps

import sentry_sdk
from sentry_sdk.ai.monitoring import set_ai_pipeline_name, record_token_usage
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import Span
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.utils import logger, capture_internal_exceptions

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, List, Callable, Dict, Union, Optional
    from uuid import UUID

try:
    from langchain_core.messages import BaseMessage
    from langchain_core.outputs import LLMResult
    from langchain_core.callbacks import (
        manager,
        BaseCallbackHandler,
    )
    from langchain_core.agents import AgentAction, AgentFinish
except ImportError:
    raise DidNotEnable("langchain not installed")


DATA_FIELDS = {
    "temperature": SPANDATA.AI_TEMPERATURE,
    "top_p": SPANDATA.AI_TOP_P,
    "top_k": SPANDATA.AI_TOP_K,
    "function_call": SPANDATA.AI_FUNCTION_CALL,
    "tool_calls": SPANDATA.AI_TOOL_CALLS,
    "tools": SPANDATA.AI_TOOLS,
    "response_format": SPANDATA.AI_RESPONSE_FORMAT,
    "logit_bias": SPANDATA.AI_LOGIT_BIAS,
    "tags": SPANDATA.AI_TAGS,
}
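
# The keys above are LangChain invocation parameters; when one of them shows
# up in a run's params, its value is copied onto the span under the matching
# SPANDATA attribute (see the `set_data_normalized` loops in the handlers).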


NO_COLLECT_TOKEN_MODELS = [
    "openai-chat",
    "anthropic-chat",
    "cohere-chat",
    "huggingface_endpoint",
]
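
# For these model types the provider-specific Sentry integrations (OpenAI,
# Anthropic, etc.) presumably already report token usage on their own spans,
# so this integration skips its tiktoken-based counting for them to avoid
# double reporting.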


class LangchainIntegration(Integration):
    identifier = "langchain"
    origin = f"auto.ai.{identifier}"

    max_spans = 1024

    def __init__(
        self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None
    ):
        self.include_prompts = include_prompts
        self.max_spans = max_spans
        self.tiktoken_encoding_name = tiktoken_encoding_name

    @staticmethod
    def setup_once():
        manager._configure = _wrap_configure(manager._configure)
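

# Enabling the integration is a standard `sentry_sdk.init` call. A minimal
# sketch (the DSN is a placeholder):
#
#     import sentry_sdk
#     from sentry_sdk.integrations.langchain import LangchainIntegration
#
#     sentry_sdk.init(
#         dsn="https://<key>@<host>/<project>",
#         traces_sample_rate=1.0,
#         integrations=[
#             LangchainIntegration(
#                 include_prompts=False,  # drop prompts/responses (PII)
#                 tiktoken_encoding_name="cl100k_base",
#             )
#         ],
#     )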


class WatchedSpan:
    span = None
    num_completion_tokens = 0
    num_prompt_tokens = 0
    no_collect_tokens = False
    is_pipeline = False

    def __init__(self, span):
        self.span = span
        # Created per instance: a class-level `children = []` would be a
        # single list shared by every WatchedSpan.
        self.children = []


class SentryLangchainCallback(BaseCallbackHandler):
    """Base callback handler that can be used to handle callbacks from langchain."""

    # Class-level, and apparently on purpose: _create_span looks up parent
    # runs in this map, and sharing it across handler instances lets a handler
    # created for a nested run find spans opened by another handler.
    span_map = OrderedDict()

    max_span_map_size = 0

    def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
        self.max_span_map_size = max_span_map_size
        self.include_prompts = include_prompts

        self.tiktoken_encoding = None
        if tiktoken_encoding_name is not None:
            import tiktoken

            self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)

    def count_tokens(self, s):
        if self.tiktoken_encoding is not None:
            return len(self.tiktoken_encoding.encode_ordinary(s))
        return 0
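
    # For example, with `tiktoken_encoding_name="cl100k_base"` (an encoding
    # name tiktoken ships with), `count_tokens("hello world")` returns the
    # length of the encoded token list. Without an encoding configured it
    # always returns 0, which effectively disables streamed token counting.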

    def gc_span_map(self):
        # Evict (and close) the oldest spans once the map exceeds its cap.
        while len(self.span_map) > self.max_span_map_size:
            run_id, watched_span = self.span_map.popitem(last=False)
            self._exit_span(watched_span, run_id)

    def _handle_error(self, run_id, error):
        if not run_id or run_id not in self.span_map:
            return

        span_data = self.span_map[run_id]
        if not span_data:
            return
        sentry_sdk.capture_exception(error, span_data.span.scope)
        span_data.span.__exit__(None, None, None)
        del self.span_map[run_id]

    def _normalize_langchain_message(self, message):
        parsed = {"content": message.content, "role": message.type}
        parsed.update(message.additional_kwargs)
        return parsed
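
    # For example, a `HumanMessage(content="Hi")` normalizes to
    # {"content": "Hi", "role": "human"}, since `message.type` is "human"
    # for that class, plus whatever `additional_kwargs` carries.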

    def _create_span(self, run_id, parent_id, **kwargs):
        watched_span = None
        if parent_id:
            parent_span = self.span_map.get(parent_id)
            if parent_span:
                watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
                parent_span.children.append(watched_span)
        if watched_span is None:
            watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))

        if kwargs.get("op", "").startswith("ai.pipeline."):
            if kwargs.get("name"):
                set_ai_pipeline_name(kwargs.get("name"))
            watched_span.is_pipeline = True

        watched_span.span.__enter__()
        self.span_map[run_id] = watched_span
        self.gc_span_map()
        return watched_span
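
    # New spans nest under the span of `parent_id` when that run is being
    # watched, so the Sentry trace mirrors LangChain's run tree; "ai.pipeline.*"
    # spans additionally set the pipeline name that child runs are grouped
    # under.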

    def _exit_span(self, span_data, run_id):
        if span_data.is_pipeline:
            set_ai_pipeline_name(None)

        span_data.span.__exit__(None, None, None)
        del self.span_map[run_id]

    def on_llm_start(
        self,
        serialized,
        prompts,
        *,
        run_id,
        tags=None,
        parent_run_id=None,
        metadata=None,
        **kwargs,
    ):
        """Run when LLM starts running."""
        with capture_internal_exceptions():
            if not run_id:
                return
            all_params = kwargs.get("invocation_params", {})
            all_params.update(serialized.get("kwargs", {}))
            watched_span = self._create_span(
                run_id,
                # `parent_run_id` is bound as a keyword-only parameter above,
                # so it never lands in **kwargs; use it directly.
                parent_run_id,
                op=OP.LANGCHAIN_RUN,
                name=kwargs.get("name") or "Langchain LLM call",
                origin=LangchainIntegration.origin,
            )
            span = watched_span.span
            if should_send_default_pii() and self.include_prompts:
                set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompts)
            for k, v in DATA_FIELDS.items():
                if k in all_params:
                    set_data_normalized(span, v, all_params[k])

    def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
        """Run when Chat Model starts running."""
        with capture_internal_exceptions():
            if not run_id:
                return
            all_params = kwargs.get("invocation_params", {})
            all_params.update(serialized.get("kwargs", {}))
            watched_span = self._create_span(
                run_id,
                kwargs.get("parent_run_id"),
                op=OP.LANGCHAIN_CHAT_COMPLETIONS_CREATE,
                name=kwargs.get("name") or "Langchain Chat Model",
                origin=LangchainIntegration.origin,
            )
            span = watched_span.span
            model = all_params.get(
                "model", all_params.get("model_name", all_params.get("model_id"))
            )
            watched_span.no_collect_tokens = any(
                x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
            )

            # Default to "" so a missing "_type" does not raise a TypeError.
            if not model and "anthropic" in all_params.get("_type", ""):
                model = "claude-2"
            if model:
                span.set_data(SPANDATA.AI_MODEL_ID, model)
            if should_send_default_pii() and self.include_prompts:
                set_data_normalized(
                    span,
                    SPANDATA.AI_INPUT_MESSAGES,
                    [
                        [self._normalize_langchain_message(x) for x in list_]
                        for list_ in messages
                    ],
                )
            for k, v in DATA_FIELDS.items():
                if k in all_params:
                    set_data_normalized(span, v, all_params[k])
            if not watched_span.no_collect_tokens:
                for list_ in messages:
                    for message in list_:
                        self.span_map[run_id].num_prompt_tokens += self.count_tokens(
                            message.content
                        ) + self.count_tokens(message.type)

    def on_llm_new_token(self, token, *, run_id, **kwargs):
        """Run on new LLM token. Only available when streaming is enabled."""
        with capture_internal_exceptions():
            if not run_id or run_id not in self.span_map:
                return
            span_data = self.span_map[run_id]
            if not span_data or span_data.no_collect_tokens:
                return
            span_data.num_completion_tokens += self.count_tokens(token)

    def on_llm_end(self, response, *, run_id, **kwargs):
        """Run when LLM ends running."""
        with capture_internal_exceptions():
            # Guard against spans already evicted by gc_span_map; a bare
            # `self.span_map[run_id]` below would raise a KeyError.
            if not run_id or run_id not in self.span_map:
                return

            token_usage = (
                response.llm_output.get("token_usage") if response.llm_output else None
            )

            span_data = self.span_map[run_id]
            if not span_data:
                return

            if should_send_default_pii() and self.include_prompts:
                set_data_normalized(
                    span_data.span,
                    SPANDATA.AI_RESPONSES,
                    [[x.text for x in list_] for list_ in response.generations],
                )

            if not span_data.no_collect_tokens:
                if token_usage:
                    # Prefer the usage reported by the provider...
                    record_token_usage(
                        span_data.span,
                        token_usage.get("prompt_tokens"),
                        token_usage.get("completion_tokens"),
                        token_usage.get("total_tokens"),
                    )
                else:
                    # ...falling back to the local tiktoken-based estimate.
                    record_token_usage(
                        span_data.span,
                        span_data.num_prompt_tokens,
                        span_data.num_completion_tokens,
                    )

            self._exit_span(span_data, run_id)

    def on_llm_error(self, error, *, run_id, **kwargs):
        """Run when LLM errors."""
        with capture_internal_exceptions():
            self._handle_error(run_id, error)

    def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
        """Run when chain starts running."""
        with capture_internal_exceptions():
            if not run_id:
                return
            watched_span = self._create_span(
                run_id,
                kwargs.get("parent_run_id"),
                # A top-level chain (no parent run) becomes the ai.pipeline
                # root; nested chains become plain runs.
                op=(
                    OP.LANGCHAIN_RUN
                    if kwargs.get("parent_run_id") is not None
                    else OP.LANGCHAIN_PIPELINE
                ),
                name=kwargs.get("name") or "Chain execution",
                origin=LangchainIntegration.origin,
            )
            metadata = kwargs.get("metadata")
            if metadata:
                set_data_normalized(watched_span.span, SPANDATA.AI_METADATA, metadata)

    def on_chain_end(self, outputs, *, run_id, **kwargs):
        """Run when chain ends running."""
        with capture_internal_exceptions():
            if not run_id or run_id not in self.span_map:
                return

            span_data = self.span_map[run_id]
            if not span_data:
                return
            self._exit_span(span_data, run_id)

    def on_chain_error(self, error, *, run_id, **kwargs):
        """Run when chain errors."""
        self._handle_error(run_id, error)

    def on_agent_action(self, action, *, run_id, **kwargs):
        """Run on agent action."""
        with capture_internal_exceptions():
            if not run_id:
                return
            watched_span = self._create_span(
                run_id,
                kwargs.get("parent_run_id"),
                op=OP.LANGCHAIN_AGENT,
                name=action.tool or "AI tool usage",
                origin=LangchainIntegration.origin,
            )
            if (
                action.tool_input
                and should_send_default_pii()
                and self.include_prompts
            ):
                set_data_normalized(
                    watched_span.span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input
                )

    def on_agent_finish(self, finish, *, run_id, **kwargs):
        """Run when the agent finishes."""
        with capture_internal_exceptions():
            if not run_id or run_id not in self.span_map:
                return

            span_data = self.span_map[run_id]
            if not span_data:
                return
            if should_send_default_pii() and self.include_prompts:
                set_data_normalized(
                    span_data.span, SPANDATA.AI_RESPONSES, finish.return_values.items()
                )
            self._exit_span(span_data, run_id)

    def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
        """Run when tool starts running."""
        with capture_internal_exceptions():
            if not run_id:
                return
            watched_span = self._create_span(
                run_id,
                kwargs.get("parent_run_id"),
                op=OP.LANGCHAIN_TOOL,
                name=serialized.get("name") or kwargs.get("name") or "AI tool usage",
                origin=LangchainIntegration.origin,
            )
            if should_send_default_pii() and self.include_prompts:
                set_data_normalized(
                    watched_span.span,
                    SPANDATA.AI_INPUT_MESSAGES,
                    kwargs.get("inputs", [input_str]),
                )
            if kwargs.get("metadata"):
                set_data_normalized(
                    watched_span.span, SPANDATA.AI_METADATA, kwargs.get("metadata")
                )

    def on_tool_end(self, output, *, run_id, **kwargs):
        """Run when tool ends running."""
        with capture_internal_exceptions():
            if not run_id or run_id not in self.span_map:
                return

            span_data = self.span_map[run_id]
            if not span_data:
                return
            if should_send_default_pii() and self.include_prompts:
                set_data_normalized(span_data.span, SPANDATA.AI_RESPONSES, output)
            self._exit_span(span_data, run_id)

    def on_tool_error(self, error, *args, run_id, **kwargs):
        """Run when tool errors."""
        self._handle_error(run_id, error)
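

# `LangchainIntegration.setup_once` wraps `manager._configure`, the hook
# LangChain calls to assemble a CallbackManager, so that a
# SentryLangchainCallback is injected into the local callbacks of every run
# (unless the caller already registered one).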
def _wrap_configure(f):

    @wraps(f)
    def new_configure(*args, **kwargs):
        integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
        if integration is None:
            return f(*args, **kwargs)

        with capture_internal_exceptions():
            new_callbacks = []
            if "local_callbacks" in kwargs:
                existing_callbacks = kwargs["local_callbacks"]
                kwargs["local_callbacks"] = new_callbacks
            elif len(args) > 2:
                # The third positional argument of _configure is the local
                # callbacks; swap it for our copy.
                existing_callbacks = args[2]
                args = (
                    args[0],
                    args[1],
                    new_callbacks,
                ) + args[3:]
            else:
                existing_callbacks = []

            # _configure accepts a single handler, a list of handlers, or None.
            if existing_callbacks:
                if isinstance(existing_callbacks, list):
                    new_callbacks.extend(existing_callbacks)
                elif isinstance(existing_callbacks, BaseCallbackHandler):
                    new_callbacks.append(existing_callbacks)
                else:
                    logger.debug("Unknown callback type: %s", existing_callbacks)

            already_added = any(
                isinstance(callback, SentryLangchainCallback)
                for callback in new_callbacks
            )
            if not already_added:
                new_callbacks.append(
                    SentryLangchainCallback(
                        integration.max_spans,
                        integration.include_prompts,
                        integration.tiktoken_encoding_name,
                    )
                )

        # Outside the capture block, so a swallowed error above cannot turn
        # the wrapped call into an implicit `return None`.
        return f(*args, **kwargs)

    return new_configure