Spaces:
Sleeping
Sleeping
"""Orify Text Detector API – FastAPI + JWT (runs on CPU; uses CUDA when available; HF Zero-GPU)

Endpoints
---------
POST /token    → returns a JWT (OAuth2 password grant; demo auth — the password is not checked)
POST /analyse  → protected (requires the bearer token); returns verdict JSON + HTML highlights
"""
from __future__ import annotations

import html
import os
import re
from datetime import datetime, timedelta, timezone
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from jose import JWTError, jwt
from pydantic import BaseModel, Field
# ─── torch.compile shim (HF CPU runtime) ───────────────────────────────
# Turn torch.compile into a pass-through so nothing in this process attempts
# graph compilation on the Space's runtime.
if hasattr(torch, "compile"):

    def _compile_passthrough(model=None, *args, **kwargs):
        """No-op stand-in for ``torch.compile``.

        The module/function may be passed positionally or as ``model=`` (the
        keyword name torch.compile itself uses) and is returned unchanged.
        When no callable is given — e.g. ``torch.compile(mode=...)`` used as a
        decorator factory — an identity decorator is returned instead.
        """
        if callable(model):
            return model
        return lambda fn: fn

    torch.compile = _compile_passthrough  # type: ignore

# NOTE(review): presumably intended to keep TorchInductor out of the picture;
# confirm the installed torch version actually reads this variable.
os.environ["TORCHINDUCTOR_DISABLED"] = "1"
# ─── model / weight config ─────────────────────────────────────────────
# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub repository holding the fine-tuned ensemble checkpoints, and the files
# fetched from it (local key -> filename inside the repo).
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
FILE_MAP = {
    "ensamble_1": "ensamble_1",
    "ensamble_2.bin": "ensamble_2.bin",
    "ensamble_3": "ensamble_3",
}

# Backbone whose architecture/tokenizer the checkpoints were trained on.
BASE_MODEL = "answerdotai/ModernBERT-base"

# Class names; position i is the name for classifier output index i.
_LABEL_NAMES = (
    "13B", "30B", "65B", "7B", "GLM130B", "bloom_7b", "bloomz", "cohere", "davinci",
    "dolly", "dolly-v2-12b", "flan_t5_base", "flan_t5_large", "flan_t5_small",
    "flan_t5_xl", "flan_t5_xxl", "gemma-7b-it", "gemma2-9b-it", "gpt-3.5-turbo",
    "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
    "llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
    "opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
    "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
)
LABELS = dict(enumerate(_LABEL_NAMES))
NUM_LABELS = len(_LABEL_NAMES)  # 41 — kept in lock-step with the label list
# ─── JWT helpers ───────────────────────────────────────────────────────
# The signing key must come from the environment; refuse to start without it
# (an unset or empty value are both rejected).
SECRET_KEY = os.getenv("SECRET_KEY")
if SECRET_KEY is None or SECRET_KEY == "":
    raise RuntimeError("Set the SECRET_KEY env-var in Space ➜ Settings ➜ Secrets")

ALGORITHM = "HS256"             # HMAC-SHA256 signing with SECRET_KEY
ACCESS_TOKEN_EXPIRE_HOURS = 24  # default lifetime of issued tokens

# Extracts the bearer token from requests; clients obtain one at "token".
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def _create_token(data: dict, exp_hours: int = ACCESS_TOKEN_EXPIRE_HOURS) -> str:
    """Sign *data* as a JWT that expires *exp_hours* from now.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). Not mutated; a
            copy receives the ``exp`` claim.
        exp_hours: Token lifetime in hours.

    Returns:
        The encoded token string, signed with SECRET_KEY using ALGORITHM.
    """
    to_encode = data.copy()
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    to_encode["exp"] = datetime.now(timezone.utc) + timedelta(hours=exp_hours)
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def _current_user(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency: resolve the request's bearer token to its subject.

    Returns:
        The token's ``sub`` claim, or "anonymous" when it is missing/empty.

    Raises:
        HTTPException: 401 when the token cannot be decoded or verified
            (bad signature, expired, malformed).
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        # A 401 for bearer auth should advertise the expected scheme (RFC 6750).
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from err
    return payload.get("sub") or "anonymous"
# ─── load ensemble once ────────────────────────────────────────────────
# Runs at import time: fetch every checkpoint in FILE_MAP, then build one
# classifier per checkpoint on top of BASE_MODEL.
print("🔄 Downloading weights …", flush=True)
# NOTE(review): resume_download is deprecated in recent huggingface_hub
# releases (downloads resume by default); kept for older versions.
local_paths = {k: hf_hub_download(WEIGHT_REPO, f, resume_download=True)
               for k, f in FILE_MAP.items()}

print("🧩 Initialising models …", flush=True)
_tok = AutoTokenizer.from_pretrained(BASE_MODEL)
_models: List[AutoModelForSequenceClassification] = []
for p in local_paths.values():
    # Fresh backbone with a NUM_LABELS-way head; its weights are replaced below.
    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL,
                                                           num_labels=NUM_LABELS)
    # The checkpoint comes from a remote repo. weights_only=True restricts
    # torch.load's unpickler to tensors/plain containers, so a tampered file
    # cannot execute code; a plain state_dict loads unchanged.
    m.load_state_dict(torch.load(p, map_location=DEVICE, weights_only=True))
    m.to(DEVICE).eval()
    _models.append(m)
