Mono / main.py
AIMaster7's picture
Update main.py
c18e9c8 verified
raw
history blame
18.2 kB
import base64
import json
import os
import secrets
import string
import time
import tempfile
import ast # <-- NEW IMPORT for safe literal evaluation
from typing import List, Optional, Union, Any
import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field, model_validator
# New import for OCR
from gradio_client import Client, handle_file
# --- Configuration ---
# Let python-dotenv populate the process environment before any lookups below.
load_dotenv()

# Generic image backend; overridable through the IMAGE_API_URL variable.
IMAGE_API_URL = os.getenv("IMAGE_API_URL", "https://image.api.example.com")

# Snapzion upload endpoint; its API key is read from the SNAP variable and
# defaults to an empty string when unset.
SNAPZION_UPLOAD_URL = "https://upload.snapzion.com/api/public-upload"
SNAPZION_API_KEY = os.getenv("SNAP", "")

# chatwithmono.xyz endpoints, all rooted at the same API prefix.
_MONO_API_ROOT = "https://www.chatwithmono.xyz/api"
CHAT_API_URL = f"{_MONO_API_ROOT}/chat"
IMAGE_GEN_API_URL = f"{_MONO_API_ROOT}/image"
MODERATION_API_URL = f"{_MONO_API_ROOT}/moderation"
# --- Model Definitions ---
# Identifiers advertised by the models endpoint, in listing order.
_MODEL_IDS = (
    "gpt-4-turbo",
    "gpt-4o",
    "gpt-3.5-turbo",
    "dall-e-3",
    "text-moderation-stable",
    "florence-2-ocr",
)

# One model record per identifier; "created" is stamped at import time.
AVAILABLE_MODELS = [
    {
        "id": model_id,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "system",
    }
    for model_id in _MODEL_IDS
]

# Optional mapping from a requested model name to the name sent upstream.
MODEL_ALIASES = {}
# --- FastAPI Application & Global Clients ---
app = FastAPI(
    title="OpenAI Compatible API",
    description="An adapter for various services to be compatible with the OpenAI API specification.",
    version="1.1.2",
)

# The OCR client is optional: if it cannot be created at import time the
# service still starts, and ocr_client stays None.
ocr_client = None
try:
    ocr_client = Client("multimodalart/Florence-2-l4")
except Exception as init_error:
    print(f"Warning: Could not initialize Gradio client for OCR: {init_error}")
# --- Pydantic Models ---
# Request and response schemas used by the endpoints in this module.
class Message(BaseModel):
    """One chat message: who is speaking and the text they contributed."""

    # Speaker label, e.g. "system" or "assistant".
    role: str = Field(...)
    # Plain-text body of the message.
    content: str = Field(...)
class ChatRequest(BaseModel):
    """Request body accepted by the chat completions endpoint."""

    # Conversation so far, in order.
    messages: List[Message] = Field(...)
    # Requested model name (may be remapped through MODEL_ALIASES).
    model: str = Field(...)
    # When truthy the reply is delivered as a server-sent event stream.
    stream: Optional[bool] = Field(default=False)
    # Optional tool definitions; each item is rendered into the system prompt.
    tools: Optional[Any] = Field(default=None)
class ImageGenerationRequest(BaseModel):
    """Request body accepted by the image generation endpoint."""

    # Text description of the desired image.
    prompt: str = Field(...)
    # Width:height ratio forwarded to the generic image backend.
    aspect_ratio: Optional[str] = Field(default="1:1")
    # Number of images to generate.
    n: Optional[int] = Field(default=1)
    # Caller-supplied user identifier.
    user: Optional[str] = Field(default=None)
    # Image model name; selects which upstream backend is used.
    model: Optional[str] = Field(default="default")
class ModerationRequest(BaseModel):
    """Request body accepted by the moderation endpoint."""

    # A single text or a batch of texts to classify.
    input: Union[str, List[str]] = Field(...)
    # Model name echoed back in the response.
    model: Optional[str] = Field(default="text-moderation-stable")
