hadadrjt commited on
Commit
d1ba698
·
1 Parent(s): 360d36a

api: Shut up!

Browse files
Files changed (1) hide show
  1. app.py +41 -55
app.py CHANGED
@@ -8,6 +8,7 @@ import time
8
  import json
9
  import asyncio
10
  import logging
 
11
  from typing import Optional, List, Union, Dict, Any, Literal
12
  from fastapi import FastAPI, HTTPException, Request, status
13
  from fastapi.middleware.cors import CORSMiddleware
@@ -29,6 +30,7 @@ class SessionData:
29
  self.history: List[Dict[str, Any]] = []
30
  self.last_access: float = time.time()
31
  self.active_tasks: Dict[str, asyncio.Task] = {}
 
32
 
33
  class SessionManager:
34
  def __init__(self):
@@ -48,55 +50,49 @@ class SessionManager:
48
  for task_id, task in data.active_tasks.items():
49
  if not task.done():
50
  task.cancel()
51
- logger.info(f"Cancelled active task {task_id} for expired session {sid}")
52
  for user, sid in expired:
53
  if user in self.sessions and sid in self.sessions[user]:
54
  del self.sessions[user][sid]
55
  if not self.sessions[user]:
56
  del self.sessions[user]
57
- logger.info(f"Session expired: user={user} session={sid}")
58
 
59
  async def get_session(self, user: Optional[str], session_id: Optional[str]) -> (str, str, SessionData):
60
  async with self.lock:
61
  if not user:
62
  user = str(uuid.uuid4())
63
- logger.debug(f"Generated new user ID: {user}")
64
  if user not in self.sessions:
65
  self.sessions[user] = {}
66
  if not session_id or session_id not in self.sessions[user]:
67
  session_id = str(uuid.uuid4())
68
  self.sessions[user][session_id] = SessionData()
69
- logger.info(f"Created new session: user={user} session={session_id}")
70
  session = self.sessions[user][session_id]
71
  session.last_access = time.time()
72
- logger.debug(f"Session accessed: user={user} session={session_id} history_length={len(session.history)}")
73
  return user, session_id, session
74
 
75
  session_manager = SessionManager()
76
 
77
  async def refresh_client(app: FastAPI):
78
  while True:
79
- await asyncio.sleep(1)
80
  async with app.state.client_lock:
81
- if app.state.client is None:
82
- await asyncio.sleep(1)
83
- continue
84
- while True:
85
- await asyncio.sleep(15)
86
- async with app.state.client_lock:
87
- if app.state.client is not None:
88
- try:
89
- old_client = app.state.client
90
- app.state.client = None
91
- del old_client
92
- app.state.client = Client("https://hadadrjt-ai.hf.space/")
93
- logger.info("Refreshed Gradio client connection")
94
- except Exception as e:
95
- logger.error(f"Error refreshing Gradio client: {e}", exc_info=True)
96
- app.state.client = None
97
- await asyncio.sleep(5)
98
- else:
99
- break
100
 
101
  @asynccontextmanager
102
  async def lifespan(app: FastAPI):
@@ -104,13 +100,15 @@ async def lifespan(app: FastAPI):
104
  app.state.client = None
105
  app.state.client_lock = asyncio.Lock()
106
  app.state.refresh_task = asyncio.create_task(refresh_client(app))
107
- logger.info("App lifespan started, refresh client task running")
 
108
  try:
109
  yield
110
  finally:
111
  app.state.refresh_task.cancel()
 
 
112
  await asyncio.sleep(0.1)
113
- logger.info("App lifespan ended, refresh client task cancelled")
114
 
