Sleepyriizi committed on
Commit
7ce0483
·
verified ·
1 Parent(s): 12aa198

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -140
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from __future__ import annotations
2
  import os, re, html
3
  from datetime import datetime, timedelta
@@ -5,169 +7,126 @@ from typing import List
5
 
6
  import torch
7
  from transformers import (
8
- AutoTokenizer,
9
  AutoConfig,
 
10
  AutoModelForSequenceClassification,
11
  )
12
  from huggingface_hub import hf_hub_download
13
- from fastapi import FastAPI, HTTPException, Depends
14
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from jose import jwt, JWTError
17
  from pydantic import BaseModel, Field
18
 
19
- # ---------------------------------------------------------------------
20
- # Torch compile shim (CPU runtime)
21
- # ---------------------------------------------------------------------
22
# Swap torch.compile for a pass-through: a callable argument is returned
# unchanged, anything else yields an identity decorator.
if hasattr(torch, "compile"):
    torch.compile = (  # type: ignore
        lambda m=None, *_, **__: m if callable(m) else (lambda f: f)
    )
os.environ["TORCHINDUCTOR_DISABLED"] = "1"

# Only set when the variable is absent, so an explicit value wins.
os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")  # allow custom ModernBERT classes

# ---------------------------------------------------------------------
# Model / weight config
# ---------------------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
# Each checkpoint is keyed by its own file name inside WEIGHT_REPO.
FILE_MAP = {name: name for name in ("ensamble_1", "ensamble_2.bin", "ensamble_3")}
BASE_MODEL = "answerdotai/ModernBERT-base"
NUM_LABELS = 41
TOKEN_KW = dict(trust_remote_code=True)

# Class index -> source label; position in this list is the class index.
LABELS = dict(enumerate([
    "13B", "30B", "65B", "7B", "GLM130B", "bloom_7b", "bloomz", "cohere", "davinci",
    "dolly", "dolly-v2-12b", "flan_t5_base", "flan_t5_large", "flan_t5_small",
    "flan_t5_xl", "flan_t5_xxl", "gemma-7b-it", "gemma2-9b-it", "gpt-3.5-turbo",
    "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
    "llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
    "opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
    "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
]))

# ---------------------------------------------------------------------
# JWT helpers
# ---------------------------------------------------------------------
SECRET_KEY = os.environ.get("SECRET_KEY")
if not SECRET_KEY:
    # Fail at import time rather than sign tokens with an empty key.
    raise RuntimeError("SECRET_KEY envvar not set")
ALGO = "HS256"
EXP_H = 24  # token lifetime in hours
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
60
 
61
def _jwt_create(sub: str) -> str:
    """Return a signed JWT whose subject is *sub*.

    The token carries an ``exp`` claim ``EXP_H`` hours in the future and is
    signed with ``SECRET_KEY`` using ``ALGO``.
    """
    # Local import: only this helper needs ``timezone``.
    from datetime import timezone

    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    expires_at = datetime.now(timezone.utc) + timedelta(hours=EXP_H)
    return jwt.encode({"sub": sub, "exp": expires_at}, SECRET_KEY, algorithm=ALGO)
 
63
 
64
def _jwt_verify(tok: str = Depends(oauth2_scheme)) -> str:
    """Decode the bearer token and return its ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded with
            ``SECRET_KEY`` / ``ALGO`` or when it has no ``sub`` claim.
    """
    try:
        return jwt.decode(tok, SECRET_KEY, algorithms=[ALGO])["sub"]
    except (JWTError, KeyError) as exc:
        # KeyError: a validly signed token without "sub" must be rejected as
        # unauthorized instead of surfacing as an internal server error.
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
69
-
70
- # ---------------------------------------------------------------------
71
- # Load tokenizer + config + ensemble
72
- # ---------------------------------------------------------------------
73
# Download every checkpoint listed in FILE_MAP from WEIGHT_REPO.
print("🔄 Downloading weights…", flush=True)
local_paths = {k: hf_hub_download(WEIGHT_REPO, v, resume_download=True) for k, v in FILE_MAP.items()}

