Sleepyriizi committed on
Commit
db897db
·
1 Parent(s): 0c4f83b

Add application file

Browse files
Files changed (2) hide show
  1. app.py +162 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Orify Text Detector API – FastAPI + JWT (CPU-only, HF Zero-GPU)
2
+
3
+ Endpoints
4
+ ---------
5
+ POST /token → returns JWT (OAuth2 password-grant, demo-auth)
6
+ POST /analyse → protected; returns verdict JSON + HTML highlights
7
+ """
8
+
9
+ from __future__ import annotations
10
+ import os, re, html
11
+ from datetime import datetime, timedelta
12
+ from typing import List
13
+
14
+ import torch
15
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ from fastapi import FastAPI, HTTPException, Depends
19
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ from jose import JWTError, jwt
22
+ from pydantic import BaseModel, Field
23
+
24
+ # ─── torch.compile shim (HF CPU runtime) ───────────────────────────────
25
# Swap torch.compile for a pass-through so nothing is actually compiled.
if hasattr(torch, "compile"):
    def _passthrough_compile(m=None, *a, **kw):
        """Return *m* unchanged, or an identity decorator when *m* is not callable."""
        if callable(m):
            return m
        # Covers the ``@torch.compile(...)`` decorator-factory spelling.
        return lambda f: f

    torch.compile = _passthrough_compile  # type: ignore
    os.environ["TORCHINDUCTOR_DISABLED"] = "1"
28
+
29
+ # ─── model / weight config ─────────────────────────────────────────────
30
# Use the GPU when torch can see one, otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hub repository holding the fine-tuned ensemble weights.
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"

# key -> file name inside WEIGHT_REPO (every key maps to itself).
FILE_MAP = {
    name: name
    for name in ("ensamble_1", "ensamble_2.bin", "ensamble_3")
}

BASE_MODEL = "answerdotai/ModernBERT-base"
NUM_LABELS = 41

# class index -> label name; index 24 is the "human" class.
LABELS = dict(enumerate([
    "13B",
    "30B",
    "65B",
    "7B",
    "GLM130B",
    "bloom_7b",
    "bloomz",
    "cohere",
    "davinci",
    "dolly",
    "dolly-v2-12b",
    "flan_t5_base",
    "flan_t5_large",
    "flan_t5_small",
    "flan_t5_xl",
    "flan_t5_xxl",
    "gemma-7b-it",
    "gemma2-9b-it",
    "gpt-3.5-turbo",
    "gpt-35",
    "gpt-4",
    "gpt-4o",
    "gpt-j",
    "gpt-neox",
    "human",
    "llama3-70b",
    "llama3-8b",
    "mixtral-8x7b",
    "opt-1.3b",
    "opt-125m",
    "opt-13b",
    "opt-2.7b",
    "opt-30b",
    "opt-350m",
    "opt-6.7b",
    "opt-iml-30b",
    "opt-iml-max-1.3b",
    "t0-11b",
    "t0-3b",
    "text-davinci-002",
    "text-davinci-003",
]))
47
+
48
+ # ─── JWT helpers ───────────────────────────────────────────────────────
49
# Signing secret for the JWTs; the module refuses to load when it is unset
# or empty.
SECRET_KEY = os.environ.get("SECRET_KEY")
if not SECRET_KEY:
    raise RuntimeError("Set the SECRET_KEY env-var in Space ➜ Settings ➜ Secrets")

ALGORITHM = "HS256"  # signing algorithm used for both encode and decode
ACCESS_TOKEN_EXPIRE_HOURS = 24  # default token lifetime

# Reads the bearer token from the request; "token" is the login route's URL.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
56
+
57
def _create_token(data: dict, exp_hours: int = ACCESS_TOKEN_EXPIRE_HOURS) -> str:
    """Return a signed JWT carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed. The mapping is copied, never mutated.
        exp_hours: Token lifetime in hours, counted from now.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Local import: the module header only pulls in datetime and timedelta.
    from datetime import timezone

    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12
    # and yields a naive value that is easily mistaken for local time.
    expires_at = datetime.now(timezone.utc) + timedelta(hours=exp_hours)
    claims = {**data, "exp": expires_at}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
