Spaces:
Configuration error
Configuration error
import pytest
from fastapi.testclient import TestClient
from fastapi import Request, Header
from unittest.mock import patch, MagicMock, AsyncMock
import sys
import os

# Put the directory two levels above the current working directory at the
# front of sys.path ("../.." is resolved against the CWD by abspath). This
# must run before the `litellm` imports below.
sys.path.insert(
    0, os.path.abspath("../..")
)
import litellm
from litellm.proxy.proxy_server import app
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.proxy.management_endpoints.ui_sso import auth_callback
from litellm.proxy._types import LitellmUserRoles
import os
import jwt
import time
from litellm.caching.caching import DualCache

# Module-level ProxyLogging instance, backed by a fresh DualCache; it is the
# object passed as `proxy_logging_obj` when the PrismaClient is constructed.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
def mock_env_vars(monkeypatch):
    """Set the environment variables used by the SSO callback tests.

    Writes mock Google OAuth credentials, the proxy base URL and the master
    key through ``monkeypatch.setenv``, in that order.

    Args:
        monkeypatch: object exposing ``setenv(name, value)``.

    NOTE(review): the tests in this module request ``mock_env_vars`` by
    parameter name, so this is presumably meant to be registered as a pytest
    fixture; no decorator is visible here -- confirm.
    """
    mock_environment = (
        ("GOOGLE_CLIENT_ID", "mock_google_client_id"),
        ("GOOGLE_CLIENT_SECRET", "mock_google_client_secret"),
        ("PROXY_BASE_URL", "http://testserver"),
        ("LITELLM_MASTER_KEY", "mock_master_key"),
    )
    for name, value in mock_environment:
        monkeypatch.setenv(name, value)
def prisma_client():
    """Build a ``PrismaClient`` for the database named by ``DATABASE_URL``.

    Side effects:
        * ``os.environ["DATABASE_URL"]`` is overwritten with the value returned
          by ``append_query_params`` for the pool settings below.
        * ``litellm.proxy.proxy_server.litellm_proxy_budget_name`` is set to a
          timestamp-suffixed name and
          ``litellm.proxy.proxy_server.user_custom_key_generate`` is set to
          ``None``.

    Returns:
        The newly constructed ``PrismaClient``; this function does not call
        ``connect()`` on it.

    NOTE(review): ``DATABASE_URL`` is read with ``os.getenv`` and passed on
    unchecked, so an unset variable reaches ``append_query_params`` as
    ``None`` -- confirm that helper's behaviour for that case.
    NOTE(review): the tests request ``prisma_client`` by parameter name, so
    this is presumably meant to be a pytest fixture; no decorator is visible
    here -- confirm.
    """
    from litellm.proxy.proxy_cli import append_query_params

    # Add connection pool + pool timeout args to the connection string.
    pool_params = {"connection_limit": 100, "pool_timeout": 60}
    os.environ["DATABASE_URL"] = append_query_params(
        os.getenv("DATABASE_URL"), pool_params
    )

    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # Give proxy_server a time-stamped budget name and clear any custom
    # key-generate hook on the shared module.
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return client
async def test_auth_callback_new_user(mock_google_sso, mock_env_vars, prisma_client):
    """
    Tests that a new SSO Sign In user is by default given an 'INTERNAL_USER_VIEW_ONLY' role

    Steps: install and connect the prisma client on ``proxy_server``, make the
    mocked Google SSO return a freshly generated user, call ``auth_callback``
    directly, then check the redirect and the row written to
    ``litellm_usertable``. The row is deleted again in ``finally``.

    NOTE(review): ``mock_google_sso`` is used as a mock of the Google SSO class
    and ``prisma_client`` / ``mock_env_vars`` are requested by name; whatever
    injects them (patch decorator, fixtures, asyncio marker) is not visible
    here -- confirm.
    """
    import uuid
    import litellm

    litellm._turn_on_debug()

    # Generate a unique user ID so the lookup and cleanup below target only
    # the row this test creates.
    unique_user_id = str(uuid.uuid4())
    unique_user_email = f"newuser{unique_user_id}@example.com"

    try:
        # Set up the prisma client
        setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
        await litellm.proxy.proxy_server.prisma_client.connect()

        # Set up the master key
        litellm.proxy.proxy_server.master_key = "mock_master_key"

        # Mock the GoogleSSO verify_and_process method
        mock_sso_result = MagicMock()
        mock_sso_result.email = unique_user_email
        mock_sso_result.id = unique_user_id
        mock_sso_result.provider = "google"
        mock_google_sso.return_value.verify_and_process = AsyncMock(
            return_value=mock_sso_result
        )

        # Create a mock Request object
        mock_request = Request(
            scope={
                "type": "http",
                "method": "GET",
                "scheme": "http",
                "server": ("testserver", 80),
                "path": "/sso/callback",
                "query_string": b"",
                "headers": {},
            }
        )

        # Call the auth_callback function directly
        response = await auth_callback(request=mock_request)

        # Assert the response
        assert response.status_code == 303
        assert response.headers["location"].startswith(
            "http://testserver/ui/?login=success"
        )

        # Verify that the user was added to the database
        user = await prisma_client.db.litellm_usertable.find_first(
            where={"user_id": unique_user_id}
        )
        print("inserted user from SSO", user)
        assert user is not None
        assert user.user_email == unique_user_email
        assert user.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
        assert user.metadata == {"auth_provider": "google"}
    finally:
        # Clean up: Delete the user from the database
        await prisma_client.db.litellm_usertable.delete(
            where={"user_id": unique_user_id}
        )
