"""Tests that applied guardrails are reflected in the standard logging payload and in Langfuse traces."""
import sys
import os
import io, asyncio
import json
import pytest
import time
from litellm import mock_completion
from unittest.mock import MagicMock, AsyncMock, patch
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.proxy.guardrails.guardrail_hooks.presidio import _OPTIONAL_PresidioPIIMasking, PresidioPerRequestConfig
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload, StandardLoggingGuardrailInformation
from litellm.types.guardrails import GuardrailEventHooks
from typing import Optional
class TestCustomLogger(CustomLogger):
    """Callback that captures the standard logging payload on a successful call.

    The tests below register an instance as a litellm callback and then
    inspect ``standard_logging_payload`` to verify the guardrail information
    that litellm attached to the request.
    """

    def __init__(self, *args, **kwargs):
        # Filled in by async_log_success_event once litellm's async logging runs.
        self.standard_logging_payload: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Capture the payload so test assertions can inspect it later.
        # (Removed a redundant trailing `pass` statement.)
        self.standard_logging_payload = kwargs.get("standard_logging_object")
@pytest.mark.asyncio
async def test_standard_logging_payload_includes_guardrail_information():
    """
    Test that the standard logging payload includes the guardrail information when a guardrail is applied
    """
    logger = TestCustomLogger()
    litellm.callbacks = [logger]
    guardrail = _OPTIONAL_PresidioPIIMasking(
        guardrail_name="presidio_guard",
        event_hook=GuardrailEventHooks.pre_call,
        presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
        presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
    )
    # 1. run the guardrail's pre-call hook over the request
    request_data = {
        "model": "gpt-4o",
        "messages": [
            {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"},
        ],
        "mock_response": "Hello",
        "guardrails": ["presidio_guard"],
        "metadata": {},
    }
    await guardrail.async_pre_call_hook(
        user_api_key_dict={},
        cache=None,
        data=request_data,
        call_type="acompletion",
    )
    # 2. issue the (mocked) completion call
    await litellm.acompletion(**request_data)
    # 3. give the async success callback time to fire, then inspect the payload
    await asyncio.sleep(1)
    payload = logger.standard_logging_payload
    print("got standard logging payload=", json.dumps(payload, indent=4, default=str))
    assert payload is not None
    guardrail_info = payload["guardrail_information"]
    assert guardrail_info is not None
    assert guardrail_info["guardrail_name"] == "presidio_guard"
    assert guardrail_info["guardrail_mode"] == GuardrailEventHooks.pre_call
    # the guardrail_response should be the raw presidio /analyze result list
    analyze_results = guardrail_info["guardrail_response"]
    assert isinstance(analyze_results, list)
    expected_keys = (
        "analysis_explanation",
        "start",
        "end",
        "score",
        "entity_type",
        "recognition_metadata",
    )
    for result in analyze_results:
        for key in expected_keys:
            assert key in result
    # guardrail execution duration must be recorded and positive
    assert guardrail_info["duration"] is not None
    assert guardrail_info["duration"] > 0
    # exactly one phone number should have been masked
    assert guardrail_info["masked_entity_count"] is not None
    assert guardrail_info["masked_entity_count"]["PHONE_NUMBER"] == 1
@pytest.mark.asyncio
@pytest.mark.skip(reason="Local only test")
async def test_langfuse_trace_includes_guardrail_information():
    """
    Test that the langfuse trace includes the guardrail information when a guardrail is applied
    """
    import httpx
    from unittest.mock import AsyncMock, patch
    from litellm.integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement
    callback = LangfusePromptManagement(flush_interval=3)
    import json
    # Stub the Langfuse HTTP transport so the batch payload can be captured
    fake_response = AsyncMock(spec=httpx.Response)
    fake_response.status_code = 200
    fake_response.json.return_value = {"status": "success"}
    fake_post = AsyncMock()
    fake_post.return_value = fake_response
    with patch("httpx.Client.post", fake_post):
        litellm._turn_on_debug()
        litellm.callbacks = [callback]
        guardrail = _OPTIONAL_PresidioPIIMasking(
            guardrail_name="presidio_guard",
            event_hook=GuardrailEventHooks.pre_call,
            presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
            presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
        )
        # 1. run the guardrail's pre-call hook over the request
        request_data = {
            "model": "gpt-4o",
            "messages": [
                {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"},
            ],
            "mock_response": "Hello",
            "guardrails": ["presidio_guard"],
            "metadata": {},
        }
        await guardrail.async_pre_call_hook(
            user_api_key_dict={},
            cache=None,
            data=request_data,
            call_type="acompletion",
        )
        # 2. issue the (mocked) completion call
        await litellm.acompletion(**request_data)
        # 3. wait for the async flush to Langfuse to happen
        await asyncio.sleep(5)
        # 4. inspect the captured Langfuse batch payload
        assert fake_post.call_count >= 1
        url = fake_post.call_args[0][0]
        raw_body = fake_post.call_args[1].get("content")
        langfuse_payload = json.loads(raw_body)
        print("\nLangfuse payload:", json.dumps(langfuse_payload, indent=2))
        # locate the guardrail span inside the batch
        guardrail_span = next(
            (
                entry
                for entry in langfuse_payload["batch"]
                if entry["type"] == "span-create"
                and entry["body"].get("name") == "guardrail"
            ),
            None,
        )
        assert guardrail_span is not None, "No guardrail span found in Langfuse payload"
        # validate the structure of the guardrail span
        span_body = guardrail_span["body"]
        assert span_body["name"] == "guardrail"
        assert "metadata" in span_body
        span_metadata = span_body["metadata"]
        assert span_metadata["guardrail_name"] == "presidio_guard"
        assert span_metadata["guardrail_mode"] == GuardrailEventHooks.pre_call
        assert "guardrail_masked_entity_count" in span_metadata
        assert span_metadata["guardrail_masked_entity_count"]["PHONE_NUMBER"] == 1
        # the span output should carry the presidio analyze results
        assert "output" in span_body
        span_output = span_body["output"]
        assert isinstance(span_output, list)
        assert len(span_output) > 0
        # validate the first output item has the expected structure
        first_item = span_output[0]
        assert "entity_type" in first_item
        assert first_item["entity_type"] == "PHONE_NUMBER"
        assert "score" in first_item
        assert "start" in first_item
        assert "end" in first_item
        assert "recognition_metadata" in first_item
        assert "recognizer_name" in first_item["recognition_metadata"]