61
+
62
async def _current_user(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency that validates the bearer token.

    Args:
        token: Raw JWT extracted from the request by ``oauth2_scheme``.

    Returns:
        The token's ``sub`` claim, or ``"anonymous"`` when the claim is
        missing or empty.

    Raises:
        HTTPException: 401 when the token cannot be decoded or verified.
    """
    try:
        # Only the decode can raise JWTError, so only it lives in the try.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        # Bearer challenge header so clients know how to re-authenticate;
        # the original error is chained for easier debugging.
        raise HTTPException(
            401,
            "Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    return payload.get("sub") or "anonymous"
68
+
69
+ # ─── load ensemble once ────────────────────────────────────────────────
70
print("🔄 Downloading weights …", flush=True)
# ``resume_download`` is deprecated in huggingface_hub (downloads resume
# whenever possible), so it is no longer passed.
local_paths = {
    k: hf_hub_download(WEIGHT_REPO, f)
    for k, f in FILE_MAP.items()
}

print("🧩 Initialising models …", flush=True)
_tok = AutoTokenizer.from_pretrained(BASE_MODEL)
_models: List[AutoModelForSequenceClassification] = []
for p in local_paths.values():
    # Build the base architecture with a NUM_LABELS-way head, then overwrite
    # its parameters with the downloaded checkpoint.
    m = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        num_labels=NUM_LABELS,
    )
    # NOTE(review): torch.load unpickles the file, which can execute code.
    # Acceptable only while WEIGHT_REPO is trusted; consider
    # ``weights_only=True`` once the checkpoints are confirmed to be plain
    # state dicts.
    m.load_state_dict(torch.load(p, map_location=DEVICE))
    m.to(DEVICE).eval()  # inference only
    _models.append(m)
print("✅ Ensemble ready")
84
+
85
+ # ─── tiny helpers ──────────────────────────────────────────────────────
86
+ def _tidy(text: str) -> str:
87
+ text = text.replace("\r\n", "\n").replace("\r", "\n")
88
+ text = re.sub(r"\n\s*\n+", "\n\n", text)
89
+ text = re.sub(r"[ \t]+", " ", text)
90
+ text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
91
+ text = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
92
+ return text.strip()
93
+
94
def _infer(seg: str):
    """Score one text segment with every model of the ensemble.

    Returns:
        ``(human, ai, top3)``: ``ai`` is the summed probability of all
        non-human classes as a percentage, ``human`` is ``100 - ai``, and
        ``top3`` names the three most probable non-human labels.
    """
    encoded = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)

    with torch.no_grad():
        per_model = []
        for model in _models:
            per_model.append(torch.softmax(model(**encoded).logits, dim=1))
        # Average over the models, then take the first (only) batch row.
        probs = torch.stack(per_model).mean(0)[0]

    ai_probs = probs.clone()
    ai_probs[24] = 0  # drop explicit “human” label
    ai = ai_probs.sum().item() * 100
    human = 100 - ai

    best_indices = torch.topk(ai_probs, 3).indices.tolist()
    top3 = [LABELS[idx] for idx in best_indices]
    return human, ai, top3
103
+
104
+ # ─── Pydantic schemas ─────────────────────────────────────────────────
105
+ from pydantic import BaseModel, Field
106
class Token(BaseModel):
    # Response body of the /token route.
    access_token: str  # encoded JWT
    token_type: str = "bearer"
109
+
110
class AnalyseIn(BaseModel):
    # Request body of the /analyse route: required text, at least one character.
    text: str = Field(default=..., min_length=1)
112
+
113
class Line(BaseModel):
    # Result for a single analysed line of the input.
    text: str
    ai: float  # AI score, in percent
    human: float  # human score, in percent
    top3: List[str]  # three most probable non-human labels
    reason: str  # short human-readable explanation
115
+
116
class AnalyseOut(BaseModel):
    # Response body of the /analyse route.
    verdict: str  # "AI-generated" or "Human-written"
    confidence: float  # the larger of ai_avg and human_avg
    ai_avg: float
    human_avg: float
    per_line: List[Line]
    highlight_html: str  # badge plus one highlighted <span> per analysed line
119
+
120
+ # ─── FastAPI app ───────────────────────────────────────────────────────
121
app = FastAPI(title="Orify Text Detector API", version="1.0.0")

# CORS is fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
124
+
125
@app.post("/token", response_model=Token, summary="Obtain JWT (demo accepts any creds)")
async def login(form: OAuth2PasswordRequestForm = Depends()):
    # Demo auth: the submitted password is never checked; the username simply
    # becomes the token's "sub" claim.
    signed = _create_token({"sub": form.username})
    return Token(access_token=signed)
128
+
129
@app.post("/analyse", response_model=AnalyseOut, summary="Detect AI-generated text")
async def analyse(data: AnalyseIn, _user=Depends(_current_user)):
    """Score the submitted text line by line and return an overall verdict.

    The text is normalised with ``_tidy`` and split on line breaks; every
    non-blank line is scored with ``_infer``. The verdict compares the
    unweighted per-line averages of the AI and human scores.

    Args:
        data: Request body holding the text to analyse.
        _user: Result of the auth dependency; present only to protect the route.

    Returns:
        ``AnalyseOut`` with the verdict, averages, per-line details and an
        HTML rendering in which each line is wrapped in a highlighted span.

    Raises:
        HTTPException: 422 when the text is empty after normalisation.
    """
    tidy = _tidy(data.text)
    if not tidy:
        # Whitespace-only input passes min_length=1 but leaves nothing to
        # score; reject it rather than report "Human-written" at 0 %.
        raise HTTPException(422, "Text contains no analysable content")

    html_parts, per_line = [], []
    h_sum = ai_sum = 0.0

    for ln in tidy.split("\n"):
        if not ln.strip():
            html_parts.append("<br>")
            continue

        human, ai, top3 = _infer(ln)
        h_sum += human
        ai_sum += ai

        if ai > human:
            cls = "ai-line"
            tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}"
            reason = f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}"
        else:
            cls = "human-line"
            tip = f"Human {human:.2f}%"
            reason = f"Lexical variety suggests human ({human:.1f}%)"

        # Both the tooltip and the line text end up inside HTML, so both are
        # escaped (the tooltip sits in a single-quoted attribute).
        html_parts.append(
            f"<span class='{cls} prob-tooltip' title='{html.escape(tip, quote=True)}'>"
            f"{html.escape(ln)}</span>"
        )
        per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))

    n = len(per_line)
    human_avg = h_sum / n if n else 0
    ai_avg = ai_sum / n if n else 0
    verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
    confidence = max(human_avg, ai_avg)

    if verdict == "AI-generated":
        badge = f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
    else:
        badge = f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>"
    highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)

    return AnalyseOut(
        verdict=verdict,
        confidence=confidence,
        ai_avg=ai_avg,
        human_avg=human_avg,
        per_line=per_line,
        highlight_html=highlight_html,
    )
161
+
162
+ # ────── local dev: uvicorn app:app --reload ───────────────────────────
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ torch==2.2.2 # pin to the CPU build you tested with
4
+ transformers>=4.40
5
+ huggingface_hub>=0.23
6
+ python-jose[cryptography]
7
+ pydantic