async def test_auth_callback_new_user_with_sso_default(
    mock_google_sso, mock_env_vars, prisma_client
):
    """
    When litellm_settings.default_internal_user_params.user_role = 'INTERNAL_USER'
    Tests that a new SSO Sign In user is by default given an 'INTERNAL_USER' role

    Steps: install the prisma client on ``proxy_server``, override
    ``litellm.default_internal_user_params``, connect, make the mocked Google
    SSO return a freshly generated user, call ``auth_callback`` directly, then
    check the redirect and the row written to ``litellm_usertable``.
    ``finally`` resets the override to ``None`` and deletes the row.

    NOTE(review): ``mock_google_sso`` is used as a mock of the Google SSO class
    and ``prisma_client`` / ``mock_env_vars`` are requested by name; whatever
    injects them (patch decorator, fixtures, asyncio marker) is not visible
    here -- confirm.
    """
    import uuid

    # Generate a unique user ID so the lookup and cleanup below target only
    # the row this test creates.
    unique_user_id = str(uuid.uuid4())
    unique_user_email = f"newuser{unique_user_id}@example.com"

    try:
        # Set up the prisma client
        setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
        litellm.default_internal_user_params = {
            "user_role": LitellmUserRoles.INTERNAL_USER.value
        }
        await litellm.proxy.proxy_server.prisma_client.connect()

        # Set up the master key
        litellm.proxy.proxy_server.master_key = "mock_master_key"

        # Mock the GoogleSSO verify_and_process method
        mock_sso_result = MagicMock()
        mock_sso_result.email = unique_user_email
        mock_sso_result.id = unique_user_id
        mock_sso_result.provider = "google"
        mock_google_sso.return_value.verify_and_process = AsyncMock(
            return_value=mock_sso_result
        )

        # Create a mock Request object
        mock_request = Request(
            scope={
                "type": "http",
                "method": "GET",
                "scheme": "http",
                "server": ("testserver", 80),
                "path": "/sso/callback",
                "query_string": b"",
                "headers": {},
            }
        )

        # Call the auth_callback function directly
        response = await auth_callback(request=mock_request)

        # Assert the response
        assert response.status_code == 303
        assert response.headers["location"].startswith(
            "http://testserver/ui/?login=success"
        )

        # Verify that the user was added to the database
        user = await prisma_client.db.litellm_usertable.find_first(
            where={"user_id": unique_user_id}
        )
        print("inserted user from SSO", user)
        assert user is not None
        assert user.user_email == unique_user_email
        assert user.user_role == LitellmUserRoles.INTERNAL_USER
    finally:
        # Reset the global override first: this assignment cannot fail, so the
        # overridden default is cleared even if the DB cleanup below raises.
        litellm.default_internal_user_params = None

        # Clean up: Delete the user from the database
        await prisma_client.db.litellm_usertable.delete(
            where={"user_id": unique_user_id}
        )