Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,30 +1,36 @@
|
|
1 |
-
"""Orify Text Detector API – FastAPI (CPU‑only)"""
|
2 |
-
|
3 |
from __future__ import annotations
|
4 |
import os, re, html
|
5 |
from datetime import datetime, timedelta
|
6 |
from typing import List
|
7 |
|
8 |
import torch
|
9 |
-
from transformers import
|
|
|
|
|
|
|
|
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
-
|
12 |
from fastapi import FastAPI, HTTPException, Depends
|
13 |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
14 |
from fastapi.middleware.cors import CORSMiddleware
|
15 |
from jose import jwt, JWTError
|
16 |
from pydantic import BaseModel, Field
|
17 |
|
18 |
-
#
|
19 |
# Torch compile shim (CPU runtime)
|
20 |
-
#
|
21 |
if hasattr(torch, "compile"):
|
22 |
torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f)) # type: ignore
|
23 |
os.environ["TORCHINDUCTOR_DISABLED"] = "1"
|
24 |
|
25 |
-
#
|
26 |
-
#
|
27 |
-
#
|
|
|
|
|
|
|
|
|
|
|
28 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
|
30 |
FILE_MAP = {"ensamble_1": "ensamble_1", "ensamble_2.bin": "ensamble_2.bin", "ensamble_3": "ensamble_3"}
|
@@ -39,49 +45,48 @@ LABELS = {i: n for i, n in enumerate([
|
|
39 |
"gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
|
40 |
"llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
|
41 |
"opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
|
42 |
-
"t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003"
|
43 |
])}
|
44 |
|
45 |
-
#
|
46 |
# JWT helpers
|
47 |
-
#
|
48 |
SECRET_KEY = os.getenv("SECRET_KEY")
|
49 |
if not SECRET_KEY:
|
50 |
raise RuntimeError("SECRET_KEY env‑var not set")
|
51 |
-
|
52 |
-
|
53 |
-
EXP_HOURS = 24
|
54 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
55 |
|
56 |
-
def _jwt_create(
|
57 |
-
|
58 |
-
return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
59 |
|
60 |
-
def _jwt_verify(
|
61 |
try:
|
62 |
-
return jwt.decode(
|
63 |
except JWTError:
|
64 |
raise HTTPException(401, "Invalid or expired token")
|
65 |
|
66 |
-
#
|
67 |
-
# Load ensemble
|
68 |
-
#
|
69 |
print("🔄 Downloading weights…", flush=True)
|
70 |
-
|
71 |
|
72 |
-
print("🧩
|
|
|
73 |
_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
|
74 |
_models: List[AutoModelForSequenceClassification] = []
|
75 |
-
for p in
|
76 |
-
m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL,
|
77 |
m.load_state_dict(torch.load(p, map_location=DEVICE))
|
78 |
m.to(DEVICE).eval()
|
79 |
_models.append(m)
|
80 |
print("✅ Ensemble ready")
|
81 |
|
82 |
-
#
|
83 |
# Helpers
|
84 |
-
#
|
85 |
|
86 |
def _tidy(t: str) -> str:
|
87 |
t = t.replace("\r\n", "\n").replace("\r", "\n")
|
@@ -96,14 +101,14 @@ def _infer(seg: str):
|
|
96 |
with torch.no_grad():
|
97 |
probs = torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
|
98 |
ai_probs = probs.clone(); ai_probs[24] = 0
|
99 |
-
ai
|
100 |
human = 100 - ai
|
101 |
top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
|
102 |
return human, ai, top3
|
103 |
|
104 |
-
#
|
105 |
# Schemas
|
106 |
-
#
|
107 |
class Token(BaseModel):
|
108 |
access_token: str
|
109 |
token_type: str = "bearer"
|
@@ -126,44 +131,43 @@ class AnalyseOut(BaseModel):
|
|
126 |
per_line: List[Line]
|
127 |
highlight_html: str
|
128 |
|
129 |
-
#
|
130 |
# FastAPI
|
131 |
-
#
|
132 |
app = FastAPI(title="Orify Text Detector API", version="1.0.0")
|
133 |
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
134 |
|
135 |
@app.post("/token", response_model=Token)
|
136 |
async def login(form: OAuth2PasswordRequestForm = Depends()):
|
137 |
-
return Token(access_token=_jwt_create(
|
138 |
|
139 |
@app.post("/analyse", response_model=AnalyseOut)
|
140 |
-
async def analyse(data: AnalyseIn,
|
141 |
lines = _tidy(data.text).split("\n")
|
142 |
per_line, html_parts = [], []
|
143 |
h_sum = ai_sum = n = 0.0
|
144 |
|
145 |
for ln in lines:
|
146 |
if not ln.strip():
|
147 |
-
html_parts.append("<br>")
|
|
|
148 |
n += 1
|
149 |
human, ai, top3 = _infer(ln)
|
150 |
h_sum += human; ai_sum += ai
|
151 |
cls = "ai-line" if ai > human else "human-line"
|
152 |
tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > human else f"Human {human:.2f}%"
|
153 |
html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
|
154 |
-
reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai > human
|
155 |
-
|
156 |
per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
|
157 |
|
158 |
human_avg = h_sum / n if n else 0
|
159 |
ai_avg = ai_sum / n if n else 0
|
160 |
verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
|
161 |
confidence = max(ai_avg, human_avg)
|
162 |
-
badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
|
163 |
-
if verdict == "AI-generated" else
|
164 |
f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
|
165 |
highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
|
166 |
|
167 |
-
return AnalyseOut(verdict=verdict, confidence=confidence, ai_avg=ai_avg,
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
1 |
from __future__ import annotations
|
2 |
import os, re, html
|
3 |
from datetime import datetime, timedelta
|
4 |
from typing import List
|
5 |
|
6 |
import torch
|
7 |
+
from transformers import (
|
8 |
+
AutoTokenizer,
|
9 |
+
AutoConfig,
|
10 |
+
AutoModelForSequenceClassification,
|
11 |
+
)
|
12 |
from huggingface_hub import hf_hub_download
|
|
|
13 |
from fastapi import FastAPI, HTTPException, Depends
|
14 |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
15 |
from fastapi.middleware.cors import CORSMiddleware
|
16 |
from jose import jwt, JWTError
|
17 |
from pydantic import BaseModel, Field
|
18 |
|
19 |
+
# ---------------------------------------------------------------------
|
20 |
# Torch compile shim (CPU runtime)
|
21 |
+
# ---------------------------------------------------------------------
|
22 |
if hasattr(torch, "compile"):
|
23 |
torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f)) # type: ignore
|
24 |
os.environ["TORCHINDUCTOR_DISABLED"] = "1"
|
25 |
|
26 |
+
# ---------------------------------------------------------------------
|
27 |
+
# Environment flags – enable remote code in Transformers
|
28 |
+
# ---------------------------------------------------------------------
|
29 |
+
os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1") # allow custom ModernBERT classes
|
30 |
+
|
31 |
+
# ---------------------------------------------------------------------
|
32 |
+
# Model / weight config
|
33 |
+
# ---------------------------------------------------------------------
|
34 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
|
36 |
FILE_MAP = {"ensamble_1": "ensamble_1", "ensamble_2.bin": "ensamble_2.bin", "ensamble_3": "ensamble_3"}
|
|
|
45 |
"gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
|
46 |
"llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
|
47 |
"opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
|
48 |
+
"t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
|
49 |
])}
|
50 |
|
51 |
+
# ---------------------------------------------------------------------
|
52 |
# JWT helpers
|
53 |
+
# ---------------------------------------------------------------------
|
54 |
SECRET_KEY = os.getenv("SECRET_KEY")
|
55 |
if not SECRET_KEY:
|
56 |
raise RuntimeError("SECRET_KEY env‑var not set")
|
57 |
+
ALGO = "HS256"
|
58 |
+
EXP_H = 24
|
|
|
59 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
60 |
|
61 |
+
def _jwt_create(sub: str) -> str:
|
62 |
+
return jwt.encode({"sub": sub, "exp": datetime.utcnow() + timedelta(hours=EXP_H)}, SECRET_KEY, algorithm=ALGO)
|
|
|
63 |
|
64 |
+
def _jwt_verify(tok: str = Depends(oauth2_scheme)) -> str:
|
65 |
try:
|
66 |
+
return jwt.decode(tok, SECRET_KEY, algorithms=[ALGO])["sub"]
|
67 |
except JWTError:
|
68 |
raise HTTPException(401, "Invalid or expired token")
|
69 |
|
70 |
+
# ---------------------------------------------------------------------
|
71 |
+
# Load tokenizer + config + ensemble
|
72 |
+
# ---------------------------------------------------------------------
|
73 |
print("🔄 Downloading weights…", flush=True)
|
74 |
+
local_paths = {k: hf_hub_download(WEIGHT_REPO, v, resume_download=True) for k, v in FILE_MAP.items()}
|
75 |
|
76 |
+
print("🧩 Loading ModernBERT remote code…", flush=True)
|
77 |
+
_cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
|
78 |
_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
|
79 |
_models: List[AutoModelForSequenceClassification] = []
|
80 |
+
for p in local_paths.values():
|
81 |
+
m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, config=_cfg, **TOKEN_KW)
|
82 |
m.load_state_dict(torch.load(p, map_location=DEVICE))
|
83 |
m.to(DEVICE).eval()
|
84 |
_models.append(m)
|
85 |
print("✅ Ensemble ready")
|
86 |
|
87 |
+
# ---------------------------------------------------------------------
|
88 |
# Helpers
|
89 |
+
# ---------------------------------------------------------------------
|
90 |
|
91 |
def _tidy(t: str) -> str:
|
92 |
t = t.replace("\r\n", "\n").replace("\r", "\n")
|
|
|
101 |
with torch.no_grad():
|
102 |
probs = torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
|
103 |
ai_probs = probs.clone(); ai_probs[24] = 0
|
104 |
+
ai = ai_probs.sum().item() * 100
|
105 |
human = 100 - ai
|
106 |
top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
|
107 |
return human, ai, top3
|
108 |
|
109 |
+
# ---------------------------------------------------------------------
|
110 |
# Schemas
|
111 |
+
# ---------------------------------------------------------------------
|
112 |
class Token(BaseModel):
|
113 |
access_token: str
|
114 |
token_type: str = "bearer"
|
|
|
131 |
per_line: List[Line]
|
132 |
highlight_html: str
|
133 |
|
134 |
+
# ---------------------------------------------------------------------
|
135 |
# FastAPI
|
136 |
+
# ---------------------------------------------------------------------
|
137 |
app = FastAPI(title="Orify Text Detector API", version="1.0.0")
|
138 |
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
139 |
|
140 |
@app.post("/token", response_model=Token)
|
141 |
async def login(form: OAuth2PasswordRequestForm = Depends()):
|
142 |
+
return Token(access_token=_jwt_create(form.username))
|
143 |
|
144 |
@app.post("/analyse", response_model=AnalyseOut)
|
145 |
+
async def analyse(data: AnalyseIn, _=Depends(_jwt_verify)):
|
146 |
lines = _tidy(data.text).split("\n")
|
147 |
per_line, html_parts = [], []
|
148 |
h_sum = ai_sum = n = 0.0
|
149 |
|
150 |
for ln in lines:
|
151 |
if not ln.strip():
|
152 |
+
html_parts.append("<br>")
|
153 |
+
continue
|
154 |
n += 1
|
155 |
human, ai, top3 = _infer(ln)
|
156 |
h_sum += human; ai_sum += ai
|
157 |
cls = "ai-line" if ai > human else "human-line"
|
158 |
tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > human else f"Human {human:.2f}%"
|
159 |
html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
|
160 |
+
reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai > human else
|
161 |
+
f"Lexical variety suggests human ({human:.1f}%)")
|
162 |
per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
|
163 |
|
164 |
human_avg = h_sum / n if n else 0
|
165 |
ai_avg = ai_sum / n if n else 0
|
166 |
verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
|
167 |
confidence = max(ai_avg, human_avg)
|
168 |
+
badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict == "AI-generated" else
|
|
|
169 |
f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
|
170 |
highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
|
171 |
|
172 |
+
return AnalyseOut(verdict=verdict, confidence=confidence, ai_avg=ai_avg, human_avg=human_avg,
|
173 |
+
per_line=per_line, highlight_html=highlight_html)
|
|