print("🧩 Loading ModernBERT remote code…", flush=True)
# The classification head must be NUM_LABELS wide: inference indexes class 24
# and maps class indices through the 41-entry LABELS table. Without
# num_labels the config falls back to the library default of 2 labels and the
# checkpoint's classifier weights would not fit.
_cfg = AutoConfig.from_pretrained(BASE_MODEL, num_labels=NUM_LABELS, **TOKEN_KW)
_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
_models: List[AutoModelForSequenceClassification] = []
for p in local_paths.values():
    # Build the base architecture, then overwrite its weights with the
    # downloaded checkpoint.
    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, config=_cfg, **TOKEN_KW)
    # NOTE(review): torch.load unpickles the file; this assumes WEIGHT_REPO is trusted.
    m.load_state_dict(torch.load(p, map_location=DEVICE))
    m.to(DEVICE).eval()
    _models.append(m)
print("✅ Ensemble ready")
86
 
87
- # ---------------------------------------------------------------------
88
- # Helpers
89
- # ---------------------------------------------------------------------
90
-
91
def _tidy(t: str) -> str:
    """Normalise whitespace and line breaks in *t*.

    In order: unify line endings to ``\\n``, collapse runs of blank lines to a
    single blank line, squeeze spaces/tabs, re-join words hyphenated across a
    line break, turn remaining single newlines into spaces, and strip the ends.
    """
    cleaned = t.replace("\r\n", "\n").replace("\r", "\n")
    # (pattern, replacement) pairs, applied strictly in this order.
    steps = (
        (r"\n\s*\n+", "\n\n"),
        (r"[ \t]+", " "),
        (r"(\w+)-\n(\w+)", r"\1\2"),
        (r"(?<!\n)\n(?!\n)", " "),
    )
    for pattern, replacement in steps:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
98
 
99
def _infer(seg: str):
    """Classify *seg* with the ensemble.

    Returns:
        ``(human, ai, top3)``: the human and AI percentages (summing to 100)
        and the three labels with the highest averaged probability once the
        "human" class has been zeroed out.
    """
    encoded = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        per_model = []
        for model in _models:
            per_model.append(torch.softmax(model(**encoded).logits, dim=1))
        # Average over the ensemble, then take the first row.
        mean_probs = torch.stack(per_model).mean(0)[0]
    machine_probs = mean_probs.clone()
    machine_probs[24] = 0  # index 24 is the "human" entry of LABELS
    ai_pct = machine_probs.sum().item() * 100
    human_pct = 100 - ai_pct
    _, top_idx = torch.topk(machine_probs, 3)
    top_labels = [LABELS[i] for i in top_idx.tolist()]
    return human_pct, ai_pct, top_labels
108
-
109
- # ---------------------------------------------------------------------
110
- # Schemas
111
- # ---------------------------------------------------------------------
112
# Response body of POST /token.
class Token(BaseModel):
    access_token: str
    token_type: str = "bearer"


# Request body of POST /analyse; the text must be non-empty.
class AnalyseIn(BaseModel):
    text: str = Field(..., min_length=1)


# Result for one scored line of the submitted text.
class Line(BaseModel):
    text: str
    ai: float  # AI percentage for this line
    human: float  # human percentage for this line
    top3: List[str]  # three highest-scoring non-human labels
    reason: str


# Response body of POST /analyse.
class AnalyseOut(BaseModel):
    verdict: str
    confidence: float
    ai_avg: float
    human_avg: float
    per_line: List[Line]
    highlight_html: str
133
-
134
- # ---------------------------------------------------------------------
135
- # FastAPI
136
- # ---------------------------------------------------------------------
137
app = FastAPI(
    title="Orify Text Detector API",
    version="1.0.0",
)
# CORS is fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
139
-
140
# POST /token: issue a bearer token whose subject is the submitted username.
# NOTE(review): form.password is never read here, so any username receives a
# valid token. Confirm this is intended before exposing the API publicly.
@app.post("/token", response_model=Token)
async def login(form: OAuth2PasswordRequestForm = Depends()):
    signed = _jwt_create(form.username)
    return Token(access_token=signed)