class OcrRequest(BaseModel):
    """Request body for the OCR endpoint: exactly one image source is required."""

    image_url: Optional[str] = Field(default=None, description="URL of the image to process.")
    image_b64: Optional[str] = Field(default=None, description="Base64 encoded string of the image to process.")

    @model_validator(mode='before')
    @classmethod
    def check_sources(cls, data: Any) -> Any:
        """Reject payloads that supply neither or both of the image sources.

        Only dict payloads are inspected; anything else is passed through
        unchanged for pydantic's own validation.
        """
        if not isinstance(data, dict):
            return data
        has_url = bool(data.get('image_url'))
        has_b64 = bool(data.get('image_b64'))
        if not has_url and not has_b64:
            raise ValueError('Either image_url or image_b64 must be provided.')
        if has_url and has_b64:
            raise ValueError('Provide either image_url or image_b64, not both.')
        return data
class OcrResponse(BaseModel):
    """Response body returned by the OCR endpoint."""

    # Text extracted from the image.
    ocr_text: str = Field(...)
    # Parsed payload received from the OCR backend.
    raw_response: dict = Field(...)
# --- Helper Function ---
def generate_random_id(prefix: str, length: int = 29) -> str:
    """Return *prefix* followed by *length* random alphanumeric characters.

    Characters are drawn with the ``secrets`` module from ASCII letters and
    digits.
    """
    alphabet = string.ascii_letters + string.digits
    suffix_chars = [secrets.choice(alphabet) for _ in range(length)]
    return f"{prefix}{''.join(suffix_chars)}"
# === API Endpoints ===
@app.get("/v1/models", tags=["Models"])
async def list_models():
    """Return AVAILABLE_MODELS wrapped in a list envelope."""
    envelope = {"object": "list"}
    envelope["data"] = AVAILABLE_MODELS
    return envelope
# --- Chat, image generation, OCR, and moderation endpoints ---
# Markers the upstream model is instructed to wrap tool calls in.
_TOOL_OPEN = "<tool_call>"
_TOOL_CLOSE = "</tool_call>"


def _build_tool_prompt(tools: Any) -> str:
    """Render the system-prompt section describing the available tools."""
    tool_tags = ";".join(f"<tool>{tool}</tool>" for tool in tools)
    return (
        "You have access to the following tools. To call a tool, please respond "
        "with JSON for a tool call within <tool_call></tool_call> XML tags. "
        'Respond in the format {"name": tool name, "parameters": dictionary of '
        "argument name and its value}. Do not use variables.\n"
        f"Tools: {tool_tags}\n"
        "Response Format for tool call:\n"
        "<tool_call>\n"
        '{"name": <function-name>, "arguments": <args-json-object>}\n'
        "</tool_call>"
    )


def _parse_tool_call(raw: str) -> Optional[dict]:
    """Turn the text found between the tool-call tags into an OpenAI tool call.

    Returns None when the text is not a JSON object with a "name" key, so the
    caller can fall back to treating it as ordinary content.
    """
    try:
        parsed = json.loads(raw.strip())
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict) or "name" not in parsed:
        return None
    # The prompt mentions both "parameters" and "arguments", so accept either.
    arguments = parsed.get("parameters", parsed.get("arguments", {}))
    if not isinstance(arguments, str):
        arguments = json.dumps(arguments)
    return {
        "id": generate_random_id("call_"),
        "type": "function",
        "function": {"name": parsed["name"], "arguments": arguments},
    }


def _openai_usage(upstream_usage: Any) -> dict:
    """Map the upstream promptTokens/completionTokens pair to OpenAI usage keys."""
    usage = upstream_usage if isinstance(upstream_usage, dict) else {}
    prompt_tokens = usage.get("promptTokens") or 0
    completion_tokens = usage.get("completionTokens") or 0
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }


def _partial_tag_len(text: str, tag: str) -> int:
    """Length of the longest suffix of *text* that is a proper prefix of *tag*."""
    for size in range(min(len(text), len(tag) - 1), 0, -1):
        if text.endswith(tag[:size]):
            return size
    return 0


def _consume_buffer(buffer: str, in_tool_call: bool):
    """Split buffered model output into ("text", s) and ("tool", s) events.

    Returns ``(events, remaining_buffer, in_tool_call)``. Text that could be
    the beginning of an opening tag, and the body of a tool call whose closing
    tag has not arrived yet, stay in the remaining buffer for the next call.
    """
    events = []
    while True:
        if in_tool_call:
            end = buffer.find(_TOOL_CLOSE)
            if end == -1:
                break
            events.append(("tool", buffer[:end]))
            buffer = buffer[end + len(_TOOL_CLOSE):]
            in_tool_call = False
            continue
        start = buffer.find(_TOOL_OPEN)
        if start == -1:
            held_back = _partial_tag_len(buffer, _TOOL_OPEN)
            ready = buffer[:len(buffer) - held_back]
            if ready:
                events.append(("text", ready))
            buffer = buffer[len(buffer) - held_back:]
            break
        if start:
            events.append(("text", buffer[:start]))
        buffer = buffer[start + len(_TOOL_OPEN):]
        in_tool_call = True
    return events, buffer, in_tool_call


