Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -16,16 +16,16 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
16 |
from jose import jwt, JWTError
|
17 |
from pydantic import BaseModel, Field
|
18 |
|
19 |
-
#
|
20 |
if hasattr(torch, "compile"):
|
21 |
torch.compile = (lambda m=None,*_,**__: m if callable(m) else (lambda f: f)) # type: ignore
|
22 |
os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")
|
23 |
|
24 |
-
#
|
25 |
os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
|
26 |
TOKEN_KW = {"trust_remote_code": True}
|
27 |
|
28 |
-
#
|
29 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
|
31 |
FILE_MAP = {"ensamble_1":"ensamble_1","ensamble_2.bin":"ensamble_2.bin","ensamble_3":"ensamble_3"}
|
@@ -38,7 +38,7 @@ LABELS = {i:n for i,n in enumerate([
|
|
38 |
"mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b","opt-30b","opt-350m","opt-6.7b",
|
39 |
"opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"])}
|
40 |
|
41 |
-
#
|
42 |
SECRET_KEY = os.getenv("SECRET_KEY")
|
43 |
if not SECRET_KEY:
|
44 |
raise RuntimeError("SECRET_KEY env‑var not set – add it in Space settings → Secrets")
|
@@ -55,32 +55,30 @@ def _verify_jwt(tok:str=Depends(oauth2)):
|
|
55 |
except JWTError:
|
56 |
raise HTTPException(401,"Invalid or expired token")
|
57 |
|
58 |
-
#
|
59 |
print("🔄 Fetching ensemble weights…", flush=True)
|
60 |
paths={k:hf_hub_download(WEIGHT_REPO,f,resume_download=True) for k,f in FILE_MAP.items()}
|
61 |
|
62 |
print("🧩 Building ModernBERT backbone…", flush=True)
|
63 |
-
_cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
|
64 |
-
_cfg.num_labels = NUM_LABELS # ➜ classification head = 41
|
65 |
_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
|
66 |
_models: List[AutoModelForSequenceClassification] = []
|
67 |
for p in paths.values():
|
68 |
m = AutoModelForSequenceClassification.from_pretrained(
|
69 |
BASE_MODEL,
|
70 |
config=_cfg,
|
71 |
-
ignore_mismatched_sizes=True,
|
72 |
**TOKEN_KW,
|
73 |
)
|
74 |
-
|
75 |
-
m.load_state_dict(state) # loads 41‑class ensemble head
|
76 |
m.to(DEVICE).eval()
|
77 |
_models.append(m)
|
78 |
print(f"✅ Ensemble ready on {DEVICE}")
|
79 |
|
80 |
-
#
|
81 |
|
82 |
def _tidy(t:str)->str:
|
83 |
-
t=t.replace("\r\n","\n").replace("\r","\n")
|
84 |
t=re.sub(r"\n\s*\n+","\n\n",t)
|
85 |
t=re.sub(r"[ \t]+"," ",t)
|
86 |
t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)
|
@@ -96,14 +94,14 @@ def _infer(seg:str):
|
|
96 |
top3=[LABELS[i] for i in torch.topk(ai_probs,3).indices.tolist()]
|
97 |
return human, ai, top3
|
98 |
|
99 |
-
#
|
100 |
class TokenOut(BaseModel): access_token:str; token_type:str="bearer"
|
101 |
class AnalyseIn(BaseModel): text:str=Field(...,min_length=1)
|
102 |
class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
|
103 |
class AnalyseOut(BaseModel): verdict:str; confidence:float; ai_avg:float; human_avg:float; per_line:List[Line]; highlight_html:str
|
104 |
|
105 |
-
#
|
106 |
-
app=FastAPI(title="Orify Text Detector API",version="1.
|
107 |
app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_methods=["*"],allow_headers=["*"])
|
108 |
|
109 |
@app.post("/token",response_model=TokenOut)
|
@@ -127,3 +125,9 @@ async def analyse(body:AnalyseIn,_=Depends(_verify_jwt)):
|
|
127 |
badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
|
128 |
html_out=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts)
|
129 |
return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,per_line=per,highlight_html=html_out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from jose import jwt, JWTError
|
17 |
from pydantic import BaseModel, Field
|
18 |
|
19 |
+
# ───────────────────────── torch shim ───────────────────────────────
# Neutralise torch.compile (no Inductor backend in this Space): calling it
# with a model returns the model unchanged; calling it with no model returns
# an identity decorator.
if hasattr(torch, "compile"):
    def _no_compile(model=None, *args, **kwargs):  # type: ignore
        if callable(model):
            return model
        return lambda fn: fn
    torch.compile = _no_compile  # type: ignore