143
-
144
# POST /analyse: score every non-blank line of the tidied text, then return
# the averaged verdict, per-line details and an HTML rendering.
@app.post("/analyse", response_model=AnalyseOut)
async def analyse(data: AnalyseIn, _=Depends(_jwt_verify)):
    fragments = []
    rows = []
    human_total = 0.0
    ai_total = 0.0
    scored = 0.0
    for ln in _tidy(data.text).split("\n"):
        if not ln.strip():
            # Blank lines are kept as visual breaks and are not scored.
            fragments.append("<br>")
            continue
        scored += 1
        human, ai, top3 = _infer(ln)
        human_total += human
        ai_total += ai
        if ai > human:
            cls = "ai-line"
            tip = f"AI {ai:.2f}% Top-3: {', '.join(top3)}"
            reason = f"High AI likelihood ({ai:.1f}%) fingerprint ≈ {top3[0]}"
        else:
            cls = "human-line"
            tip = f"Human {human:.2f}%"
            reason = f"Lexical variety suggests human ({human:.1f}%)"
        # The user's text is escaped; tip only contains numbers and labels.
        fragments.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
        rows.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))

    # Averages are over scored lines only; both are 0 when nothing was scored.
    human_avg = human_total / scored if scored else 0
    ai_avg = ai_total / scored if scored else 0
    if ai_avg > human_avg:
        verdict = "AI-generated"
        badge = f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
    else:
        verdict = "Human-written"
        badge = f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>"
    rendered = f"<h3>{badge}</h3><hr>" + "<br>".join(fragments)

    return AnalyseOut(
        verdict=verdict,
        confidence=max(ai_avg, human_avg),
        ai_avg=ai_avg,
        human_avg=human_avg,
        per_line=rows,
        highlight_html=rendered,
    )
 
1
+ """Orify Text Detector API – FastAPI / JWT / HF Zero-GPU"""
2
+
3
  from __future__ import annotations
4
  import os, re, html
5
  from datetime import datetime, timedelta
 
7
 
8
  import torch
9
  from transformers import (
 
10
  AutoConfig,
11
+ AutoTokenizer,
12
  AutoModelForSequenceClassification,
13
  )
14
  from huggingface_hub import hf_hub_download
15
+ from fastapi import FastAPI, Depends, HTTPException
16
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
17
  from fastapi.middleware.cors import CORSMiddleware
18
  from jose import jwt, JWTError
19
  from pydantic import BaseModel, Field
20
 
21
+ # ───────────────────── Torch shim ─────────────────────
22
# Swap torch.compile for a pass-through: a callable argument is returned
# unchanged, anything else yields an identity decorator.
if hasattr(torch, "compile"):  # HF CPU image has torch-compile
    torch.compile = (  # type: ignore
        lambda m=None, *_, **__: m if callable(m) else (lambda f: f)
    )
os.environ["TORCHINDUCTOR_DISABLED"] = "1"

# ───────────────────── Remote-code flag ───────────────
# Only set when the variable is absent, so an explicit value wins.
os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")  # let Transformers load ModernBERT code
TOKEN_KW = dict(trust_remote_code=True)

# ───────────────────── Config ─────────────────────────
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
# Each checkpoint is keyed by its own file name inside WEIGHT_REPO.
FILE_MAP = {name: name for name in ("ensamble_1", "ensamble_2.bin", "ensamble_3")}
BASE_MODEL = "answerdotai/ModernBERT-base"
NUM_LABELS = 41

