Sleepyriizi committed
Commit 5a21ae7 · verified · 1 Parent(s): 43f47c6

Update app.py

Files changed (1)
  1. app.py +84 -98
app.py CHANGED
@@ -1,132 +1,118 @@
-"""Orify Text Detector API – FastAPI / JWT / HF Zero-GPU"""
-
 from __future__ import annotations
 import os, re, html
 from datetime import datetime, timedelta
 from typing import List
 
 import torch
-from transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    AutoModelForSequenceClassification,
-)
+from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
 from huggingface_hub import hf_hub_download
-from fastapi import FastAPI, Depends, HTTPException
+from fastapi import FastAPI, HTTPException, Depends
 from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
 from fastapi.middleware.cors import CORSMiddleware
 from jose import jwt, JWTError
 from pydantic import BaseModel, Field
 
-# ───────────────────── Torch shim ─────────────────────
-if hasattr(torch, "compile"):  # HF CPU image has torch-compile
-    torch.compile = lambda m=None,*_,**__: m if callable(m) else (lambda f: f)  # type: ignore
-os.environ["TORCHINDUCTOR_DISABLED"] = "1"
+# ── torch.compile shim (optional) ──────────────────────────────────────
+if hasattr(torch, "compile"):
+    torch.compile = (lambda m=None,*_,**__: m if callable(m) else (lambda f: f))  # type: ignore
+# keep Inductor disabled on CPU Spaces; harmless on GPU
+os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")
 
-# ───────────────────── Remote-code flag ───────────────
-os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")  # let Transformers load ModernBERT code
+# ── allow ModernBERT remote code ──────────────────────────────────────
+os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")  # required for custom architecture
 TOKEN_KW = {"trust_remote_code": True}
 
-# ───────────────────── Config ─────────────────────────
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
-FILE_MAP = {"ensamble_1":"ensamble_1","ensamble_2.bin":"ensamble_2.bin","ensamble_3":"ensamble_3"}
-BASE_MODEL = "answerdotai/ModernBERT-base"
-NUM_LABELS = 41
-
+# ── config ────────────────────────────────────────────────────────────
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
+FILE_MAP = {"ensamble_1":"ensamble_1", "ensamble_2.bin":"ensamble_2.bin", "ensamble_3":"ensamble_3"}
+BASE_MODEL = "answerdotai/ModernBERT-base"
+NUM_LABELS = 41
 LABELS = {i:n for i,n in enumerate([
     "13B","30B","65B","7B","GLM130B","bloom_7b","bloomz","cohere","davinci","dolly","dolly-v2-12b",
     "flan_t5_base","flan_t5_large","flan_t5_small","flan_t5_xl","flan_t5_xxl","gemma-7b-it","gemma2-9b-it",
     "gpt-3.5-turbo","gpt-35","gpt-4","gpt-4o","gpt-j","gpt-neox","human","llama3-70b","llama3-8b",
     "mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b","opt-30b","opt-350m","opt-6.7b",
-    "opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"
-])}
+    "opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"])}
 
-# ───────────────────── JWT helpers ────────────────────
-SECRET_KEY = os.getenv("SECRET_KEY")  # set this in Space → Settings → Secrets
+# ── JWT helpers ───────────────────────────────────────────────────────
+SECRET_KEY = os.getenv("SECRET_KEY")
 if not SECRET_KEY:
-    raise RuntimeError("Set SECRET_KEY env-var")
-ALG = "HS256"
-EXP_HOURS = 24
-oauth_scheme = OAuth2PasswordBearer(tokenUrl="token")
+    raise RuntimeError("SECRET_KEY envvar not set (add it in the Space ➜ Settings ➜ Secrets)")
+ALG = "HS256"; EXP_HOURS = 24
+oauth2 = OAuth2PasswordBearer(tokenUrl="token")
 
-def _jwt_make(sub:str)->str:
-    payload={"sub":sub,"exp":datetime.utcnow()+timedelta(hours=EXP_HOURS)}
-    return jwt.encode(payload, SECRET_KEY, algorithm=ALG)
+def _make_token(sub:str)->str:
+    return jwt.encode({"sub":sub,"exp":datetime.utcnow()+timedelta(hours=EXP_HOURS)}, SECRET_KEY, algorithm=ALG)
 
-def _jwt_ok(token:str=Depends(oauth_scheme))->str:
+def _verify(tok:str=Depends(oauth2)):
     try:
-        return jwt.decode(token, SECRET_KEY, algorithms=[ALG])["sub"]
+        return jwt.decode(tok, SECRET_KEY, algorithms=[ALG])["sub"]
     except JWTError:
-        raise HTTPException(401,"Invalid or expired token")
-
-# ───────────────────── Model bootstrap ────────────────
-print("🔄 Fetching ensemble weights…", flush=True)
-paths={k:hf_hub_download(WEIGHT_REPO,f,resume_download=True) for k,f in FILE_MAP.items()}
-
-print("🧩 Loading ModernBERT code…", flush=True)
-cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
-tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
-models=[]
-for p in paths.values():
-    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, config=cfg, **TOKEN_KW)
-    m.load_state_dict(torch.load(p,map_location=DEVICE))
-    m.to(DEVICE).eval(); models.append(m)
-print("✅ Ensemble ready")
-
-# ───────────────────── Utils ──────────────────────────
+        raise HTTPException(401, "Invalid or expired token")
+
+# ── load ensemble ─────────────────────────────────────────────────────
+print("🔄 Downloading weights…", flush=True)
+local_paths = {k: hf_hub_download(WEIGHT_REPO, f, resume_download=True) for k, f in FILE_MAP.items()}
+print("🧩 Loading ModernBERT checkpoints…", flush=True)
+_cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
+_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
+_models: List[AutoModelForSequenceClassification] = []
+for p in local_paths.values():
+    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, config=_cfg, **TOKEN_KW)
+    m.load_state_dict(torch.load(p, map_location=DEVICE))
+    m.to(DEVICE).eval()
+    _models.append(m)
+print(f"✅ Ensemble ready on {DEVICE}")
+
+# ── helpers ───────────────────────────────────────────────────────────
+
 def _tidy(t:str)->str:
-    t=t.replace("\r\n","\n").replace("\r","\n")
-    t=re.sub(r"\n\s*\n+","\n\n",t)
-    t=re.sub(r"[ \t]+"," ",t)
-    t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)
-    t=re.sub(r"(?<!\n)\n(?!\n)"," ",t)
+    t=t.replace("\r\n", "\n").replace("\r", "\n")
+    t=re.sub(r"\n\s*\n+", "\n\n", t)
+    t=re.sub(r"[ \t]+", " ", t)
+    t=re.sub(r"(\w+)-\n(\w+)", r"\1\2", t)
+    t=re.sub(r"(?<!\n)\n(?!\n)", " ", t)
    return t.strip()
 
 def _infer(seg:str):
-    inp=tok(seg,return_tensors="pt",trunc=True,padding=True).to(DEVICE)
+    inp=_tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
     with torch.no_grad():
-        probs=torch.stack([torch.softmax(m(**inp).logits,1) for m in models]).mean(0)[0]
+        probs=torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
     ai_probs=probs.clone(); ai_probs[24]=0
     ai=ai_probs.sum().item()*100; human=100-ai
-    top3=[LABELS[i] for i in torch.topk(ai_probs,3).indices.tolist()]
-    return human,ai,top3
-
-# ───────────────────── Schemas ────────────────────────
-class Token(BaseModel): access_token:str; token_type:str="bearer"
-class AnalyseIn(BaseModel): text:str=Field(...,min_length=1)
-class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
-class AnalyseOut(BaseModel):
-    verdict:str; confidence:float; ai_avg:float; human_avg:float
-    per_line:List[Line]; highlight_html:str
-
-# ───────────────────── FastAPI app ────────────────────
-app=FastAPI(title="Orify Text Detector API",version="1.0.0")
-app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_methods=["*"],allow_headers=["*"])
-
-@app.post("/token",response_model=Token,summary="Get JWT")
-async def login(form:OAuth2PasswordRequestForm=Depends()):
-    return Token(access_token=_jwt_make(form.username))
-
-@app.post("/analyse",response_model=AnalyseOut,summary="Detect AI-generated text")
-async def analyse(req:AnalyseIn,_user=Depends(_jwt_ok)):
-    lines=_tidy(req.text).split("\n")
-    html_parts=[]; per_line=[]; h_sum=ai_sum=n=0.0
+    top3=[LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
+    return human, ai, top3
+
+# ── Pydantic schemas ─────────────────────────────────────────────────
+class TokenOut(BaseModel): access_token:str; token_type:str="bearer"
+class AnalyseIn(BaseModel): text:str=Field(..., min_length=1)
+class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
+class AnalyseOut(BaseModel): verdict:str; confidence:float; ai_avg:float; human_avg:float; per_line:List[Line]; highlight_html:str
+
+# ── FastAPI app ───────────────────────────────────────────────────────
+app = FastAPI(title="Orify Text Detector API", version="1.1.0")
+app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
+
+@app.post("/token", response_model=TokenOut)
+async def token_route(form: OAuth2PasswordRequestForm = Depends()):
+    return TokenOut(access_token=_make_token(form.username))
+
+@app.post("/analyse", response_model=AnalyseOut)
+async def analyse(req: AnalyseIn, _user=Depends(_verify)):
+    lines=_tidy(req.text).split("\n"); html_parts=[]; records=[]; h_sum=ai_sum=n=0.0
     for ln in lines:
-        if not ln.strip(): html_parts.append("<br>"); continue
-        n+=1
-        h,ai,top3=_infer(ln); h_sum+=h; ai_sum+=ai
-        cls="ai-line" if ai>h else "human-line"
-        tip=f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai>h else f"Human {h:.2f}%"
-        html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
-        reason=(f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai>h
-                else f"Lexical variety suggests human ({h:.1f}%)")
-        per_line.append(Line(text=ln,ai=ai,human=h,top3=top3,reason=reason))
-    human_avg=h_sum/n if n else 0; ai_avg=ai_sum/n if n else 0
-    verdict="AI-generated" if ai_avg>human_avg else "Human-written"
-    conf=max(ai_avg,human_avg)
-    badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
-           if verdict=="AI-generated" else
-           f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
-    return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,
-                      per_line=per_line,highlight_html=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts))
+        if not ln.strip():
+            html_parts.append("<br>"); continue
+        n+=1; human, ai, top3 = _infer(ln); h_sum+=human; ai_sum+=ai
+        cls = "ai-line" if ai>human else "human-line"
+        tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai>human else f"Human {human:.2f}%"
+        html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
+        reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai>human else f"Lexical variety suggests human ({human:.1f}%)")
+        records.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
+    human_avg = h_sum/n if n else 0; ai_avg = ai_sum/n if n else 0
+    verdict = "AI-generated" if ai_avg>human_avg else "Human-written"; conf=max(human_avg, ai_avg)
+    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
+    html_out = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
+    return AnalyseOut(verdict=verdict, confidence=conf, ai_avg=ai_avg, human_avg=human_avg, per_line=records, highlight_html=html_out)
+
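
To smoke-test the two routes this commit exposes, here is a minimal client sketch. Assumptions: the Space serves on http://localhost:7860 (replace with your Space URL), and the demo credentials are placeholders — the /token route signs a JWT for whatever username is posted, since the commit adds no credential check.

import requests

BASE = "http://localhost:7860"  # assumption: replace with your Space URL

# 1) Get a JWT — OAuth2PasswordRequestForm expects form-encoded username/password
tok = requests.post(f"{BASE}/token", data={"username": "demo", "password": "demo"}).json()

# 2) Call /analyse with a Bearer token and an AnalyseIn JSON body
resp = requests.post(
    f"{BASE}/analyse",
    headers={"Authorization": f"Bearer {tok['access_token']}"},
    json={"text": "The quick brown fox jumps over the lazy dog."},
)
resp.raise_for_status()
out = resp.json()  # AnalyseOut: verdict, confidence, ai_avg, human_avg, per_line, highlight_html
print(out["verdict"], out["confidence"])
for line in out["per_line"]:
    print(line["ai"], line["human"], line["top3"], line["reason"])

Note that /analyse scores each non-blank line of the tidied input independently, so a one-line input yields a single per_line record and a verdict equal to that line's majority side.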