import os
import sys
import traceback
import uuid
from typing import List
from datetime import datetime, timezone, timedelta

from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute
import httpx
import json
from unittest.mock import MagicMock, patch
load_dotenv()
import io
import time
import fakeredis

# Tests for the litellm proxy: PodLockManager (Redis-based cron locks) and the Redis spend-update buffer

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
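
# A rough sketch of how these tests can be run locally (the path below is a
# placeholder, and REDIS_HOST / REDIS_PORT / REDIS_PASSWORD / DATABASE_URL must
# point at services you control):
#
#   REDIS_HOST=localhost REDIS_PORT=6379 DATABASE_URL="postgresql://..." \
#       pytest path/to/this_test_file.py -x -vv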
import asyncio
import logging

import pytest
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.management_endpoints.internal_user_endpoints import (
    new_user,
    user_info,
    user_update,
)
from litellm.proxy.auth.auth_checks import get_key_object
from litellm.proxy.management_endpoints.key_management_endpoints import (
    delete_key_fn,
    generate_key_fn,
    generate_key_helper_fn,
    info_key_fn,
    list_keys,
    regenerate_key_fn,
    update_key_fn,
)
from litellm.proxy.management_endpoints.team_endpoints import (
    new_team,
    team_info,
    update_team,
)
from litellm.proxy.proxy_server import (
    LitellmUserRoles,
    audio_transcriptions,
    chat_completion,
    completion,
    embeddings,
    model_list,
    moderations,
    user_api_key_auth,
)
from litellm.proxy.management_endpoints.customer_endpoints import (
    new_end_user,
)
from litellm.proxy.spend_tracking.spend_management_endpoints import (
    global_spend,
    spend_key_fn,
    spend_user_fn,
    view_spend_logs,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend

verbose_proxy_logger.setLevel(level=logging.DEBUG)

from starlette.datastructures import URL

from litellm.caching.caching import DualCache, RedisCache
from litellm.proxy._types import (
    DynamoDBArgs,
    GenerateKeyRequest,
    KeyRequest,
    LiteLLM_UpperboundKeyGenerateParams,
    NewCustomerRequest,
    NewTeamRequest,
    NewUserRequest,
    ProxyErrorTypes,
    ProxyException,
    UpdateKeyRequest,
    UpdateTeamRequest,
    UpdateUserRequest,
    UserAPIKeyAuth,
)

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())


request_data = {
    "model": "azure-gpt-3.5",
    "messages": [
        {"role": "user", "content": "this is my new test. respond in 50 lines"}
    ],
}



@pytest.fixture
def prisma_client():
    from litellm.proxy.proxy_cli import append_query_params

    ### add connection pool + pool timeout args
    params = {"connection_limit": 100, "pool_timeout": 60}
    database_url = os.getenv("DATABASE_URL")
    modified_url = append_query_params(database_url, params)
    os.environ["DATABASE_URL"] = modified_url

    # Instantiate the PrismaClient the proxy server will use for these tests
    prisma_client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # Give each test run a unique proxy budget name and clear any custom key generator
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return prisma_client


async def setup_db_connection(prisma_client):
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()


@pytest.mark.asyncio
async def test_pod_lock_acquisition_when_no_active_lock():
    """Test if a pod can acquire a lock when no lock is active"""
    cronjob_id = str(uuid.uuid4())
    global_redis_cache = RedisCache(
        host=os.getenv("REDIS_HOST"),
        port=os.getenv("REDIS_PORT"),
        password=os.getenv("REDIS_PASSWORD"),
    )
    lock_manager = PodLockManager(redis_cache=global_redis_cache)

    # Attempt to acquire lock
    result = await lock_manager.acquire_lock(
        cronjob_id=cronjob_id,
    )

    assert result == True, "Pod should be able to acquire lock when no lock exists"

    # Verify the lock record in Redis
    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
    lock_record = await global_redis_cache.async_get_cache(lock_key)
    print("lock_record=", lock_record)
    assert lock_record == lock_manager.pod_id



@pytest.mark.asyncio
async def test_pod_lock_acquisition_after_completion():
    """Test if a new pod can acquire lock after previous pod completes"""
    cronjob_id = str(uuid.uuid4())
    # First pod acquires and releases lock
    global_redis_cache = RedisCache(
        host=os.getenv("REDIS_HOST"),
        port=os.getenv("REDIS_PORT"),
        password=os.getenv("REDIS_PASSWORD"),
    )
    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
    await first_lock_manager.acquire_lock(
        cronjob_id=cronjob_id,
    )
    await first_lock_manager.release_lock(
        cronjob_id=cronjob_id,
    )

    # Second pod attempts to acquire lock
    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
    result = await second_lock_manager.acquire_lock(
        cronjob_id=cronjob_id,
    )

    assert result == True, "Second pod should acquire lock after first pod releases it"

    # Verify in redis
    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
    lock_record = await global_redis_cache.async_get_cache(lock_key)
    assert lock_record == second_lock_manager.pod_id


@pytest.mark.asyncio
async def test_pod_lock_acquisition_after_expiry():
    """Test if a new pod can acquire lock after previous pod's lock expires"""
    cronjob_id = str(uuid.uuid4())
    # First pod acquires lock
    global_redis_cache = RedisCache(
        host=os.getenv("REDIS_HOST"),
        port=os.getenv("REDIS_PORT"),
        password=os.getenv("REDIS_PASSWORD"),
    )
    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
    await first_lock_manager.acquire_lock(
        cronjob_id=cronjob_id,
    )

    # release the lock from the first pod
    await first_lock_manager.release_lock(
        cronjob_id=cronjob_id,
    )

    # Second pod attempts to acquire lock
    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
    result = await second_lock_manager.acquire_lock(
        cronjob_id=cronjob_id,
    )

    assert (
        result == True
    ), "Second pod should acquire lock after first pod releases it"

    # Verify in redis
    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
    lock_record = await global_redis_cache.async_get_cache(lock_key)
    assert lock_record == second_lock_manager.pod_id


@pytest.mark.asyncio
async def test_pod_lock_release():
    """Test if a pod can successfully release its lock"""
    cronjob_id = str(uuid.uuid4())
    global_redis_cache = RedisCache(
        host=os.getenv("REDIS_HOST"),
        port=os.getenv("REDIS_PORT"),
        password=os.getenv("REDIS_PASSWORD"),
    )
    lock_manager = PodLockManager(redis_cache=global_redis_cache)

    # Acquire and then release lock
    await lock_manager.acquire_lock(
        cronjob_id=cronjob_id,
    )
    await lock_manager.release_lock(
        cronjob_id=cronjob_id,
    )

    # Verify in redis
    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
    lock_record = await global_redis_cache.async_get_cache(lock_key)
    assert lock_record is None


@pytest.mark.asyncio
async def test_concurrent_lock_acquisition():
    """Test that only one pod can acquire the lock when multiple pods try simultaneously"""

    cronjob_id = str(uuid.uuid4())
    global_redis_cache = RedisCache(
        host=os.getenv("REDIS_HOST"),
        port=os.getenv("REDIS_PORT"),
        password=os.getenv("REDIS_PASSWORD"),
    )
    # Create multiple lock managers simulating different pods
    lock_manager1 = PodLockManager(redis_cache=global_redis_cache)
    lock_manager2 = PodLockManager(redis_cache=global_redis_cache)
    lock_manager3 = PodLockManager(redis_cache=global_redis_cache)

    # Try to acquire locks concurrently
    results = await asyncio.gather(
        lock_manager1.acquire_lock(
            cronjob_id=cronjob_id,
        ),
        lock_manager2.acquire_lock(
            cronjob_id=cronjob_id,
        ),
        lock_manager3.acquire_lock(
            cronjob_id=cronjob_id,
        ),
    )

    # Only one should succeed
    print("all results=", results)
    assert sum(results) == 1, "Only one pod should acquire the lock"

    # Verify in redis
    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
    lock_record = await global_redis_cache.async_get_cache(lock_key)
    assert lock_record in [
        lock_manager1.pod_id,
        lock_manager2.pod_id,
        lock_manager3.pod_id,
    ]
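

# The concurrent-acquisition test above depends on the lock write being atomic:
# when several pods race for the same key, exactly one write wins. The helper
# below is only an illustrative sketch of that primitive (a Redis SET with
# nx=True and an expiry), written against redis-py's asyncio client. It is an
# assumption about the pattern, not PodLockManager's actual implementation, and
# no test in this file calls it.
async def _example_set_nx_lock_sketch(
    redis_client, lock_key: str, pod_id: str, ttl_seconds: int = 60
) -> bool:
    """Try to claim `lock_key` for `pod_id`; only one concurrent caller wins."""
    # nx=True: set only if the key does not already exist.
    # ex=ttl_seconds: auto-expire so a crashed pod cannot hold the lock forever.
    # redis-py returns True when the set happened and None when the key was held.
    acquired = await redis_client.set(lock_key, pod_id, nx=True, ex=ttl_seconds)
    return bool(acquired)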



@pytest.mark.asyncio
async def test_lock_acquisition_with_expired_ttl():
    """Test that a pod can acquire a lock when existing lock has expired TTL"""
    cronjob_id = str(uuid.uuid4())
    global_redis_cache = RedisCache(
        host=os.getenv("REDIS_HOST"),
        port=os.getenv("REDIS_PORT"),
        password=os.getenv("REDIS_PASSWORD"),
    )
    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)

    # First pod acquires lock with a very short TTL to simulate expiration
    short_ttl = 1  # 1 second
    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
    await global_redis_cache.async_set_cache(
        lock_key,
        first_lock_manager.pod_id,
        ttl=short_ttl,
    )

    # Wait for the lock to expire
    await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL

    # Second pod tries to acquire without explicit release
    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
    result = await second_lock_manager.acquire_lock(
        cronjob_id=cronjob_id,
    )

    assert result == True, "Should acquire lock when existing lock has expired TTL"

    # Verify in Redis
    lock_record = await global_redis_cache.async_get_cache(lock_key)
    print("lock_record=", lock_record)
    assert lock_record == second_lock_manager.pod_id


@pytest.mark.asyncio
async def test_release_expired_lock():
    """Test that a pod cannot release a lock that has been taken over by another pod"""
    cronjob_id = str(uuid.uuid4())
    global_redis_cache = RedisCache(
        host=os.getenv("REDIS_HOST"),
        port=os.getenv("REDIS_PORT"),
        password=os.getenv("REDIS_PASSWORD"),
    )
    # First pod acquires lock with a very short TTL
    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
    short_ttl = 1  # 1 second
    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
    await global_redis_cache.async_set_cache(
        lock_key,
        first_lock_manager.pod_id,
        ttl=short_ttl,
    )

    # Wait for the lock to expire
    await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL

    # Second pod acquires the lock
    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
    await second_lock_manager.acquire_lock(
        cronjob_id=cronjob_id,
    )

    # First pod attempts to release its lock
    await first_lock_manager.release_lock(
        cronjob_id=cronjob_id,
    )

    # Verify that second pod's lock is still active
    lock_record = await global_redis_cache.async_get_cache(lock_key)
    assert lock_record == second_lock_manager.pod_id


@pytest.mark.asyncio
async def test_e2e_size_of_redis_buffer():
    """
    Ensure that all elements from the redis queue's get flushed to the DB

    Goal of this is to ensure Redis does not blow up in size
    """
    from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
    from litellm.proxy.db.db_transaction_queue.base_update_queue import BaseUpdateQueue
    from litellm.caching import RedisCache
    import uuid

    redis_cache = RedisCache(
        host=os.getenv("REDIS_HOST"),
        port=os.getenv("REDIS_PORT"),
        password=os.getenv("REDIS_PASSWORD"),
    )
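    # Swap in an in-memory fakeredis client so this buffer test does not need a
    # live Redis server; every update the writer flushes lands in this fake.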
    fake_redis_client = fakeredis.FakeAsyncRedis()
    redis_cache.redis_async_client = fake_redis_client

    setattr(litellm.proxy.proxy_server, "use_redis_transaction_buffer", True)
    db_writer = DBSpendUpdateWriter(redis_cache=redis_cache)
    
    # get all the queues
    initialized_queues: List[BaseUpdateQueue] = []
    for attr in dir(db_writer):
        if isinstance(getattr(db_writer, attr), BaseUpdateQueue):
            initialized_queues.append(getattr(db_writer, attr))

    # add mock data to each queue
    new_keys_added = []
    for queue in initialized_queues:
        key = f"test_key_{queue.__class__.__name__}_{uuid.uuid4()}"
        new_keys_added.append(key)
        await queue.add_update(
            {
                key: {
                    "spend": 1.0,
                    "entity_id": "test_entity_id",
                    "entity_type": "user",
                    "api_key": "test_api_key",
                    "model": "test_model",
                    "custom_llm_provider": "test_custom_llm_provider",
                    "date": "2025-01-01",
                    "prompt_tokens": 100,
                    "completion_tokens": 100,
                    "total_tokens": 200,
                    "response_cost": 1.0,
                    "api_requests": 1,
                    "successful_requests": 1,
                    "failed_requests": 0,
                }
            }
        )
    
    print("initialized_queues=", initialized_queues)
    print("new_keys_added=", new_keys_added)

    # get the size of each queue
    for queue in initialized_queues:
        assert queue.update_queue.qsize() == 1, (
            f"Queue {queue.__class__.__name__} was not initialized with mock data. "
            f"Expected size 1, got {queue.update_queue.qsize()}"
        )


    # flush from in-memory -> redis -> to DB
    with patch("litellm.proxy.db.db_spend_update_writer.PodLockManager.acquire_lock", return_value=True):
        await db_writer._commit_spend_updates_to_db_with_redis(
            prisma_client=MagicMock(),
            n_retry_times=3,
            proxy_logging_obj=MagicMock()
        )
    
    # Verify all the keys were looked up in Redis
    keys = await fake_redis_client.keys("*")
    print("found keys even after flushing to DB", keys)
    assert len(keys) == 0, f"Expected Redis to be empty, but found keys: {keys}"