Sleepyriizi commited on
Commit
3c6385c
·
verified ·
1 Parent(s): 4c4fd1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -78
app.py CHANGED
@@ -1,10 +1,4 @@
1
- """Orify Text Detector API – FastAPI + JWT (CPU-only, HF Zero-GPU)
2
-
3
- Endpoints
4
- ---------
5
- POST /token → returns JWT (OAuth2 password-grant, demo-auth)
6
- POST /analyse → protected; returns verdict JSON + HTML highlights
7
- """
8
 
9
  from __future__ import annotations
10
  import os, re, html
@@ -18,91 +12,98 @@ from huggingface_hub import hf_hub_download
18
  from fastapi import FastAPI, HTTPException, Depends
19
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
20
  from fastapi.middleware.cors import CORSMiddleware
21
- from jose import JWTError, jwt
22
  from pydantic import BaseModel, Field
23
 
24
- # ─── torch.compile shim (HF CPU runtime) ───────────────────────────────
 
 
25
  if hasattr(torch, "compile"):
26
- torch.compile = (lambda m=None,*a,**kw: m if callable(m) else (lambda f: f)) # type: ignore
27
  os.environ["TORCHINDUCTOR_DISABLED"] = "1"
28
 
29
- # ─── model / weight config ─────────────────────────────────────────────
 
 
30
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
32
- FILE_MAP = {"ensamble_1":"ensamble_1",
33
- "ensamble_2.bin":"ensamble_2.bin",
34
- "ensamble_3":"ensamble_3"}
35
  BASE_MODEL = "answerdotai/ModernBERT-base"
36
  NUM_LABELS = 41
37
-
38
- LABELS = {i:n for i,n in enumerate([
39
- "13B","30B","65B","7B","GLM130B","bloom_7b","bloomz","cohere","davinci",
40
- "dolly","dolly-v2-12b","flan_t5_base","flan_t5_large","flan_t5_small",
41
- "flan_t5_xl","flan_t5_xxl","gemma-7b-it","gemma2-9b-it","gpt-3.5-turbo",
42
- "gpt-35","gpt-4","gpt-4o","gpt-j","gpt-neox","human","llama3-70b",
43
- "llama3-8b","mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b",
44
- "opt-30b","opt-350m","opt-6.7b","opt-iml-30b","opt-iml-max-1.3b",
45
- "t0-11b","t0-3b","text-davinci-002","text-davinci-003"
 
46
  ])}
47
 
48
- # ─── JWT helpers ───────────────────────────────────────────────────────
 
 
49
  SECRET_KEY = os.getenv("SECRET_KEY")
50
  if not SECRET_KEY:
51
- raise RuntimeError("Set the SECRET_KEY env-var in Space ➜ Settings ➜ Secrets")
52
 
53
- ALGORITHM = "HS256"
54
- ACCESS_TOKEN_EXPIRE_HOURS = 24
55
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
56
 
57
- def _create_token(data: dict, exp_hours: int = ACCESS_TOKEN_EXPIRE_HOURS) -> str:
58
- to_encode = data.copy()
59
- to_encode["exp"] = datetime.utcnow() + timedelta(hours=exp_hours)
60
- return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
61
 
62
- async def _current_user(token: str = Depends(oauth2_scheme)):
63
  try:
64
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
65
- return payload.get("sub") or "anonymous"
66
  except JWTError:
67
  raise HTTPException(401, "Invalid or expired token")
68
 
69
- # ─── load ensemble once ────────────────────────────────────────────────
70
- print("🔄 Downloading weights …", flush=True)
71
- local_paths = {k:hf_hub_download(WEIGHT_REPO,f,resume_download=True)
72
- for k,f in FILE_MAP.items()}
 
73
 
74
- print("🧩 Initialising models …", flush=True)
75
- _tok = AutoTokenizer.from_pretrained(BASE_MODEL)
76
  _models: List[AutoModelForSequenceClassification] = []
77
- for p in local_paths.values():
78
- m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL,
79
- num_labels=NUM_LABELS)
80
  m.load_state_dict(torch.load(p, map_location=DEVICE))
81
  m.to(DEVICE).eval()
82
  _models.append(m)
83
- print("✅ Ensemble ready")
 
 
 
 
84
 
85
- # ─── tiny helpers ──────────────────────────────────────────────────────
86
- def _tidy(text: str) -> str:
87
- text = text.replace("\r\n", "\n").replace("\r", "\n")
88
- text = re.sub(r"\n\s*\n+", "\n\n", text)
89
- text = re.sub(r"[ \t]+", " ", text)
90
- text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
91
- text = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
92
- return text.strip()
93
 
94
  def _infer(seg: str):
95
  inp = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
96
  with torch.no_grad():
97
  probs = torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
