|
import json
import logging
import os
import threading
from typing import Any, Dict, List

import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer
|
|
|
|
|
# Send INFO-and-above records from every logger to the root handler so the
# startup and prediction messages below show up on the console.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
# Location of the TFLite model file and the local tokenizer directory. Both can
# be overridden through environment variables (useful in containers or hosted
# deployments); the defaults are the original relative paths.
MODEL_PATH = os.environ.get("MODEL_PATH", "model.tflite")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "tokenizer")

# Number of tokens every request is padded/truncated to; the model inputs are
# resized to [1, MAX_LENGTH] before inference.
MAX_LENGTH = 128

# Model output index -> report category name.
# NOTE: the JavaScript in HTML_TEMPLATE keeps its own copy of these labels
# (with emoji prefixes); keep the two in sync when categories change.
CLASS_LABELS = {
    0: "Evakuasi/Penyelamatan Hewan",
    1: "Kebakaran",
    2: "Layanan Lingkungan & Fasilitas Umum",
    3: "Penyelamatan Non Hewan & Bantuan Teknis",
}
|
|
|
|
|
|
|
app = FastAPI(
    version="1.1.0",
    title="Damkar Classification API (TFLite)",
    description="API untuk klasifikasi tipe laporan damkar menggunakan TFLite model",
)

# Shared model state. Every name starts out as None and is populated by the
# startup hook (load_model) once the interpreter and tokenizer are loaded.
interpreter = tokenizer = input_details = output_details = None
|
|
|
@app.on_event("startup")
async def load_model():
    """Load the TFLite interpreter and the tokenizer when the app starts.

    Populates the module-level ``interpreter``, ``tokenizer``, ``input_details``
    and ``output_details``. The globals are assigned only after every component
    has loaded successfully, so a failure part-way through never leaves a
    half-initialised model behind.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
        Exception: any error from TFLite or ``transformers`` is logged with its
            traceback and re-raised to the framework (expected to make startup
            fail).
    """
    global interpreter, tokenizer, input_details, output_details

    try:
        logger.info("Loading TFLite model...")
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

        new_interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
        new_interpreter.allocate_tensors()
        new_input_details = new_interpreter.get_input_details()
        new_output_details = new_interpreter.get_output_details()

        logger.info("Model loaded successfully!")
        logger.info("Input details: %s", new_input_details)
        logger.info("Output details: %s", new_output_details)

        logger.info("Loading tokenizer...")
        if os.path.exists(TOKENIZER_PATH):
            new_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
        else:
            # No local copy: load by Hugging Face Hub id instead (presumably a
            # download, or the local HF cache if it was fetched before).
            logger.warning("Local tokenizer not found, using online tokenizer")
            new_tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
    except Exception as exc:
        # logger.exception keeps the traceback; bare ``raise`` re-raises unchanged.
        logger.exception("Error loading model: %s", exc)
        raise

    interpreter = new_interpreter
    tokenizer = new_tokenizer
    input_details = new_input_details
    output_details = new_output_details
    logger.info("All components loaded successfully!")
|
|
|
# tf.lite.Interpreter is not safe for concurrent use, and the synchronous
# ``predict`` endpoint runs in FastAPI's worker threadpool. Without this lock,
# concurrent requests can overwrite each other's input tensors between
# ``set_tensor`` and ``invoke``, or resize/reallocate tensors mid-inference.
# Tokenization is done under the lock too: Hugging Face fast tokenizers update
# their internal padding/truncation settings on each call.
_INTERPRETER_LOCK = threading.Lock()

# The three BERT-style model inputs, listed in the positional order that is
# assumed for the model's inputs when they cannot be matched by name.
_MODEL_INPUT_KEYS = ("attention_mask", "input_ids", "token_type_ids")