async def _stream_chat(chat_id: str, model_id: str, headers: dict, payload: dict):
    """Yield OpenAI-style SSE chunks translated from the upstream line stream.

    Upstream lines starting with "0:" carry a JSON-encoded text piece; the
    first line starting with "e:" or "d:" carries usage data and ends reading.
    The stream always finishes with "data: [DONE]".
    """
    created = int(time.time())
    usage_info = None
    buffer = ""
    in_tool_call = False
    role_sent = False
    tool_calls_sent = 0
    failed = False

    def sse(delta: dict, finish_reason: Optional[str] = None, usage: Optional[dict] = None) -> str:
        chunk = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
            "usage": usage,
        }
        return f"data: {json.dumps(chunk)}\n\n"

    def render(kind: str, value: str, closed: bool = True) -> str:
        nonlocal role_sent, tool_calls_sent
        call = _parse_tool_call(value) if kind == "tool" else None
        if call is not None:
            delta = {"content": None, "tool_calls": [{"index": tool_calls_sent, **call}]}
            tool_calls_sent += 1
        elif kind == "tool":
            # Unparseable tool call: hand the raw text back rather than drop it.
            delta = {"content": _TOOL_OPEN + value + (_TOOL_CLOSE if closed else "")}
        else:
            delta = {"content": value}
        if not role_sent:
            delta["role"] = "assistant"
            role_sent = True
        return sse(delta)

    try:
        async with httpx.AsyncClient(timeout=120) as client:
            async with client.stream("POST", CHAT_API_URL, headers=headers, json=payload) as response:
                if not 200 <= response.status_code < 300:
                    # A streamed body must be read explicitly, and while the
                    # response is still open, before it can be reported.
                    body = (await response.aread()).decode("utf-8", errors="replace")
                    failed = True
                    error_content = {
                        "error": {
                            "message": f"Upstream API error: {response.status_code}. Details: {body}",
                            "type": "upstream_error",
                            "code": str(response.status_code),
                        }
                    }
                    yield f"data: {json.dumps(error_content)}\n\n"
                else:
                    async for line in response.aiter_lines():
                        if line.startswith("0:"):
                            try:
                                piece = json.loads(line[2:])
                            except json.JSONDecodeError:
                                continue
                            if not isinstance(piece, str):
                                continue
                            events, buffer, in_tool_call = _consume_buffer(buffer + piece, in_tool_call)
                            for kind, value in events:
                                yield render(kind, value)
                        elif line.startswith(("e:", "d:")):
                            try:
                                usage_info = json.loads(line[2:]).get("usage")
                            except (json.JSONDecodeError, AttributeError):
                                pass
                            break
    except httpx.HTTPError as exc:
        failed = True
        error_content = {"error": {"message": f"Upstream request failed: {exc}", "type": "upstream_error"}}
        yield f"data: {json.dumps(error_content)}\n\n"

    if not failed:
        # Flush whatever was held back waiting for a tag that never completed.
        if in_tool_call:
            yield render("tool", buffer, closed=False)
        elif buffer:
            yield render("text", buffer)
        finish_reason = "tool_calls" if tool_calls_sent else "stop"
        final_usage = _openai_usage(usage_info) if usage_info else None
        yield sse({}, finish_reason, final_usage)
    # Deliberately not in a ``finally``: yielding while the generator is being
    # closed (client disconnect) is an error.
    yield "data: [DONE]\n\n"