115
  app = FastAPI(
116
  title="J.A.R.V.I.S. OpenAI-Compatible API",
@@ -220,20 +218,16 @@ async def get_client(app: FastAPI) -> Client:
220
  async def call_gradio(client: Client, params: dict):
221
  for attempt in range(3):
222
  try:
223
- logger.debug(f"Calling Gradio attempt {attempt+1}")
224
  return await asyncio.to_thread(lambda: client.submit(**params))
225
  except Exception as e:
226
- logger.warning(f"Gradio call failed attempt {attempt+1}: {e}", exc_info=True)
227
  await asyncio.sleep(0.2 * (attempt + 1))
228
- logger.error("Gradio upstream error after 3 attempts")
229
  raise HTTPException(status_code=502, detail="Upstream Gradio app error")
230
 
231
  async def stream_response(job, session_id: str, session_history: List[Dict[str, Any]], new_messages: List[Message], response_type: str):
232
  partial = ""
233
  try:
234
  chunks = await asyncio.to_thread(lambda: list(job))
235
- except Exception as e:
236
- logger.error(f"Streaming error: {e}", exc_info=True)
237
  chunks = []
238
  for chunk in chunks:
239
  try:
@@ -258,8 +252,7 @@ async def stream_response(job, session_id: str, session_history: List[Dict[str,
258
  "session_id": session_id
259
  }
260
  yield f"data: {json.dumps(data)}\n\n"
261
- except Exception as e:
262
- logger.error(f"Chunk yield error: {e}", exc_info=True)
263
  continue
264
  session_history.extend([m.model_dump() for m in new_messages if m.role != "system"])
265
  session_history.append({"role": "assistant", "content": partial})
@@ -271,9 +264,15 @@ async def stream_response(job, session_id: str, session_history: List[Dict[str,
271
  }
272
  yield f"data: {json.dumps(done_data)}\n\n"
273
 
 
 
274
  @app.post("/v1/chat/completions")
275
  async def chat_completions(req: ChatCompletionRequest):
276
  user, session_id, session = await session_manager.get_session(req.user, req.session_id)
 
 
 
 
277
  req.messages = sanitize_messages(req.messages)
278
  for m in req.messages:
279
  if m.role == "system":
@@ -299,7 +298,6 @@ async def chat_completions(req: ChatCompletionRequest):
299
  "function_call": req.function_call or req.tool_choice,
300
  }
301
  params = {k: v for k, v in params.items() if v is not None}
302
- logger.info(f"Chat completion request user={user} session={session_id} model={req.model} stream={req.stream}")
303
  client = await get_client(app)
304
  if req.stream:
305
  job = await call_gradio(client, params)
@@ -309,12 +307,10 @@ async def chat_completions(req: ChatCompletionRequest):
309
  loop = asyncio.get_running_loop()
310
  try:
311
  result = await loop.run_in_executor(None, lambda: client.predict(**params))
312
- except Exception as e:
313
- logger.error(f"Gradio predict error: {e}", exc_info=True)
314
  raise HTTPException(status_code=502, detail="Upstream Gradio app error")
315
  session.history.extend([m.model_dump() for m in req.messages if m.role != "system"])
316
  session.history.append({"role": "assistant", "content": result})
317
- logger.info(f"Chat completion response sent user={user} session={session_id}")
318
  return {
319
  "id": str(uuid.uuid4()),
320
  "object": "chat.completion",
@@ -324,7 +320,11 @@ async def chat_completions(req: ChatCompletionRequest):
324
 
325
  @app.post("/v1/completions")
326
  async def completions(req: CompletionRequest):
327
- user, session_id, _ = await session_manager.get_session(req.user, req.session_id)
 
 
 
 
328
  prompt = req.prompt if isinstance(req.prompt, str) else "\n".join(req.prompt)
329
  params = {
330
  "message": prompt,
@@ -343,7 +343,6 @@ async def completions(req: CompletionRequest):
343
  "seed": req.seed,
344
  }
345
  params = {k: v for k, v in params.items() if v is not None}
346
- logger.info(f"Completion request user={user} session={session_id} model={req.model} stream={req.stream}")
347
  client = await get_client(app)
348
  if req.stream:
349
  job = await call_gradio(client, params)
@@ -353,29 +352,24 @@ async def completions(req: CompletionRequest):
353
  loop = asyncio.get_running_loop()
354
  try:
355
  result = await loop.run_in_executor(None, lambda: client.predict(**params))
356
- except Exception as e:
357
- logger.error(f"Gradio predict error: {e}", exc_info=True)
358
  raise HTTPException(status_code=502, detail="Upstream Gradio app error")
359
- logger.info(f"Completion response sent user={user} session={session_id}")
360
  return {"id": str(uuid.uuid4()), "object": "text_completion", "choices": [{"text": result}]}
361
 
362
  @app.post("/v1/embeddings")
363
  async def embeddings(req: EmbeddingRequest):
364
  inputs = req.input if isinstance(req.input, list) else [req.input]
365
  embeddings = [[0.0] * 768 for _ in inputs]
366
- logger.info(f"Embedding request model={req.model} inputs_count={len(inputs)}")
367
  return {"object": "list", "data": [{"embedding": emb, "index": i} for i, emb in enumerate(embeddings)]}
368
 
369
  @app.get("/v1/models")
370
  async def get_models():
371
- logger.info("Models list requested")
372
  return {"object": "list", "data": [{"id": "Q8_K_XL", "object": "model", "owned_by": "J.A.R.V.I.S."}]}
373
 
374
  @app.get("/v1/history")
375
  async def get_history(user: Optional[str] = None, session_id: Optional[str] = None):
376
  user = user or "anonymous"
377
  sessions = session_manager.sessions
378
- logger.info(f"History requested user={user} session={session_id}")
379
  if user in sessions and session_id and session_id in sessions[user]:
380
  return {"user": user, "session_id": session_id, "history": sessions[user][session_id].history}
381
  return {"user": user, "session_id": session_id, "history": []}
@@ -384,7 +378,6 @@ async def get_history(user: Optional[str] = None, session_id: Optional[str] = No
384
  async def cancel_response(user: Optional[str], session_id: Optional[str], task_id: Optional[str]):
385
  user = user or "anonymous"
386
  if not task_id:
387
- logger.warning(f"Cancel response missing task_id user={user} session={session_id}")
388
  raise HTTPException(status_code=400, detail="Missing task_id for cancellation")
389
  async with session_manager.lock:
390
  if user in session_manager.sessions and session_id in session_manager.sessions[user]:
@@ -392,9 +385,7 @@ async def cancel_response(user: Optional[str], session_id: Optional[str], task_i
392
  task = session.active_tasks.get(task_id)
393
  if task and not task.done():
394
  task.cancel()
395
- logger.info(f"Cancelled task {task_id} for user={user} session={session_id}")
396
  return {"message": f"Cancelled task {task_id}"}
397
- logger.warning(f"Task not found or already completed task_id={task_id} user={user} session={session_id}")
398
  raise HTTPException(status_code=404, detail="Task not found or already completed")
399
 
400
  @app.api_route("/v1", methods=["POST", "GET", "OPTIONS", "HEAD"])
@@ -403,15 +394,12 @@ async def router(request: Request):
403
  try:
404
  body_json = await request.json()
405
  except Exception:
406
- logger.error("Invalid JSON body in router POST")
407
  raise HTTPException(status_code=400, detail="Invalid JSON body")
408
  try:
409
  body = RouterRequest(**body_json)
410
  except ValidationError as e:
411
- logger.error(f"Validation error in router POST: {e.errors()}")
412
  raise HTTPException(status_code=422, detail=e.errors())
413
  endpoint = body.endpoint or "chat/completions"
414
- logger.info(f"Router POST to endpoint={endpoint}")
415
  if endpoint == "chat/completions":
416
  if not body.model or not body.messages:
417
  raise HTTPException(status_code=422, detail="Missing 'model' or 'messages'")
@@ -432,12 +420,10 @@ async def router(request: Request):
432
  elif endpoint == "history":
433
  return await get_history(body.user, body.session_id)
434
  elif endpoint == "responses/cancel":
435
- return await cancel_response(body.user, body.session_id, body.session_id)
436
  else:
437
- logger.warning(f"Router POST unknown endpoint: {endpoint}")
438
  raise HTTPException(status_code=404, detail="Endpoint not found")
439
  else:
440
- logger.info(f"Router {request.method} called - only POST supported with JSON body")
441
  return JSONResponse({"message": "Send POST request with JSON body"}, status_code=status.HTTP_405_METHOD_NOT_ALLOWED)
442
 
443
  @app.get("/")
 
8
  import json
9
  import asyncio
10
  import logging
11
+ import os
12
  from typing import Optional, List, Union, Dict, Any, Literal
13
  from fastapi import FastAPI, HTTPException, Request, status
14
  from fastapi.middleware.cors import CORSMiddleware
 
30
  self.history: List[Dict[str, Any]] = []
31
  self.last_access: float = time.time()
32
  self.active_tasks: Dict[str, asyncio.Task] = {}
33
+ self.last_request_time: float = 0.0
34
 
35
  class SessionManager:
36
  def __init__(self):
 
50
  for task_id, task in data.active_tasks.items():
51
  if not task.done():
52
  task.cancel()
 
53
  for user, sid in expired:
54
  if user in self.sessions and sid in self.sessions[user]:
55
  del self.sessions[user][sid]
56
  if not self.sessions[user]:
57
  del self.sessions[user]
 
58
 
59
  async def get_session(self, user: Optional[str], session_id: Optional[str]) -> (str, str, SessionData):
60
  async with self.lock:
61
  if not user:
62
  user = str(uuid.uuid4())
 
63
  if user not in self.sessions:
64
  self.sessions[user] = {}
65
  if not session_id or session_id not in self.sessions[user]:
66
  session_id = str(uuid.uuid4())
67
  self.sessions[user][session_id] = SessionData()
 
68
  session = self.sessions[user][session_id]
69
  session.last_access = time.time()
 
70
  return user, session_id, session
71
 
72
  session_manager = SessionManager()
73
 
74
  async def refresh_client(app: FastAPI):
75
  while True:
76
+ await asyncio.sleep(15 * 60)
77
  async with app.state.client_lock:
78
+ if app.state.client is not None:
79
+ try:
80
+ old_client = app.state.client
81
+ app.state.client = None
82
+ del old_client
83
+ app.state.client = Client("https://hadadrjt-ai.hf.space/")
84
+ logger.info("Refreshed Gradio client connection")
85
+ except Exception as e:
86
+ logger.error(f"Error refreshing Gradio client: {e}", exc_info=True)
87
+ app.state.client = None
88
+
89
+ async def clear_terminal_periodically():
90
+ while True:
91
+ await asyncio.sleep(300)
92
+ if os.name == "nt":
93
+ os.system("cls")
94
+ else:
95
+ print("\033c", end="", flush=True)
 
96
 
97
  @asynccontextmanager
98
  async def lifespan(app: FastAPI):
 
100
  app.state.client = None
101
  app.state.client_lock = asyncio.Lock()
102
  app.state.refresh_task = asyncio.create_task(refresh_client(app))
103
+ app.state.cleanup_task = asyncio.create_task(session_manager.cleanup())
104
+ app.state.clear_log_task = asyncio.create_task(clear_terminal_periodically())
105
  try:
106
  yield
107
  finally:
108
  app.state.refresh_task.cancel()
109
+ app.state.cleanup_task.cancel()
110
+ app.state.clear_log_task.cancel()
111
  await asyncio.sleep(0.1)
 
112
 
113
  app = FastAPI(
114
  title="J.A.R.V.I.S. OpenAI-Compatible API",
 
218
  async def call_gradio(client: Client, params: dict):
219
  for attempt in range(3):
220
  try:
 
221
  return await asyncio.to_thread(lambda: client.submit(**params))
222
  except Exception as e:
 
223
  await asyncio.sleep(0.2 * (attempt + 1))
 
224
  raise HTTPException(status_code=502, detail="Upstream Gradio app error")
225
 
226
  async def stream_response(job, session_id: str, session_history: List[Dict[str, Any]], new_messages: List[Message], response_type: str):
227
  partial = ""
228
  try:
229
  chunks = await asyncio.to_thread(lambda: list(job))
230
+ except Exception:
 
231
  chunks = []
232
  for chunk in chunks:
233
  try:
 
252
  "session_id": session_id
253
  }
254
  yield f"data: {json.dumps(data)}\n\n"
255
+ except Exception:
 
256
  continue
257
  session_history.extend([m.model_dump() for m in new_messages if m.role != "system"])
258
  session_history.append({"role": "assistant", "content": partial})
 
264
  }
265
  yield f"data: {json.dumps(done_data)}\n\n"
266
 
267
+ RATE_LIMIT_SECONDS = 1.0
268
+
269
  @app.post("/v1/chat/completions")
270
  async def chat_completions(req: ChatCompletionRequest):
271
  user, session_id, session = await session_manager.get_session(req.user, req.session_id)
272
+ now = time.time()
273
+ if now - session.last_request_time < RATE_LIMIT_SECONDS:
274
+ raise HTTPException(status_code=429, detail="Too many requests, please slow down")
275
+ session.last_request_time = now
276
  req.messages = sanitize_messages(req.messages)
277
  for m in req.messages:
278
  if m.role == "system":
 
298
  "function_call": req.function_call or req.tool_choice,
299
  }
300
  params = {k: v for k, v in params.items() if v is not None}
 
301
  client = await get_client(app)
302
  if req.stream:
303
  job = await call_gradio(client, params)
 
307
  loop = asyncio.get_running_loop()
308
  try:
309
  result = await loop.run_in_executor(None, lambda: client.predict(**params))
310
+ except Exception:
 
311
  raise HTTPException(status_code=502, detail="Upstream Gradio app error")
312
  session.history.extend([m.model_dump() for m in req.messages if m.role != "system"])
313
  session.history.append({"role": "assistant", "content": result})
 
314
  return {
315
  "id": str(uuid.uuid4()),
316
  "object": "chat.completion",
 
320
 
321
  @app.post("/v1/completions")
322
  async def completions(req: CompletionRequest):
323
+ user, session_id, session = await session_manager.get_session(req.user, req.session_id)
324
+ now = time.time()
325
+ if now - session.last_request_time < RATE_LIMIT_SECONDS:
326
+ raise HTTPException(status_code=429, detail="Too many requests, please slow down")
327
+ session.last_request_time = now
328
  prompt = req.prompt if isinstance(req.prompt, str) else "\n".join(req.prompt)
329
  params = {
330
  "message": prompt,
 
343
  "seed": req.seed,
344
  }
345
  params = {k: v for k, v in params.items() if v is not None}
 
346
  client = await get_client(app)
347
  if req.stream:
348
  job = await call_gradio(client, params)
 
352
  loop = asyncio.get_running_loop()
353
  try:
354
  result = await loop.run_in_executor(None, lambda: client.predict(**params))
355
+ except Exception:
 
356
  raise HTTPException(status_code=502, detail="Upstream Gradio app error")
 
357
  return {"id": str(uuid.uuid4()), "object": "text_completion", "choices": [{"text": result}]}
358
 
359
  @app.post("/v1/embeddings")
360
  async def embeddings(req: EmbeddingRequest):
361
  inputs = req.input if isinstance(req.input, list) else [req.input]
362
  embeddings = [[0.0] * 768 for _ in inputs]
 
363
  return {"object": "list", "data": [{"embedding": emb, "index": i} for i, emb in enumerate(embeddings)]}
364
 
365
  @app.get("/v1/models")
366
  async def get_models():
 
367
  return {"object": "list", "data": [{"id": "Q8_K_XL", "object": "model", "owned_by": "J.A.R.V.I.S."}]}
368
 
369
  @app.get("/v1/history")
370
  async def get_history(user: Optional[str] = None, session_id: Optional[str] = None):
371
  user = user or "anonymous"
372
  sessions = session_manager.sessions
 
373
  if user in sessions and session_id and session_id in sessions[user]:
374
  return {"user": user, "session_id": session_id, "history": sessions[user][session_id].history}
375
  return {"user": user, "session_id": session_id, "history": []}
 
378
  async def cancel_response(user: Optional[str], session_id: Optional[str], task_id: Optional[str]):
379
  user = user or "anonymous"
380
  if not task_id:
 
381
  raise HTTPException(status_code=400, detail="Missing task_id for cancellation")
382
  async with session_manager.lock:
383
  if user in session_manager.sessions and session_id in session_manager.sessions[user]:
 
385
  task = session.active_tasks.get(task_id)
386
  if task and not task.done():
387
  task.cancel()
 
388
  return {"message": f"Cancelled task {task_id}"}
 
389
  raise HTTPException(status_code=404, detail="Task not found or already completed")
390
 
391
  @app.api_route("/v1", methods=["POST", "GET", "OPTIONS", "HEAD"])
 
394
  try:
395
  body_json = await request.json()
396
  except Exception:
 
397
  raise HTTPException(status_code=400, detail="Invalid JSON body")
398
  try:
399
  body = RouterRequest(**body_json)
400
  except ValidationError as e:
 
401
  raise HTTPException(status_code=422, detail=e.errors())
402
  endpoint = body.endpoint or "chat/completions"
 
403
  if endpoint == "chat/completions":
404
  if not body.model or not body.messages:
405
  raise HTTPException(status_code=422, detail="Missing 'model' or 'messages'")
 
420
  elif endpoint == "history":
421
  return await get_history(body.user, body.session_id)
422
  elif endpoint == "responses/cancel":
423
+ return await cancel_response(body.user, body.session_id, body.tool_choice if isinstance(body.tool_choice, str) else None)
424
  else:
 
425
  raise HTTPException(status_code=404, detail="Endpoint not found")
426
  else:
 
427
  return JSONResponse({"message": "Send POST request with JSON body"}, status_code=status.HTTP_405_METHOD_NOT_ALLOWED)
428
 
429
  @app.get("/")