Sleepyriizi committed
Commit a72abd5 · verified · 1 Parent(s): 5a21ae7

Update app.py

Files changed (1)
  1. app.py +64 -52
app.py CHANGED
@@ -1,10 +1,15 @@
+
 from __future__ import annotations
 import os, re, html
 from datetime import datetime, timedelta
 from typing import List

 import torch
-from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+)
 from huggingface_hub import hf_hub_download
 from fastapi import FastAPI, HTTPException, Depends
 from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
@@ -12,23 +17,22 @@ from fastapi.middleware.cors import CORSMiddleware
 from jose import jwt, JWTError
 from pydantic import BaseModel, Field

-# ── torch.compile shim (optional) ──────────────────────────────────────
+# ── torch shim ─────────────────────────────────────────────────────────
 if hasattr(torch, "compile"):
     torch.compile = (lambda m=None,*_,**__: m if callable(m) else (lambda f: f))  # type: ignore
-# keep Inductor disabled on CPU Spaces; harmless on GPU
 os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")

-# ── allow ModernBERT remote code ──────────────────────────────────────
-os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")  # required for custom architecture
+# ── remote-code flag ───────────────────────────────────────────────────
+os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
 TOKEN_KW = {"trust_remote_code": True}

-# ── config ────────────────────────────────────────────────────────────
+# ── config ─────────────────────────────────────────────────────────────
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
-FILE_MAP = {"ensamble_1":"ensamble_1", "ensamble_2.bin":"ensamble_2.bin", "ensamble_3":"ensamble_3"}
+FILE_MAP = {"ensamble_1":"ensamble_1","ensamble_2.bin":"ensamble_2.bin","ensamble_3":"ensamble_3"}
 BASE_MODEL = "answerdotai/ModernBERT-base"
 NUM_LABELS = 41
-LABELS = {i:n for i,n in enumerate([
+LABELS = {i:n for i,n in enumerate([
     "13B","30B","65B","7B","GLM130B","bloom_7b","bloomz","cohere","davinci","dolly","dolly-v2-12b",
     "flan_t5_base","flan_t5_large","flan_t5_small","flan_t5_xl","flan_t5_xxl","gemma-7b-it","gemma2-9b-it",
     "gpt-3.5-turbo","gpt-35","gpt-4","gpt-4o","gpt-j","gpt-neox","human","llama3-70b","llama3-8b",
@@ -38,81 +42,89 @@ LABELS = {i:n for i,n in enumerate([
 # ── JWT helpers ───────────────────────────────────────────────────────
 SECRET_KEY = os.getenv("SECRET_KEY")
 if not SECRET_KEY:
-    raise RuntimeError("SECRET_KEY env‑var not set (add it in the Space Settings Secrets)")
-ALG = "HS256"; EXP_HOURS = 24
+    raise RuntimeError("SECRET_KEY env‑var not set – add it in Space settings → Secrets")
+ALG="HS256"; EXP=24
 oauth2 = OAuth2PasswordBearer(tokenUrl="token")

-def _make_token(sub:str)->str:
-    return jwt.encode({"sub":sub,"exp":datetime.utcnow()+timedelta(hours=EXP_HOURS)}, SECRET_KEY, algorithm=ALG)
+def _make_jwt(sub:str)->str:
+    payload={"sub":sub,"exp":datetime.utcnow()+timedelta(hours=EXP)}
+    return jwt.encode(payload,SECRET_KEY,algorithm=ALG)

-def _verify(tok:str=Depends(oauth2)):
+def _verify_jwt(tok:str=Depends(oauth2)):
     try:
-        return jwt.decode(tok, SECRET_KEY, algorithms=[ALG])["sub"]
+        return jwt.decode(tok,SECRET_KEY,algorithms=[ALG])["sub"]
     except JWTError:
-        raise HTTPException(401, "Invalid or expired token")
+        raise HTTPException(401,"Invalid or expired token")
+
+# ── model bootstrap ───────────────────────────────────────────────────
+print("🔄 Fetching ensemble weights…", flush=True)
+paths={k:hf_hub_download(WEIGHT_REPO,f,resume_download=True) for k,f in FILE_MAP.items()}

-# ── load ensemble ─────────────────────────────────────────────────────
-print("🔄 Downloading weights…", flush=True)
-local_paths = {k: hf_hub_download(WEIGHT_REPO, f, resume_download=True) for k, f in FILE_MAP.items()}
-print("🧩 Loading ModernBERT checkpoints…", flush=True)
+print("🧩 Building ModernBERT backbone…", flush=True)
 _cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
+_cfg.num_labels = NUM_LABELS  # ➜ classification head = 41
 _tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
 _models: List[AutoModelForSequenceClassification] = []
-for p in local_paths.values():
-    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, config=_cfg, **TOKEN_KW)
-    m.load_state_dict(torch.load(p, map_location=DEVICE))
+for p in paths.values():
+    m = AutoModelForSequenceClassification.from_pretrained(
+        BASE_MODEL,
+        config=_cfg,
+        ignore_mismatched_sizes=True,  # skip the 2‑class head in checkpoint
+        **TOKEN_KW,
+    )
+    state=torch.load(p, map_location=DEVICE)
+    m.load_state_dict(state)  # loads 41‑class ensemble head
     m.to(DEVICE).eval()
     _models.append(m)
 print(f"✅ Ensemble ready on {DEVICE}")

-# ── helpers ───────────────────────────────────────────────────────────
+# ── helper fns ─────────────────────────────────────────────────────────

 def _tidy(t:str)->str:
-    t=t.replace("\r\n", "\n").replace("\r", "\n")
-    t=re.sub(r"\n\s*\n+", "\n\n", t)
-    t=re.sub(r"[ \t]+", " ", t)
-    t=re.sub(r"(\w+)-\n(\w+)", r"\1\2", t)
-    t=re.sub(r"(?<!\n)\n(?!\n)", " ", t)
+    t=t.replace("\r\n","\n").replace("\r","\n")
+    t=re.sub(r"\n\s*\n+","\n\n",t)
+    t=re.sub(r"[ \t]+"," ",t)
+    t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)
+    t=re.sub(r"(?<!\n)\n(?!\n)"," ",t)
     return t.strip()

 def _infer(seg:str):
-    inp=_tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
+    inp=_tok(seg,return_tensors="pt",truncation=True,padding=True).to(DEVICE)
     with torch.no_grad():
-        probs=torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
+        probs=torch.stack([torch.softmax(m(**inp).logits,1) for m in _models]).mean(0)[0]
     ai_probs=probs.clone(); ai_probs[24]=0
     ai=ai_probs.sum().item()*100; human=100-ai
-    top3=[LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
+    top3=[LABELS[i] for i in torch.topk(ai_probs,3).indices.tolist()]
     return human, ai, top3

-# ── Pydantic schemas ─────────────────────────────────────────────────
+# ── schemas ───────────────────────────────────────────────────────────
 class TokenOut(BaseModel): access_token:str; token_type:str="bearer"
-class AnalyseIn(BaseModel): text:str=Field(..., min_length=1)
+class AnalyseIn(BaseModel): text:str=Field(...,min_length=1)
 class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
 class AnalyseOut(BaseModel): verdict:str; confidence:float; ai_avg:float; human_avg:float; per_line:List[Line]; highlight_html:str

 # ── FastAPI app ───────────────────────────────────────────────────────
-app = FastAPI(title="Orify Text Detector API", version="1.1.0")
-app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
+app=FastAPI(title="Orify Text Detector API",version="1.1.1")
+app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_methods=["*"],allow_headers=["*"])

-@app.post("/token", response_model=TokenOut)
-async def token_route(form: OAuth2PasswordRequestForm = Depends()):
-    return TokenOut(access_token=_make_token(form.username))
+@app.post("/token",response_model=TokenOut)
+async def token(form:OAuth2PasswordRequestForm=Depends()):
+    return TokenOut(access_token=_make_jwt(form.username))

-@app.post("/analyse", response_model=AnalyseOut)
-async def analyse(req: AnalyseIn, _user=Depends(_verify)):
-    lines=_tidy(req.text).split("\n"); html_parts=[]; records=[]; h_sum=ai_sum=n=0.0
+@app.post("/analyse",response_model=AnalyseOut)
+async def analyse(body:AnalyseIn,_=Depends(_verify_jwt)):
+    lines=_tidy(body.text).split("\n"); html_parts=[]; per=[]; h_sum=ai_sum=n=0.0
     for ln in lines:
         if not ln.strip():
             html_parts.append("<br>"); continue
-        n+=1; human, ai, top3 = _infer(ln); h_sum+=human; ai_sum+=ai
-        cls = "ai-line" if ai>human else "human-line"
-        tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai>human else f"Human {human:.2f}%"
+        n+=1; human,ai,top3=_infer(ln); h_sum+=human; ai_sum+=ai
+        cls="ai-line" if ai>human else "human-line"
+        tip=f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai>human else f"Human {human:.2f}%"
         html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
-        reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai>human else f"Lexical variety suggests human ({human:.1f}%)")
-        records.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
-    human_avg = h_sum/n if n else 0; ai_avg = ai_sum/n if n else 0
-    verdict = "AI-generated" if ai_avg>human_avg else "Human-written"; conf=max(human_avg, ai_avg)
-    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
-    html_out = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
-    return AnalyseOut(verdict=verdict, confidence=conf, ai_avg=ai_avg, human_avg=human_avg, per_line=records, highlight_html=html_out)
-
+        reason=(f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai>human else f"Lexical variety suggests human ({human:.1f}%)")
+        per.append(Line(text=ln,ai=ai,human=human,top3=top3,reason=reason))
+    human_avg=h_sum/n if n else 0; ai_avg=ai_sum/n if n else 0
+    verdict="AI-generated" if ai_avg>human_avg else "Human-written"; conf=max(human_avg,ai_avg)
+    badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
+    html_out=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts)
+    return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,per_line=per,highlight_html=html_out)
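
For reference, a minimal client sketch against the two endpoints this commit defines: POST /token issues a JWT from an OAuth2 password form (app.py signs whatever username is posted and does not check the password), and POST /analyse takes JSON {"text": ...} plus the bearer token. The base URL and the demo credentials below are placeholders, not part of the commit:

import requests

BASE = "http://localhost:7860"  # placeholder – substitute the deployed Space URL

# 1) obtain a JWT; the username is echoed into the token's "sub" claim
form = {"username": "demo", "password": "demo"}
token = requests.post(f"{BASE}/token", data=form).json()["access_token"]

# 2) score a passage with the bearer token
resp = requests.post(
    f"{BASE}/analyse",
    headers={"Authorization": f"Bearer {token}"},
    json={"text": "Paste the passage to score here."},
)
result = resp.json()
print(result["verdict"], result["confidence"])  # per_line / highlight_html also returned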