print("✅ Ensemble ready", flush=True)
# ─── tiny helpers ────────────────────────────────────────────────────── | |
def _tidy(text: str) -> str: | |
text = text.replace("\r\n", "\n").replace("\r", "\n") | |
text = re.sub(r"\n\s*\n+", "\n\n", text) | |
text = re.sub(r"[ \t]+", " ", text) | |
text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text) | |
text = re.sub(r"(?<!\n)\n(?!\n)", " ", text) | |
return text.strip() | |
def _infer(seg: str) -> tuple[float, float, list[str]]:
    """Score one text segment with the model ensemble.

    Returns:
        ``(human, ai, top3)``: *ai* is the summed probability (in percent) of
        every non-"human" label, *human* is ``100 - ai``, and *top3* names the
        three most probable non-human labels.
    """
    inp = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        # Average the per-model softmax distributions; [0] picks the single input row.
        probs = torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
    # Locate the "human" class from LABELS instead of hard-coding its index.
    human_idx = next(i for i, name in LABELS.items() if name == "human")
    ai_probs = probs.clone()
    ai_probs[human_idx] = 0
    ai = ai_probs.sum().item() * 100
    human = 100 - ai
    top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
    return human, ai, top3
# ─── Pydantic schemas ─────────────────────────────────────────────────
class Token(BaseModel):
    """Response body of the token endpoint."""

    access_token: str
    token_type: str = "bearer"


class AnalyseIn(BaseModel):
    """Request body of the analyse endpoint."""

    # min_length=1 rejects "", but whitespace-only text still validates.
    text: str = Field(..., min_length=1)


class Line(BaseModel):
    """Result for one analysed segment."""

    text: str
    ai: float        # AI likelihood, percent
    human: float     # 100 - ai
    top3: List[str]  # most probable non-human labels
    reason: str


class AnalyseOut(BaseModel):
    """Response body of the analyse endpoint."""

    verdict: str
    confidence: float
    ai_avg: float
    human_avg: float
    per_line: List[Line]
    highlight_html: str
# ─── FastAPI app ───────────────────────────────────────────────────────
app = FastAPI(
    title="Orify Text Detector API",
    version="1.0.0",
)

# Wide-open CORS: any origin may call the API with any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/token", response_model=Token)
async def login(form: OAuth2PasswordRequestForm = Depends()):
    """Issue a JWT whose subject is the submitted username.

    Demo auth only: the password is never checked, so any credentials succeed.
    """
    return Token(access_token=_create_token({"sub": form.username}))
@app.post("/analyse", response_model=AnalyseOut)
async def analyse(data: AnalyseIn, _user=Depends(_current_user)):
    """Classify each paragraph of the submitted text and build an HTML report.

    The text is normalised with ``_tidy`` and split on "\\n". Each non-blank
    entry is scored with ``_infer``; blank entries become ``<br>`` spacers.
    The verdict compares the mean AI and human percentages (ties count as
    human). If nothing scorable remains, both averages are 0 and the verdict
    is "Human-written" with confidence 0.
    """
    lines = _tidy(data.text).split("\n")
    html_parts: List[str] = []
    per_line: List[Line] = []
    h_sum = ai_sum = 0.0
    n = 0
    for ln in lines:
        if not ln.strip():
            html_parts.append("<br>")
            continue
        n += 1
        # NOTE(review): _infer is compute-bound and runs on the event-loop
        # thread, so concurrent requests wait; offloading would first need the
        # shared tokenizer/models confirmed thread-safe.
        human, ai, top3 = _infer(ln)
        h_sum += human
        ai_sum += ai
        is_ai = ai > human
        cls = "ai-line" if is_ai else "human-line"
        tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if is_ai else f"Human {human:.2f}%"
        # Escape both the quoted title attribute and the element text.
        html_parts.append(
            f"<span class='{cls} prob-tooltip' title='{html.escape(tip)}'>{html.escape(ln)}</span>"
        )
        reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}"
                  if is_ai else f"Lexical variety suggests human ({human:.1f}%)")
        per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))

    human_avg = h_sum / n if n else 0.0
    ai_avg = ai_sum / n if n else 0.0
    verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
    confidence = max(human_avg, ai_avg)
    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
             if verdict == "AI-generated" else
             f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
    highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
    return AnalyseOut(verdict=verdict, confidence=confidence,
                      ai_avg=ai_avg, human_avg=human_avg,
                      per_line=per_line, highlight_html=highlight_html)
# ────── local dev: uvicorn app:app --reload ─────────────────────────── | |