Sleepyriizi committed on
Commit
d6c4fbd
·
verified ·
1 Parent(s): 30f0fad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -16,16 +16,16 @@ from fastapi.middleware.cors import CORSMiddleware
16
  from jose import jwt, JWTError
17
  from pydantic import BaseModel, Field
18
 
19
- # ── torch shim ─────────────────────────────────────────────────────────
20
  if hasattr(torch, "compile"):
21
  torch.compile = (lambda m=None,*_,**__: m if callable(m) else (lambda f: f)) # type: ignore
22
  os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")
23
 
24
- # ── remote‑code flag ───────────────────────────────────────────────────
25
  os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
26
  TOKEN_KW = {"trust_remote_code": True}
27
 
28
- # ── config ─────────────────────────────────────────────────────────────
29
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
31
  FILE_MAP = {"ensamble_1":"ensamble_1","ensamble_2.bin":"ensamble_2.bin","ensamble_3":"ensamble_3"}
@@ -38,7 +38,7 @@ LABELS = {i:n for i,n in enumerate([
38
  "mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b","opt-30b","opt-350m","opt-6.7b",
39
  "opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"])}
40
 
41
- # ── JWT helpers ───────────────────────────────────────────────────────
42
  SECRET_KEY = os.getenv("SECRET_KEY")
43
  if not SECRET_KEY:
44
  raise RuntimeError("SECRET_KEY env‑var not set – add it in Space settings → Secrets")
@@ -55,32 +55,30 @@ def _verify_jwt(tok:str=Depends(oauth2)):
55
  except JWTError:
56
  raise HTTPException(401,"Invalid or expired token")
57
 
58
- # ── model bootstrap ───────────────────────────────────────────────────
59
  print("🔄 Fetching ensemble weights…", flush=True)
60
  paths={k:hf_hub_download(WEIGHT_REPO,f,resume_download=True) for k,f in FILE_MAP.items()}
61
 
62
  print("🧩 Building ModernBERT backbone…", flush=True)
63
- _cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
64
- _cfg.num_labels = NUM_LABELS # ➜ classification head = 41
65
  _tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
66
  _models: List[AutoModelForSequenceClassification] = []
67
  for p in paths.values():
68
  m = AutoModelForSequenceClassification.from_pretrained(
69
  BASE_MODEL,
70
  config=_cfg,
71
- ignore_mismatched_sizes=True, # skip the 2‑class head in checkpoint
72
  **TOKEN_KW,
73
  )
74
- state=torch.load(p, map_location=DEVICE)
75
- m.load_state_dict(state) # loads 41‑class ensemble head
76
  m.to(DEVICE).eval()
77
  _models.append(m)
78
  print(f"✅ Ensemble ready on {DEVICE}")
79
 
80
- # ── helper fns ─────────────────────────────────────────────────────────
81
 
82
  def _tidy(t:str)->str:
83
- t=t.replace("\r\n","\n").replace("\r","\n")
84
  t=re.sub(r"\n\s*\n+","\n\n",t)
85
  t=re.sub(r"[ \t]+"," ",t)
86
  t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)
@@ -96,14 +94,14 @@ def _infer(seg:str):
96
  top3=[LABELS[i] for i in torch.topk(ai_probs,3).indices.tolist()]
97
  return human, ai, top3
98
 
99
- # ── schemas ────────────────────────────────────────���──────────────────
100
  class TokenOut(BaseModel): access_token:str; token_type:str="bearer"
101
  class AnalyseIn(BaseModel): text:str=Field(...,min_length=1)
102
  class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
103
  class AnalyseOut(BaseModel): verdict:str; confidence:float; ai_avg:float; human_avg:float; per_line:List[Line]; highlight_html:str
104
 
105
- # ── FastAPI app ───────────────────────────────────────────────────────
106
- app=FastAPI(title="Orify Text Detector API",version="1.1.1")
107
  app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_methods=["*"],allow_headers=["*"])
108
 
109
  @app.post("/token",response_model=TokenOut)
