hadadrjt committed on
Commit
c0d067d
·
1 Parent(s): 91573a9

api: Apply OpenAI plugins.

Browse files
Files changed (1) hide show
  1. app.py +142 -2
app.py CHANGED
@@ -14,7 +14,7 @@ from fastapi import FastAPI, HTTPException
14
  from fastapi.responses import JSONResponse, StreamingResponse
15
  from gradio_client import Client
16
  from pydantic import BaseModel
17
- from typing import AsyncGenerator, Optional, Dict, List, Tuple
18
 
19
  # Default AI model name used when no model is specified by user
20
  MODEL = "JARVIS: 2.1.3"
@@ -49,6 +49,21 @@ class ResponseRequest(BaseModel):
49
  stream: Optional[bool] = False
50
  session_id: Optional[str] = None
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def cleanup_expired_sessions():
53
  """
54
  Remove sessions that have been inactive for longer than EXPIRE.
@@ -230,7 +245,7 @@ async def event_generator(user_input: str, model: str, session_id: str) -> Async
230
  @app.post("/v1/responses")
231
  async def responses(req: ResponseRequest):
232
  """
233
- Main API endpoint to get AI responses.
234
  Supports both streaming and non-streaming modes.
235
 
236
  Workflow:
@@ -294,6 +309,131 @@ async def responses(req: ResponseRequest):
294
  # Return the JSON response
295
  return JSONResponse(response)
296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  @app.get("/v1/history")
298
  async def get_history(session_id: Optional[str] = None):
299
  """
 
14
  from fastapi.responses import JSONResponse, StreamingResponse
15
  from gradio_client import Client
16
  from pydantic import BaseModel
17
+ from typing import AsyncGenerator, Optional, Dict, List, Tuple, Any
18
 
19
  # Default AI model name used when no model is specified by user
20
  MODEL = "JARVIS: 2.1.3"
 
49
  stream: Optional[bool] = False
50
  session_id: Optional[str] = None
51
 
52
class OpenAIChatRequest(BaseModel):
    """
    Request schema mirroring the OpenAI /v1/chat/completions contract.

    Attributes:
    - model: which AI model to use; the endpoint falls back to MODEL when omitted
    - messages: ordered chat messages, each a dict carrying 'role' and 'content'
    - stream: when True, the reply is delivered incrementally as a stream
    - session_id: optional identifier used to persist conversation history
    """
    model: Optional[str] = None
    messages: List[Dict[str, str]]
    stream: Optional[bool] = False
    session_id: Optional[str] = None
67
  def cleanup_expired_sessions():
68
  """
69
  Remove sessions that have been inactive for longer than EXPIRE.
 
245
  @app.post("/v1/responses")
246
  async def responses(req: ResponseRequest):
247
  """
248
+ Original API endpoint to get AI responses.
249
  Supports both streaming and non-streaming modes.
250
 
251
  Workflow:
 
309
  # Return the JSON response
310
  return JSONResponse(response)
311
 
312
@app.post("/v1/chat/completions")
async def openai_chat_completions(req: OpenAIChatRequest):
    """
    OpenAI-compatible endpoint for chat completions.
    Supports both streaming and non-streaming modes.

    Workflow:
    - Validate message structure and extract conversation history
    - Validate or create session
    - Update session history from messages
    - Handle streaming or full response
    - Save new interaction to session history

    Returns:
    - JSON response in OpenAI chat.completion format, extended with the
      session ID so clients can continue the conversation

    Raises:
    - HTTPException 400 on empty or malformed messages
    - HTTPException 503 when the AI client is unavailable
    - HTTPException 500 when submitting the prompt to the AI backend fails
    """
    # Validate messages structure
    if not req.messages:
        raise HTTPException(status_code=400, detail="Messages cannot be empty")

    # Conversation history rebuilt from the request messages
    history = []

    try:
        # The last message must be the user's current input
        if req.messages[-1]["role"] != "user":
            raise ValueError("Last message must be from user")
        current_input = req.messages[-1]["content"]

        # Rebuild history from the earlier messages. Roles other than
        # user/assistant (e.g. a leading "system" prompt, common with
        # OpenAI clients) are ignored so they do not shift the pairing;
        # the fixed index-pair scheme previously dropped all history in
        # that case. Pair consecutive user -> assistant turns.
        chat_turns = [
            m for m in req.messages[:-1]
            if m["role"] in ("user", "assistant")
        ]
        i = 0
        while i + 1 < len(chat_turns):
            user_msg = chat_turns[i]
            assistant_msg = chat_turns[i + 1]
            if user_msg["role"] == "user" and assistant_msg["role"] == "assistant":
                history.append({
                    "input": user_msg["content"],
                    "response": assistant_msg["content"],
                })
                i += 2
            else:
                # Skip a message that breaks the expected alternation
                # but keep scanning the rest of the conversation.
                i += 1
    except (KeyError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid message format: {str(e)}")

    model = req.model or MODEL  # Use requested model or default
    session_id = get_or_create_session(req.session_id, model)  # Get or create session
    _, session_data = session_store[session_id]  # last-update timestamp not needed here

    # Replace the session history with the one reconstructed from the request
    session_data["history"] = history
    session_store[session_id] = (time.time(), session_data)

    client = session_data["client"]
    if client is None:
        raise HTTPException(status_code=503, detail="AI client not available")

    if req.stream:
        # Streaming response: delegate to the shared SSE generator
        return StreamingResponse(
            event_generator(current_input, model, session_id),
            media_type="text/event-stream"
        )

    # Non-streaming response: submit the prompt to the AI backend
    try:
        jarvis_response = client.submit(multi={"text": current_input}, api_name="/api")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to submit to AI: {str(e)}")

    # Drain partial results, keeping only the last one (presumably the
    # cumulative/final text — matches the behavior of the other endpoints).
    buffer = ""
    for partial in jarvis_response:
        buffer = partial[0][0][1]

    # Persist the new interaction in the session history
    session_data["history"].append({"input": current_input, "response": buffer})
    session_store[session_id] = (time.time(), session_data)

    # Format the reply in OpenAI chat.completion style
    response = {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": buffer
                },
                "finish_reason": "stop"
            }
        ],
        "session_id": session_id  # Custom extension for session management
    }

    return JSONResponse(response)
417
@app.get("/v1/models")
async def list_models():
    """
    OpenAI-compatible model listing endpoint.

    Many OpenAI-style clients probe this route on startup; it returns a
    fixed single-entry list describing the default model.
    """
    model_entry = {
        "id": MODEL,
        "object": "model",
        "created": 0,  # Timestamp not available
        "owned_by": "J.A.R.V.I.S."
    }
    return JSONResponse({"object": "list", "data": [model_entry]})
+
437
  @app.get("/v1/history")
438
  async def get_history(session_id: Optional[str] = None):
439
  """