98
- ai_probs = probs.clone(); ai_probs[24] = 0 # drop explicit “human” label
99
- ai = ai_probs.sum().item()*100
100
  human = 100 - ai
101
- top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
102
  return human, ai, top3
103
 
104
- # ─── Pydantic schemas ─────────────────────────────────────────────────
105
- from pydantic import BaseModel, Field
 
106
  class Token(BaseModel):
107
  access_token: str
108
  token_type: str = "bearer"
@@ -111,52 +112,58 @@ class AnalyseIn(BaseModel):
111
  text: str = Field(..., min_length=1)
112
 
113
  class Line(BaseModel):
114
- text: str; ai: float; human: float; top3: List[str]; reason: str
 
 
 
 
115
 
116
  class AnalyseOut(BaseModel):
117
- verdict: str; confidence: float; ai_avg: float; human_avg: float
118
- per_line: List[Line]; highlight_html: str
119
-
120
- # ─── FastAPI app ───────────────────────────────────────────────────────
 
 
 
 
 
 
121
  app = FastAPI(title="Orify Text Detector API", version="1.0.0")
122
- app.add_middleware(CORSMiddleware,
123
- allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
124
 
125
- @app.post("/token", response_model=Token, summary="Obtain JWT (demo accepts any creds)")
126
  async def login(form: OAuth2PasswordRequestForm = Depends()):
127
- return Token(access_token=_create_token({"sub": form.username}))
128
 
129
- @app.post("/analyse", response_model=AnalyseOut, summary="Detect AI-generated text")
130
- async def analyse(data: AnalyseIn, _user=Depends(_current_user)):
131
  lines = _tidy(data.text).split("\n")
132
- html_parts, per_line = [], []
133
  h_sum = ai_sum = n = 0.0
134
 
135
  for ln in lines:
136
  if not ln.strip():
137
- html_parts.append("<br>")
138
- continue
139
  n += 1
140
  human, ai, top3 = _infer(ln)
141
  h_sum += human; ai_sum += ai
142
  cls = "ai-line" if ai > human else "human-line"
143
  tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > human else f"Human {human:.2f}%"
144
  html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
145
- reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}"
146
- if ai > human else f"Lexical variety suggests human ({human:.1f}%)")
147
  per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
148
 
149
  human_avg = h_sum / n if n else 0
150
  ai_avg = ai_sum / n if n else 0
151
  verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
152
- confidence = max(human_avg, ai_avg)
153
  badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
154
  if verdict == "AI-generated" else
155
  f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
156
  highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
157
 
158
- return AnalyseOut(verdict=verdict, confidence=confidence,
159
- ai_avg=ai_avg, human_avg=human_avg,
160
- per_line=per_line, highlight_html=highlight_html)
161
 
162
- # ────── local dev: uvicorn app:app --reload ───────────────────────────
 
1
+ """Orify Text Detector API – FastAPI (CPUonly)"""
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
  import os, re, html
 
12
  from fastapi import FastAPI, HTTPException, Depends
13
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
14
  from fastapi.middleware.cors import CORSMiddleware
15
+ from jose import jwt, JWTError
16
  from pydantic import BaseModel, Field
17
 
18
+ # ----------------------------------------------------------------------------
19
+ # Torch compile shim (CPU runtime)
20
+ # ----------------------------------------------------------------------------
21
  if hasattr(torch, "compile"):
22
+ torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f)) # type: ignore
23
  os.environ["TORCHINDUCTOR_DISABLED"] = "1"
24
 
25
+ # ----------------------------------------------------------------------------
26
+ # Model config
27
+ # ----------------------------------------------------------------------------
28
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
30
+ FILE_MAP = {"ensamble_1": "ensamble_1", "ensamble_2.bin": "ensamble_2.bin", "ensamble_3": "ensamble_3"}
 
 
31
  BASE_MODEL = "answerdotai/ModernBERT-base"
32
  NUM_LABELS = 41
33
+ TOKEN_KW = {"trust_remote_code": True}
34
+
35
+ LABELS = {i: n for i, n in enumerate([
36
+ "13B", "30B", "65B", "7B", "GLM130B", "bloom_7b", "bloomz", "cohere", "davinci",
37
+ "dolly", "dolly-v2-12b", "flan_t5_base", "flan_t5_large", "flan_t5_small",
38
+ "flan_t5_xl", "flan_t5_xxl", "gemma-7b-it", "gemma2-9b-it", "gpt-3.5-turbo",
39
+ "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
40
+ "llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
41
+ "opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
42
+ "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003"
43
  ])}
44
 
45
+ # ----------------------------------------------------------------------------
46
+ # JWT helpers
47
+ # ----------------------------------------------------------------------------
48
  SECRET_KEY = os.getenv("SECRET_KEY")
49
  if not SECRET_KEY:
