Commit · db897db
Parent(s): 0c4f83b
Add application file

Files changed:
- app.py            +162 -0
- requirements.txt    +7 -0
app.py ADDED
@@ -0,0 +1,162 @@
"""Orify Text Detector API – FastAPI + JWT (CPU-only, HF Zero-GPU)

Endpoints
---------
POST /token    → returns JWT (OAuth2 password grant, demo auth)
POST /analyse  → protected; returns verdict JSON + HTML highlights
"""

from __future__ import annotations
import os, re, html
from datetime import datetime, timedelta
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download

from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from jose import JWTError, jwt
from pydantic import BaseModel, Field

# ─── torch.compile shim (HF CPU runtime) ───────────────────────────────
# No-op torch.compile on the CPU-only Space: return the model (or a
# pass-through decorator) unchanged, and keep TorchInductor off.
if hasattr(torch, "compile"):
    torch.compile = (lambda m=None, *a, **kw: m if callable(m) else (lambda f: f))  # type: ignore
    os.environ["TORCHINDUCTOR_DISABLED"] = "1"

# ─── model / weight config ─────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
FILE_MAP = {"ensamble_1": "ensamble_1",
            "ensamble_2.bin": "ensamble_2.bin",
            "ensamble_3": "ensamble_3"}
BASE_MODEL = "answerdotai/ModernBERT-base"
NUM_LABELS = 41

# 41 source labels; index 24 ("human") is the only non-AI class.
LABELS = {i: n for i, n in enumerate([
    "13B", "30B", "65B", "7B", "GLM130B", "bloom_7b", "bloomz", "cohere", "davinci",
    "dolly", "dolly-v2-12b", "flan_t5_base", "flan_t5_large", "flan_t5_small",
    "flan_t5_xl", "flan_t5_xxl", "gemma-7b-it", "gemma2-9b-it", "gpt-3.5-turbo",
    "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
    "llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
    "opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
    "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
])}

# ─── JWT helpers ───────────────────────────────────────────────────────
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    raise RuntimeError("Set the SECRET_KEY env var in Space ➜ Settings ➜ Secrets")

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def _create_token(data: dict, exp_hours: int = ACCESS_TOKEN_EXPIRE_HOURS) -> str:
    """Sign a JWT whose payload is `data` plus an `exp` claim."""
    to_encode = data.copy()
    to_encode["exp"] = datetime.utcnow() + timedelta(hours=exp_hours)
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


async def _current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token to a username, or reject with 401."""
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return payload.get("sub") or "anonymous"
    except JWTError:
        raise HTTPException(401, "Invalid or expired token")

# ─── load ensemble once ────────────────────────────────────────────────
print("🔄 Downloading weights …", flush=True)
local_paths = {k: hf_hub_download(WEIGHT_REPO, f, resume_download=True)
               for k, f in FILE_MAP.items()}

print("🧩 Initialising models …", flush=True)
_tok = AutoTokenizer.from_pretrained(BASE_MODEL)
_models: List[AutoModelForSequenceClassification] = []
for p in local_paths.values():
    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL,
                                                           num_labels=NUM_LABELS)
    m.load_state_dict(torch.load(p, map_location=DEVICE))
    m.to(DEVICE).eval()
    _models.append(m)
print("✅ Ensemble ready")

# ─── tiny helpers ──────────────────────────────────────────────────────
def _tidy(text: str) -> str:
    """Normalise line endings and whitespace; undo hyphenation and soft wraps."""
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    text = re.sub(r"\n\s*\n+", "\n\n", text)        # collapse runs of blank lines
    text = re.sub(r"[ \t]+", " ", text)             # collapse spaces/tabs
    text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)  # re-join hyphenated breaks
    text = re.sub(r"(?<!\n)\n(?!\n)", " ", text)    # unwrap single newlines
    return text.strip()


def _infer(seg: str):
    """Average softmax outputs across the ensemble for one segment."""
    inp = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        probs = torch.stack([torch.softmax(m(**inp).logits, dim=1)
                             for m in _models]).mean(0)[0]
    ai_probs = probs.clone()
    ai_probs[24] = 0                                # drop explicit "human" label
    ai = ai_probs.sum().item() * 100
    human = 100 - ai
    top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
    return human, ai, top3

# ─── Pydantic schemas ──────────────────────────────────────────────────
class Token(BaseModel):
    access_token: str
    token_type: str = "bearer"


class AnalyseIn(BaseModel):
    text: str = Field(..., min_length=1)


class Line(BaseModel):
    text: str
    ai: float
    human: float
    top3: List[str]
    reason: str


class AnalyseOut(BaseModel):
    verdict: str
    confidence: float
    ai_avg: float
    human_avg: float
    per_line: List[Line]
    highlight_html: str

# ─── FastAPI app ───────────────────────────────────────────────────────
app = FastAPI(title="Orify Text Detector API", version="1.0.0")
app.add_middleware(CORSMiddleware,
                   allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


@app.post("/token", response_model=Token, summary="Obtain JWT (demo accepts any creds)")
async def login(form: OAuth2PasswordRequestForm = Depends()):
    return Token(access_token=_create_token({"sub": form.username}))


@app.post("/analyse", response_model=AnalyseOut, summary="Detect AI-generated text")
async def analyse(data: AnalyseIn, _user=Depends(_current_user)):
    lines = _tidy(data.text).split("\n")
    html_parts, per_line = [], []
    h_sum = ai_sum = n = 0.0

    for ln in lines:
        if not ln.strip():
            html_parts.append("<br>")
            continue
        n += 1
        human, ai, top3 = _infer(ln)
        h_sum += human
        ai_sum += ai
        cls = "ai-line" if ai > human else "human-line"
        tip = (f"AI {ai:.2f}% – Top-3: {', '.join(top3)}"
               if ai > human else f"Human {human:.2f}%")
        html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
        reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}"
                  if ai > human else f"Lexical variety suggests human ({human:.1f}%)")
        per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))

    human_avg = h_sum / n if n else 0
    ai_avg = ai_sum / n if n else 0
    verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
    confidence = max(human_avg, ai_avg)
    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
             if verdict == "AI-generated" else
             f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
    highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)

    return AnalyseOut(verdict=verdict, confidence=confidence,
                      ai_avg=ai_avg, human_avg=human_avg,
                      per_line=per_line, highlight_html=highlight_html)

# ────── local dev: uvicorn app:app --reload ────────────────────────────
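
For reference, here is a minimal client sketch for the two endpoints above. It is an illustration, not part of the commit: it assumes the API is served locally (`uvicorn app:app`) and that the `requests` package is installed, which requirements.txt does not include. The demo `/token` route signs a JWT for any username/password pair.

    import requests

    BASE = "http://localhost:8000"   # assumed local dev host

    # 1. Obtain a JWT; the demo auth accepts any credentials.
    tok = requests.post(f"{BASE}/token",
                        data={"username": "demo", "password": "demo"}).json()

    # 2. Call the protected endpoint with the bearer token.
    res = requests.post(f"{BASE}/analyse",
                        headers={"Authorization": f"Bearer {tok['access_token']}"},
                        json={"text": "Example paragraph to classify."}).json()
    print(res["verdict"], f"{res['confidence']:.1f}%")

Besides the overall verdict, the response carries `per_line` scores and `highlight_html` ready for direct embedding.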
requirements.txt ADDED
@@ -0,0 +1,7 @@
fastapi
uvicorn[standard]
torch==2.2.2               # pin to the CPU build you tested with
transformers>=4.40
huggingface_hub>=0.23
python-jose[cryptography]
pydantic
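
For a quick in-process smoke test, a sketch along these lines should work; it assumes `httpx` is installed (FastAPI's TestClient needs it, and it is not pinned above). Importing the module runs the SECRET_KEY check and downloads the ensemble weights, so set the secret first.

    import os
    os.environ.setdefault("SECRET_KEY", "dev-only-secret")   # must exist before import

    from fastapi.testclient import TestClient
    import app                                               # triggers weight download

    client = TestClient(app.app)
    token = client.post("/token",
                        data={"username": "u", "password": "p"}).json()["access_token"]
    out = client.post("/analyse",
                      headers={"Authorization": f"Bearer {token}"},
                      json={"text": "First line.\n\nSecond paragraph."}).json()
    assert out["verdict"] in {"AI-generated", "Human-written"}
    print(out["verdict"], f"{out['ai_avg']:.1f}% AI on average")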