Sleepyriizi commited on
Commit
12aa198
·
verified ·
1 Parent(s): 6c3647a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -44
app.py CHANGED
@@ -1,30 +1,36 @@
1
- """Orify Text Detector API – FastAPI (CPU‑only)"""
2
-
3
  from __future__ import annotations
4
  import os, re, html
5
  from datetime import datetime, timedelta
6
  from typing import List
7
 
8
  import torch
9
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 
 
10
  from huggingface_hub import hf_hub_download
11
-
12
  from fastapi import FastAPI, HTTPException, Depends
13
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from jose import jwt, JWTError
16
  from pydantic import BaseModel, Field
17
 
18
- # ----------------------------------------------------------------------------
19
  # Torch compile shim (CPU runtime)
20
- # ----------------------------------------------------------------------------
21
  if hasattr(torch, "compile"):
22
  torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f)) # type: ignore
23
  os.environ["TORCHINDUCTOR_DISABLED"] = "1"
24
 
25
- # ----------------------------------------------------------------------------
26
- # Model config
27
- # ----------------------------------------------------------------------------
 
 
 
 
 
28
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
30
  FILE_MAP = {"ensamble_1": "ensamble_1", "ensamble_2.bin": "ensamble_2.bin", "ensamble_3": "ensamble_3"}
@@ -39,49 +45,48 @@ LABELS = {i: n for i, n in enumerate([
39
  "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
40
  "llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
41
  "opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
42
- "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003"
43
  ])}
44
 
45
- # ----------------------------------------------------------------------------
46
  # JWT helpers
47
- # ----------------------------------------------------------------------------
48
  SECRET_KEY = os.getenv("SECRET_KEY")
49
  if not SECRET_KEY:
50
  raise RuntimeError("SECRET_KEY env‑var not set")
51
-
52
- ALGORITHM = "HS256"
53
- EXP_HOURS = 24
54
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
55
 
56
- def _jwt_create(data: dict) -> str:
57
- payload = data | {"exp": datetime.utcnow() + timedelta(hours=EXP_HOURS)}
58
- return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
59
 
60
- def _jwt_verify(token: str = Depends(oauth2_scheme)) -> str:
61
  try:
62
- return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])["sub"]
63
  except JWTError:
64
  raise HTTPException(401, "Invalid or expired token")
65
 
66
- # ----------------------------------------------------------------------------
67
- # Load ensemble once
68
- # ----------------------------------------------------------------------------
69
  print("🔄 Downloading weights…", flush=True)
70
- local = {k: hf_hub_download(WEIGHT_REPO, v, resume_download=True) for k, v in FILE_MAP.items()}
71
 
72
- print("🧩 Initialising models…", flush=True)
 
73
  _tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
74
  _models: List[AutoModelForSequenceClassification] = []
75
- for p in local.values():
76
- m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=NUM_LABELS, **TOKEN_KW)
77
  m.load_state_dict(torch.load(p, map_location=DEVICE))
78
  m.to(DEVICE).eval()
79
  _models.append(m)
80
  print("✅ Ensemble ready")
81
 
82
- # ----------------------------------------------------------------------------
83
  # Helpers
84
- # ----------------------------------------------------------------------------
85
 
86
  def _tidy(t: str) -> str:
87
  t = t.replace("\r\n", "\n").replace("\r", "\n")
@@ -96,14 +101,14 @@ def _infer(seg: str):
96
  with torch.no_grad():
97
  probs = torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
98
  ai_probs = probs.clone(); ai_probs[24] = 0
99
- ai, human = ai_probs.sum().item() * 100, 0.0
100
  human = 100 - ai
101
  top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
102
  return human, ai, top3
103
 
104
- # ----------------------------------------------------------------------------
105
  # Schemas
106
- # ----------------------------------------------------------------------------
107
  class Token(BaseModel):
108
  access_token: str
109
  token_type: str = "bearer"
@@ -126,44 +131,43 @@ class AnalyseOut(BaseModel):
126
  per_line: List[Line]
127
  highlight_html: str
128
 
129
- # ----------------------------------------------------------------------------
130
  # FastAPI
131
- # ----------------------------------------------------------------------------
132
  app = FastAPI(title="Orify Text Detector API", version="1.0.0")
133
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
134
 
135
  @app.post("/token", response_model=Token)
136
  async def login(form: OAuth2PasswordRequestForm = Depends()):
137
- return Token(access_token=_jwt_create({"sub": form.username}))
138
 
139
  @app.post("/analyse", response_model=AnalyseOut)
140
- async def analyse(data: AnalyseIn, _user=Depends(_jwt_verify)):
141
  lines = _tidy(data.text).split("\n")
142
  per_line, html_parts = [], []
143
  h_sum = ai_sum = n = 0.0
144
 
145
  for ln in lines:
146
  if not ln.strip():
147
- html_parts.append("<br>"); continue
 
148
  n += 1
149
  human, ai, top3 = _infer(ln)
150
  h_sum += human; ai_sum += ai
151
  cls = "ai-line" if ai > human else "human-line"
152
  tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > human else f"Human {human:.2f}%"
153
  html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
154
- reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai > human
155
- else f"Lexical variety suggests human ({human:.1f}%)")
156
  per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
157
 
158
  human_avg = h_sum / n if n else 0
159
  ai_avg = ai_sum / n if n else 0
160
  verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
161
  confidence = max(ai_avg, human_avg)
162
- badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
163
- if verdict == "AI-generated" else
164
  f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
165
  highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
166
 
167
- return AnalyseOut(verdict=verdict, confidence=confidence, ai_avg=ai_avg,
168
- human_avg=human_avg, per_line=per_line, highlight_html=highlight_html)
169
-
 
 
 
1
  from __future__ import annotations
