Spaces:
Configuration error
Configuration error
# +------------------------------+ | |
# | |
# Banned Keywords | |
# | |
# +------------------------------+ | |
# Thank you users! We ❤️ you! - Krrish & Ishaan | |
## Reject a call / response if it contains certain keywords | |
from typing import Literal | |
import litellm | |
from litellm.caching.caching import DualCache | |
from litellm.proxy._types import UserAPIKeyAuth | |
from litellm.integrations.custom_logger import CustomLogger | |
from litellm._logging import verbose_proxy_logger | |
from fastapi import HTTPException | |
class _ENTERPRISE_BannedKeywords(CustomLogger): | |
# Class variables or attributes | |
def __init__(self): | |
banned_keywords_list = litellm.banned_keywords_list | |
if banned_keywords_list is None: | |
raise Exception( | |
"`banned_keywords_list` can either be a list or filepath. None set." | |
) | |
if isinstance(banned_keywords_list, list): | |
self.banned_keywords_list = banned_keywords_list | |
if isinstance(banned_keywords_list, str): # assume it's a filepath | |
try: | |
with open(banned_keywords_list, "r") as file: | |
data = file.read() | |
self.banned_keywords_list = data.split("\n") | |
except FileNotFoundError: | |
raise Exception( | |
f"File not found. banned_keywords_list={banned_keywords_list}" | |
) | |
except Exception as e: | |
raise Exception( | |
f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}" | |
) | |
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"): | |
if level == "INFO": | |
verbose_proxy_logger.info(print_statement) | |
elif level == "DEBUG": | |
verbose_proxy_logger.debug(print_statement) | |
if litellm.set_verbose is True: | |
print(print_statement) # noqa | |
def test_violation(self, test_str: str): | |
for word in self.banned_keywords_list: | |
if word in test_str.lower(): | |
raise HTTPException( | |
status_code=400, | |
detail={"error": f"Keyword banned. Keyword={word}"}, | |
) | |
async def async_pre_call_hook( | |
self, | |
user_api_key_dict: UserAPIKeyAuth, | |
cache: DualCache, | |
data: dict, | |
call_type: str, # "completion", "embeddings", "image_generation", "moderation" | |
): | |
try: | |
""" | |
- check if user id part of call | |
- check if user id part of blocked list | |
""" | |
self.print_verbose("Inside Banned Keyword List Pre-Call Hook") | |
if call_type == "completion" and "messages" in data: | |
for m in data["messages"]: | |
if "content" in m and isinstance(m["content"], str): | |
self.test_violation(test_str=m["content"]) | |
except HTTPException as e: | |
raise e | |
except Exception as e: | |
verbose_proxy_logger.exception( | |
"litellm.enterprise.enterprise_hooks.banned_keywords::async_pre_call_hook - Exception occurred - {}".format( | |
str(e) | |
) | |
) | |
async def async_post_call_success_hook( | |
self, | |
data: dict, | |
user_api_key_dict: UserAPIKeyAuth, | |
response, | |
): | |
if isinstance(response, litellm.ModelResponse) and isinstance( | |
response.choices[0], litellm.utils.Choices | |
): | |
for word in self.banned_keywords_list: | |
self.test_violation(test_str=response.choices[0].message.content or "") | |
async def async_post_call_streaming_hook( | |
self, | |
user_api_key_dict: UserAPIKeyAuth, | |
response: str, | |
): | |
self.test_violation(test_str=response) | |