def _ensure_input_shapes() -> None:
    """Resize the first three model inputs to ``[1, MAX_LENGTH]`` if they differ.

    Re-allocation only happens when a resize was actually needed, instead of
    on every request. Must be called while holding ``_INTERPRETER_LOCK``.
    """
    resized = False
    for detail in interpreter.get_input_details()[: len(_MODEL_INPUT_KEYS)]:
        if detail["shape"].tolist() != [1, MAX_LENGTH]:
            interpreter.resize_tensor_input(detail["index"], [1, MAX_LENGTH])
            resized = True
    if resized:
        interpreter.allocate_tensors()


def _pair_inputs(details, arrays):
    """Pair each model input detail with the tokenizer array it should receive.

    If the model has exactly three inputs and each input name contains exactly
    one distinct key from ``_MODEL_INPUT_KEYS`` (a name such as
    ``serving_default_input_ids:0``), inputs are matched by name. Otherwise it
    falls back to positional order: attention_mask, input_ids, token_type_ids.
    The fallback raises IndexError if there are fewer than three inputs.
    """
    named = {}
    for detail in details:
        hits = [key for key in _MODEL_INPUT_KEYS if key in detail.get("name", "")]
        if len(hits) == 1:
            named[hits[0]] = detail
    if len(details) == len(named) == len(_MODEL_INPUT_KEYS):
        return [(named[key], arrays[key]) for key in _MODEL_INPUT_KEYS]
    return [(details[pos], arrays[key]) for pos, key in enumerate(_MODEL_INPUT_KEYS)]


def predict_tflite(text: str) -> Dict[str, Any]:
    """Classify ``text`` with the TFLite model and return a labelled result.

    Args:
        text: Report text to classify. It is padded/truncated to exactly
            ``MAX_LENGTH`` tokens.

    Returns:
        Dict with these keys:
        ``predicted_class_index``: argmax of the first output row.
        ``predicted_class_label``: looked up in ``CLASS_LABELS``, or
            ``"Unknown Class"``.
        ``confidence``: the largest softmax probability.
        ``raw_output``: the first output row, treated as logits.
        ``probabilities``: softmax of ``raw_output``.
        ``input_text``: the input as given.
        ``model_info``: ``output_shape`` and ``num_classes``.

    Raises:
        HTTPException: 503 if the model components are not loaded; 500 if
            tokenization or inference fails.
    """
    if interpreter is None or tokenizer is None or output_details is None:
        raise HTTPException(status_code=503, detail="Model components not loaded")

    try:
        with _INTERPRETER_LOCK:
            encoded = tokenizer(
                [text],
                max_length=MAX_LENGTH,
                padding="max_length",
                truncation=True,
                return_tensors="np",
            )
            input_ids = encoded["input_ids"]
            arrays = {
                "input_ids": input_ids,
                "attention_mask": encoded["attention_mask"],
                # Some tokenizers do not emit segment ids; a single sentence
                # then uses segment 0 everywhere.
                "token_type_ids": encoded.get("token_type_ids", np.zeros_like(input_ids)),
            }

            _ensure_input_shapes()
            for detail, array in _pair_inputs(interpreter.get_input_details(), arrays):
                # Cast to the dtype the model declares for this input.
                interpreter.set_tensor(detail["index"], array.astype(detail["dtype"]))
            interpreter.invoke()
            # get_tensor returns a copy, so it stays valid after the lock is released.
            raw_output = interpreter.get_tensor(output_details[0]["index"])

        logits = raw_output[0]
        probabilities = tf.nn.softmax(logits).numpy()
        predicted_class_index = int(np.argmax(logits))
        max_confidence = float(np.max(probabilities))
        predicted_class_label = CLASS_LABELS.get(predicted_class_index, "Unknown Class")

        return {
            "predicted_class_index": predicted_class_index,
            "predicted_class_label": predicted_class_label,
            "confidence": max_confidence,
            "raw_output": logits.tolist(),
            "probabilities": probabilities.tolist(),
            "input_text": text,
            "model_info": {
                "output_shape": list(raw_output.shape),
                "num_classes": len(probabilities),
            },
        }
    except Exception as exc:
        logger.exception("Prediction error: %s", exc)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(exc)}") from exc
