Update app.py
app.py CHANGED
@@ -1,10 +1,15 @@
+
 from __future__ import annotations
 import os, re, html
 from datetime import datetime, timedelta
 from typing import List
 
 import torch
-from transformers import
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+)
 from huggingface_hub import hf_hub_download
 from fastapi import FastAPI, HTTPException, Depends
 from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
@@ -12,23 +17,22 @@ from fastapi.middleware.cors import CORSMiddleware
 from jose import jwt, JWTError
 from pydantic import BaseModel, Field
 
-# ── torch
+# ── torch shim ─────────────────────────────────────────────────────────
 if hasattr(torch, "compile"):
     torch.compile = (lambda m=None,*_,**__: m if callable(m) else (lambda f: f)) # type: ignore
-# keep Inductor disabled on CPU Spaces; harmless on GPU
 os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")
 
-# ──
-os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
+# ── remote‑code flag ───────────────────────────────────────────────────
+os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
 TOKEN_KW = {"trust_remote_code": True}
 
-# ── config
+# ── config ─────────────────────────────────────────────────────────────
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
-FILE_MAP = {"ensamble_1":"ensamble_1",
+FILE_MAP = {"ensamble_1":"ensamble_1","ensamble_2.bin":"ensamble_2.bin","ensamble_3":"ensamble_3"}
 BASE_MODEL = "answerdotai/ModernBERT-base"
 NUM_LABELS = 41
-LABELS
+LABELS = {i:n for i,n in enumerate([
     "13B","30B","65B","7B","GLM130B","bloom_7b","bloomz","cohere","davinci","dolly","dolly-v2-12b",
     "flan_t5_base","flan_t5_large","flan_t5_small","flan_t5_xl","flan_t5_xxl","gemma-7b-it","gemma2-9b-it",
    "gpt-3.5-turbo","gpt-35","gpt-4","gpt-4o","gpt-j","gpt-neox","human","llama3-70b","llama3-8b",
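Note on the torch shim above: `torch.compile` is patched so both of its call styles keep working, `torch.compile(model)` (returns the model untouched) and `@torch.compile(...)` (returns a pass-through decorator). A quick standalone sanity check, not part of the commit:

import torch

# No-op replacement mirroring the shim in this diff: if the first
# positional argument is a callable, hand it back unchanged; otherwise
# (decorator-factory style) return an identity decorator.
if hasattr(torch, "compile"):
    torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f))  # type: ignore

def f(x):
    return x + 1

assert torch.compile(f) is f                        # direct call style
assert torch.compile(mode="max-autotune")(f) is f   # decorator-factory style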
@@ -38,81 +42,89 @@ LABELS = {i:n for i,n in enumerate([
 # ── JWT helpers ───────────────────────────────────────────────────────
 SECRET_KEY = os.getenv("SECRET_KEY")
 if not SECRET_KEY:
-    raise RuntimeError("SECRET_KEY env‑var not set
-ALG
+    raise RuntimeError("SECRET_KEY env‑var not set – add it in Space settings → Secrets")
+ALG="HS256"; EXP=24
 oauth2 = OAuth2PasswordBearer(tokenUrl="token")
 
-def
-
+def _make_jwt(sub:str)->str:
+    payload={"sub":sub,"exp":datetime.utcnow()+timedelta(hours=EXP)}
+    return jwt.encode(payload,SECRET_KEY,algorithm=ALG)
 
-def
+def _verify_jwt(tok:str=Depends(oauth2)):
     try:
-        return jwt.decode(tok,
+        return jwt.decode(tok,SECRET_KEY,algorithms=[ALG])["sub"]
     except JWTError:
-        raise HTTPException(401,
+        raise HTTPException(401,"Invalid or expired token")
+
+# ── model bootstrap ───────────────────────────────────────────────────
+print("🔄 Fetching ensemble weights…", flush=True)
+paths={k:hf_hub_download(WEIGHT_REPO,f,resume_download=True) for k,f in FILE_MAP.items()}
 
-
-print("🔄 Downloading weights…", flush=True)
-local_paths = {k: hf_hub_download(WEIGHT_REPO, f, resume_download=True) for k, f in FILE_MAP.items()}
-print("🧩 Loading ModernBERT checkpoints…", flush=True)
+print("🧩 Building ModernBERT backbone…", flush=True)
 _cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
+_cfg.num_labels = NUM_LABELS # ➜ classification head = 41
 _tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
 _models: List[AutoModelForSequenceClassification] = []
-for p in
-    m = AutoModelForSequenceClassification.from_pretrained(
-
+for p in paths.values():
+    m = AutoModelForSequenceClassification.from_pretrained(
+        BASE_MODEL,
+        config=_cfg,
+        ignore_mismatched_sizes=True, # skip the 2‑class head in checkpoint
+        **TOKEN_KW,
+    )
+    state=torch.load(p, map_location=DEVICE)
+    m.load_state_dict(state) # loads 41‑class ensemble head
     m.to(DEVICE).eval()
     _models.append(m)
 print(f"✅ Ensemble ready on {DEVICE}")
 
-# ──
+# ── helper fns ─────────────────────────────────────────────────────────
 
 def _tidy(t:str)->str:
-    t=t.replace("\r\n",
-    t=re.sub(r"\n\s*\n+",
-    t=re.sub(r"[ \t]+",
-    t=re.sub(r"(\w+)-\n(\w+)",
-    t=re.sub(r"(?<!\n)\n(?!\n)",
+    t=t.replace("\r\n","\n").replace("\r","\n")
+    t=re.sub(r"\n\s*\n+","\n\n",t)
+    t=re.sub(r"[ \t]+"," ",t)
+    t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)
+    t=re.sub(r"(?<!\n)\n(?!\n)"," ",t)
     return t.strip()
 
 def _infer(seg:str):
-    inp=_tok(seg,
+    inp=_tok(seg,return_tensors="pt",truncation=True,padding=True).to(DEVICE)
     with torch.no_grad():
-        probs=torch.stack([torch.softmax(m(**inp).logits,
+        probs=torch.stack([torch.softmax(m(**inp).logits,1) for m in _models]).mean(0)[0]
     ai_probs=probs.clone(); ai_probs[24]=0
     ai=ai_probs.sum().item()*100; human=100-ai
-    top3=[LABELS[i] for i in torch.topk(ai_probs,
+    top3=[LABELS[i] for i in torch.topk(ai_probs,3).indices.tolist()]
     return human, ai, top3
 
-# ──
+# ── schemas ───────────────────────────────────────────────────────────
 class TokenOut(BaseModel): access_token:str; token_type:str="bearer"
-class AnalyseIn(BaseModel): text:str=Field(...,
+class AnalyseIn(BaseModel): text:str=Field(...,min_length=1)
 class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
 class AnalyseOut(BaseModel): verdict:str; confidence:float; ai_avg:float; human_avg:float; per_line:List[Line]; highlight_html:str
 
 # ── FastAPI app ───────────────────────────────────────────────────────
-app
-app.add_middleware(CORSMiddleware,
+app=FastAPI(title="Orify Text Detector API",version="1.1.1")
+app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_methods=["*"],allow_headers=["*"])
 
-@app.post("/token",
-async def
-    return TokenOut(access_token=
+@app.post("/token",response_model=TokenOut)
+async def token(form:OAuth2PasswordRequestForm=Depends()):
+    return TokenOut(access_token=_make_jwt(form.username))
 
-@app.post("/analyse",
-async def analyse(
-    lines=_tidy(
+@app.post("/analyse",response_model=AnalyseOut)
+async def analyse(body:AnalyseIn,_=Depends(_verify_jwt)):
+    lines=_tidy(body.text).split("\n"); html_parts=[]; per=[]; h_sum=ai_sum=n=0.0
     for ln in lines:
         if not ln.strip():
             html_parts.append("<br>"); continue
-        n+=1; human,
-        cls
-        tip
+        n+=1; human,ai,top3=_infer(ln); h_sum+=human; ai_sum+=ai
+        cls="ai-line" if ai>human else "human-line"
+        tip=f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai>human else f"Human {human:.2f}%"
         html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
-        reason
-
-    human_avg
-    verdict
-    badge
-    html_out
-    return AnalyseOut(verdict=verdict,
-
+        reason=(f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai>human else f"Lexical variety suggests human ({human:.1f}%)")
+        per.append(Line(text=ln,ai=ai,human=human,top3=top3,reason=reason))
+    human_avg=h_sum/n if n else 0; ai_avg=ai_sum/n if n else 0
+    verdict="AI-generated" if ai_avg>human_avg else "Human-written"; conf=max(human_avg,ai_avg)
+    badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
+    html_out=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts)
+    return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,per_line=per,highlight_html=html_out)
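For reference, the two new endpoints compose like this: /token signs a JWT for whatever username the form submits (this revision performs no password check), and /analyse requires that token as a bearer header, then scores each line with the 41-class ensemble; the softmax mass of every class except index 24 ("human") is summed to produce the AI percentage. A minimal client sketch, assuming the service is reachable at a hypothetical BASE URL:

import requests

BASE = "http://localhost:7860"  # hypothetical URL – substitute the real Space host

# 1) Obtain a token. OAuth2PasswordRequestForm requires both form fields,
#    but this revision signs any username without verifying the password.
tok = requests.post(f"{BASE}/token",
                    data={"username": "demo", "password": "x"}).json()["access_token"]

# 2) Analyse some text with the bearer token.
resp = requests.post(f"{BASE}/analyse",
                     json={"text": "First line.\n\nSecond line."},
                     headers={"Authorization": f"Bearer {tok}"})
out = resp.json()
print(out["verdict"], out["confidence"])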