#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#

import json
import time
import uuid
import uvicorn

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from gradio_client import Client
from pydantic import BaseModel
from typing import AsyncGenerator, Optional, Dict, List, Tuple

# Default AI model name used when the client does not specify one
MODEL = "JARVIS: 2.1.3"

# Session store keeps track of active sessions.
# Each session_id maps to a tuple:
# (last_update_timestamp, session_data_dict)
# session_data_dict contains:
#   - "model": the AI model name used in this session
#   - "history": list of past chat messages (input and response)
#   - "client": the Gradio Client instance specific to this session
session_store: Dict[str, Tuple[float, Dict]] = {}
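# Illustrative shape of a single entry (key, timestamp, and values are examples,
# not real data; the client value is a live gradio_client.Client object):
# session_store["3f2b..."] = (1700000000.0, {
#     "model": "JARVIS: 2.1.3",
#     "history": [{"input": "Hi", "response": "Hello!"}],
#     "client": <gradio_client.Client instance>,
# })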

# Duration (in seconds) after which inactive sessions are removed
EXPIRE = 3600  # 1 hour

# Create FastAPI app instance
app = FastAPI()

class ResponseRequest(BaseModel):
    """
    Defines the expected structure of the request body for /v1/responses endpoint.
    
    Attributes:
    - model: Optional; specifies which AI model to use. Defaults to MODEL if not provided.
    - input: The user's input text to send to the AI.
    - stream: Optional; if True, the response will be streamed incrementally.
    - session_id: Optional; unique identifier for the user's session. If missing, a new session will be created.
    """
    model: Optional[str] = None
    input: str
    stream: Optional[bool] = False
    session_id: Optional[str] = None
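
# A minimal example body for the /v1/responses endpoint (values illustrative;
# omit "session_id" on the first call and reuse the one returned afterwards):
# {
#     "input": "Hello, who are you?",
#     "stream": false,
#     "session_id": "optional-previously-returned-id"
# }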

class OpenAIChatRequest(BaseModel):
    """
    Defines the OpenAI-compatible request structure for /v1/chat/completions endpoint.
    
    Attributes:
    - model: Optional; specifies which AI model to use. Defaults to MODEL if not provided.
    - messages: List of message objects, each containing 'role' and 'content' keys.
    - stream: Optional; if True, the response will be streamed incrementally.
    - session_id: Optional; unique session identifier for maintaining conversation history.
    """
    model: Optional[str] = None
    messages: List[Dict[str, str]]
    stream: Optional[bool] = False
    session_id: Optional[str] = None
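
# A minimal OpenAI-style body for /v1/chat/completions (values illustrative;
# the last message must have role "user", per the validation below):
# {
#     "messages": [
#         {"role": "user", "content": "What is FastAPI?"},
#         {"role": "assistant", "content": "A Python web framework."},
#         {"role": "user", "content": "Show me an example."}
#     ],
#     "stream": false
# }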

def cleanup_expired_sessions():
    """
    Remove sessions that have been inactive for longer than EXPIRE.
    This helps free up memory by deleting old sessions and closing their clients.
    """
    now = time.time()
    expired_sessions = [
        sid for sid, (last_update, _) in session_store.items()
        if now - last_update > EXPIRE
    ]
    for sid in expired_sessions:
        # Attempt to close the Gradio client associated with the session
        _, data = session_store[sid]
        client = data.get("client")
        if client:
            try:
                client.close()
            except Exception:
                # Ignore errors during client close to avoid crashing cleanup
                pass
        # Remove the session from the store
        del session_store[sid]

def create_client_for_model(model: str) -> Client:
    """
    Create a new Gradio Client instance and set it to use the specified AI model.
    
    Parameters:
    - model: The name of the AI model to initialize the client with.
    
    Returns:
    - A new Gradio Client instance configured with the given model.
    """
    client = Client("hadadrjt/ai")
    # Set the model on the Gradio client by calling the /change_model API
    client.predict(new=model, api_name="/change_model")
    return client

def get_or_create_session(session_id: Optional[str], model: str) -> str:
    """
    Retrieve an existing session by session_id or create a new one if it doesn't exist.
    Also cleans up expired sessions before proceeding.
    
    Parameters:
    - session_id: The unique identifier of the session (optional).
    - model: The AI model to use for this session.
    
    Returns:
    - The session_id for the active or newly created session.
    """
    cleanup_expired_sessions()

    # If no session_id provided or session does not exist, create a new session
    if not session_id or session_id not in session_store:
        session_id = str(uuid.uuid4())  # Generate a new unique session ID
        client = create_client_for_model(model)  # Create a new client for this session
        session_store[session_id] = (time.time(), {
            "model": model,
            "history": [],
            "client": client
        })
    else:
        # Session exists, update last access time and check if model changed
        last_update, data = session_store[session_id]
        if data["model"] != model:
            # If model changed, close old client and create a new one with the new model
            old_client = data.get("client")
            if old_client:
                try:
                    old_client.close()
                except Exception:
                    pass  # Ignore errors on close
            new_client = create_client_for_model(model)
            data["model"] = model
            data["client"] = new_client
            session_store[session_id] = (time.time(), data)
        else:
            # Just update the last access time to keep session alive
            session_store[session_id] = (time.time(), data)

    return session_id

async def event_generator(user_input: str, model: str, session_id: str) -> AsyncGenerator[str, None]:
    """
    Asynchronous generator that streams AI responses incrementally as Server-Sent Events (SSE).
    
    Parameters:
    - user_input: The input text from the user.
    - model: The AI model to use.
    - session_id: The unique session identifier.
    
    Yields:
    - JSON-formatted chunks representing incremental AI response deltas.
    """
    _, session_data = session_store.get(session_id, (0, None))
    if session_data is None:
        # Session not found; yield error and stop
        yield f"data: {json.dumps({'error': 'Session not found'})}\n\n"
        return

    client = session_data["client"]
    if client is None:
        # Client missing for session; yield error and stop
        yield f"data: {json.dumps({'error': 'AI client not available'})}\n\n"
        return

    try:
        # Submit the user input to the AI model via Gradio client
        jarvis_response = client.submit(multi={"text": user_input}, api_name="/api")
    except Exception as e:
        # If submission fails, yield error and stop
        yield f"data: {json.dumps({'error': f'Failed to submit to AI: {str(e)}'})}\n\n"
        return

    buffer = ""  # Buffer to track full response text progressively

    try:
        for partial in jarvis_response:
            # Extract the current partial text from the response
            text = partial[0][0][1]

            # Calculate the delta (new text since last chunk)
            if text.startswith(buffer):
                delta = text[len(buffer):]
            else:
                delta = text

            buffer = text  # Update buffer with latest full text

            if delta == "":
                # Skip empty delta chunks
                continue

            # Prepare chunk data in OpenAI streaming format
            chunk = {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": delta},
                        "finish_reason": None
                    }
                ]
            }

            # Yield the chunk as a Server-Sent Event
            yield f"data: {json.dumps(chunk)}\n\n"

        # After streaming completes, save the full input-response pair to session history
        session_data["history"].append({"input": user_input, "response": buffer})
        session_store[session_id] = (time.time(), session_data)  # Update last access time

        # Send a final chunk signaling completion of the stream
        done_chunk = {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }
            ]
        }
        yield f"data: {json.dumps(done_chunk)}\n\n"

    except Exception as e:
        # If streaming fails at any point, yield an error chunk
        error_chunk = {
            "error": {"message": f"Streaming error: {str(e)}"}
        }
        yield f"data: {json.dumps(error_chunk)}\n\n"

@app.post("/v1/responses")
async def responses(req: ResponseRequest):
    """
    Original API endpoint to get AI responses.
    Supports both streaming and non-streaming modes.
    
    Workflow:
    - Validate or create session.
    - Ensure AI client is available.
    - Handle streaming or full response accordingly.
    - Save chat history per session.
    
    Returns:
    - JSON response with AI output and session ID.
    """
    model = req.model or MODEL  # Use requested model or default
    session_id = get_or_create_session(req.session_id, model)  # Get or create session
    _, session_data = session_store[session_id]
    user_input = req.input

    client = session_data["client"]
    if client is None:
        # If client is missing, return 503 error
        raise HTTPException(status_code=503, detail="AI client not available")

    if req.stream:
        # If streaming requested, return a streaming response using event_generator
        return StreamingResponse(event_generator(user_input, model, session_id), media_type="text/event-stream")

    # Non-streaming request: submit input and collect full response
    try:
        jarvis_response = client.submit(multi={"text": user_input}, api_name="/api")
    except Exception as e:
        # Return 500 error if submission fails
        raise HTTPException(status_code=500, detail=f"Failed to submit to AI: {str(e)}")

    buffer = ""
    for partial in jarvis_response:
        text = partial[0][0][1]
        buffer = text  # Update buffer with latest full response

    # Save input and response to session history and update last access time
    session_data["history"].append({"input": user_input, "response": buffer})
    session_store[session_id] = (time.time(), session_data)

    # Prepare the JSON response in OpenAI style format
    response = {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": buffer
                },
                "finish_reason": "stop"
            }
        ],
        "session_id": session_id  # Return session_id so client can reuse it
    }

    # Return the JSON response
    return JSONResponse(response)
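
# Example invocation, assuming the server is running locally on port 7860
# (matching the uvicorn call at the bottom of this file):
#   curl -s http://localhost:7860/v1/responses \
#     -H "Content-Type: application/json" \
#     -d '{"input": "Hello", "stream": false}'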

@app.post("/v1/chat/completions")
async def openai_chat_completions(req: OpenAIChatRequest):
    """
    OpenAI-compatible endpoint for chat completions.
    Supports both streaming and non-streaming modes.
    
    Workflow:
    - Validate message structure and extract conversation history
    - Validate or create session
    - Update session history from messages
    - Handle streaming or full response
    - Save new interaction to session history
    
    Returns:
    - JSON response in OpenAI format with session ID extension
    """
    # Validate messages structure
    if not req.messages:
        raise HTTPException(status_code=400, detail="Messages cannot be empty")
    
    # Extract conversation history and current input
    history = []
    current_input = ""
    
    # Process messages to extract conversation history
    try:
        # The last message must come from the user and is used as the current input
        if req.messages[-1]["role"] != "user":
            raise ValueError("Last message must be from user")
        
        current_input = req.messages[-1]["content"]
        
        # Process message pairs (user + assistant)
        messages = req.messages[:-1]  # Exclude last message (current input)
        for i in range(0, len(messages), 2):
            if i+1 < len(messages):
                user_msg = messages[i]
                assistant_msg = messages[i+1]
                
                if user_msg["role"] != "user" or assistant_msg["role"] != "assistant":
                    # Skip invalid pairs but continue processing
                    continue
                
                history.append({
                    "input": user_msg["content"],
                    "response": assistant_msg["content"]
                })
    except (KeyError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid message format: {str(e)}")

    model = req.model or MODEL  # Use requested model or default
    session_id = get_or_create_session(req.session_id, model)  # Get or create session
    _, session_data = session_store[session_id]

    # Update session history from messages
    session_data["history"] = history
    session_store[session_id] = (time.time(), session_data)

    client = session_data["client"]
    if client is None:
        raise HTTPException(status_code=503, detail="AI client not available")

    if req.stream:
        # Streaming response
        return StreamingResponse(
            event_generator(current_input, model, session_id),
            media_type="text/event-stream"
        )

    # Non-streaming response
    try:
        jarvis_response = client.submit(multi={"text": current_input}, api_name="/api")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to submit to AI: {str(e)}")

    buffer = ""
    for partial in jarvis_response:
        text = partial[0][0][1]
        buffer = text

    # Update session history with new interaction
    session_data["history"].append({"input": current_input, "response": buffer})
    session_store[session_id] = (time.time(), session_data)

    # Format response in OpenAI style
    response = {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": buffer
                },
                "finish_reason": "stop"
            }
        ],
        "session_id": session_id  # Custom extension for session management
    }

    return JSONResponse(response)
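
# Example invocation with streaming enabled (server assumed local on port 7860;
# curl's -N flag disables buffering so SSE chunks appear as they arrive):
#   curl -sN http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": true}'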

@app.get("/v1/models")
async def list_models():
    """
    OpenAI-compatible endpoint to list available models.
    Returns a fixed list containing our default model.
    
    This endpoint is required by many OpenAI-compatible clients.
    """
    return JSONResponse({
        "object": "list",
        "data": [
            {
                "id": MODEL,
                "object": "model",
                "created": 0,  # Timestamp not available
                "owned_by": "J.A.R.V.I.S."
            }
        ]
    })

@app.get("/v1/history")
async def get_history(session_id: Optional[str] = None):
    """
    Endpoint to retrieve chat history for a given session.
    
    Parameters:
    - session_id: The unique session identifier.
    
    Returns:
    - JSON object containing session_id and list of past input-response pairs.
    
    Raises:
    - 404 error if session_id is missing or session does not exist.
    """
    if not session_id or session_id not in session_store:
        raise HTTPException(status_code=404, detail="Session not found or session_id missing.")

    _, session_data = session_store[session_id]
    return {"session_id": session_id, "history": session_data["history"]}

@app.get("/")
def root():
    """
    Simple health check endpoint.
    Returns basic status indicating if API is running.
    """
    return {"status": "API is running"}

# Run the app with the Uvicorn ASGI server when this file is executed directly
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)