async def _complete_chat(chat_id: str, model_id: str, headers: dict, payload: dict):
    """Collect the whole upstream reply and return one chat.completion JSONResponse."""
    pieces = []
    usage_info = {}
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            async with client.stream("POST", CHAT_API_URL, headers=headers, json=payload) as response:
                if not 200 <= response.status_code < 300:
                    body = (await response.aread()).decode("utf-8", errors="replace")
                    return JSONResponse(
                        status_code=response.status_code,
                        content={"error": {"message": f"Upstream API error. Details: {body}", "type": "upstream_error"}},
                    )
                async for line in response.aiter_lines():
                    if line.startswith("0:"):
                        try:
                            piece = json.loads(line[2:])
                        except json.JSONDecodeError:
                            continue
                        if isinstance(piece, str):
                            pieces.append(piece)
                    elif line.startswith(("e:", "d:")):
                        try:
                            # Keep the last non-empty usage object seen.
                            usage_info = json.loads(line[2:]).get("usage") or usage_info
                        except (json.JSONDecodeError, AttributeError):
                            continue
    except httpx.HTTPError as exc:
        return JSONResponse(
            status_code=502,
            content={"error": {"message": f"Upstream request failed: {exc}", "type": "upstream_error"}},
        )

    full_response = "".join(pieces)
    tool_calls = []
    text_parts = []

    def collect(kind: str, value: str, closed: bool = True) -> None:
        call = _parse_tool_call(value) if kind == "tool" else None
        if call is not None:
            tool_calls.append(call)
        elif kind == "tool":
            text_parts.append(_TOOL_OPEN + value + (_TOOL_CLOSE if closed else ""))
        else:
            text_parts.append(value)

    events, leftover, unclosed = _consume_buffer(full_response, False)
    for kind, value in events:
        collect(kind, value)
    if unclosed:
        collect("tool", leftover, closed=False)
    elif leftover:
        collect("text", leftover)

    content_response = "".join(text_parts)
    if tool_calls:
        # With tool calls present, surrounding whitespace-only text becomes null.
        content_response = content_response.strip() or None

    return JSONResponse(content={
        "id": chat_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_id,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": content_response, "tool_calls": tool_calls or None},
            "finish_reason": "tool_calls" if tool_calls else "stop",
        }],
        "usage": _openai_usage(usage_info),
    })


@app.post("/v1/chat/completions", tags=["Chat"])
async def chat_completion(request: ChatRequest):
    """Proxy a chat request to the upstream chat API in OpenAI response format.

    When ``request.tools`` is set, a tool-calling instruction is appended to
    the leading system message (or inserted as a new one). Depending on
    ``request.stream`` the reply is a server-sent event stream of
    ``chat.completion.chunk`` objects or a single ``chat.completion`` object.
    Upstream failures are reported as an error payload rather than raised.
    """
    model_id = MODEL_ALIASES.get(request.model, request.model)
    chat_id = generate_random_id("chatcmpl-")
    headers = {
        'accept': 'text/event-stream',
        'content-type': 'application/json',
        'origin': 'https://www.chatwithmono.xyz',
        'referer': 'https://www.chatwithmono.xyz/',
        'user-agent': 'Mozilla/5.0',
    }

    if request.tools:
        tool_prompt = _build_tool_prompt(request.tools)
        # Guard against an empty message list before looking at messages[0].
        if request.messages and request.messages[0].role == "system":
            request.messages[0].content += "\n\n" + tool_prompt
        else:
            request.messages.insert(0, Message(role="system", content=tool_prompt))

    payload = {"messages": [msg.model_dump() for msg in request.messages], "model": model_id}

    if request.stream:
        return StreamingResponse(
            _stream_chat(chat_id, model_id, headers, payload),
            media_type="text/event-stream",
        )
    return await _complete_chat(chat_id, model_id, headers, payload)
async def _upload_to_snapzion(client, b64_image: str, fallback_url: str) -> str:
    """Best-effort upload of a base64 PNG to Snapzion; return its URL or *fallback_url*."""
    try:
        upload_resp = await client.post(
            SNAPZION_UPLOAD_URL,
            headers={"Authorization": SNAPZION_API_KEY},
            files={'file': ('image.png', base64.b64decode(b64_image), 'image/png')},
        )
        if upload_resp.status_code != 200:
            return fallback_url
        body = upload_resp.json()
    except (httpx.HTTPError, ValueError):
        # Transport failure, undecodable base64 or a non-JSON reply: the image
        # itself was generated fine, so keep serving the inline data URL.
        return fallback_url
    if isinstance(body, dict):
        return body.get("url") or fallback_url
    return fallback_url