50
+ raise RuntimeError("SECRET_KEY envvar not set")
51
 
52
+ ALGORITHM = "HS256"
53
+ EXP_HOURS = 24
54
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
55
 
56
+ def _jwt_create(data: dict) -> str:
57
+ payload = data | {"exp": datetime.utcnow() + timedelta(hours=EXP_HOURS)}
58
+ return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
 
59
 
60
+ def _jwt_verify(token: str = Depends(oauth2_scheme)) -> str:
61
  try:
62
+ return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])["sub"]
 
63
  except JWTError:
64
  raise HTTPException(401, "Invalid or expired token")
65
 
66
+ # ----------------------------------------------------------------------------
67
+ # Load ensemble once
68
+ # ----------------------------------------------------------------------------
69
+ print("🔄 Downloading weights…", flush=True)
70
+ local = {k: hf_hub_download(WEIGHT_REPO, v, resume_download=True) for k, v in FILE_MAP.items()}
71
 
72
+ print("🧩 Initialising models…", flush=True)
73
+ _tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
74
  _models: List[AutoModelForSequenceClassification] = []
75
+ for p in local.values():
76
+ m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=NUM_LABELS, **TOKEN_KW)
 
77
  m.load_state_dict(torch.load(p, map_location=DEVICE))
78
  m.to(DEVICE).eval()
79
  _models.append(m)
80
+ print("✅ Ensemble ready")
81
+
82
+ # ----------------------------------------------------------------------------
83
+ # Helpers
84
+ # ----------------------------------------------------------------------------
85
 
86
+ def _tidy(t: str) -> str:
87
+ t = t.replace("\r\n", "\n").replace("\r", "\n")
88
+ t = re.sub(r"\n\s*\n+", "\n\n", t)
89
+ t = re.sub(r"[ \t]+", " ", t)
90
+ t = re.sub(r"(\w+)-\n(\w+)", r"\1\2", t)
91
+ t = re.sub(r"(?<!\n)\n(?!\n)", " ", t)
92
+ return t.strip()
 
93
 
94
  def _infer(seg: str):
95
  inp = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
96
  with torch.no_grad():
97
  probs = torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
98
+ ai_probs = probs.clone(); ai_probs[24] = 0
99
+ ai, human = ai_probs.sum().item() * 100, 0.0
100
  human = 100 - ai
101
+ top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
102
  return human, ai, top3
103
 
104
+ # ----------------------------------------------------------------------------
105
+ # Schemas
106
+ # ----------------------------------------------------------------------------
107
  class Token(BaseModel):
108
  access_token: str
109
  token_type: str = "bearer"
 
112
  text: str = Field(..., min_length=1)
113
 
114
  class Line(BaseModel):
115
+ text: str
116
+ ai: float
117
+ human: float
118
+ top3: List[str]
119
+ reason: str
120
 
121
  class AnalyseOut(BaseModel):
122
+ verdict: str
123
+ confidence: float
124
+ ai_avg: float
125
+ human_avg: float
126
+ per_line: List[Line]
127
+ highlight_html: str
128
+
129
+ # ----------------------------------------------------------------------------
130
+ # FastAPI
131
+ # ----------------------------------------------------------------------------
132
  app = FastAPI(title="Orify Text Detector API", version="1.0.0")
133
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
134
 
135
+ @app.post("/token", response_model=Token)
136
  async def login(form: OAuth2PasswordRequestForm = Depends()):
137
+ return Token(access_token=_jwt_create({"sub": form.username}))
138
 
139
+ @app.post("/analyse", response_model=AnalyseOut)
140
+ async def analyse(data: AnalyseIn, _user=Depends(_jwt_verify)):
141
  lines = _tidy(data.text).split("\n")
142
+ per_line, html_parts = [], []
143
  h_sum = ai_sum = n = 0.0
144
 
145
  for ln in lines:
146
  if not ln.strip():
147
+ html_parts.append("<br>"); continue
 
148
  n += 1
149
  human, ai, top3 = _infer(ln)
150
  h_sum += human; ai_sum += ai
151
  cls = "ai-line" if ai > human else "human-line"
152
  tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > human else f"Human {human:.2f}%"
153
  html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
154
+ reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai > human
155
+ else f"Lexical variety suggests human ({human:.1f}%)")
156
  per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
157
 
158
  human_avg = h_sum / n if n else 0
159
  ai_avg = ai_sum / n if n else 0
160
  verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
161
+ confidence = max(ai_avg, human_avg)
162
  badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
163
  if verdict == "AI-generated" else
164
  f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
165
  highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
166
 
167
+ return AnalyseOut(verdict=verdict, confidence=confidence, ai_avg=ai_avg,
168
+ human_avg=human_avg, per_line=per_line, highlight_html=highlight_html)
 
169