2
  import os, re, html
3
  from datetime import datetime, timedelta
4
  from typing import List
5
 
6
  import torch
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ AutoConfig,
10
+ AutoModelForSequenceClassification,
11
+ )
12
  from huggingface_hub import hf_hub_download
 
13
  from fastapi import FastAPI, HTTPException, Depends
14
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from jose import jwt, JWTError
17
  from pydantic import BaseModel, Field
18
 
19
+ # ---------------------------------------------------------------------
20
  # Torch compile shim (CPU runtime)
21
+ # ---------------------------------------------------------------------
22
  if hasattr(torch, "compile"):
23
  torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f)) # type: ignore
24
  os.environ["TORCHINDUCTOR_DISABLED"] = "1"
25
 
26
+ # ---------------------------------------------------------------------
27
+ # Environment flags – enable remote code in Transformers
28
+ # ---------------------------------------------------------------------
29
+ os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1") # allow custom ModernBERT classes
30
+
31
+ # ---------------------------------------------------------------------
32
+ # Model / weight config
33
+ # ---------------------------------------------------------------------
34
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
  WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
36
  FILE_MAP = {"ensamble_1": "ensamble_1", "ensamble_2.bin": "ensamble_2.bin", "ensamble_3": "ensamble_3"}
 
45
  "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b",
46
  "llama3-8b", "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b",
47
  "opt-30b", "opt-350m", "opt-6.7b", "opt-iml-30b", "opt-iml-max-1.3b",
48
+ "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
49
  ])}
50
 
51
+ # ---------------------------------------------------------------------
52
  # JWT helpers
53
+ # ---------------------------------------------------------------------
54
  SECRET_KEY = os.getenv("SECRET_KEY")
55
  if not SECRET_KEY:
56
  raise RuntimeError("SECRET_KEY env‑var not set")
57
+ ALGO = "HS256"
58
+ EXP_H = 24
 
59
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
60
 
61
+ def _jwt_create(sub: str) -> str:
62
+ return jwt.encode({"sub": sub, "exp": datetime.utcnow() + timedelta(hours=EXP_H)}, SECRET_KEY, algorithm=ALGO)
 
63
 
64
+ def _jwt_verify(tok: str = Depends(oauth2_scheme)) -> str:
65
  try:
66
+ return jwt.decode(tok, SECRET_KEY, algorithms=[ALGO])["sub"]
67
  except JWTError:
68
  raise HTTPException(401, "Invalid or expired token")
69
 
70
+ # ---------------------------------------------------------------------
71
+ # Load tokenizer + config + ensemble
72
+ # ---------------------------------------------------------------------
73
  print("🔄 Downloading weights…", flush=True)
74
+ local_paths = {k: hf_hub_download(WEIGHT_REPO, v, resume_download=True) for k, v in FILE_MAP.items()}
75
 
76
+ print("🧩 Loading ModernBERT remote code…", flush=True)
77
+ _cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
78
  _tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
79
  _models: List[AutoModelForSequenceClassification] = []
80
+ for p in local_paths.values():
81
+ m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, config=_cfg, **TOKEN_KW)
82
  m.load_state_dict(torch.load(p, map_location=DEVICE))
83
  m.to(DEVICE).eval()
84
  _models.append(m)
85
  print("✅ Ensemble ready")
86
 
87
+ # ---------------------------------------------------------------------
88
  # Helpers
89
+ # ---------------------------------------------------------------------
90
 
91
  def _tidy(t: str) -> str:
92
  t = t.replace("\r\n", "\n").replace("\r", "\n")
 
101
  with torch.no_grad():
102
  probs = torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
103
  ai_probs = probs.clone(); ai_probs[24] = 0
104
+ ai = ai_probs.sum().item() * 100
105
  human = 100 - ai
106
  top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
107
  return human, ai, top3
108
 
109
+ # ---------------------------------------------------------------------
110
  # Schemas
111
+ # ---------------------------------------------------------------------
112
  class Token(BaseModel):
113
  access_token: str
114
  token_type: str = "bearer"
 
131
  per_line: List[Line]
132
  highlight_html: str
133
 
134
+ # ---------------------------------------------------------------------
135
  # FastAPI
136
+ # ---------------------------------------------------------------------
137
  app = FastAPI(title="Orify Text Detector API", version="1.0.0")
138
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
139
 
140
  @app.post("/token", response_model=Token)
141
  async def login(form: OAuth2PasswordRequestForm = Depends()):
142
+ return Token(access_token=_jwt_create(form.username))
143
 
144
  @app.post("/analyse", response_model=AnalyseOut)
145
+ async def analyse(data: AnalyseIn, _=Depends(_jwt_verify)):
146
  lines = _tidy(data.text).split("\n")
147
  per_line, html_parts = [], []
148
  h_sum = ai_sum = n = 0.0
149
 
150
  for ln in lines:
151
  if not ln.strip():
152
+ html_parts.append("<br>")
153
+ continue
154
  n += 1
155
  human, ai, top3 = _infer(ln)
156
  h_sum += human; ai_sum += ai
157
  cls = "ai-line" if ai > human else "human-line"
158
  tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > human else f"Human {human:.2f}%"
159
  html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
160
+ reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai > human else
161
+ f"Lexical variety suggests human ({human:.1f}%)")
162
  per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
163
 
164
  human_avg = h_sum / n if n else 0
165
  ai_avg = ai_sum / n if n else 0
166
  verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
167
  confidence = max(ai_avg, human_avg)
168
+ badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict == "AI-generated" else
 
169
  f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
170
  highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
171
 
172
+ return AnalyseOut(verdict=verdict, confidence=confidence, ai_avg=ai_avg, human_avg=human_avg,
173
+ per_line=per_line, highlight_html=highlight_html)