@app.post("/v1/images/generations", tags=["Images"])
async def generate_images(request: ImageGenerationRequest):
    """Generate ``request.n`` images and return them in OpenAI list format.

    Models named gpt-image-1, dall-e-3, dall-e-2 and nextlm-image-1 are sent
    to IMAGE_GEN_API_URL; when SNAPZION_API_KEY is set the result is also
    uploaded to obtain a hosted URL, otherwise a data URL is returned. Any
    other model name goes to IMAGE_API_URL. Upstream HTTP errors produce a
    502 response and any other failure a 500 response.
    """
    model = request.model or "default"
    # The schema allows an explicit null for n; treat it as the default of 1.
    count = 1 if request.n is None else request.n
    use_chat_backend = model in ("gpt-image-1", "dall-e-3", "dall-e-2", "nextlm-image-1")
    results = []
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            for _ in range(count):
                if use_chat_backend:
                    headers = {
                        'Content-Type': 'application/json',
                        'User-Agent': 'Mozilla/5.0',
                        'Referer': 'https://www.chatwithmono.xyz/',
                    }
                    payload = {"prompt": request.prompt, "model": model}
                    resp = await client.post(IMAGE_GEN_API_URL, headers=headers, json=payload)
                    resp.raise_for_status()
                    data = resp.json()
                    b64_image = data.get("image")
                    if not b64_image:
                        return JSONResponse(status_code=502, content={"error": "Missing base64 image in response"})
                    image_url = f"data:image/png;base64,{b64_image}"
                    if SNAPZION_API_KEY:
                        image_url = await _upload_to_snapzion(client, b64_image, image_url)
                    results.append({
                        "url": image_url,
                        "b64_json": b64_image,
                        "revised_prompt": data.get("revised_prompt"),
                    })
                else:
                    params = {"prompt": request.prompt, "aspect_ratio": request.aspect_ratio, "link": "typegpt.net"}
                    resp = await client.get(IMAGE_API_URL, params=params)
                    resp.raise_for_status()
                    data = resp.json()
                    results.append({"url": data.get("image_link"), "b64_json": data.get("base64_output")})
    except httpx.HTTPStatusError as e:
        return JSONResponse(
            status_code=502,
            content={"error": f"Image generation failed. Upstream error: {e.response.status_code}", "details": e.response.text},
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": "An internal error occurred.", "details": str(e)})
    return {"created": int(time.time()), "data": results}
# === OCR Endpoint ===
def _parse_ocr_output(raw_output: Any) -> dict:
    """Normalise the first value returned by the OCR backend into a dict.

    A dict is returned unchanged. A string is tried as JSON, then as a Python
    literal (which tolerates single-quoted dicts). If neither yields a dict,
    the string is treated as the OCR text itself. Any other type raises a 502
    HTTPException.
    """
    if isinstance(raw_output, dict):
        return raw_output
    if not isinstance(raw_output, str):
        raise HTTPException(status_code=502, detail=f"Unexpected data type from OCR service: {type(raw_output)}")
    try:
        parsed = json.loads(raw_output)
    except json.JSONDecodeError:
        try:
            parsed = ast.literal_eval(raw_output)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return {"ocr_text": raw_output}
    if isinstance(parsed, dict):
        return parsed
    # Text such as "2024" or "[1]" parses to a non-dict value; it is plain OCR
    # output, so keep the original string rather than a re-rendered literal.
    return {"ocr_text": raw_output}


@app.post("/v1/ocr", response_model=OcrResponse, tags=["OCR"])
async def perform_ocr(request: OcrRequest):
    """
    Performs Optical Character Recognition (OCR) on an image using the Florence-2 model.
    Provide an image via a URL or a base64 encoded string.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    if not ocr_client:
        raise HTTPException(status_code=503, detail="OCR service is not available. Gradio client failed to initialize.")

    temp_file_path = None
    try:
        if request.image_url:
            image_path = request.image_url
        elif request.image_b64:
            encoded = request.image_b64
            # Accept "data:image/png;base64,...." style payloads as well.
            if encoded.startswith("data:") and "," in encoded:
                encoded = encoded.split(",", 1)[1]
            try:
                image_bytes = base64.b64decode(encoded)
            except ValueError as exc:
                raise HTTPException(status_code=400, detail="image_b64 is not valid base64 data.") from exc
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                temp_file.write(image_bytes)
                temp_file_path = temp_file.name
            image_path = temp_file_path
        else:
            raise HTTPException(status_code=400, detail="Either image_url or image_b64 must be provided.")

        def run_prediction():
            return ocr_client.predict(image=handle_file(image_path), task_prompt="OCR", api_name="/process_image")

        # predict() is a blocking call; run it in a worker thread so other
        # requests keep being served while it is in flight.
        prediction = await asyncio.to_thread(run_prediction)

        if isinstance(prediction, (tuple, list)):
            if not prediction:
                raise HTTPException(status_code=502, detail="Invalid or empty response from OCR service.")
            raw_output = prediction[0]
        elif prediction:
            # NOTE(review): assumes a single-output API returns the bare value
            # instead of a tuple - confirm against the gradio_client docs.
            raw_output = prediction
        else:
            raise HTTPException(status_code=502, detail="Invalid or empty response from OCR service.")

        raw_result_dict = _parse_ocr_output(raw_output)

        # Extract text from the dictionary, with fallbacks.
        for key in ("OCR", "ocr_text"):
            if key in raw_result_dict:
                ocr_text = raw_result_dict[key]
                break
        else:
            ocr_text = str(raw_result_dict)
        if not isinstance(ocr_text, str):
            # The response schema requires a string.
            ocr_text = str(ocr_text)

        return OcrResponse(ocr_text=ocr_text, raw_response=raw_result_dict)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred during OCR processing: {str(e)}") from e
    finally:
        if temp_file_path:
            try:
                os.unlink(temp_file_path)
            except OSError:
                # Best-effort cleanup; the file may already be gone.
                pass
@app.post("/v1/moderations", tags=["Moderation"])
async def create_moderation(request: ModerationRequest):
    """Classify each input text via the upstream moderation API.

    A string input is treated as a one-element batch. Each text is posted to
    MODERATION_API_URL in turn and the reply is mapped onto the OpenAI
    category names; categories the upstream does not report are always False
    and scores are 1.0 or 0.0 mirroring the boolean. An empty batch yields a
    400 response, an upstream HTTP error a 502 and any other failure a 500.
    """
    if isinstance(request.input, str):
        input_texts = [request.input]
    else:
        input_texts = request.input
    if not input_texts:
        return JSONResponse(
            status_code=400,
            content={"error": {"message": "Request must have at least one input string."}},
        )

    headers = {
        'Content-Type': 'application/json',
        'User-Agent': 'Mozilla/5.0',
        'Referer': 'https://www.chatwithmono.xyz/',
    }
    # Output categories in response order, and the subset read from upstream.
    category_names = (
        "hate",
        "hate/threatening",
        "harassment",
        "harassment/threatening",
        "self-harm",
        "self-harm/intent",
        "self-harm/instructions",
        "sexual",
        "sexual/minors",
        "violence",
        "violence/graphic",
    )
    upstream_backed = {"hate", "self-harm", "sexual", "violence"}

    results = []
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            for text_input in input_texts:
                resp = await client.post(MODERATION_API_URL, headers=headers, json={"text": text_input})
                resp.raise_for_status()
                upstream_data = resp.json()
                upstream_categories = upstream_data.get("categories", {})

                openai_categories = {}
                for name in category_names:
                    if name in upstream_backed:
                        openai_categories[name] = upstream_categories.get(name, False)
                    else:
                        openai_categories[name] = False

                category_scores = {}
                for name, is_hit in openai_categories.items():
                    category_scores[name] = 1.0 if is_hit else 0.0

                result_item = {
                    "flagged": upstream_data.get("overall_sentiment") == "flagged",
                    "categories": openai_categories,
                    "category_scores": category_scores,
                }
                reason = upstream_data.get("reason")
                if reason:
                    result_item["reason"] = reason
                results.append(result_item)
    except httpx.HTTPStatusError as e:
        return JSONResponse(
            status_code=502,
            content={"error": {"message": f"Moderation failed. Upstream error: {e.response.status_code}", "details": e.response.text}},
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": {"message": "An internal error occurred during moderation.", "details": str(e)}},
        )

    return JSONResponse(content={"id": generate_random_id("modr-"), "model": request.model, "results": results})
# --- Main Execution ---
if __name__ == "__main__":
    # uvicorn is imported here so it is only required when run as a script.
    import uvicorn

    # Listen on every interface, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)