# Class index -> source label; position in this list is the class index.
LABELS = dict(enumerate([
    "13B", "30B", "65B", "7B", "GLM130B", "bloom_7b", "bloomz", "cohere", "davinci", "dolly", "dolly-v2-12b",
    "flan_t5_base", "flan_t5_large", "flan_t5_small", "flan_t5_xl", "flan_t5_xxl", "gemma-7b-it", "gemma2-9b-it",
    "gpt-3.5-turbo", "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b", "llama3-8b",
    "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b", "opt-30b", "opt-350m", "opt-6.7b",
    "opt-iml-30b", "opt-iml-max-1.3b", "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
]))

# ───────────────────── JWT helpers ────────────────────
SECRET_KEY = os.environ.get("SECRET_KEY")  # set this in Space → Settings → Secrets
if not SECRET_KEY:
    # Fail at import time rather than sign tokens with an empty key.
    raise RuntimeError("Set SECRET_KEY env-var")
ALG = "HS256"
EXP_HOURS = 24  # token lifetime in hours
oauth_scheme = OAuth2PasswordBearer(tokenUrl="token")
52
 
53
def _jwt_make(sub: str) -> str:
    """Return a signed JWT whose subject is *sub*.

    The token carries an ``exp`` claim ``EXP_HOURS`` hours in the future and
    is signed with ``SECRET_KEY`` using ``ALG``.
    """
    # Local import: only this helper needs ``timezone``.
    from datetime import timezone

    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    expires_at = datetime.now(timezone.utc) + timedelta(hours=EXP_HOURS)
    payload = {"sub": sub, "exp": expires_at}
    return jwt.encode(payload, SECRET_KEY, algorithm=ALG)
56
 
57
def _jwt_ok(token: str = Depends(oauth_scheme)) -> str:
    """Decode the bearer token and return its ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded with
            ``SECRET_KEY`` / ``ALG`` or when it has no ``sub`` claim.
    """
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALG])["sub"]
    except (JWTError, KeyError) as exc:
        # KeyError: a validly signed token without "sub" must be rejected as
        # unauthorized instead of surfacing as an internal server error.
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
62
+
63
+ # ───────────────────── Model bootstrap ────────────────
64
# Download every checkpoint listed in FILE_MAP from WEIGHT_REPO.
print("🔄 Fetching ensemble weights…", flush=True)
paths = {k: hf_hub_download(WEIGHT_REPO, f, resume_download=True) for k, f in FILE_MAP.items()}

print("🧩 Loading ModernBERT code…", flush=True)
# The classification head must be NUM_LABELS wide: inference indexes class 24
# and maps class indices through the 41-entry LABELS table. Without
# num_labels the config falls back to the library default of 2 labels and the
# checkpoint's classifier weights would not fit.
cfg = AutoConfig.from_pretrained(BASE_MODEL, num_labels=NUM_LABELS, **TOKEN_KW)
tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
models = []
for p in paths.values():
    # Build the base architecture, then overwrite its weights with the
    # downloaded checkpoint.
    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, config=cfg, **TOKEN_KW)
    # NOTE(review): torch.load unpickles the file; this assumes WEIGHT_REPO is trusted.
    m.load_state_dict(torch.load(p, map_location=DEVICE))
    m.to(DEVICE).eval()
    models.append(m)
print("✅ Ensemble ready")
76
 
77
+ # ───────────────────── Utils ──────────────────────────
78
def _tidy(t: str) -> str:
    """Normalise whitespace and line breaks in *t*.

    In order: unify line endings to ``\\n``, collapse runs of blank lines to a
    single blank line, squeeze spaces/tabs, re-join words hyphenated across a
    line break, turn remaining single newlines into spaces, and strip the ends.
    """
    t = t.replace("\r\n", "\n").replace("\r", "\n")
    t = re.sub(r"\n\s*\n+", "\n\n", t)
    t = re.sub(r"[ \t]+", " ", t)
    # Single backslashes so \1 and \2 are group references; a doubled
    # backslash would insert the literal text "\1\2" instead of the words.
    t = re.sub(r"(\w+)-\n(\w+)", r"\1\2", t)
    t = re.sub(r"(?<!\n)\n(?!\n)", " ", t)
    return t.strip()