|
|
|
|
|
class InputText(BaseModel):
    """Request body for ``POST /predict``."""

    # Report text to classify. Emptiness is not validated here; the /predict
    # endpoint rejects blank text with HTTP 400.
    text: str
|
|
|
class PredictionResponse(BaseModel):
    """Response body for ``POST /predict``; mirrors the dict returned by ``predict_tflite``."""

    predicted_class_index: int  # argmax over the model's first output row
    predicted_class_label: str  # CLASS_LABELS entry for the index, or "Unknown Class"
    confidence: float  # largest softmax probability
    raw_output: List[float]  # model output row before softmax
    probabilities: List[float]  # softmax of raw_output
    input_text: str  # text as submitted
    model_info: Dict[str, Any]  # holds "output_shape" and "num_classes"
    status: str = "success"
|
|
|
|
|
# Single-page testing UI served at GET /. It POSTs to /predict and renders the
# JSON response client-side. Any text that comes from the server or from a
# caught exception is HTML-escaped before it is inserted via innerHTML.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <title>Damkar Classification</title>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <style>
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            max-width: 900px;
            margin: 0 auto;
            padding: 20px;
            background-color: #f5f5f5;
        }
        .container {
            background: white;
            padding: 30px;
            border-radius: 10px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }
        h1 {
            color: #d32f2f; /* Red color for Damkar */
            text-align: center;
            margin-bottom: 30px;
        }
        .form-group {
            margin-bottom: 20px;
        }
        label {
            display: block;
            margin-bottom: 8px;
            font-weight: bold;
            color: #555;
        }
        textarea {
            width: 100%;
            min-height: 120px;
            padding: 12px;
            border: 2px solid #ddd;
            border-radius: 6px;
            font-size: 14px;
            resize: vertical;
            box-sizing: border-box;
        }
        textarea:focus {
            outline: none;
            border-color: #007bff;
        }
        button {
            background-color: #007bff;
            color: white;
            padding: 12px 30px;
            border: none;
            border-radius: 6px;
            cursor: pointer;
            font-size: 16px;
            width: 100%;
        }
        button:hover {
            background-color: #0056b3;
        }
        button:disabled {
            background-color: #ccc;
            cursor: not-allowed;
        }
        .result {
            margin-top: 20px;
            padding: 15px;
            border-radius: 6px;
            display: none;
        }
        .result.success {
            background-color: #d4edda;
            border: 1px solid #c3e6cb;
            color: #155724;
        }
        .result.error {
            background-color: #f8d7da;
            border: 1px solid #f5c6cb;
            color: #721c24;
        }
        .loading {
            text-align: center;
            display: none;
        }
        .prob-item {
            display: flex;
            justify-content: space-between;
            margin: 5px 0;
            padding: 8px;
            background-color: #f8f9fa;
            border-radius: 4px;
            font-family: monospace;
        }
        .examples {
            margin-top: 20px;
            padding: 15px;
            background-color: #f8f9fa;
            border-radius: 6px;
        }
        .example-text {
            cursor: pointer;
            color: #007bff;
            text-decoration: underline;
            margin: 5px 0;
        }
        .example-text:hover {
            color: #0056b3;
        }
        .raw-output {
            background-color: #f0f0f0;
            padding: 10px;
            border-radius: 4px;
            font-family: monospace;
            font-size: 12px;
            margin: 10px 0;
            max-height: 150px;
            overflow-y: auto;
            white-space: pre-wrap;
            word-wrap: break-word;
        }
        .predicted-label {
            font-size: 1.5em;
            font-weight: bold;
            color: #0056b3;
            text-align: center;
            margin: 15px 0;
            padding: 10px;
            background-color: #e7f3ff;
            border-radius: 6px;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>🚒 Klasifikasi Laporan Damkar</h1>
        <p style="text-align: center; color: #666;">Masukkan teks laporan untuk diklasifikasikan oleh model AI.</p>

        <div class="form-group">
            <label for="textInput">Masukkan teks laporan:</label>
            <textarea id="textInput" placeholder="Contoh: ada kebakaran di gedung perkantoran..."></textarea>
        </div>

        <button onclick="predict()" id="predictBtn">Prediksi Kategori</button>

        <div class="loading" id="loading">
            <p>⏳ Sedang memproses...</p>
        </div>

        <div class="result" id="result"></div>

        <div class="examples">
            <h3>Contoh Teks:</h3>
            <div class="example-text" onclick="setExample('ada kebakaran di gedung perkantoran lantai 5')">
                🔥 "ada kebakaran di gedung perkantoran lantai 5"
            </div>
            <div class="example-text" onclick="setExample('ular masuk ke dalam rumah warga')">
                🐍 "ular masuk ke dalam rumah warga"
            </div>
            <div class="example-text" onclick="setExample('pohon tumbang menghalangi jalan raya')">
                🌳 "pohon tumbang menghalangi jalan raya"
            </div>
            <div class="example-text" onclick="setExample('cincin tidak bisa dilepas dari jari')">
                💍 "cincin tidak bisa dilepas dari jari"
            </div>
        </div>
    </div>

    <script>
        const CLASS_LABELS = {
            0: "🐍 Evakuasi/Penyelamatan Hewan",
            1: "🔥 Kebakaran",
            2: "🌳 Layanan Lingkungan & Fasilitas Umum",
            3: "💍 Penyelamatan Non Hewan & Bantuan Teknis"
        };

        // Escape text before it is inserted with innerHTML.
        function escapeHtml(value) {
            return String(value).replace(/[&<>"']/g, (ch) => ({
                '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'
            }[ch]));
        }

        // FastAPI sends `detail` as a string for HTTPException, but as a list of
        // objects for request-validation (422) errors.
        function formatDetail(detail) {
            if (Array.isArray(detail)) {
                return detail.map((item) => (item && item.msg) ? item.msg : JSON.stringify(item)).join('; ');
            }
            if (typeof detail === 'string') {
                return detail;
            }
            return detail ? JSON.stringify(detail) : '';
        }

        function setExample(text) {
            document.getElementById('textInput').value = text;
        }

        async function predict() {
            const text = document.getElementById('textInput').value.trim();
            const resultDiv = document.getElementById('result');
            const loadingDiv = document.getElementById('loading');
            const predictBtn = document.getElementById('predictBtn');

            // Ctrl+Enter can trigger predict() while a request is still in flight.
            if (predictBtn.disabled) {
                return;
            }

            if (!text) {
                showResult('error', 'Mohon masukkan teks untuk diprediksi.');
                return;
            }

            // Show loading
            loadingDiv.style.display = 'block';
            resultDiv.style.display = 'none';
            predictBtn.disabled = true;

            try {
                const response = await fetch('/predict', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ text: text })
                });

                // Error bodies (e.g. a proxy error page or a plain-text 500) may not be JSON.
                const bodyText = await response.text();
                let data = null;
                try {
                    data = bodyText ? JSON.parse(bodyText) : null;
                } catch (parseError) {
                    data = null;
                }

                if (response.ok && data) {
                    const label = CLASS_LABELS[data.predicted_class_index]
                        || escapeHtml(data.predicted_class_label || 'Label tidak diketahui');

                    let resultHTML = `
                        <h3>Hasil Prediksi:</h3>
                        <div class="predicted-label">${label}</div>
                        <p><strong>Confidence:</strong> ${(data.confidence * 100).toFixed(2)}%</p>

                        <h4>Probabilitas per Kelas:</h4>
                    `;

                    data.probabilities.forEach((prob, index) => {
                        const percentage = (prob * 100).toFixed(4);
                        const isMax = index === data.predicted_class_index;
                        const classLabel = CLASS_LABELS[index] || `Class ${index}`;
                        resultHTML += `
                            <div class="prob-item" style="${isMax ? 'background-color: #fff3cd; font-weight: bold;' : ''}">
                                <span>${classLabel}</span>
                                <span>${percentage}%</span>
                            </div>
                        `;
                    });

                    resultHTML += `
                        <details>
                            <summary style="cursor: pointer; margin-top: 15px;">Lihat Raw Kategori Laporan (untuk developer)</summary>
                            <p><strong>Predicted Class Index:</strong> ${escapeHtml(data.predicted_class_index)}</p>
                            <h4>Raw Kategori Laporan (Logits):</h4>
                            <div class="raw-output">${escapeHtml(JSON.stringify(data.raw_output, null, 2))}</div>
                        </details>
                    `;

                    showResult('success', resultHTML);
                } else {
                    const detail = data ? formatDetail(data.detail) : '';
                    const message = detail || ('Unknown error (HTTP ' + response.status + ')');
                    showResult('error', escapeHtml(`Error: ${message}`));
                }
            } catch (error) {
                showResult('error', escapeHtml(`Network error: ${error.message}`));
            } finally {
                loadingDiv.style.display = 'none';
                predictBtn.disabled = false;
            }
        }

        // `content` is inserted as HTML; callers must escape untrusted text.
        function showResult(type, content) {
            const resultDiv = document.getElementById('result');
            resultDiv.className = `result ${type}`;
            resultDiv.innerHTML = content;
            resultDiv.style.display = 'block';
        }

        // Allow Ctrl+Enter to submit
        document.getElementById('textInput').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && (e.ctrlKey || e.metaKey)) {
                predict();
            }
        });
    </script>
</body>
</html>
"""
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the browser-based testing page (HTML_TEMPLATE) as text/html."""
    page = HTML_TEMPLATE
    return page
|
|
|
def _describe_tensors(details, prefix):
    """Summarise TFLite tensor details as JSON-friendly dicts (name, shape, dtype)."""
    return [
        {
            "name": detail.get("name", f"{prefix}_{i}"),
            "shape": detail["shape"].tolist(),
            "dtype": str(detail["dtype"]),
        }
        for i, detail in enumerate(details)
    ]


@app.get("/health")
def health_check():
    """Report whether the model is loaded, plus its tensor layout.

    If any model component is missing, responds with HTTP 503 and
    ``{"status": "unhealthy", ...}``. The non-2xx code lets probes that only
    inspect the status code (load balancers, uptime monitors) detect the
    failure. Otherwise responds with HTTP 200 and the model's input/output
    details, ``MAX_LENGTH`` and ``CLASS_LABELS``.
    """
    if interpreter is None or tokenizer is None or input_details is None or output_details is None:
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "message": "Model components not loaded"},
        )

    return {
        "status": "healthy",
        "message": "TFLite model is ready",
        "model_info": {
            "input_details": _describe_tensors(input_details, "input"),
            "output_details": _describe_tensors(output_details, "output"),
            "max_length": MAX_LENGTH,
            "class_labels": CLASS_LABELS,
        },
    }
|
|
|
@app.post("/predict", response_model=PredictionResponse)
def predict(input: InputText):
    """Classify a single report text.

    Args:
        input: Request body. The name shadows the ``input`` builtin but is kept
            so existing keyword callers keep working.

    Returns:
        PredictionResponse built from ``predict_tflite``'s result.

    Raises:
        HTTPException: 400 for blank text. HTTPExceptions from
            ``predict_tflite`` (503/500) are passed through unchanged. Any
            other error becomes a 500.
    """
    if not input.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    try:
        result = predict_tflite(input.text)
        return PredictionResponse(**result)
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Unexpected error: %s", exc)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(exc)}") from exc
|
|
|
@app.get("/test")
def test_endpoint():
    """Liveness smoke test that also lists the API's main routes.

    The reported version is read from the FastAPI app, so it cannot drift
    from the version declared when the app is constructed.
    """
    return {
        "message": "TFLite API is working!",
        "status": "ok",
        "version": app.version,
        "endpoints": {
            "ui": "/",
            "predict": "/predict",
            "health": "/health",
            "docs": "/docs",
        },
    }
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces. Port 7860 is the default; set the PORT
    # environment variable to override it (e.g. when the host assigns one).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))