|
from fastapi import FastAPI, HTTPException, Header |
|
from pydantic import BaseModel |
|
import numpy as np |
|
from tensorflow.keras.models import load_model |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import torch |
|
import torch.nn.functional as F |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
# File path of the saved Keras model that the /predict handler calls.
_STRESS_MODEL_PATH = "model_stresss.h5"
model = load_model(_STRESS_MODEL_PATH)
|
# Class names for the stress model. The /predict handler indexes this list
# with the argmax of the model output, so the order must match the model's
# output order (presumably fixed at training time; confirm against the model).
labels = [
    'Tidak Stress',
    'Sedikit Stress',
    'Normal',
    'Stress',
    'Sangat Stress',
]
|
|
|
|
|
# Checkpoint identifier passed to both ``from_pretrained`` calls below, so the
# tokenizer and the classifier are loaded from the same source.
model_dir = "Chipan/indobert-emotion"

tokenizer = AutoTokenizer.from_pretrained(model_dir)

# ``eval()`` returns the module itself, so the classifier is loaded and put
# into evaluation mode in a single expression.
model_bert = AutoModelForSequenceClassification.from_pretrained(model_dir).eval()
|
|
|
|
|
# Maps the class index chosen by the /analyze handler to its emotion name.
# The position of each name is its index (0-4).
label_map = dict(
    enumerate(("Bersyukur", "Marah", "Sedih", "Senang", "Stress"))
)
|
|
|
|
|
app = FastAPI()

# CORS: accept requests from any origin, with any method and any header.
#
# Credentials stay disabled on purpose. The CORS rules do not allow a
# wildcard origin together with credentialed requests, and none of the
# handlers in this file read cookies. To support credentialed requests,
# replace the "*" entries with explicit origins/methods/headers and only
# then set allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
class CheckInData(BaseModel):
    """Request body for ``POST /predict``: five numeric check-in scores.

    The handler feeds the fields to the stress model in declaration order
    (mood, sleep, anxiety, exercise, support). The expected scale of each
    score is not defined here; it presumably has to match the data the model
    was trained on (TODO confirm).
    """

    mood: float
    sleep: float
    anxiety: float
    exercise: float
    support: float
|
|
|
class TextInput(BaseModel):
    """Request body for ``POST /analyze``: the text to classify."""

    # Raw text; the handler tokenizes it with truncation at 128 tokens.
    text: str
|
|
|
@app.post("/predict")
def predict(data: CheckInData, authorization: str = Header(None)):
    """Classify a check-in into one of the stress levels listed in ``labels``.

    Args:
        data: Check-in scores. They are passed to the model as a single row
            in the order mood, sleep, anxiety, exercise, support.
        authorization: Optional ``Authorization`` request header. It is
            accepted but not inspected by this handler.

    Returns:
        A dict with ``predicted_index`` (argmax of the model output),
        ``predicted_label`` (the matching entry of ``labels``) and
        ``raw_prediction`` (the model output converted to nested lists).

    Raises:
        HTTPException: 422 when any score is NaN or infinite; 500 when the
            model call or the label lookup fails.
    """
    features = np.array(
        [[data.mood, data.sleep, data.anxiety, data.exercise, data.support]]
    )

    # NaN/inf inputs typically turn the model output into NaN, which cannot
    # be encoded in a JSON response. Reject them up front with a client error
    # instead of failing later while the response is rendered. This check
    # sits outside the try block so the 422 is not rewrapped as a 500.
    if not np.isfinite(features).all():
        raise HTTPException(
            status_code=422,
            detail="Check-in values must be finite numbers",
        )

    try:
        # verbose=0 stops Keras from printing a progress bar for every request.
        prediction = model.predict(features, verbose=0)
        idx = int(np.argmax(prediction))
        return {
            "predicted_index": idx,
            "predicted_label": labels[idx],
            "raw_prediction": prediction.tolist(),
        }
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(
            status_code=500, detail=f"Prediction error: {e}"
        ) from e
|
|
|
@app.post("/analyze")
def analyze_emotion(input: TextInput):
    """Classify the emotion expressed in a piece of text.

    Args:
        input: Request body holding the text to analyze. The parameter name
            shadows the ``input`` builtin; it is kept for compatibility.

    Returns:
        A dict with ``emotion`` (the ``label_map`` name of the most probable
        class, or ``"unknown"`` when the index has no entry) and
        ``confidence`` (that class's softmax probability, rounded to four
        decimals).

    Raises:
        HTTPException: 500 when tokenization or inference fails.
    """
    try:
        encoded = tokenizer(
            input.text,
            return_tensors="pt",
            padding=True,
            truncation=True,  # text beyond max_length tokens is cut off
            max_length=128,
        )
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            logits = model_bert(**encoded).logits
            probs = F.softmax(logits, dim=1)

        # One text per request, so the batch holds a single row and the
        # flattened argmax is the class index of that row.
        idx = int(torch.argmax(probs))
        return {
            "emotion": label_map.get(idx, "unknown"),
            "confidence": round(probs[0, idx].item(), 4),
        }
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(
            status_code=500, detail=f"Emotion analysis error: {e}"
        ) from e
|
|
|
@app.get("/")
def root():
    """Return a fixed status payload for ``GET /``."""
    return dict(status="ok")
|
|