85
 
86
def _infer(seg: str):
    """Classify *seg* with the ensemble.

    Returns:
        ``(human, ai, top3)``: the human and AI percentages (summing to 100)
        and the three labels with the highest averaged probability once the
        "human" class has been zeroed out.
    """
    # The tokenizer keyword is ``truncation``; ``trunc`` is not a recognised
    # argument, so over-long input would not be truncated.
    inp = tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        # Softmax per model, average over the ensemble, take the first row.
        probs = torch.stack([torch.softmax(m(**inp).logits, 1) for m in models]).mean(0)[0]
    ai_probs = probs.clone()
    ai_probs[24] = 0  # index 24 is the "human" entry of LABELS
    ai = ai_probs.sum().item() * 100
    human = 100 - ai
    top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
    return human, ai, top3
94
+
95
+ # ───────────────────── Schemas ────────────────────────
96
# Response body of POST /token.
class Token(BaseModel):
    access_token: str
    token_type: str = "bearer"


# Request body of POST /analyse; the text must be non-empty.
class AnalyseIn(BaseModel):
    text: str = Field(..., min_length=1)


# Result for one scored line of the submitted text.
class Line(BaseModel):
    text: str
    ai: float  # AI percentage for this line
    human: float  # human percentage for this line
    top3: List[str]  # three highest-scoring non-human labels
    reason: str


# Response body of POST /analyse.
class AnalyseOut(BaseModel):
    verdict: str
    confidence: float
    ai_avg: float
    human_avg: float
    per_line: List[Line]
    highlight_html: str
102
+
103
+ # ───────────────────── FastAPI app ────────────────────
104
app = FastAPI(
    title="Orify Text Detector API",
    version="1.0.0",
)
# CORS is fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
106
+
107
# POST /token: issue a bearer token whose subject is the submitted username.
# NOTE(review): form.password is never read here, so any username receives a
# valid token. Confirm this is intended before exposing the API publicly.
@app.post("/token", response_model=Token, summary="Get JWT")
async def login(form: OAuth2PasswordRequestForm = Depends()):
    signed = _jwt_make(form.username)
    return Token(access_token=signed)
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
@app.post("/analyse", response_model=AnalyseOut, summary="Detect AI-generated text")
async def analyse(req: AnalyseIn, _user=Depends(_jwt_ok)):
    """Score each line of ``req.text`` and return per-line and overall results.

    The text is tidied, split on newlines, and every non-blank line is run
    through ``_infer``. Blank lines become ``<br>`` in the HTML and are not
    scored. The overall verdict compares the average AI and human percentages
    over the scored lines; both averages are 0 when no line was scored.
    """
    # Split on a real newline character, not on a backslash followed by "n".
    lines = _tidy(req.text).split("\n")
    html_parts = []
    per_line = []
    h_sum = ai_sum = n = 0.0
    for ln in lines:
        if not ln.strip():
            html_parts.append("<br>")
            continue
        n += 1
        h, ai, top3 = _infer(ln)
        h_sum += h
        ai_sum += ai
        cls = "ai-line" if ai > h else "human-line"
        tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > h else f"Human {h:.2f}%"
        # The user's text is escaped; tip only contains numbers and labels.
        html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
        reason = (f"High AI likelihood ({ai:.1f}%) fingerprint {top3[0]}" if ai > h
                  else f"Lexical variety suggests human ({h:.1f}%)")
        per_line.append(Line(text=ln, ai=ai, human=h, top3=top3, reason=reason))
    human_avg = h_sum / n if n else 0
    ai_avg = ai_sum / n if n else 0
    verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
    conf = max(ai_avg, human_avg)
    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
             if verdict == "AI-generated" else
             f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
    return AnalyseOut(verdict=verdict, confidence=conf, ai_avg=ai_avg, human_avg=human_avg,
                      per_line=per_line, highlight_html=f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts))