Update app.py

app.py CHANGED
@@ -1,3 +1,5 @@
+"""Orify Text Detector API – FastAPI / JWT / HF Zero-GPU"""
+
from __future__ import annotations
import os, re, html
from datetime import datetime, timedelta
@@ -5,169 +7,126 @@ from typing import List

import torch
from transformers import (
-    AutoTokenizer,
    AutoConfig,
+    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from huggingface_hub import hf_hub_download
-from fastapi import FastAPI, …
+from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from jose import jwt, JWTError
from pydantic import BaseModel, Field

-…
-if hasattr(torch, "compile"):
-    torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f))  # type: ignore
+# ───────────────────── Torch shim ─────────────────────
+if hasattr(torch, "compile"):  # HF CPU image has torch-compile
+    torch.compile = lambda m=None,*_,**__: m if callable(m) else (lambda f: f)  # type: ignore
os.environ["TORCHINDUCTOR_DISABLED"] = "1"

-…
-    "dolly", "dolly-v2-12b", "flan_t5_base", "flan_t5_large", "flan_t5_small",
-    "flan_t5_xl", "flan_t5_xxl", "gemma-7b-it", "gemma2-9b-it", "gpt-3.5-turbo",
-    "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
-    "llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
-    "opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
-    "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
+# ───────────────────── Remote-code flag ───────────────
+os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")  # let Transformers load ModernBERT code
+TOKEN_KW = {"trust_remote_code": True}
+
+# ───────────────────── Config ─────────────────────────
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
+FILE_MAP = {"ensamble_1":"ensamble_1","ensamble_2.bin":"ensamble_2.bin","ensamble_3":"ensamble_3"}
+BASE_MODEL = "answerdotai/ModernBERT-base"
+NUM_LABELS = 41
+
+LABELS = {i:n for i,n in enumerate([
+    "13B","30B","65B","7B","GLM130B","bloom_7b","bloomz","cohere","davinci","dolly","dolly-v2-12b",
+    "flan_t5_base","flan_t5_large","flan_t5_small","flan_t5_xl","flan_t5_xxl","gemma-7b-it","gemma2-9b-it",
+    "gpt-3.5-turbo","gpt-35","gpt-4","gpt-4o","gpt-j","gpt-neox","human","llama3-70b","llama3-8b",
+    "mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b","opt-30b","opt-350m","opt-6.7b",
+    "opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"
])}

-…
-# ---------------------------------------------------------------------
-SECRET_KEY = os.getenv("SECRET_KEY")
+# ───────────────────── JWT helpers ────────────────────
+SECRET_KEY = os.getenv("SECRET_KEY")  # set this in Space → Settings → Secrets
if not SECRET_KEY:
-    raise RuntimeError("SECRET_KEY env…
+    raise RuntimeError("Set SECRET_KEY env-var")
+ALG = "HS256"
+EXP_HOURS = 24
+oauth_scheme = OAuth2PasswordBearer(tokenUrl="token")

-def …
+def _jwt_make(sub:str)->str:
+    payload={"sub":sub,"exp":datetime.utcnow()+timedelta(hours=EXP_HOURS)}
+    return jwt.encode(payload, SECRET_KEY, algorithm=ALG)

-def …
+def _jwt_ok(token:str=Depends(oauth_scheme))->str:
    try:
-        return jwt.decode(…
+        return jwt.decode(token, SECRET_KEY, algorithms=[ALG])["sub"]
    except JWTError:
-        raise HTTPException(401, …
+        raise HTTPException(401,"Invalid or expired token")

-…
-    m …
-    m.load_state_dict(torch.load(p, map_location=DEVICE))
-    m.to(DEVICE).eval()
-    _models.append(m)
+# ───────────────────── Model bootstrap ────────────────
+print("🔄 Fetching ensemble weights…", flush=True)
+paths={k:hf_hub_download(WEIGHT_REPO,f,resume_download=True) for k,f in FILE_MAP.items()}
+
+print("🧩 Loading ModernBERT code…", flush=True)
+cfg = AutoConfig.from_pretrained(BASE_MODEL, num_labels=NUM_LABELS, **TOKEN_KW)  # 41-way head to match the checkpoints
+tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
+models=[]
+for p in paths.values():
+    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, config=cfg, **TOKEN_KW)
+    m.load_state_dict(torch.load(p,map_location=DEVICE))
+    m.to(DEVICE).eval(); models.append(m)
print("✅ Ensemble ready")

-def …
-    t …
-    t …
-    t = re.sub(r"[ \t]+", " ", t)
-    t = re.sub(r"(\w+)-\n(\w+)", r"\1\2", t)
-    t = re.sub(r"(?<!\n)\n(?!\n)", " ", t)
+# ───────────────────── Utils ──────────────────────────
+def _tidy(t:str)->str:
+    t=t.replace("\r\n","\n").replace("\r","\n")
+    t=re.sub(r"\n\s*\n+","\n\n",t)        # collapse blank runs to one paragraph break
+    t=re.sub(r"[ \t]+"," ",t)
+    t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)  # re-join words hyphenated across line breaks
+    t=re.sub(r"(?<!\n)\n(?!\n)"," ",t)    # single newline becomes a space
    return t.strip()

-def _infer(seg: …
-    inp …
+def _infer(seg:str):
+    inp=tok(seg,return_tensors="pt",truncation=True,padding=True).to(DEVICE)
    with torch.no_grad():
-        probs …
-    ai_probs …
-    ai …
+        probs=torch.stack([torch.softmax(m(**inp).logits,1) for m in models]).mean(0)[0]
+    ai_probs=probs.clone(); ai_probs[24]=0  # zero out index 24 ("human")
+    ai=ai_probs.sum().item()*100; human=100-ai
+    top3=[LABELS[i] for i in torch.topk(ai_probs,3).indices.tolist()]
+    return human,ai,top3

-class Token(BaseModel):
-    access_token: str
-    token_type: str = "bearer"
-
-class AnalyseIn(BaseModel):
-    text: str = Field(..., min_length=1)
-
-class Line(BaseModel):
-    text: str
-    ai: float
-    human: float
-    top3: List[str]
-    reason: str
-
+# ───────────────────── Schemas ────────────────────────
+class Token(BaseModel): access_token:str; token_type:str="bearer"
+class AnalyseIn(BaseModel): text:str=Field(...,min_length=1)
+class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
class AnalyseOut(BaseModel):
-    verdict: …
+    verdict:str; confidence:float; ai_avg:float; human_avg:float
+    per_line:List[Line]; highlight_html:str

-app = FastAPI(title="Orify Text Detector API", version="1.0.0")
-app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
+# ───────────────────── FastAPI app ────────────────────
+app=FastAPI(title="Orify Text Detector API",version="1.0.0")
+app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_methods=["*"],allow_headers=["*"])

-@app.post("/token", response_model=Token)
-async def login(form: OAuth2PasswordRequestForm = Depends()):
-    return Token(access_token=_jwt_create(form.username))
+@app.post("/token",response_model=Token,summary="Get JWT")
+async def login(form:OAuth2PasswordRequestForm=Depends()):
+    return Token(access_token=_jwt_make(form.username))

-@app.post("/analyse", response_model=AnalyseOut)
-async def analyse(data: AnalyseIn, _=Depends(_jwt_verify)):
-    lines = _tidy(data.text).split("\n")
-    per_line, html_parts = [], []
-    h_sum = ai_sum = n = 0.0
+@app.post("/analyse",response_model=AnalyseOut,summary="Detect AI-generated text")
+async def analyse(req:AnalyseIn,_user=Depends(_jwt_ok)):
+    lines=_tidy(req.text).split("\n")
+    html_parts=[]; per_line=[]; h_sum=ai_sum=n=0.0
    for ln in lines:
-        if not ln.strip():
-…
-    verdict …
-    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict == "AI-generated" else
-             f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
-    highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
-
-    return AnalyseOut(verdict=verdict, confidence=confidence, ai_avg=ai_avg, human_avg=human_avg,
-                      per_line=per_line, highlight_html=highlight_html)
+        if not ln.strip(): html_parts.append("<br>"); continue
+        n+=1
+        h,ai,top3=_infer(ln); h_sum+=h; ai_sum+=ai
+        cls="ai-line" if ai>h else "human-line"
+        tip=f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai>h else f"Human {h:.2f}%"
+        html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
+        reason=(f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai>h
+                else f"Lexical variety suggests human ({h:.1f}%)")
+        per_line.append(Line(text=ln,ai=ai,human=h,top3=top3,reason=reason))
+    human_avg=h_sum/n if n else 0; ai_avg=ai_sum/n if n else 0
+    verdict="AI-generated" if ai_avg>human_avg else "Human-written"
+    conf=max(ai_avg,human_avg)
+    badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
+           if verdict=="AI-generated" else
+           f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
+    return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,
+                      per_line=per_line,highlight_html=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts))
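
For reference, a minimal client sketch for the two routes in this commit. The base URL and the demo credentials are placeholders (anything not shown in the diff above is an assumption): /token signs whatever username the form sends, so any username/password pair works once SECRET_KEY is set on the Space.

import requests

BASE = "https://your-space.hf.space"  # placeholder; point at the running Space

# 1) OAuth2 password flow: form-encoded body, returns the Token schema
tok = requests.post(f"{BASE}/token", data={"username": "demo", "password": "demo"})
tok.raise_for_status()
headers = {"Authorization": f"Bearer {tok.json()['access_token']}"}

# 2) /analyse takes AnalyseIn ({"text": ...}) and returns AnalyseOut
resp = requests.post(f"{BASE}/analyse",
                     json={"text": "First paragraph.\n\nSecond paragraph."},
                     headers=headers)
resp.raise_for_status()
out = resp.json()
print(out["verdict"], f"({out['confidence']:.1f}%)")
for line in out["per_line"]:
    print(f"  AI {line['ai']:5.1f}%  {line['text'][:60]}")

Note that _tidy() joins single newlines into spaces and keeps only blank-line breaks, so each entry in per_line is effectively one paragraph of the input, not one physical line.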