#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#
import json
import time
import uuid
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from gradio_client import Client
from pydantic import BaseModel
from typing import AsyncGenerator, Optional, Dict, List, Tuple
# Default AI model name used when no model is specified by the user
MODEL = "JARVIS: 2.1.3"
# Session store keeps track of active sessions.
# Each session_id maps to a tuple:
# (last_update_timestamp, session_data_dict)
# session_data_dict contains:
# - "model": the AI model name used in this session
# - "history": list of past chat messages (input and response)
# - "client": the Gradio Client instance specific to this session
session_store: Dict[str, Tuple[float, Dict]] = {}
# Duration (in seconds) after which inactive sessions are removed
EXPIRE = 3600 # 1 hour
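# For illustration, a populated session_store entry might look like this
# (values are hypothetical; "client" holds a live gradio_client.Client object):
#
#   session_store["3f2b9c1e-..."] = (
#       1717000000.0,  # last_update_timestamp
#       {
#           "model": "JARVIS: 2.1.3",
#           "history": [{"input": "Hi", "response": "Hello!"}],
#           "client": client,  # gradio_client.Client instance
#       },
#   )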
# Create FastAPI app instance
app = FastAPI()
class ResponseRequest(BaseModel):
"""
Defines the expected structure of the request body for /v1/responses endpoint.
Attributes:
- model: Optional; specifies which AI model to use. Defaults to MODEL if not provided.
- input: The user's input text to send to the AI.
- stream: Optional; if True, the response will be streamed incrementally.
- session_id: Optional; unique identifier for the user's session. If missing, a new session will be created.
"""
model: Optional[str] = None
input: str
stream: Optional[bool] = False
session_id: Optional[str] = None
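# Example /v1/responses request body (the session_id is illustrative; omit it
# to start a new session):
#
#   {
#       "input": "What is the capital of France?",
#       "stream": false,
#       "session_id": "3f2b9c1e-..."
#   }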
class OpenAIChatRequest(BaseModel):
"""
Defines the OpenAI-compatible request structure for /v1/chat/completions endpoint.
Attributes:
- model: Optional; specifies which AI model to use. Defaults to MODEL if not provided.
- messages: List of message objects containing 'role' and 'content'
- stream: Optional; if True, the response will be streamed incrementally.
- session_id: Optional; unique session identifier for maintaining conversation history
"""
model: Optional[str] = None
messages: List[Dict[str, str]]
stream: Optional[bool] = False
session_id: Optional[str] = None
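# Example /v1/chat/completions request body (OpenAI-style; contents are
# illustrative). Note that the final message must come from the user:
#
#   {
#       "messages": [
#           {"role": "user", "content": "Hi"},
#           {"role": "assistant", "content": "Hello!"},
#           {"role": "user", "content": "Tell me a joke."}
#       ],
#       "stream": true
#   }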
def cleanup_expired_sessions():
"""
Remove sessions that have been inactive for longer than EXPIRE.
This helps free up memory by deleting old sessions and closing their clients.
"""
now = time.time()
expired_sessions = [
sid for sid, (last_update, _) in session_store.items()
if now - last_update > EXPIRE
]
for sid in expired_sessions:
# Attempt to close the Gradio client associated with the session
_, data = session_store[sid]
client = data.get("client")
if client:
try:
client.close()
except Exception:
# Ignore errors during client close to avoid crashing cleanup
pass
# Remove the session from the store
del session_store[sid]
def create_client_for_model(model: str) -> Client:
"""
Create a new Gradio Client instance and set it to use the specified AI model.
Parameters:
- model: The name of the AI model to initialize the client with.
Returns:
- A new Gradio Client instance configured with the given model.
"""
client = Client("hadadrjt/ai")
# Set the model on the Gradio client by calling the /change_model API
client.predict(new=model, api_name="/change_model")
return client
def get_or_create_session(session_id: Optional[str], model: str) -> str:
"""
Retrieve an existing session by session_id or create a new one if it doesn't exist.
Also cleans up expired sessions before proceeding.
Parameters:
- session_id: The unique identifier of the session (optional).
- model: The AI model to use for this session.
Returns:
- The session_id for the active or newly created session.
"""
cleanup_expired_sessions()
# If no session_id provided or session does not exist, create a new session
if not session_id or session_id not in session_store:
session_id = str(uuid.uuid4()) # Generate a new unique session ID
client = create_client_for_model(model) # Create a new client for this session
session_store[session_id] = (time.time(), {
"model": model,
"history": [],
"client": client
})
else:
# Session exists, update last access time and check if model changed
last_update, data = session_store[session_id]
if data["model"] != model:
# If model changed, close old client and create a new one with the new model
old_client = data.get("client")
if old_client:
try:
old_client.close()
except Exception:
pass # Ignore errors on close
new_client = create_client_for_model(model)
data["model"] = model
data["client"] = new_client
session_store[session_id] = (time.time(), data)
else:
# Just update the last access time to keep session alive
session_store[session_id] = (time.time(), data)
return session_id
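# Session reuse sketch (ids and the alternate model name are illustrative):
# the same id keeps history and the Gradio client alive, None or an unknown
# id always creates a fresh session, and changing the model swaps the client
# while keeping the same id:
#
#   sid = get_or_create_session(None, MODEL)          # new session
#   sid = get_or_create_session(sid, MODEL)           # same session, refreshed
#   sid = get_or_create_session(sid, "other-model")   # same id, new client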
async def event_generator(user_input: str, model: str, session_id: str) -> AsyncGenerator[str, None]:
"""
Asynchronous generator that streams AI responses incrementally as Server-Sent Events (SSE).
Parameters:
- user_input: The input text from the user.
- model: The AI model to use.
- session_id: The unique session identifier.
Yields:
- JSON-formatted chunks representing incremental AI response deltas.
"""
last_update, session_data = session_store.get(session_id, (0, None))
if session_data is None:
# Session not found; yield error and stop
yield f"data: {json.dumps({'error': 'Session not found'})}\n\n"
return
client = session_data["client"]
if client is None:
# Client missing for session; yield error and stop
yield f"data: {json.dumps({'error': 'AI client not available'})}\n\n"
return
try:
# Submit the user input to the AI model via Gradio client
jarvis_response = client.submit(multi={"text": user_input}, api_name="/api")
except Exception as e:
# If submission fails, yield error and stop
yield f"data: {json.dumps({'error': f'Failed to submit to AI: {str(e)}'})}\n\n"
return
buffer = "" # Buffer to track full response text progressively
try:
for partial in jarvis_response:
# Extract the current partial text from the response
text = partial[0][0][1]
# Calculate the delta (new text since last chunk)
if text.startswith(buffer):
delta = text[len(buffer):]
else:
delta = text
buffer = text # Update buffer with latest full text
if delta == "":
# Skip empty delta chunks
continue
# Prepare chunk data in OpenAI streaming format
chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": delta},
"finish_reason": None
}
]
}
# Yield the chunk as a Server-Sent Event
yield f"data: {json.dumps(chunk)}\n\n"
# After streaming completes, save the full input-response pair to session history
session_data["history"].append({"input": user_input, "response": buffer})
session_store[session_id] = (time.time(), session_data) # Update last access time
# Send a final chunk signaling completion of the stream
done_chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "stop"
}
]
}
yield f"data: {json.dumps(done_chunk)}\n\n"
except Exception as e:
# If streaming fails at any point, yield an error chunk
error_chunk = {
"error": {"message": f"Streaming error: {str(e)}"}
}
yield f"data: {json.dumps(error_chunk)}\n\n"
@app.post("/v1/responses")
async def responses(req: ResponseRequest):
"""
Original API endpoint to get AI responses.
Supports both streaming and non-streaming modes.
Workflow:
- Validate or create session.
- Ensure AI client is available.
- Handle streaming or full response accordingly.
- Save chat history per session.
Returns:
- JSON response with AI output and session ID.
"""
model = req.model or MODEL # Use requested model or default
session_id = get_or_create_session(req.session_id, model) # Get or create session
last_update, session_data = session_store[session_id]
user_input = req.input
client = session_data["client"]
if client is None:
# If client is missing, return 503 error
raise HTTPException(status_code=503, detail="AI client not available")
if req.stream:
# If streaming requested, return a streaming response using event_generator
return StreamingResponse(event_generator(user_input, model, session_id), media_type="text/event-stream")
# Non-streaming request: submit input and collect full response
try:
jarvis_response = client.submit(multi={"text": user_input}, api_name="/api")
except Exception as e:
# Return 500 error if submission fails
raise HTTPException(status_code=500, detail=f"Failed to submit to AI: {str(e)}")
buffer = ""
for partial in jarvis_response:
text = partial[0][0][1]
buffer = text # Update buffer with latest full response
# Save input and response to session history and update last access time
session_data["history"].append({"input": user_input, "response": buffer})
session_store[session_id] = (time.time(), session_data)
# Prepare the JSON response in OpenAI style format
response = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": buffer
},
"finish_reason": "stop"
}
],
"session_id": session_id # Return session_id so client can reuse it
}
# Return the JSON response
return JSONResponse(response)
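# Minimal client sketch for /v1/responses. Assumes the server runs locally on
# port 7860 and that the third-party `requests` package is installed:
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/v1/responses",
#       json={"input": "Hello", "stream": False},
#   )
#   body = r.json()
#   print(body["choices"][0]["message"]["content"], body["session_id"])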
@app.post("/v1/chat/completions")
async def openai_chat_completions(req: OpenAIChatRequest):
"""
OpenAI-compatible endpoint for chat completions.
Supports both streaming and non-streaming modes.
Workflow:
- Validate message structure and extract conversation history
- Validate or create session
- Update session history from messages
- Handle streaming or full response
- Save new interaction to session history
Returns:
- JSON response in OpenAI format with session ID extension
"""
# Validate messages structure
if not req.messages:
raise HTTPException(status_code=400, detail="Messages cannot be empty")
# Extract conversation history and current input
history = []
current_input = ""
# Process messages to extract conversation history
try:
# Last message should be from user and used as current input
if req.messages[-1]["role"] != "user":
raise ValueError("Last message must be from user")
current_input = req.messages[-1]["content"]
        # Pair prior messages as (user, assistant) turns. This assumes strict
        # user/assistant alternation (no system messages); pairs that do not
        # match are skipped below.
messages = req.messages[:-1] # Exclude last message (current input)
for i in range(0, len(messages), 2):
if i+1 < len(messages):
user_msg = messages[i]
assistant_msg = messages[i+1]
if user_msg["role"] != "user" or assistant_msg["role"] != "assistant":
# Skip invalid pairs but continue processing
continue
history.append({
"input": user_msg["content"],
"response": assistant_msg["content"]
})
except (KeyError, ValueError) as e:
raise HTTPException(status_code=400, detail=f"Invalid message format: {str(e)}")
model = req.model or MODEL # Use requested model or default
session_id = get_or_create_session(req.session_id, model) # Get or create session
last_update, session_data = session_store[session_id]
# Update session history from messages
session_data["history"] = history
session_store[session_id] = (time.time(), session_data)
client = session_data["client"]
if client is None:
raise HTTPException(status_code=503, detail="AI client not available")
if req.stream:
# Streaming response
return StreamingResponse(
event_generator(current_input, model, session_id),
media_type="text/event-stream"
)
# Non-streaming response
try:
jarvis_response = client.submit(multi={"text": current_input}, api_name="/api")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to submit to AI: {str(e)}")
buffer = ""
for partial in jarvis_response:
text = partial[0][0][1]
buffer = text
# Update session history with new interaction
session_data["history"].append({"input": current_input, "response": buffer})
session_store[session_id] = (time.time(), session_data)
# Format response in OpenAI style
response = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": buffer
},
"finish_reason": "stop"
}
],
"session_id": session_id # Custom extension for session management
}
return JSONResponse(response)
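# Because this endpoint mirrors OpenAI's schema, the official `openai` Python
# client (v1+) can target it directly. A sketch, assuming the server is local
# and noting that the api_key value is unused by this server:
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="unused")
#   reply = client.chat.completions.create(
#       model="JARVIS: 2.1.3",
#       messages=[{"role": "user", "content": "Hi"}],
#   )
#   print(reply.choices[0].message.content)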
@app.get("/v1/models")
async def list_models():
"""
OpenAI-compatible endpoint to list available models.
Returns a fixed list containing our default model.
This endpoint is required by many OpenAI-compatible clients.
"""
return JSONResponse({
"object": "list",
"data": [
{
"id": MODEL,
"object": "model",
"created": 0, # Timestamp not available
"owned_by": "J.A.R.V.I.S."
}
]
})
@app.get("/v1/history")
async def get_history(session_id: Optional[str] = None):
"""
Endpoint to retrieve chat history for a given session.
Parameters:
- session_id: The unique session identifier.
Returns:
- JSON object containing session_id and list of past input-response pairs.
Raises:
- 404 error if session_id is missing or session does not exist.
"""
if not session_id or session_id not in session_store:
raise HTTPException(status_code=404, detail="Session not found or session_id missing.")
_, session_data = session_store[session_id]
return {"session_id": session_id, "history": session_data["history"]}
@app.get("/")
def root():
"""
Simple health check endpoint.
Returns basic status indicating if API is running.
"""
return {"status": "API is running"}
# Run the app with Uvicorn ASGI server when executed directly
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)