|
|
|
|
|
from fastapi import FastAPI |
|
|
from fastapi.responses import HTMLResponse |
|
|
from pydantic import BaseModel |
|
|
from typing import List, Dict, Any |
|
|
|
|
|
# Device index handed to the pipelines: 0 when torch reports CUDA as available,
# -1 otherwise. Start from -1 and only upgrade once CUDA has been confirmed;
# any failure while importing or probing torch leaves the -1 fallback in place.
DEVICE = -1
try:
    import torch

    if torch.cuda.is_available():
        DEVICE = 0
except Exception:
    DEVICE = -1
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline |
|
|
|
|
|
|
|
|
# Model identifiers passed to load_pipe() (and from there to from_pretrained()).
# The FABSA and Twitter outputs are normalised with norm3 (three classes);
# the MoodMeter output is normalised with norm2 (two classes).
FABSA_ID = "Anudeep-Narala/fabsa-roberta-sentiment"
TWITTER_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
MOOD_ID = "Priyanshuchaudhary2425/MoodMeter-sentimental-analysis"
|
|
|
|
|
def load_pipe(model_id: str) -> TextClassificationPipeline:
    """Build a text-classification pipeline for ``model_id``.

    The tokenizer and the sequence-classification model are both loaded with
    ``from_pretrained(model_id)`` and wrapped in a ``TextClassificationPipeline``
    placed on the module-level ``DEVICE``.

    Args:
        model_id: Identifier forwarded unchanged to ``from_pretrained``.

    Returns:
        A pipeline configured with ``return_all_scores=True`` and
        ``truncation=True``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    # Keyword set kept exactly as before; only the call shape differs.
    pipeline_options = dict(
        model=model,
        tokenizer=tokenizer,
        device=DEVICE,
        return_all_scores=True,
        truncation=True,
    )
    return TextClassificationPipeline(**pipeline_options)
|
|
|
|
|
# Load all three pipelines once, at import time, so the request handlers can
# reuse them. The generator is consumed left to right, so the models are still
# loaded in the order FABSA -> Twitter -> MoodMeter.
print("Loading models…")
fabsa, twitter, mood = (
    load_pipe(model_id) for model_id in (FABSA_ID, TWITTER_ID, MOOD_ID)
)
print("Models ready.")
|
|
|
|
|
def norm3(scores: List[Dict[str, Any]]):
    """Normalise a list of ``{"label", "score"}`` dicts to three classes.

    A label is matched case-insensitively: it counts as negative / neutral /
    positive when it contains ``neg`` / ``neu`` / ``pos`` or equals
    ``label_0`` / ``label_1`` / ``label_2`` respectively (first match wins).
    Entries whose label matches none of these are ignored.

    Args:
        scores: Per-label entries, each with a ``"label"`` and a ``"score"``.

    Returns:
        ``(pred, scores3, top, margin)`` where ``scores3`` maps
        ``negative``/``neutral``/``positive`` to floats (0.0 when absent),
        ``pred`` is the highest-scoring class (ties resolve in that key
        order), ``top`` is its score and ``margin`` is the gap between the
        best and second-best score.
    """
    # (class name, substring to look for, exact generic label)
    rules = (
        ("negative", "neg", "label_0"),
        ("neutral", "neu", "label_1"),
        ("positive", "pos", "label_2"),
    )
    probs = {name: 0.0 for name, _, _ in rules}

    for entry in scores:
        raw_label = entry["label"].lower()
        value = float(entry["score"])
        for name, fragment, generic in rules:
            if fragment in raw_label or raw_label == generic:
                probs[name] = value
                break

    best = max(probs, key=probs.get)
    ranked = sorted(probs.values(), reverse=True)
    return best, probs, float(probs[best]), float(ranked[0] - ranked[1])
|
|
|
|
|
def norm2(scores: List[Dict[str, Any]]):
    """Normalise a two-class result into the three-class layout.

    Labels are matched case-insensitively: ``neg`` substring or ``label_0``
    is negative, ``pos`` substring or ``label_1`` is positive; anything else
    is ignored. The ``neutral`` slot of the returned dict is always 0.0.

    Args:
        scores: Per-label entries, each with a ``"label"`` and a ``"score"``.

    Returns:
        ``(pred, scores3, top, margin)``: the winning class (negative wins
        ties), the three-key score dict, the larger of the two scores and the
        absolute difference between them.
    """
    neg_score = 0.0
    pos_score = 0.0

    for entry in scores:
        raw_label = entry["label"].lower()
        value = float(entry["score"])
        if "neg" in raw_label or raw_label == "label_0":
            neg_score = value
        elif "pos" in raw_label or raw_label == "label_1":
            pos_score = value

    if neg_score >= pos_score:
        winner = "negative"
    else:
        winner = "positive"

    embedded = {"negative": neg_score, "neutral": 0.0, "positive": pos_score}
    best_score = float(max(neg_score, pos_score))
    gap = float(abs(neg_score - pos_score))
    return winner, embedded, best_score, gap
|
|
|
|
|
def fuse(fabsa_label: str, twitter_label: str) -> str:
    """Combine the FABSA and Twitter labels into one ensemble label.

    A ``"negative"`` FABSA label always wins; in every other case the Twitter
    label is returned as-is.
    """
    return "negative" if fabsa_label == "negative" else twitter_label
|
|
|
|
|
# ASGI application object; the route decorators below register their handlers on it.
app = FastAPI(title="HF Space — Sentiment Inference (FABSA + MoodMeter + Twitter)")
|
|
|
|
|
class PredictIn(BaseModel):
    """Request body for ``POST /predict``."""

    # Text to classify; the handler strips surrounding whitespace before use.
    text: str
|
|
|
|
|
class BatchIn(BaseModel):
    """Request body for ``POST /batch``."""

    # Texts to classify; each one is stripped by the handler and yields one
    # entry in the response's "items" list.
    texts: List[str]
|
|
|
|
|
@app.get("/health")
def health():
    """Report that the service is up, along with the device index in use."""
    status = {"ok": True}
    status["device"] = DEVICE
    return status
|
|
|
|
|
@app.post("/predict")
def predict(inp: PredictIn):
    """Classify one text with all three models and fuse the result.

    The stripped text is passed to the FABSA, Twitter and MoodMeter pipelines;
    FABSA and Twitter outputs are normalised with ``norm3`` and MoodMeter with
    ``norm2``.

    Returns:
        A dict with the stripped ``text``, one ``{"label", "scores", "top",
        "margin"}`` section per model (``fabsa``, ``twitter``, ``mood``) and
        ``ensemble`` holding the label from ``fuse(fabsa, twitter)``.
    """
    text = (inp.text or "").strip()

    # [0] selects the result that belongs to the single input text.
    fabsa_raw = fabsa(text)[0]
    twitter_raw = twitter(text)[0]
    mood_raw = mood(text)[0]

    fabsa_norm = norm3(fabsa_raw)
    twitter_norm = norm3(twitter_raw)
    mood_norm = norm2(mood_raw)

    # The normalisers return (pred, scores, top, margin) in this order.
    fields = ("label", "scores", "top", "margin")
    return {
        "text": text,
        "fabsa": dict(zip(fields, fabsa_norm)),
        "twitter": dict(zip(fields, twitter_norm)),
        "mood": dict(zip(fields, mood_norm)),
        "ensemble": {"label": fuse(fabsa_norm[0], twitter_norm[0])},
    }
|
|
|
|
|
@app.post("/batch")
def batch(inp: BatchIn):
    """Classify several texts with all three models and fuse each result.

    Every text is stripped, the whole list is sent to each pipeline with
    ``batch_size=16``, and the per-text outputs are normalised (``norm3`` for
    FABSA and Twitter, ``norm2`` for MoodMeter) and fused.

    Returns:
        ``{"items": [...]}`` with one entry per input text, in input order.
        Each entry has the same layout as the ``/predict`` response. An empty
        ``texts`` list yields ``{"items": []}`` without invoking any model.
    """
    texts = [(x or "").strip() for x in inp.texts]
    if not texts:
        # Nothing to classify: skip the three model invocations entirely.
        return {"items": []}

    f_raw = fabsa(texts, batch_size=16)
    t_raw = twitter(texts, batch_size=16)
    m_raw = mood(texts, batch_size=16)

    out = []
    # The three result lists are aligned with ``texts``; walk them in lockstep.
    for t, f_item, t_item, m_item in zip(texts, f_raw, t_raw, m_raw):
        f_pred, f_scores, f_top, f_margin = norm3(f_item)
        t_pred, t_scores, t_top, t_margin = norm3(t_item)
        m_pred, m_scores, m_top, m_margin = norm2(m_item)
        out.append({
            "text": t,
            "fabsa": {"label": f_pred, "scores": f_scores, "top": f_top, "margin": f_margin},
            "twitter": {"label": t_pred, "scores": t_scores, "top": t_top, "margin": t_margin},
            "mood": {"label": m_pred, "scores": m_scores, "top": m_top, "margin": m_margin},
            "ensemble": {"label": fuse(f_pred, t_pred)},
        })
    return {"items": out}
|
|
|
|
|
|
|
|
# Single-page test UI served at "/". Its script POSTs the textarea content to
# /predict and renders the fabsa / twitter / mood / ensemble sections of the
# JSON reply. A non-2xx reply is turned into an exception so the status line
# shows the HTTP error instead of "Done." over empty panels.
INDEX_HTML = """<!doctype html>
<html lang="en"><head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<title>Sentiment Space — Quick Test</title>
<style>
  body{font-family: system-ui, Arial, sans-serif; background:#0b0b0b; color:#eee; padding:24px; max-width:900px; margin:0 auto;}
  .card{border:1px solid #333; border-radius:12px; padding:16px; background:#111; margin-top:16px;}
  textarea{width:100%; background:#0e0e0e; color:#eee; border:1px solid #333; border-radius:8px; padding:12px;}
  button{background:#4F46E5; color:#fff; border:none; padding:10px 14px; border-radius:8px; cursor:pointer;}
  pre{background:#0e0e0e; padding:12px; border-radius:8px; overflow:auto;}
  .row{display:grid; grid-template-columns: 1fr 1fr; gap:16px;}
</style>
</head><body>
<h1>Sentiment Model Space — Quick Test</h1>
<div class="card">
  <textarea id="txt" rows="5" placeholder="Type something like: I feel exhausted and nothing seems to help."></textarea>
  <div style="margin-top:12px; display:flex; gap:12px;">
    <button onclick="run()">Predict</button>
    <button onclick="demo()">Demo Text</button>
  </div>
  <div id="status" style="opacity:.7; margin-top:8px;"></div>
</div>

<div class="row">
  <div class="card"><h3>FABSA</h3><pre id="fabsa"></pre></div>
  <div class="card"><h3>Twitter-RoBERTa</h3><pre id="twitter"></pre></div>
</div>
<div class="row">
  <div class="card"><h3>MoodMeter (2-class)</h3><pre id="mood"></pre></div>
  <div class="card"><h3>Ensemble</h3><pre id="ens"></pre></div>
</div>

<script>
async function run(){
  const s = document.getElementById('txt').value.trim();
  if(!s){ alert('Enter some text'); return; }
  set('#status','Predicting...')
  try{
    const r = await fetch('/predict', {
      method:'POST', headers:{'Content-Type':'application/json'},
      body: JSON.stringify({text:s})
    });
    if(!r.ok){ throw new Error('HTTP '+r.status); }
    const j = await r.json();
    set('#fabsa', fmt(j.fabsa));
    set('#twitter', fmt(j.twitter));
    set('#mood', fmt(j.mood));
    set('#ens', JSON.stringify(j.ensemble, null, 2));
    set('#status','Done.')
  }catch(e){
    set('#status','Error: '+e.message)
  }
}
function set(sel, val){ document.querySelector(sel).textContent = typeof val==='string'? val : JSON.stringify(val,null,2); }
function fmt(x){
  if(!x) return '';
  const s = x.scores||{};
  return JSON.stringify({
    label: x.label,
    neg: round(s.negative), neu: round(s.neutral), pos: round(s.positive),
    top: round(x.top), margin: round(x.margin)
  }, null, 2);
}
function round(v){ return (v==null? null : Math.round(v*1000)/1000); }
function demo(){
  document.getElementById('txt').value = "It’s 3 a.m. again. I’m staring at the ceiling replaying everything I might fail tomorrow.";
}
</script>
</body></html>
"""
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the embedded quick-test page."""
    page = HTMLResponse(content=INDEX_HTML)
    return page
|
|
|