"""Orify Text Detector API – FastAPI + JWT (CPU-only, HF Zero-GPU)

Endpoints
---------
POST /token    → returns JWT  (OAuth2 password-grant, demo-auth)
POST /analyse  → protected; returns verdict JSON + HTML highlights
"""

from __future__ import annotations
import os, re, html
from datetime import datetime, timedelta, timezone
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download

from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from jose import JWTError, jwt
from pydantic import BaseModel, Field

# ─── torch.compile shim (HF CPU runtime) ───────────────────────────────
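# The shim replaces torch.compile with a no-op so downstream code that calls it
# still works: torch.compile(model) returns the model unchanged, and
# decorator-factory usage returns an identity decorator.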
if hasattr(torch, "compile"):
    torch.compile = (lambda m=None,*a,**kw: m if callable(m) else (lambda f: f))  # type: ignore
    os.environ["TORCHINDUCTOR_DISABLED"] = "1"

# ─── model / weight config ─────────────────────────────────────────────
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
FILE_MAP    = {"ensamble_1":"ensamble_1",
               "ensamble_2.bin":"ensamble_2.bin",
               "ensamble_3":"ensamble_3"}
BASE_MODEL  = "answerdotai/ModernBERT-base"
NUM_LABELS  = 41

LABELS = {i:n for i,n in enumerate([
    "13B","30B","65B","7B","GLM130B","bloom_7b","bloomz","cohere","davinci",
    "dolly","dolly-v2-12b","flan_t5_base","flan_t5_large","flan_t5_small",
    "flan_t5_xl","flan_t5_xxl","gemma-7b-it","gemma2-9b-it","gpt-3.5-turbo",
    "gpt-35","gpt-4","gpt-4o","gpt-j","gpt-neox","human","llama3-70b",
    "llama3-8b","mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b",
    "opt-30b","opt-350m","opt-6.7b","opt-iml-30b","opt-iml-max-1.3b",
    "t0-11b","t0-3b","text-davinci-002","text-davinci-003"
])}

# ─── JWT helpers ───────────────────────────────────────────────────────
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    raise RuntimeError("Set the SECRET_KEY env-var in Space ➜ Settings ➜ Secrets")

ALGORITHM               = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

def _create_token(data: dict, exp_hours: int = ACCESS_TOKEN_EXPIRE_HOURS) -> str:
    to_encode = data.copy()
    to_encode["exp"] = datetime.utcnow() + timedelta(hours=exp_hours)
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

async def _current_user(token: str = Depends(oauth2_scheme)):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return payload.get("sub") or "anonymous"
    except JWTError:
        raise HTTPException(401, "Invalid or expired token")

# ─── load ensemble once ────────────────────────────────────────────────
print("🔄 Downloading weights …", flush=True)
local_paths = {k: hf_hub_download(WEIGHT_REPO, f) for k, f in FILE_MAP.items()}

print("🧩 Initialising models …", flush=True)
_tok = AutoTokenizer.from_pretrained(BASE_MODEL)
_models: List[AutoModelForSequenceClassification] = []
for p in local_paths.values():
    m = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL,
                                                           num_labels=NUM_LABELS)
    m.load_state_dict(torch.load(p, map_location=DEVICE))
    m.to(DEVICE).eval()
    _models.append(m)
print("✅  Ensemble ready")

# ─── tiny helpers ──────────────────────────────────────────────────────
def _tidy(text: str) -> str:
    text = text.replace("\r\n", "\n").replace("\r", "\n")   # normalise newlines
    text = re.sub(r"\n\s*\n+", "\n\n", text)                # collapse blank-line runs
    text = re.sub(r"[ \t]+", " ", text)                     # collapse spaces/tabs
    text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)          # re-join words hyphenated across breaks
    text = re.sub(r"(?<!\n)\n(?!\n)", " ", text)            # unwrap soft-wrapped single newlines
    return text.strip()

def _infer(seg: str):
    """Score one segment: returns (human %, AI %, top-3 model names),
    averaging softmax probabilities across the ensemble members."""
    inp = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        probs = torch.stack([torch.softmax(m(**inp).logits, dim=1) for m in _models]).mean(0)[0]
    ai_probs = probs.clone()
    ai_probs[24] = 0.0                       # zero out the explicit "human" class (LABELS[24])
    ai    = ai_probs.sum().item() * 100
    human = 100 - ai
    top3  = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
    return human, ai, top3

# ─── Pydantic schemas ─────────────────────────────────────────────────
class Token(BaseModel):
    access_token: str
    token_type: str = "bearer"

class AnalyseIn(BaseModel):
    text: str = Field(..., min_length=1)

class Line(BaseModel):
    text: str
    ai: float
    human: float
    top3: List[str]
    reason: str

class AnalyseOut(BaseModel):
    verdict: str
    confidence: float
    ai_avg: float
    human_avg: float
    per_line: List[Line]
    highlight_html: str

# ─── FastAPI app ───────────────────────────────────────────────────────
app = FastAPI(title="Orify Text Detector API", version="1.0.0")
# Open CORS for the demo front-end; restrict allow_origins in production.
app.add_middleware(CORSMiddleware,
                   allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

@app.post("/token", response_model=Token, summary="Obtain JWT (demo accepts any creds)")
async def login(form: OAuth2PasswordRequestForm = Depends()):
    return Token(access_token=_create_token({"sub": form.username}))

@app.post("/analyse", response_model=AnalyseOut, summary="Detect AI-generated text")
async def analyse(data: AnalyseIn, _user=Depends(_current_user)):
    lines = _tidy(data.text).split("\n")
    html_parts, per_line = [], []
    h_sum = ai_sum = n = 0.0

    for ln in lines:
        if not ln.strip():
            html_parts.append("<br>")
            continue
        n += 1
        human, ai, top3 = _infer(ln)
        h_sum += human
        ai_sum += ai
        cls = "ai-line" if ai > human else "human-line"
        tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > human else f"Human {human:.2f}%"
        html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
        reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}"
                  if ai > human else f"Lexical variety suggests human ({human:.1f}%)")
        per_line.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))

    human_avg = h_sum / n if n else 0
    ai_avg    = ai_sum / n if n else 0
    verdict   = "AI-generated" if ai_avg > human_avg else "Human-written"
    confidence = max(human_avg, ai_avg)
    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
             if verdict == "AI-generated" else
             f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
    highlight_html = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)

    return AnalyseOut(verdict=verdict, confidence=confidence,
                      ai_avg=ai_avg, human_avg=human_avg,
                      per_line=per_line, highlight_html=highlight_html)

# ────── local dev:  uvicorn app:app --reload ───────────────────────────
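
# Optional direct launch, equivalent to the command above (a sketch: it assumes
# this module is saved as app.py so the "app:app" import string resolves).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", reload=True)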