hadadrjt committed
Commit 91573a9 · 1 Parent(s): eb36b93

api: Production ready.

Files changed (1):
  1. app.py +207 -59
app.py CHANGED
@@ -6,6 +6,7 @@
 import json
 import time
 import uuid
+import asyncio
 import uvicorn
 
 from contextlib import asynccontextmanager
@@ -13,74 +14,175 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse, StreamingResponse
 from gradio_client import Client
 from pydantic import BaseModel
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, Optional, Dict, List, Tuple
 
-# Default AI model
+# Default AI model name used when no model is specified by user
 MODEL = "JARVIS: 2.1.3"
 
-# Global Gradio client instance
-jarvis: Optional[Client] = None
+# Session store keeps track of active sessions.
+# Each session_id maps to a tuple:
+# (last_update_timestamp, session_data_dict)
+# session_data_dict contains:
+# - "model": the AI model name used in this session
+# - "history": list of past chat messages (input and response)
+# - "client": the Gradio Client instance specific to this session
+session_store: Dict[str, Tuple[float, Dict]] = {}
 
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """
-    Initialize Gradio client at app startup.
-    """
-    global jarvis
-    print("Initializing Gradio AI client...")
-    try:
-        jarvis = Client("hadadrjt/ai")
-        print(f"Connected to Gradio AI client at: {jarvis.src}")
-
-        jarvis.predict(new=MODEL, api_name="/change_model")
-        print(f"Default model set to: {MODEL}")
-
-        yield
-    except Exception as e:
-        print(f"Error initializing Gradio client: {e}")
-        yield
-
-app = FastAPI(lifespan=lifespan)
+# Duration (in seconds) after which inactive sessions are removed
+EXPIRE = 3600  # 1 hour
+
+# Create FastAPI app instance
+app = FastAPI()
 
 class ResponseRequest(BaseModel):
     """
-    Request body for /v1/responses endpoint.
-    - model: AI model to use (optional).
-    - input: User input text.
-    - stream: Whether to stream response.
+    Defines the expected structure of the request body for /v1/responses endpoint.
+
+    Attributes:
+    - model: Optional; specifies which AI model to use. Defaults to MODEL if not provided.
+    - input: The user's input text to send to the AI.
+    - stream: Optional; if True, the response will be streamed incrementally.
+    - session_id: Optional; unique identifier for the user's session. If missing, a new session will be created.
     """
-    model: Optional[str] = MODEL
+    model: Optional[str] = None
     input: str
     stream: Optional[bool] = False
+    session_id: Optional[str] = None
 
-async def event_generator(user_input: str, model: str) -> AsyncGenerator[str, None]:
+def cleanup_expired_sessions():
+    """
+    Remove sessions that have been inactive for longer than EXPIRE.
+    This helps free up memory by deleting old sessions and closing their clients.
+    """
+    now = time.time()
+    expired_sessions = [
+        sid for sid, (last_update, _) in session_store.items()
+        if now - last_update > EXPIRE
+    ]
+    for sid in expired_sessions:
+        # Attempt to close the Gradio client associated with the session
+        _, data = session_store[sid]
+        client = data.get("client")
+        if client:
+            try:
+                client.close()
+            except Exception:
+                # Ignore errors during client close to avoid crashing cleanup
+                pass
+        # Remove the session from the store
+        del session_store[sid]
+
+def create_client_for_model(model: str) -> Client:
     """
-    Stream incremental AI responses (deltas) as Server-Sent Events.
+    Create a new Gradio Client instance and set it to use the specified AI model.
+
+    Parameters:
+    - model: The name of the AI model to initialize the client with.
+
+    Returns:
+    - A new Gradio Client instance configured with the given model.
     """
-    global jarvis
+    client = Client("hadadrjt/ai")
+    # Set the model on the Gradio client by calling the /change_model API
+    client.predict(new=model, api_name="/change_model")
+    return client
 
-    if model != MODEL:
-        jarvis.predict(new=model, api_name="/change_model")
+def get_or_create_session(session_id: Optional[str], model: str) -> str:
+    """
+    Retrieve an existing session by session_id or create a new one if it doesn't exist.
+    Also cleans up expired sessions before proceeding.
+
+    Parameters:
+    - session_id: The unique identifier of the session (optional).
+    - model: The AI model to use for this session.
+
+    Returns:
+    - The session_id for the active or newly created session.
+    """
+    cleanup_expired_sessions()
 
-    jarvis_response = jarvis.submit(multi={"text": user_input}, api_name="/api")
+    # If no session_id provided or session does not exist, create a new session
+    if not session_id or session_id not in session_store:
+        session_id = str(uuid.uuid4())  # Generate a new unique session ID
+        client = create_client_for_model(model)  # Create a new client for this session
+        session_store[session_id] = (time.time(), {
+            "model": model,
+            "history": [],
+            "client": client
+        })
+    else:
+        # Session exists, update last access time and check if model changed
+        last_update, data = session_store[session_id]
+        if data["model"] != model:
+            # If model changed, close old client and create a new one with the new model
+            old_client = data.get("client")
+            if old_client:
+                try:
+                    old_client.close()
+                except Exception:
+                    pass  # Ignore errors on close
+            new_client = create_client_for_model(model)
+            data["model"] = model
+            data["client"] = new_client
+            session_store[session_id] = (time.time(), data)
+        else:
+            # Just update the last access time to keep session alive
+            session_store[session_id] = (time.time(), data)
+
+    return session_id
+
+async def event_generator(user_input: str, model: str, session_id: str) -> AsyncGenerator[str, None]:
+    """
+    Asynchronous generator that streams AI responses incrementally as Server-Sent Events (SSE).
+
+    Parameters:
+    - user_input: The input text from the user.
+    - model: The AI model to use.
+    - session_id: The unique session identifier.
+
+    Yields:
+    - JSON-formatted chunks representing incremental AI response deltas.
+    """
+    last_update, session_data = session_store.get(session_id, (0, None))
+    if session_data is None:
+        # Session not found; yield error and stop
+        yield f"data: {json.dumps({'error': 'Session not found'})}\n\n"
+        return
+
+    client = session_data["client"]
+    if client is None:
+        # Client missing for session; yield error and stop
+        yield f"data: {json.dumps({'error': 'AI client not available'})}\n\n"
+        return
 
-    buffer = ""
+    try:
+        # Submit the user input to the AI model via Gradio client
+        jarvis_response = client.submit(multi={"text": user_input}, api_name="/api")
+    except Exception as e:
+        # If submission fails, yield error and stop
+        yield f"data: {json.dumps({'error': f'Failed to submit to AI: {str(e)}'})}\n\n"
+        return
+
+    buffer = ""  # Buffer to track full response text progressively
 
     try:
        for partial in jarvis_response:
+            # Extract the current partial text from the response
            text = partial[0][0][1]
 
+            # Calculate the delta (new text since last chunk)
            if text.startswith(buffer):
                delta = text[len(buffer):]
            else:
                delta = text
 
-            buffer = text
+            buffer = text  # Update buffer with latest full text
 
-            # Skip empty chunks
            if delta == "":
+                # Skip empty delta chunks
                continue
 
+            # Prepare chunk data in OpenAI streaming format
            chunk = {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "object": "chat.completion.chunk",
@@ -95,9 +197,14 @@ async def event_generator(user_input: str, model: str) -> AsyncGenerator[str, None]:
            ]
        }
 
+            # Yield the chunk as a Server-Sent Event
            yield f"data: {json.dumps(chunk)}\n\n"
 
-        # Final chunk to signal completion
+        # After streaming completes, save the full input-response pair to session history
+        session_data["history"].append({"input": user_input, "response": buffer})
+        session_store[session_id] = (time.time(), session_data)  # Update last access time
+
+        # Send a final chunk signaling completion of the stream
        done_chunk = {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion.chunk",
@@ -114,6 +221,7 @@ async def event_generator(user_input: str, model: str) -> AsyncGenerator[str, None]:
        yield f"data: {json.dumps(done_chunk)}\n\n"
 
    except Exception as e:
+        # If streaming fails at any point, yield an error chunk
        error_chunk = {
            "error": {"message": f"Streaming error: {str(e)}"}
        }
@@ -122,30 +230,49 @@ async def event_generator(user_input: str, model: str) -> AsyncGenerator[str, None]:
 @app.post("/v1/responses")
 async def responses(req: ResponseRequest):
    """
-    Main endpoint to get AI response.
-    Supports streaming or full JSON response.
+    Main API endpoint to get AI responses.
+    Supports both streaming and non-streaming modes.
+
+    Workflow:
+    - Validate or create session.
+    - Ensure AI client is available.
+    - Handle streaming or full response accordingly.
+    - Save chat history per session.
+
+    Returns:
+    - JSON response with AI output and session ID.
    """
-    global jarvis
-
-    if jarvis is None:
-        raise HTTPException(status_code=503, detail="AI service not initialized or failed to connect.")
-
+    model = req.model or MODEL  # Use requested model or default
+    session_id = get_or_create_session(req.session_id, model)  # Get or create session
+    last_update, session_data = session_store[session_id]
    user_input = req.input
-    model = req.model or MODEL
 
-    if req.stream:
-        return StreamingResponse(event_generator(user_input, model), media_type="text/event-stream")
+    client = session_data["client"]
+    if client is None:
+        # If client is missing, return 503 error
+        raise HTTPException(status_code=503, detail="AI client not available")
 
-    if model != MODEL:
-        jarvis.predict(new=model, api_name="/change_model")
+    if req.stream:
+        # If streaming requested, return a streaming response using event_generator
+        return StreamingResponse(event_generator(user_input, model, session_id), media_type="text/event-stream")
 
-    jarvis_response = jarvis.submit(multi={"text": user_input}, api_name="/api")
+    # Non-streaming request: submit input and collect full response
+    try:
+        jarvis_response = client.submit(multi={"text": user_input}, api_name="/api")
+    except Exception as e:
+        # Return 500 error if submission fails
+        raise HTTPException(status_code=500, detail=f"Failed to submit to AI: {str(e)}")
 
    buffer = ""
    for partial in jarvis_response:
        text = partial[0][0][1]
-        buffer = text
+        buffer = text  # Update buffer with latest full response
+
+    # Save input and response to session history and update last access time
+    session_data["history"].append({"input": user_input, "response": buffer})
+    session_store[session_id] = (time.time(), session_data)
 
+    # Prepare the JSON response in OpenAI style format
    response = {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
@@ -160,20 +287,41 @@ async def responses(req: ResponseRequest):
            },
            "finish_reason": "stop"
        }
-    ]
+        ],
+        "session_id": session_id  # Return session_id so client can reuse it
    }
 
+    # Return the JSON response
    return JSONResponse(response)
 
+@app.get("/v1/history")
+async def get_history(session_id: Optional[str] = None):
+    """
+    Endpoint to retrieve chat history for a given session.
+
+    Parameters:
+    - session_id: The unique session identifier.
+
+    Returns:
+    - JSON object containing session_id and list of past input-response pairs.
+
+    Raises:
+    - 404 error if session_id is missing or session does not exist.
+    """
+    if not session_id or session_id not in session_store:
+        raise HTTPException(status_code=404, detail="Session not found or session_id missing.")
+
+    _, session_data = session_store[session_id]
+    return {"session_id": session_id, "history": session_data["history"]}
+
 @app.get("/")
 def root():
    """
-    Health check endpoint.
+    Simple health check endpoint.
+    Returns basic status indicating if API is running.
    """
-    if jarvis:
-        return {"status": "API is running", "jarvis_service": True}
-    else:
-        return {"status": "API is running", "jarvis_service": False, "message": "AI service not ready."}
+    return {"status": "API is running"}
 
+# Run the app with Uvicorn ASGI server when executed directly
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
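
For reference, a minimal client-side sketch of the session flow this commit introduces. It is not part of the commit: it assumes the API is served at http://localhost:7860, uses the third-party requests library, and only relies on fields visible in the diff (input, stream, session_id, the returned session_id, and the /v1/history payload); the elided parts of the non-streaming response body are not parsed here.

# Hypothetical usage sketch, not part of this commit.
import requests

BASE = "http://localhost:7860"  # assumed local deployment; adjust to your host

# First call without session_id: the server creates a session and returns its id.
first = requests.post(f"{BASE}/v1/responses", json={"input": "Hello!"}).json()
session_id = first["session_id"]

# Follow-up call reuses the session so the server accumulates history.
requests.post(f"{BASE}/v1/responses",
              json={"input": "Summarize our chat.", "session_id": session_id})

# Streaming variant: read the Server-Sent Events as they arrive.
with requests.post(f"{BASE}/v1/responses",
                   json={"input": "Stream, please.", "session_id": session_id, "stream": True},
                   stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            print(line[6:].decode())  # each event is a JSON chunk carrying a delta

# Inspect the stored input/response pairs for this session.
history = requests.get(f"{BASE}/v1/history", params={"session_id": session_id}).json()
print(history["history"])

Note that state lives in the server process (session_store is a module-level dict), so the returned session_id is only meaningful against the same running instance.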