@@ -127,3 +125,9 @@ async def analyse(body:AnalyseIn,_=Depends(_verify_jwt)):
127
  badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
128
  html_out=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts)
129
  return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,per_line=per,highlight_html=html_out)
 
 
 
 
 
 
 
16
  from jose import jwt, JWTError
17
  from pydantic import BaseModel, Field
18
 
19
+ # ───────────────────────── torch shim ───────────────────────────────
20
  if hasattr(torch, "compile"):
21
  torch.compile = (lambda m=None,*_,**__: m if callable(m) else (lambda f: f)) # type: ignore
22
  os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")
23
 
24
+ # ─────────────────────── remote‑code flag ───────────────────────────
25
  os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
26
  TOKEN_KW = {"trust_remote_code": True}
27
 
28
+ # ─────────────────────────── config ─────────────────────────────────
29
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
31
  FILE_MAP = {"ensamble_1":"ensamble_1","ensamble_2.bin":"ensamble_2.bin","ensamble_3":"ensamble_3"}
 
38
  "mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b","opt-30b","opt-350m","opt-6.7b",
39
  "opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"])}
40
 
41
+ # ──────────────────────── JWT helpers ──────────────────────────────
42
  SECRET_KEY = os.getenv("SECRET_KEY")
43
  if not SECRET_KEY:
44
  raise RuntimeError("SECRET_KEY env‑var not set – add it in Space settings → Secrets")
 
55
  except JWTError:
56
  raise HTTPException(401,"Invalid or expired token")
57
 
58
+ # ─────────────────────── model bootstrap ───────────────────────────
59
  print("🔄 Fetching ensemble weights…", flush=True)
60
  paths={k:hf_hub_download(WEIGHT_REPO,f,resume_download=True) for k,f in FILE_MAP.items()}
61
 
62
  print("🧩 Building ModernBERT backbone…", flush=True)
63
+ _cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW); _cfg.num_labels = NUM_LABELS
 
64
  _tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
65
  _models: List[AutoModelForSequenceClassification] = []
66
  for p in paths.values():
67
  m = AutoModelForSequenceClassification.from_pretrained(
68
  BASE_MODEL,
69
  config=_cfg,
70
+ ignore_mismatched_sizes=True,
71
  **TOKEN_KW,
72
  )
73
+ m.load_state_dict(torch.load(p, map_location=DEVICE))
 
74
  m.to(DEVICE).eval()
75
  _models.append(m)
76
  print(f"✅ Ensemble ready on {DEVICE}")
77
 
78
+ # ───────────────────────── helpers ─────────────────────────────────
79
 
80
  def _tidy(t:str)->str:
81
+ t=t.replace("\r\n","\n").replace("\r", "\n")
82
  t=re.sub(r"\n\s*\n+","\n\n",t)
83
  t=re.sub(r"[ \t]+"," ",t)
84
  t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)
 
94
  top3=[LABELS[i] for i in torch.topk(ai_probs,3).indices.tolist()]
95
  return human, ai, top3
96
 
97
+ # ───────────────────────── schemas ─────────────────────────────────
98
  class TokenOut(BaseModel): access_token:str; token_type:str="bearer"
99
  class AnalyseIn(BaseModel): text:str=Field(...,min_length=1)
100
  class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
101
  class AnalyseOut(BaseModel): verdict:str; confidence:float; ai_avg:float; human_avg:float; per_line:List[Line]; highlight_html:str
102
 
103
+ # ───────────────────────── FastAPI app ─────────────────────────────
104
+ app=FastAPI(title="Orify Text Detector API",version="1.2.0")
105
  app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_methods=["*"],allow_headers=["*"])
106
 
107
  @app.post("/token",response_model=TokenOut)
 
125
  badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
126
  html_out=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts)
127
  return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,per_line=per,highlight_html=html_out)
128
+
129
+ # ─────────────────────────── entrypoint ────────────────────────────
130
+ if __name__ == "__main__":
131
+ import uvicorn, sys
132
+ port=int(os.environ.get("PORT", "7860"))
133
+ uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info", reload=False)