import uuid
import time
import json
import asyncio
import logging
import os

from typing import Optional, List, Tuple, Union, Dict, Any, Literal

from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field, ValidationError
from gradio_client import Client
from contextlib import asynccontextmanager

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(threadName)s | %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("api_gateway")

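# In-memory session store. All state lives in this process, so histories and
# rate-limit timestamps are lost on restart and are not shared across workers.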
class SessionData:
    """Per-session state: system prompt, chat history, and in-flight tasks."""

    def __init__(self):
        self.system: str = ""
        self.history: List[Dict[str, Any]] = []
        self.last_access: float = time.time()
        self.active_tasks: Dict[str, asyncio.Task] = {}
        self.last_request_time: float = 0.0


class SessionManager:
    """Maps user id -> session id -> SessionData, with periodic eviction."""

    def __init__(self):
        self.sessions: Dict[str, Dict[str, SessionData]] = {}
        self.lock = asyncio.Lock()

    async def cleanup(self):
        # Every 60 seconds, cancel the pending tasks of sessions idle for
        # more than 300 seconds and drop them from the store.
        while True:
            await asyncio.sleep(60)
            async with self.lock:
                now = time.time()
                expired = []
                for user, sessions in list(self.sessions.items()):
                    for sid, data in list(sessions.items()):
                        if now - data.last_access > 300:
                            expired.append((user, sid))
                            for task_id, task in data.active_tasks.items():
                                if not task.done():
                                    task.cancel()
                for user, sid in expired:
                    if user in self.sessions and sid in self.sessions[user]:
                        del self.sessions[user][sid]
                        if not self.sessions[user]:
                            del self.sessions[user]

    async def get_session(
        self, user: Optional[str], session_id: Optional[str]
    ) -> Tuple[str, str, SessionData]:
        # Create user and session entries on first use, and refresh
        # last_access so active sessions survive cleanup.
        async with self.lock:
            if not user:
                user = str(uuid.uuid4())
            if user not in self.sessions:
                self.sessions[user] = {}
            if not session_id or session_id not in self.sessions[user]:
                session_id = str(uuid.uuid4())
                self.sessions[user][session_id] = SessionData()
            session = self.sessions[user][session_id]
            session.last_access = time.time()
            return user, session_id, session


session_manager = SessionManager()

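# Background maintenance coroutines, started from the lifespan handler below.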
async def refresh_client(app: FastAPI):
    # Rebuild the Gradio client every 15 minutes so a stale upstream
    # connection cannot linger for the life of the process.
    while True:
        await asyncio.sleep(15 * 60)
        async with app.state.client_lock:
            if app.state.client is not None:
                try:
                    old_client = app.state.client
                    app.state.client = None
                    del old_client  # drop the last reference before reconnecting
                    app.state.client = Client("https://hadadrjt-ai.hf.space/")
                    logger.info("Refreshed Gradio client connection")
                except Exception as e:
                    logger.error(f"Error refreshing Gradio client: {e}", exc_info=True)
                    app.state.client = None


async def clear_terminal_periodically():
    # Clear the console every 5 minutes to keep long-running logs readable.
    while True:
        await asyncio.sleep(300)
        if os.name == "nt":
            os.system("cls")
        else:
            print("\033c", end="", flush=True)


@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.session_manager = session_manager
    app.state.client = None
    app.state.client_lock = asyncio.Lock()
    app.state.refresh_task = asyncio.create_task(refresh_client(app))
    app.state.cleanup_task = asyncio.create_task(session_manager.cleanup())
    app.state.clear_log_task = asyncio.create_task(clear_terminal_periodically())
    try:
        yield
    finally:
        tasks = (app.state.refresh_task, app.state.cleanup_task, app.state.clear_log_task)
        for task in tasks:
            task.cancel()
        # Wait for cancellation to complete instead of sleeping a fixed delay.
        await asyncio.gather(*tasks, return_exceptions=True)

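# Application setup. CORS is wide open ("*"), which is convenient for demos
# but should be narrowed for production deployments.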
app = FastAPI(
    title="J.A.R.V.I.S. OpenAI-Compatible API",
    version="2.1.3-0625",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "OPTIONS", "HEAD"],
    allow_headers=["*"],
)

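# Request schemas, shaped after the OpenAI chat/completions/embeddings APIs.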
class Function(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Dict[str, Any]


class Tool(BaseModel):
    type: Literal["function"] = "function"
    function: Function


class ToolCall(BaseModel):
    id: str
    type: Literal["function"]
    function: Dict[str, str]


class Message(BaseModel):
    role: str = Field(..., pattern="^(system|user|assistant|tool|function)$")
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None


class CommonParams(BaseModel):
    model: str
    stream: bool = False
    user: Optional[str] = None
    session_id: Optional[str] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    max_new_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    repetition_penalty: Optional[float] = None
    logit_bias: Optional[Dict[str, float]] = None
    repeat_penalty: Optional[float] = None
    seed: Optional[int] = None
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict[str, str]]] = None
    functions: Optional[List[Function]] = None
    function_call: Optional[Union[str, Dict[str, str]]] = None


class ChatCompletionRequest(CommonParams):
    messages: List[Message]


class CompletionRequest(CommonParams):
    prompt: Union[str, List[str]]


class EmbeddingRequest(BaseModel):
    model: str
    input: Union[str, List[str]]
    user: Optional[str] = None


class RouterRequest(CommonParams):
    endpoint: Optional[str] = "chat/completions"
    messages: Optional[List[Message]] = None
    prompt: Optional[Union[str, List[str]]] = None
    input: Optional[Union[str, List[str]]] = None

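# Helpers that normalize OpenAI-style message lists into the single text
# prompt the upstream Gradio app consumes.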
def sanitize_messages(messages: List[Message]) -> List[Message]:
    # Flatten multimodal content parts down to their text segments; drop
    # messages that carry no usable text.
    cleaned = []
    for m in messages:
        if isinstance(m.content, list):
            texts = [c.get("text", "") for c in m.content if isinstance(c, dict) and c.get("type") == "text"]
            if texts:
                cleaned.append(Message(role=m.role, content=" ".join(texts)))
        elif isinstance(m.content, str):
            cleaned.append(m)
    return cleaned


def map_messages(system: str, history: List[Dict[str, Any]], new_msgs: List[Message]) -> str:
    # Serialize the system prompt, prior history, and new messages into
    # "role:content" lines.
    msgs = []
    if system:
        msgs.append({"role": "system", "content": system})
    msgs.extend(history)
    msgs.extend([m.model_dump() for m in new_msgs if m.role != "system"])
    text = ""
    for m in msgs:
        text += f"{m.get('role','')}:{m.get('content','')}\n"
    return text.strip()


async def get_client(app: FastAPI) -> Client:
    # Lazily create the shared Gradio client under the lock.
    async with app.state.client_lock:
        if app.state.client is None:
            try:
                app.state.client = Client("https://hadadrjt-ai.hf.space/")
                logger.info("Created Gradio client connection on demand")
            except Exception as e:
                logger.error(f"Failed to create Gradio client: {e}", exc_info=True)
                raise HTTPException(status_code=502, detail="Failed to connect to upstream Gradio app")
        return app.state.client


async def call_gradio(client: Client, params: dict):
    # Submit to the upstream app with up to three attempts and a short,
    # linearly growing backoff before failing with a 502.
    last_error: Optional[Exception] = None
    for attempt in range(3):
        try:
            return await asyncio.to_thread(lambda: client.submit(**params))
        except Exception as e:
            last_error = e
            await asyncio.sleep(0.2 * (attempt + 1))
    logger.error(f"Upstream Gradio submit failed: {last_error}", exc_info=last_error)
    raise HTTPException(status_code=502, detail="Upstream Gradio app error")

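# SSE streaming. The upstream job is drained in a worker thread and then
# replayed as OpenAI-style "data:" events, so chunks are replayed rather
# than proxied in real time.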
async def stream_response(job, session_id: str, session_history: List[Dict[str, Any]], new_messages: List[Message], response_type: str):
    partial = ""
    try:
        chunks = await asyncio.to_thread(lambda: list(job))
    except Exception:
        chunks = []
    for chunk in chunks:
        try:
            if isinstance(chunk, list):
                response = next((item.get("content") for item in chunk if isinstance(item, dict) and "content" in item), str(chunk))
            else:
                response = str(chunk)
            # Each chunk carries the full text so far; emit only the new suffix.
            token = response[len(partial):] if response.startswith(partial) else response
            partial = response
            if response_type == "chat":
                data = {
                    "id": str(uuid.uuid4()),
                    "object": "chat.completion.chunk",
                    "choices": [{"delta": {"content": token}, "index": 0, "finish_reason": None}],
                    "session_id": session_id,
                }
            else:
                data = {
                    "id": str(uuid.uuid4()),
                    "object": "text_completion.chunk",
                    "choices": [{"text": token, "index": 0, "finish_reason": None}],
                    "session_id": session_id,
                }
            yield f"data: {json.dumps(data)}\n\n"
        except Exception:
            continue
    session_history.extend([m.model_dump() for m in new_messages if m.role != "system"])
    session_history.append({"role": "assistant", "content": partial})
    if response_type == "chat":
        done_data = {
            "id": str(uuid.uuid4()),
            "object": "chat.completion.chunk",
            "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
            "session_id": session_id,
        }
    else:
        done_data = {
            "id": str(uuid.uuid4()),
            "object": "text_completion.chunk",
            "choices": [{"text": "", "index": 0, "finish_reason": "stop"}],
            "session_id": session_id,
        }
    yield f"data: {json.dumps(done_data)}\n\n"
    # OpenAI-style streams terminate with a [DONE] sentinel.
    yield "data: [DONE]\n\n"

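# Minimum spacing between requests on the same session, in seconds.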
RATE_LIMIT_SECONDS = 1.0

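# OpenAI-compatible chat endpoint. Conversation state is keyed by the
# (user, session_id) pair rather than resent by the client on every call.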
@app.post("/v1/chat/completions") |
|
async def chat_completions(req: ChatCompletionRequest): |
|
user, session_id, session = await session_manager.get_session(req.user, req.session_id) |
|
now = time.time() |
|
if now - session.last_request_time < RATE_LIMIT_SECONDS: |
|
raise HTTPException(status_code=429, detail="Too many requests, please slow down") |
|
session.last_request_time = now |
|
req.messages = sanitize_messages(req.messages) |
|
for m in req.messages: |
|
if m.role == "system": |
|
session.system = m.content |
|
break |
|
text = map_messages(session.system, session.history, req.messages) |
|
params = { |
|
"message": text, |
|
"model_label": req.model, |
|
"api_name": "/api", |
|
"top_p": req.top_p, |
|
"top_k": req.top_k, |
|
"temperature": req.temperature, |
|
"max_tokens": req.max_tokens, |
|
"max_new_tokens": req.max_new_tokens, |
|
"presence_penalty": req.presence_penalty, |
|
"frequency_penalty": req.frequency_penalty, |
|
"repetition_penalty": req.repetition_penalty, |
|
"repeat_penalty": req.repeat_penalty, |
|
"logit_bias": req.logit_bias, |
|
"seed": req.seed, |
|
"functions": req.functions or req.tools, |
|
"function_call": req.function_call or req.tool_choice, |
|
} |
|
params = {k: v for k, v in params.items() if v is not None} |
|
client = await get_client(app) |
|
if req.stream: |
|
job = await call_gradio(client, params) |
|
generator = stream_response(job, session_id, session.history, req.messages, "chat") |
|
return StreamingResponse(generator, media_type="text/event-stream") |
|
else: |
|
loop = asyncio.get_running_loop() |
|
try: |
|
result = await loop.run_in_executor(None, lambda: client.predict(**params)) |
|
except Exception: |
|
raise HTTPException(status_code=502, detail="Upstream Gradio app error") |
|
session.history.extend([m.model_dump() for m in req.messages if m.role != "system"]) |
|
session.history.append({"role": "assistant", "content": result}) |
|
return { |
|
"id": str(uuid.uuid4()), |
|
"object": "chat.completion", |
|
"choices": [{"message": {"role": "assistant", "content": result}}], |
|
"session_id": session_id |
|
} |
|
|
|
@app.post("/v1/completions") |
|
async def completions(req: CompletionRequest): |
|
user, session_id, session = await session_manager.get_session(req.user, req.session_id) |
|
now = time.time() |
|
if now - session.last_request_time < RATE_LIMIT_SECONDS: |
|
raise HTTPException(status_code=429, detail="Too many requests, please slow down") |
|
session.last_request_time = now |
|
prompt = req.prompt if isinstance(req.prompt, str) else "\n".join(req.prompt) |
|
params = { |
|
"message": prompt, |
|
"model_label": req.model, |
|
"api_name": "/api", |
|
"top_p": req.top_p, |
|
"top_k": req.top_k, |
|
"temperature": req.temperature, |
|
"max_tokens": req.max_tokens, |
|
"max_new_tokens": req.max_new_tokens, |
|
"presence_penalty": req.presence_penalty, |
|
"frequency_penalty": req.frequency_penalty, |
|
"repetition_penalty": req.repetition_penalty, |
|
"repeat_penalty": req.repeat_penalty, |
|
"logit_bias": req.logit_bias, |
|
"seed": req.seed, |
|
} |
|
params = {k: v for k, v in params.items() if v is not None} |
|
client = await get_client(app) |
|
if req.stream: |
|
job = await call_gradio(client, params) |
|
generator = stream_response(job, session_id, [], [], "text") |
|
return StreamingResponse(generator, media_type="text/event-stream") |
|
else: |
|
loop = asyncio.get_running_loop() |
|
try: |
|
result = await loop.run_in_executor(None, lambda: client.predict(**params)) |
|
except Exception: |
|
raise HTTPException(status_code=502, detail="Upstream Gradio app error") |
|
return {"id": str(uuid.uuid4()), "object": "text_completion", "choices": [{"text": result}]} |
|
|
|
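# Embeddings are a stub: no model is wired up, so 768-dimensional zero
# vectors are returned purely to satisfy the response shape.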
@app.post("/v1/embeddings") |
|
async def embeddings(req: EmbeddingRequest): |
|
inputs = req.input if isinstance(req.input, list) else [req.input] |
|
embeddings = [[0.0] * 768 for _ in inputs] |
|
return {"object": "list", "data": [{"embedding": emb, "index": i} for i, emb in enumerate(embeddings)]} |
|
|
|
@app.get("/v1/models") |
|
async def get_models(): |
|
return {"object": "list", "data": [{"id": "Q8_K_XL", "object": "model", "owned_by": "J.A.R.V.I.S."}]} |
|
|
|
@app.get("/v1/history") |
|
async def get_history(user: Optional[str] = None, session_id: Optional[str] = None): |
|
user = user or "anonymous" |
|
sessions = session_manager.sessions |
|
if user in sessions and session_id and session_id in sessions[user]: |
|
return {"user": user, "session_id": session_id, "history": sessions[user][session_id].history} |
|
return {"user": user, "session_id": session_id, "history": []} |
|
|
|
@app.post("/v1/responses/cancel") |
|
async def cancel_response(user: Optional[str], session_id: Optional[str], task_id: Optional[str]): |
|
user = user or "anonymous" |
|
if not task_id: |
|
raise HTTPException(status_code=400, detail="Missing task_id for cancellation") |
|
async with session_manager.lock: |
|
if user in session_manager.sessions and session_id in session_manager.sessions[user]: |
|
session = session_manager.sessions[user][session_id] |
|
task = session.active_tasks.get(task_id) |
|
if task and not task.done(): |
|
task.cancel() |
|
return {"message": f"Cancelled task {task_id}"} |
|
raise HTTPException(status_code=404, detail="Task not found or already completed") |
|
|
|
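# Generic "/v1" router that dispatches to the handlers above based on the
# "endpoint" field of the JSON body.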
@app.api_route("/v1", methods=["POST", "GET", "OPTIONS", "HEAD"])
async def router(request: Request):
    if request.method == "POST":
        try:
            body_json = await request.json()
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid JSON body")
        try:
            body = RouterRequest(**body_json)
        except ValidationError as e:
            raise HTTPException(status_code=422, detail=e.errors())
        endpoint = body.endpoint or "chat/completions"
        if endpoint == "chat/completions":
            if not body.model or not body.messages:
                raise HTTPException(status_code=422, detail="Missing 'model' or 'messages'")
            req_obj = ChatCompletionRequest(**body.model_dump())
            return await chat_completions(req_obj)
        elif endpoint == "completions":
            if not body.model or not body.prompt:
                raise HTTPException(status_code=422, detail="Missing 'model' or 'prompt'")
            req_obj = CompletionRequest(**body.model_dump())
            return await completions(req_obj)
        elif endpoint == "embeddings":
            if not body.model or body.input is None:
                raise HTTPException(status_code=422, detail="Missing 'model' or 'input'")
            req_obj = EmbeddingRequest(**body.model_dump())
            return await embeddings(req_obj)
        elif endpoint == "models":
            return await get_models()
        elif endpoint == "history":
            return await get_history(body.user, body.session_id)
        elif endpoint == "responses/cancel":
            # RouterRequest has no task_id field, so a string tool_choice is
            # reused to carry the task id here.
            return await cancel_response(body.user, body.session_id, body.tool_choice if isinstance(body.tool_choice, str) else None)
        else:
            raise HTTPException(status_code=404, detail="Endpoint not found")
    else:
        return JSONResponse({"message": "Send POST request with JSON body"}, status_code=status.HTTP_405_METHOD_NOT_ALLOWED)

@app.get("/") |
|
async def root(): |
|
return { |
|
"endpoints": [ |
|
"/v1/chat/completions", |
|
"/v1/completions", |
|
"/v1/embeddings", |
|
"/v1/models", |
|
"/v1/history", |
|
"/v1/responses/cancel", |
|
"/v1" |
|
] |
|
} |
|
|
|
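# Run directly with: python <this file>. Port 7860 is the Hugging Face
# Spaces / Gradio default.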
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)