import json
import time
import uuid

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from gradio_client import Client
from pydantic import BaseModel
from typing import AsyncGenerator, Dict, List, Optional, Tuple

# Default model identifier used when a request does not specify one.
MODEL = "JARVIS: 2.1.3"

# In-memory session store mapping session_id -> (last_update, session_data),
# where session_data holds the model name, the chat history, and the Gradio
# client, e.g. {"model": MODEL, "history": [...], "client": <Client>}.
session_store: Dict[str, Tuple[float, Dict]] = {}

# Sessions idle for longer than this many seconds are evicted.
EXPIRE = 3600

app = FastAPI()
|
|
|
class ResponseRequest(BaseModel):
    """
    Defines the expected structure of the request body for the /v1/responses endpoint.

    Attributes:
    - model: Optional; specifies which AI model to use. Defaults to MODEL if not provided.
    - input: The user's input text to send to the AI.
    - stream: Optional; if True, the response will be streamed incrementally.
    - session_id: Optional; unique identifier for the user's session. If missing, a new session will be created.
    """
    model: Optional[str] = None
    input: str
    stream: Optional[bool] = False
    session_id: Optional[str] = None


class OpenAIChatRequest(BaseModel):
    """
    Defines the OpenAI-compatible request structure for the /v1/chat/completions endpoint.

    Attributes:
    - model: Optional; specifies which AI model to use. Defaults to MODEL if not provided.
    - messages: List of message objects, each containing 'role' and 'content'.
    - stream: Optional; if True, the response will be streamed incrementally.
    - session_id: Optional; unique session identifier for maintaining conversation history.
    """
    model: Optional[str] = None
    messages: List[Dict[str, str]]
    stream: Optional[bool] = False
    session_id: Optional[str] = None
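
# For reference, minimal JSON bodies matching the two request models above
# (illustrative values only; omit session_id to start a fresh session):
#
#   ResponseRequest:   {"input": "Hello", "stream": false}
#   OpenAIChatRequest: {"messages": [{"role": "user", "content": "Hello"}]}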
|
|
|
def cleanup_expired_sessions():
    """
    Remove sessions that have been inactive for longer than EXPIRE.
    This helps free up memory by deleting old sessions and closing their clients.
    """
    now = time.time()
    expired_sessions = [
        sid for sid, (last_update, _) in session_store.items()
        if now - last_update > EXPIRE
    ]
    for sid in expired_sessions:
        _, data = session_store[sid]
        client = data.get("client")
        if client:
            try:
                client.close()
            except Exception:
                # Best-effort cleanup; a client that fails to close is dropped anyway.
                pass
        del session_store[sid]
|
|
|
def create_client_for_model(model: str) -> Client:
    """
    Create a new Gradio Client instance and set it to use the specified AI model.

    Parameters:
    - model: The name of the AI model to initialize the client with.

    Returns:
    - A new Gradio Client instance configured with the given model.
    """
    client = Client("hadadrjt/ai")
    # Switch the remote Space to the requested model before first use.
    client.predict(new=model, api_name="/change_model")
    return client
|
|
|
def get_or_create_session(session_id: Optional[str], model: str) -> str:
    """
    Retrieve an existing session by session_id or create a new one if it doesn't exist.
    Also cleans up expired sessions before proceeding.

    Parameters:
    - session_id: The unique identifier of the session (optional).
    - model: The AI model to use for this session.

    Returns:
    - The session_id for the active or newly created session.
    """
    cleanup_expired_sessions()

    if not session_id or session_id not in session_store:
        # Unknown or missing session: create a fresh one with its own client.
        session_id = str(uuid.uuid4())
        client = create_client_for_model(model)
        session_store[session_id] = (time.time(), {
            "model": model,
            "history": [],
            "client": client
        })
    else:
        _, data = session_store[session_id]
        if data["model"] != model:
            # The caller switched models: replace the old client with a new one.
            old_client = data.get("client")
            if old_client:
                try:
                    old_client.close()
                except Exception:
                    pass
            data["model"] = model
            data["client"] = create_client_for_model(model)
        # Refresh the last-update timestamp in either case.
        session_store[session_id] = (time.time(), data)

    return session_id
|
|
|
async def event_generator(user_input: str, model: str, session_id: str) -> AsyncGenerator[str, None]:
    """
    Asynchronous generator that streams AI responses incrementally as Server-Sent Events (SSE).

    Parameters:
    - user_input: The input text from the user.
    - model: The AI model to use.
    - session_id: The unique session identifier.

    Yields:
    - JSON-formatted chunks representing incremental AI response deltas.
    """
    _, session_data = session_store.get(session_id, (0, None))
    if session_data is None:
        yield f"data: {json.dumps({'error': 'Session not found'})}\n\n"
        return

    client = session_data["client"]
    if client is None:
        yield f"data: {json.dumps({'error': 'AI client not available'})}\n\n"
        return

    try:
        jarvis_response = client.submit(multi={"text": user_input}, api_name="/api")
    except Exception as e:
        yield f"data: {json.dumps({'error': f'Failed to submit to AI: {str(e)}'})}\n\n"
        return

    buffer = ""

    try:
        # Note: iterating the Gradio job is blocking, so long generations can
        # hold up the event loop; offloading to a thread would avoid that.
        for partial in jarvis_response:
            text = partial[0][0][1]

            # The backend returns the full text so far; emit only the new suffix.
            if text.startswith(buffer):
                delta = text[len(buffer):]
            else:
                delta = text

            buffer = text

            if delta == "":
                continue

            chunk = {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": delta},
                        "finish_reason": None
                    }
                ]
            }

            yield f"data: {json.dumps(chunk)}\n\n"

        session_data["history"].append({"input": user_input, "response": buffer})
        session_store[session_id] = (time.time(), session_data)

        done_chunk = {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }
            ]
        }
        yield f"data: {json.dumps(done_chunk)}\n\n"
        # OpenAI-compatible clients expect this sentinel to terminate the stream.
        yield "data: [DONE]\n\n"

    except Exception as e:
        error_chunk = {
            "error": {"message": f"Streaming error: {str(e)}"}
        }
        yield f"data: {json.dumps(error_chunk)}\n\n"
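
# A quick way to watch the SSE stream this generator produces, via the
# streaming endpoints below (assuming a local server on port 7860; adjust
# the URL for your deployment):
#
#   curl -N http://localhost:7860/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": true}'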
|
|
|
@app.post("/v1/responses") |
|
async def responses(req: ResponseRequest): |
|
""" |
|
Original API endpoint to get AI responses. |
|
Supports both streaming and non-streaming modes. |
|
|
|
Workflow: |
|
- Validate or create session. |
|
- Ensure AI client is available. |
|
- Handle streaming or full response accordingly. |
|
- Save chat history per session. |
|
|
|
Returns: |
|
- JSON response with AI output and session ID. |
|
""" |
|
model = req.model or MODEL |
|
session_id = get_or_create_session(req.session_id, model) |
|
last_update, session_data = session_store[session_id] |
|
user_input = req.input |
|
|
|
client = session_data["client"] |
|
if client is None: |
|
|
|
raise HTTPException(status_code=503, detail="AI client not available") |
|
|
|
if req.stream: |
|
|
|
return StreamingResponse(event_generator(user_input, model, session_id), media_type="text/event-stream") |
|
|
|
|
|
try: |
|
jarvis_response = client.submit(multi={"text": user_input}, api_name="/api") |
|
except Exception as e: |
|
|
|
raise HTTPException(status_code=500, detail=f"Failed to submit to AI: {str(e)}") |
|
|
|
buffer = "" |
|
for partial in jarvis_response: |
|
text = partial[0][0][1] |
|
buffer = text |
|
|
|
|
|
session_data["history"].append({"input": user_input, "response": buffer}) |
|
session_store[session_id] = (time.time(), session_data) |
|
|
|
|
|
response = { |
|
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}", |
|
"object": "chat.completion", |
|
"created": int(time.time()), |
|
"model": model, |
|
"choices": [ |
|
{ |
|
"index": 0, |
|
"message": { |
|
"role": "assistant", |
|
"content": buffer |
|
}, |
|
"finish_reason": "stop" |
|
} |
|
], |
|
"session_id": session_id |
|
} |
|
|
|
|
|
return JSONResponse(response) |
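
# Example non-streaming call (assuming a local server on port 7860):
#
#   curl -X POST http://localhost:7860/v1/responses \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Hello"}'
#
# Reuse the returned session_id in later requests to continue the conversation.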
|
|
|
@app.post("/v1/chat/completions") |
|
async def openai_chat_completions(req: OpenAIChatRequest): |
|
""" |
|
OpenAI-compatible endpoint for chat completions. |
|
Supports both streaming and non-streaming modes. |
|
|
|
Workflow: |
|
- Validate message structure and extract conversation history |
|
- Validate or create session |
|
- Update session history from messages |
|
- Handle streaming or full response |
|
- Save new interaction to session history |
|
|
|
Returns: |
|
- JSON response in OpenAI format with session ID extension |
|
""" |
|
|
|
if not req.messages: |
|
raise HTTPException(status_code=400, detail="Messages cannot be empty") |
|
|
|
|
|
history = [] |
|
current_input = "" |
|
|
|
|
|
try: |
|
|
|
if req.messages[-1]["role"] != "user": |
|
raise ValueError("Last message must be from user") |
|
|
|
current_input = req.messages[-1]["content"] |
|
|
|
|
|
messages = req.messages[:-1] |
|
for i in range(0, len(messages), 2): |
|
if i+1 < len(messages): |
|
user_msg = messages[i] |
|
assistant_msg = messages[i+1] |
|
|
|
if user_msg["role"] != "user" or assistant_msg["role"] != "assistant": |
|
|
|
continue |
|
|
|
history.append({ |
|
"input": user_msg["content"], |
|
"response": assistant_msg["content"] |
|
}) |
|
except (KeyError, ValueError) as e: |
|
raise HTTPException(status_code=400, detail=f"Invalid message format: {str(e)}") |
|
|
|
model = req.model or MODEL |
|
session_id = get_or_create_session(req.session_id, model) |
|
last_update, session_data = session_store[session_id] |
|
|
|
|
|
session_data["history"] = history |
|
session_store[session_id] = (time.time(), session_data) |
|
|
|
client = session_data["client"] |
|
if client is None: |
|
raise HTTPException(status_code=503, detail="AI client not available") |
|
|
|
if req.stream: |
|
|
|
return StreamingResponse( |
|
event_generator(current_input, model, session_id), |
|
media_type="text/event-stream" |
|
) |
|
|
|
|
|
try: |
|
jarvis_response = client.submit(multi={"text": current_input}, api_name="/api") |
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=f"Failed to submit to AI: {str(e)}") |
|
|
|
buffer = "" |
|
for partial in jarvis_response: |
|
text = partial[0][0][1] |
|
buffer = text |
|
|
|
|
|
session_data["history"].append({"input": current_input, "response": buffer}) |
|
session_store[session_id] = (time.time(), session_data) |
|
|
|
|
|
response = { |
|
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}", |
|
"object": "chat.completion", |
|
"created": int(time.time()), |
|
"model": model, |
|
"choices": [ |
|
{ |
|
"index": 0, |
|
"message": { |
|
"role": "assistant", |
|
"content": buffer |
|
}, |
|
"finish_reason": "stop" |
|
} |
|
], |
|
"session_id": session_id |
|
} |
|
|
|
return JSONResponse(response) |
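
# Because this endpoint mirrors the OpenAI schema, the official Python client
# can point at it directly. A minimal sketch, assuming the openai package
# (v1+) is installed and the server is running locally on port 7860:
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="unused")
#   reply = client.chat.completions.create(
#       model="JARVIS: 2.1.3",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(reply.choices[0].message.content)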
|
|
|
@app.get("/v1/models") |
|
async def list_models(): |
|
""" |
|
OpenAI-compatible endpoint to list available models. |
|
Returns a fixed list containing our default model. |
|
|
|
This endpoint is required by many OpenAI-compatible clients. |
|
""" |
|
return JSONResponse({ |
|
"object": "list", |
|
"data": [ |
|
{ |
|
"id": MODEL, |
|
"object": "model", |
|
"created": 0, |
|
"owned_by": "J.A.R.V.I.S." |
|
} |
|
] |
|
}) |
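
# Example (assuming a local server on port 7860):
#
#   curl http://localhost:7860/v1/models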
|
|
|
@app.get("/v1/history") |
|
async def get_history(session_id: Optional[str] = None): |
|
""" |
|
Endpoint to retrieve chat history for a given session. |
|
|
|
Parameters: |
|
- session_id: The unique session identifier. |
|
|
|
Returns: |
|
- JSON object containing session_id and list of past input-response pairs. |
|
|
|
Raises: |
|
- 404 error if session_id is missing or session does not exist. |
|
""" |
|
if not session_id or session_id not in session_store: |
|
raise HTTPException(status_code=404, detail="Session not found or session_id missing.") |
|
|
|
_, session_data = session_store[session_id] |
|
return {"session_id": session_id, "history": session_data["history"]} |
|
|
|
@app.get("/") |
|
def root(): |
|
""" |
|
Simple health check endpoint. |
|
Returns basic status indicating if API is running. |
|
""" |
|
return {"status": "API is running"} |
|
|
|
|
|
if __name__ == "__main__": |
|
uvicorn.run(app, host="0.0.0.0", port=7860) |
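
# Alternatively, assuming this file is saved as app.py, the server can be
# started with the uvicorn CLI:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860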
|
|