File size: 10,467 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
"""
Tests the following endpoints used by the UI:

/global/spend/logs
/global/spend/keys
/global/spend/models
/global/activity
/global/activity/model


For each endpoint, the tests check that:
- the response is valid
- the response for an Admin user differs from the response for an Internal user
"""

import os
import sys
import traceback
import uuid
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute

load_dotenv()
import io
import os
import time

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import logging

import pytest

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.management_endpoints.internal_user_endpoints import (
    new_user,
    user_info,
    user_update,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
    delete_key_fn,
    generate_key_fn,
    generate_key_helper_fn,
    info_key_fn,
    regenerate_key_fn,
    update_key_fn,
)
from litellm.proxy.management_endpoints.team_endpoints import (
    new_team,
    team_info,
    update_team,
)
from litellm.proxy.proxy_server import (
    LitellmUserRoles,
    audio_transcriptions,
    chat_completion,
    completion,
    embeddings,
    model_list,
    moderations,
    user_api_key_auth,
)
from litellm.proxy.management_endpoints.customer_endpoints import (
    new_end_user,
)
from litellm.proxy.spend_tracking.spend_management_endpoints import (
    global_spend,
    global_spend_logs,
    global_spend_models,
    global_spend_keys,
    spend_key_fn,
    spend_user_fn,
    view_spend_logs,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend

verbose_proxy_logger.setLevel(level=logging.DEBUG)

from starlette.datastructures import URL

from litellm.caching.caching import DualCache
from litellm.proxy._types import (
    DynamoDBArgs,
    GenerateKeyRequest,
    RegenerateKeyRequest,
    KeyRequest,
    LiteLLM_UpperboundKeyGenerateParams,
    NewCustomerRequest,
    NewTeamRequest,
    NewUserRequest,
    ProxyErrorTypes,
    ProxyException,
    UpdateKeyRequest,
    UpdateTeamRequest,
    UpdateUserRequest,
    UserAPIKeyAuth,
)

# Module-level ProxyLogging instance, built once at import time with a freshly
# created DualCache as its user_api_key_cache. It is passed as
# proxy_logging_obj to every PrismaClient the prisma_client fixture creates.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())


@pytest.fixture
def prisma_client():
    """Create a PrismaClient for the database named by ``DATABASE_URL``.

    Side effects, none of which are undone after the test:

    - ``os.environ["DATABASE_URL"]`` is overwritten with the URL returned by
      ``append_query_params`` for ``connection_limit=100`` and
      ``pool_timeout=60``.
    - ``litellm.proxy.proxy_server.litellm_proxy_budget_name`` is set to a
      value suffixed with the current ``time.time()``.
    - ``litellm.proxy.proxy_server.user_custom_key_generate`` is set to
      ``None``.

    The fixture does not call ``connect()`` and does not install the client on
    ``proxy_server``; each test does both itself.

    Aborts the requesting test via ``pytest.fail`` if ``DATABASE_URL`` is unset
    or empty.
    """
    from litellm.proxy.proxy_cli import append_query_params

    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        # Without this guard, None/"" would be handed to append_query_params and
        # then written into os.environ, failing far away from the real cause.
        pytest.fail("DATABASE_URL must be set to run the UI spend endpoint tests")

    ### add connection pool + pool timeout args
    params = {"connection_limit": 100, "pool_timeout": 60}
    modified_url = append_query_params(database_url, params)
    os.environ["DATABASE_URL"] = modified_url

    client = PrismaClient(
        database_url=modified_url, proxy_logging_obj=proxy_logging_obj
    )

    # Re-initialise proxy_server globals for this test. proxy_server.prisma_client
    # itself is NOT touched here; tests install the client via setattr.
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return client


@pytest.mark.asyncio
async def test_view_daily_spend_ui(prisma_client):
    """/global/spend/logs: admin total spend must exceed the internal user's.

    Installs the fixture's client and master key "sk-1234" on
    ``litellm.proxy.proxy_server`` and connects it. It then calls
    ``global_spend_logs`` once as a PROXY_ADMIN and once as an INTERNAL_USER
    with user_id "1234". Finally it asserts that the admin's summed ``spend``
    is strictly greater than the internal user's.

    NOTE(review): the strict ``>`` means this test fails on an empty database.
    It relies on pre-existing spend logs visible to the admin that are not
    attributed to user "1234".
    """
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")

    await litellm.proxy.proxy_server.prisma_client.connect()

    spend_logs_for_admin = await global_spend_logs(
        user_api_key_dict=UserAPIKeyAuth(
            api_key="sk-1234",
            user_role=LitellmUserRoles.PROXY_ADMIN,
        ),
        api_key=None,
    )

    print("spend_logs_for_admin=", spend_logs_for_admin)

    spend_logs_for_internal_user = await global_spend_logs(
        user_api_key_dict=UserAPIKeyAuth(
            api_key="sk-1234", user_role=LitellmUserRoles.INTERNAL_USER, user_id="1234"
        ),
        api_key=None,
    )

    print("spend_logs_for_internal_user=", spend_logs_for_internal_user)

    # Each log entry is read with .get(), so entries without "spend" count as 0.
    admin_total_spend = sum(log.get("spend", 0) for log in spend_logs_for_admin)
    internal_user_total_spend = sum(
        log.get("spend", 0) for log in spend_logs_for_internal_user
    )

    print("total_spend_for_admin=", admin_total_spend)
    print("total_spend_for_internal_user=", internal_user_total_spend)

    assert (
        admin_total_spend > internal_user_total_spend
    ), "Admin should have more spend than internal user"


def _assert_models_spend_response(rows, label):
    """Validate one ``global_spend_models`` response.

    Asserts that:

    - ``rows`` is a list;
    - every row has a ``model`` key holding a str;
    - every row has a ``total_spend`` key holding an int or float;
    - rows are ordered by ``total_spend`` in descending order.

    ``label`` (e.g. "Admin" or "Internal user") prefixes the failure messages.
    """
    assert isinstance(rows, list), f"{label} response should be a list"

    expected_keys = ["model", "total_spend"]
    for row in rows:
        assert all(
            key in row for key in expected_keys
        ), f"{label} response should contain keys: {expected_keys}"
        assert isinstance(row["model"], str), "Model should be a string"
        assert isinstance(
            row["total_spend"], (int, float)
        ), "Total spend should be a number"

    # Pairwise comparison of neighbours; trivially true for 0 or 1 rows.
    spends = [row["total_spend"] for row in rows]
    assert all(
        earlier >= later for earlier, later in zip(spends, spends[1:])
    ), f"{label} response should be sorted by total_spend in descending order"


@pytest.mark.asyncio
async def test_global_spend_models(prisma_client):
    """/global/spend/models: response shape, ordering, and role visibility.

    Installs the fixture's client and master key "sk-1234" on
    ``litellm.proxy.proxy_server`` and connects it. It then calls
    ``global_spend_models(limit=10)`` as a PROXY_ADMIN and as an INTERNAL_USER
    with user_id "1234". Each response is validated with
    ``_assert_models_spend_response``. The admin must see at least as many
    models as the internal user.
    """
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")

    await litellm.proxy.proxy_server.prisma_client.connect()

    # Test for admin user
    models_spend_for_admin = await global_spend_models(
        limit=10,
        user_api_key_dict=UserAPIKeyAuth(
            api_key="sk-1234",
            user_role=LitellmUserRoles.PROXY_ADMIN,
        ),
    )

    print("models_spend_for_admin=", models_spend_for_admin)

    # Test for internal user
    models_spend_for_internal_user = await global_spend_models(
        limit=10,
        user_api_key_dict=UserAPIKeyAuth(
            api_key="sk-1234", user_role=LitellmUserRoles.INTERNAL_USER, user_id="1234"
        ),
    )

    print("models_spend_for_internal_user=", models_spend_for_internal_user)

    _assert_models_spend_response(models_spend_for_admin, "Admin")
    _assert_models_spend_response(models_spend_for_internal_user, "Internal user")

    # The admin is expected to see at least as many models as the internal user.
    assert len(models_spend_for_admin) >= len(
        models_spend_for_internal_user
    ), "Admin should have access to at least as many models as internal user"


@pytest.mark.asyncio
async def test_global_spend_keys(prisma_client):
    """/global/spend/keys: response shape and admin-vs-internal visibility.

    Installs the fixture's client and master key "sk-1234" on
    ``litellm.proxy.proxy_server`` and connects it. It then queries
    ``global_spend_keys(limit=10)`` as a PROXY_ADMIN and as an INTERNAL_USER
    with user_id "1234". Both results must be lists, and the admin must see at
    least as many rows. The first row of each non-empty result must carry
    api_key, total_spend, key_alias and key_name.
    """
    print("prisma client=", prisma_client)
    proxy_server = litellm.proxy.proxy_server
    setattr(proxy_server, "prisma_client", prisma_client)
    setattr(proxy_server, "master_key", "sk-1234")
    await proxy_server.prisma_client.connect()

    admin_rows = await global_spend_keys(
        limit=10,
        user_api_key_dict=UserAPIKeyAuth(
            api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN
        ),
    )
    print("keys_spend_for_admin=", admin_rows)

    internal_rows = await global_spend_keys(
        limit=10,
        user_api_key_dict=UserAPIKeyAuth(
            api_key="sk-1234",
            user_role=LitellmUserRoles.INTERNAL_USER,
            user_id="1234",
        ),
    )
    print("keys_spend_for_internal_user=", internal_rows)

    required_fields = ("api_key", "total_spend", "key_alias", "key_name")

    assert isinstance(admin_rows, list), "Admin response should be a list"
    assert isinstance(
        internal_rows, list
    ), "Internal user response should be a list"

    # Role scoping: an internal user never sees more keys than the admin.
    assert len(admin_rows) >= len(internal_rows), (
        "Admin should have access to at least as many keys as internal user"
    )

    if admin_rows:
        first_admin_row = admin_rows[0]
        assert all(
            field in first_admin_row for field in required_fields
        ), "Admin response should contain api_key, total_spend, key_alias, and key_name"

    if internal_rows:
        first_internal_row = internal_rows[0]
        assert all(
            field in first_internal_row for field in required_fields
        ), "Internal user response should contain api_key, total_spend, key_alias, and key_name"