os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")
|
23 |
|
24 |
+
# ─────────────────────── remote‑code flag ───────────────────────────
# Permit hub models that ship custom modelling code; TOKEN_KW is forwarded
# to every from_pretrained call below.
os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
TOKEN_KW = dict(trust_remote_code=True)
|
27 |
|
28 |
+
# ─────────────────────────── config ─────────────────────────────────
# Run on GPU when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
# NOTE(review): the "ensamble" spelling appears to match the filenames stored
# in the weight repo — do not "fix" it here; confirm against the repo listing.
FILE_MAP = {
    "ensamble_1": "ensamble_1",
    "ensamble_2.bin": "ensamble_2.bin",
    "ensamble_3": "ensamble_3",
}
|
|
|
38 |
"mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b","opt-30b","opt-350m","opt-6.7b",
|
39 |
"opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"])}
|
40 |
|
41 |
+
# ──────────────────────── JWT helpers ──────────────────────────────
# Fail fast at import time: without a signing secret no token can be issued
# or verified, so refusing to start is safer than serving broken auth.
SECRET_KEY = os.getenv("SECRET_KEY")
if SECRET_KEY in (None, ""):
    raise RuntimeError("SECRET_KEY env‑var not set – add it in Space settings → Secrets")
|
|
|
55 |
except JWTError:
|
56 |
raise HTTPException(401,"Invalid or expired token")
|
57 |
|
58 |
+
# ─────────────────────── model bootstrap ───────────────────────────
# Download the three ensemble checkpoints, then build one classifier per
# checkpoint on top of the shared ModernBERT backbone.
print("🔄 Fetching ensemble weights…", flush=True)
paths = {
    alias: hf_hub_download(WEIGHT_REPO, filename, resume_download=True)
    for alias, filename in FILE_MAP.items()
}

print("🧩 Building ModernBERT backbone…", flush=True)
_cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
_cfg.num_labels = NUM_LABELS  # size the classification head for the ensemble labels
_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)

_models: List[AutoModelForSequenceClassification] = []
for weight_path in paths.values():
    classifier = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        config=_cfg,
        # head size differs from the base checkpoint, so mismatches are expected
        ignore_mismatched_sizes=True,
        **TOKEN_KW,
    )
    # NOTE(review): torch.load unpickles the downloaded file; the repo is fixed,
    # but consider weights_only=True (torch>=2.0) — confirm the file format first.
    classifier.load_state_dict(torch.load(weight_path, map_location=DEVICE))
    classifier.to(DEVICE).eval()
    _models.append(classifier)
print(f"✅ Ensemble ready on {DEVICE}")
|
77 |
|
78 |
+
# ───────────────────────── helpers ─────────────────────────────────
|
79 |
|
80 |
def _tidy(t:str)->str:
|
81 |
+
t=t.replace("\r\n","\n").replace("\r", "\n")
|
82 |
t=re.sub(r"\n\s*\n+","\n\n",t)
|
83 |
t=re.sub(r"[ \t]+"," ",t)
|
84 |
t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)
|
|
|
94 |
top3=[LABELS[i] for i in torch.topk(ai_probs,3).indices.tolist()]
|
95 |
return human, ai, top3
|
96 |
|
97 |
+
# ───────────────────────── schemas ─────────────────────────────────
class TokenOut(BaseModel):
    """Bearer token returned by POST /token."""
    access_token: str
    token_type: str = "bearer"


class AnalyseIn(BaseModel):
    """Request body for POST /analyse; text must be non-empty."""
    text: str = Field(..., min_length=1)


class Line(BaseModel):
    """Per-line verdict with top-3 candidate generator models and a reason."""
    text: str
    ai: float
    human: float
    top3: List[str]
    reason: str


class AnalyseOut(BaseModel):
    """Aggregate verdict plus per-line details and highlighted HTML."""
    verdict: str
    confidence: float
    ai_avg: float
    human_avg: float
    per_line: List[Line]
    highlight_html: str
|
102 |
|
103 |
+
# ───────────────────────── FastAPI app ─────────────────────────────
app = FastAPI(title="Orify Text Detector API", version="1.2.0")
# Wide-open CORS: the API is consumed from arbitrary front-end origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
106 |
|
107 |
@app.post("/token",response_model=TokenOut)
|
|
|
125 |
badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
|
126 |
html_out=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts)
|
127 |
return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,per_line=per,highlight_html=html_out)
|
128 |
+
|
129 |
+
# ─────────────────────────── entrypoint ────────────────────────────
# Direct launch (local dev or Space): serve on $PORT, defaulting to the
# Hugging Face Spaces convention of 7860.
if __name__ == "__main__":
    import uvicorn  # fix: dropped unused `sys` import

    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info", reload=False)
|