|
import logging
import os
import pickle
import threading
from typing import Any, Dict

import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer
|
|
|
|
|
# Configure the root logger at INFO so the startup progress messages emitted
# while loading the model are visible, then grab this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
|
|
|
|
|
# Every request is padded/truncated to exactly this many tokens before inference.
MAX_LENGTH = 128

# Artifact locations; relative paths resolve against the process working directory.
MODEL_PATH = "model.tflite"  # TFLite classifier, required at startup
TOKENIZER_PATH = "tokenizer"  # local tokenizer directory (startup falls back to the Hub)
LABEL_ENCODER_PATH = "label_encoder.pkl"  # pickled encoder (startup falls back to defaults)
|
|
|
|
|
# The ASGI application object served by uvicorn; metadata feeds the /docs page.
app = FastAPI(
    version="1.0.0",
    title="Damkar Classification API (TFLite)",
    description="API untuk klasifikasi tipe laporan damkar menggunakan TFLite model",
)
|
|
|
|
|
# Runtime state filled in by the startup hook and read by the request handlers.
# Everything stays None until loading succeeds.
interpreter = tokenizer = label_encoder = None
input_details = output_details = None
|
|
|
# NOTE: FastAPI deprecates on_event in favour of lifespan handlers; kept here to
# preserve the existing startup wiring.
@app.on_event("startup")
async def load_model():
    """Load the TFLite interpreter, tokenizer and label encoder at startup.

    Populates the module-level ``interpreter``, ``input_details``,
    ``output_details``, ``tokenizer`` and ``label_encoder`` globals that the
    prediction and health endpoints read.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
        Exception: any error from TFLite, transformers or unpickling is logged
            (with traceback) and re-raised so the application fails to start.
    """
    global interpreter, tokenizer, label_encoder, input_details, output_details

    try:
        logger.info("Loading TFLite model...")
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

        interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        logger.info(
            "Model loaded. Input shape: %s",
            [detail["shape"] for detail in input_details],
        )

        logger.info("Loading tokenizer...")
        if os.path.exists(TOKENIZER_PATH):
            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
        else:
            # Falls back to the Hugging Face Hub, which requires network access.
            logger.warning("Local tokenizer not found, using online tokenizer")
            tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")

        logger.info("Loading label encoder...")
        if os.path.exists(LABEL_ENCODER_PATH):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only deploy a label_encoder.pkl you produced yourself.
            with open(LABEL_ENCODER_PATH, "rb") as f:
                label_encoder = pickle.load(f)
        else:
            logger.warning("Label encoder not found, using default labels")
            label_encoder = create_default_label_encoder()

        # Sanity check: if the label list and the model's output width disagree,
        # the per-class probability mapping built at prediction time (via zip)
        # would silently drop classes. Warn instead of failing startup.
        classes = getattr(label_encoder, "classes_", None)
        output_shape = list(output_details[0]["shape"]) if output_details else []
        if classes is not None and output_shape and int(output_shape[-1]) != len(classes):
            logger.warning(
                "Label encoder has %d classes but the model outputs %d scores",
                len(classes),
                int(output_shape[-1]),
            )

        logger.info("All components loaded successfully!")

    except Exception as e:
        # logger.exception keeps the traceback; bare raise preserves it for uvicorn.
        logger.exception("Error loading model: %s", e)
        raise
|
|
|
def create_default_label_encoder():
    """Build a minimal stand-in for the pickled label encoder.

    Used when ``LABEL_ENCODER_PATH`` is missing. The returned object exposes the
    two members the prediction code relies on: ``classes_`` (list of label
    names) and ``inverse_transform`` (class indices -> list of label names).

    NOTE(review): position i in the list below is treated as model output i. If
    the real encoder was an sklearn ``LabelEncoder``, its classes are sorted
    alphabetically, which would NOT match this order -- confirm against the
    training pipeline before relying on this fallback.
    """
    labels = [
        "Kebakaran",  # index 0
        "Evakuasi/Penyelamatan Hewan",  # index 1
        "Penyelamatan Non Hewan & Bantuan Teknis",  # index 2
        "Lain-lain",  # index 3
    ]

    class DefaultLabelEncoder:
        """Fixed-vocabulary decoder mapping class indices to label names."""

        def __init__(self):
            # Fresh copy per instance, so callers cannot mutate shared state.
            self.classes_ = list(labels)

        def inverse_transform(self, encoded):
            return list(map(self.classes_.__getitem__, encoded))

    return DefaultLabelEncoder()
|
|
|
# The TFLite Interpreter is not thread-safe, and FastAPI runs the synchronous
# /predict handler in a worker thread pool, so concurrent requests would race on
# resize/allocate/set_tensor/invoke. Fast (Rust) tokenizers can also fail when
# one instance is called concurrently. This lock serializes both.
_inference_lock = threading.Lock()

# Model input names, listed in the positional order previously assumed for
# ``input_details`` (used as a fallback when tensor names cannot be matched).
_MODEL_INPUT_NAMES = ("attention_mask", "input_ids", "token_type_ids")


def _bind_model_inputs(details):
    """Map each model input name to its TFLite input detail dict.

    Tensor names such as ``serving_default_input_ids:0`` are matched by
    substring. If every name matches exactly one distinct tensor, that mapping
    is used; otherwise the positional order of ``_MODEL_INPUT_NAMES`` applies.
    """
    by_name = {}
    for key in _MODEL_INPUT_NAMES:
        matches = [d for d in details if key in str(d.get("name", ""))]
        if len(matches) == 1:
            by_name[key] = matches[0]
    distinct_indices = {d["index"] for d in by_name.values()}
    if len(by_name) == len(_MODEL_INPUT_NAMES) and len(distinct_indices) == len(by_name):
        return by_name
    return dict(zip(_MODEL_INPUT_NAMES, details))


def _ensure_input_shapes():
    """Resize every model input to ``[1, MAX_LENGTH]`` if it is not already.

    Must be called with ``_inference_lock`` held. Refreshes the cached
    ``input_details`` after resizing so later calls skip the reallocation.
    """
    global input_details
    target = [1, MAX_LENGTH]
    stale = [d for d in input_details if list(d["shape"]) != target]
    if not stale:
        return
    for detail in stale:
        # resize_tensor_input expects the tensor index, not the input position.
        interpreter.resize_tensor_input(detail["index"], target)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()


def predict_tflite(text: str) -> Dict[str, Any]:
    """Classify ``text`` with the TFLite model.

    Args:
        text: raw report text; tokenized, truncated and padded to MAX_LENGTH.

    Returns:
        dict with ``label`` (predicted class name), ``confidence`` (softmax
        probability of that class) and ``probabilities`` (class name ->
        probability, in ``label_encoder.classes_`` order).

    Raises:
        HTTPException: 503 if the startup hook has not loaded every component;
            500 if tokenization or inference fails.
    """
    if interpreter is None or tokenizer is None or label_encoder is None:
        raise HTTPException(status_code=503, detail="Model components not loaded")

    try:
        with _inference_lock:
            _ensure_input_shapes()

            encoded = tokenizer(
                [text],
                max_length=MAX_LENGTH,
                padding="max_length",
                truncation=True,
                return_tensors="np",
            )
            input_ids = encoded["input_ids"].astype(np.int32)
            attention_mask = encoded["attention_mask"].astype(np.int32)
            token_type_ids = encoded.get("token_type_ids")
            # Some tokenizers omit segment ids; a single segment is all zeros.
            if token_type_ids is None:
                token_type_ids = np.zeros_like(input_ids)
            else:
                token_type_ids = token_type_ids.astype(np.int32)

            bound = _bind_model_inputs(input_details)
            interpreter.set_tensor(bound["attention_mask"]["index"], attention_mask)
            interpreter.set_tensor(bound["input_ids"]["index"], input_ids)
            interpreter.set_tensor(bound["token_type_ids"]["index"], token_type_ids)
            interpreter.invoke()

            # get_tensor returns a copy, so it is safe to use after releasing the lock.
            logits = np.asarray(
                interpreter.get_tensor(output_details[0]["index"])[0], dtype=np.float64
            )

        # Numerically stable softmax; assumes the model emits raw logits
        # (as the softmax here has always assumed).
        exp = np.exp(logits - np.max(logits))
        probabilities = exp / exp.sum()

        pred_index = int(np.argmax(probabilities))
        predicted_label = label_encoder.inverse_transform([pred_index])[0]

        return {
            "label": str(predicted_label),
            "confidence": float(probabilities[pred_index]),
            "probabilities": {
                str(label): float(prob)
                for label, prob in zip(label_encoder.classes_, probabilities)
            },
        }

    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
|
|
|
|
|
class InputText(BaseModel):
    """JSON request body accepted by POST /predict."""

    # Raw report text to classify; the endpoint rejects blank values with HTTP 400.
    text: str
|
|
|
class PredictionResponse(BaseModel):
    """JSON response body returned by POST /predict on success."""

    label: str  # predicted category name
    confidence: float  # softmax probability of ``label``
    probabilities: Dict[str, float]  # every category name -> its probability
    status: str = "success"  # fixed marker; failures are reported via HTTPException
|
|
|
|
|
# Single-page testing UI served at "/". Server-provided values (labels, error
# details) are inserted as DOM text nodes only, never as HTML markup.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="id">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Damkar Classification</title>
    <style>
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            background-color: #f5f5f5;
        }
        .container {
            background: white;
            padding: 30px;
            border-radius: 10px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }
        h1 {
            color: #333;
            text-align: center;
            margin-bottom: 30px;
        }
        .form-group {
            margin-bottom: 20px;
        }
        label {
            display: block;
            margin-bottom: 8px;
            font-weight: bold;
            color: #555;
        }
        textarea {
            width: 100%;
            min-height: 120px;
            padding: 12px;
            border: 2px solid #ddd;
            border-radius: 6px;
            font-size: 14px;
            resize: vertical;
            box-sizing: border-box;
        }
        textarea:focus {
            outline: none;
            border-color: #007bff;
        }
        button {
            background-color: #007bff;
            color: white;
            padding: 12px 30px;
            border: none;
            border-radius: 6px;
            cursor: pointer;
            font-size: 16px;
            width: 100%;
        }
        button:hover {
            background-color: #0056b3;
        }
        button:disabled {
            background-color: #ccc;
            cursor: not-allowed;
        }
        .result {
            margin-top: 20px;
            padding: 15px;
            border-radius: 6px;
            display: none;
        }
        .result.success {
            background-color: #d4edda;
            border: 1px solid #c3e6cb;
            color: #155724;
        }
        .result.error {
            background-color: #f8d7da;
            border: 1px solid #f5c6cb;
            color: #721c24;
        }
        .loading {
            text-align: center;
            display: none;
        }
        .prob-item {
            display: flex;
            justify-content: space-between;
            margin: 5px 0;
            padding: 5px;
            background-color: #f8f9fa;
            border-radius: 4px;
        }
        .examples {
            margin-top: 20px;
            padding: 15px;
            background-color: #f8f9fa;
            border-radius: 6px;
        }
        .example-text {
            cursor: pointer;
            color: #007bff;
            text-decoration: underline;
            margin: 5px 0;
        }
        .example-text:hover {
            color: #0056b3;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>🚒 Klasifikasi Laporan Damkar</h1>

        <div class="form-group">
            <label for="textInput">Masukkan teks laporan:</label>
            <textarea id="textInput" placeholder="Contoh: ada kebakaran di gedung perkantoran..."></textarea>
        </div>

        <button onclick="predict()" id="predictBtn">Prediksi Kategori</button>

        <div class="loading" id="loading">
            <p>⏳ Sedang memproses...</p>
        </div>

        <div class="result" id="result"></div>

        <div class="examples">
            <h3>Contoh Teks:</h3>
            <div class="example-text" onclick="setExample('ada kebakaran di gedung perkantoran lantai 5')">
                🔥 "ada kebakaran di gedung perkantoran lantai 5"
            </div>
            <div class="example-text" onclick="setExample('ular masuk ke dalam rumah warga')">
                🐍 "ular masuk ke dalam rumah warga"
            </div>
            <div class="example-text" onclick="setExample('kucing terjebak di atas pohon tinggi')">
                🐱 "kucing terjebak di atas pohon tinggi"
            </div>
            <div class="example-text" onclick="setExample('pohon tumbang menghalangi jalan raya')">
                🌳 "pohon tumbang menghalangi jalan raya"
            </div>
        </div>
    </div>

    <script>
        function setExample(text) {
            document.getElementById('textInput').value = text;
        }

        // Append <p><strong>caption</strong> value</p> built from text nodes only,
        // so server-provided values can never be interpreted as HTML.
        function appendField(parent, caption, value) {
            const p = document.createElement('p');
            const strong = document.createElement('strong');
            strong.textContent = caption;
            p.appendChild(strong);
            p.appendChild(document.createTextNode(' ' + value));
            parent.appendChild(p);
        }

        function appendHeading(parent, tag, text) {
            const heading = document.createElement(tag);
            heading.textContent = text;
            parent.appendChild(heading);
        }

        function buildPredictionView(data) {
            const fragment = document.createDocumentFragment();
            appendHeading(fragment, 'h3', 'Hasil Prediksi:');
            appendField(fragment, 'Kategori:', data.label);
            appendField(fragment, 'Confidence:', (data.confidence * 100).toFixed(2) + '%');
            appendHeading(fragment, 'h4', 'Detail Probabilitas:');

            for (const [label, prob] of Object.entries(data.probabilities || {})) {
                const row = document.createElement('div');
                row.className = 'prob-item';
                const name = document.createElement('span');
                name.textContent = label;
                const value = document.createElement('span');
                value.textContent = (prob * 100).toFixed(2) + '%';
                row.appendChild(name);
                row.appendChild(value);
                fragment.appendChild(row);
            }
            return fragment;
        }

        // FastAPI sends `detail` as a string for HTTPException, but as a list of
        // {loc, msg, type} objects for request-validation (422) errors.
        function formatErrorDetail(detail) {
            if (Array.isArray(detail)) {
                return detail.map(function (item) {
                    return item && item.msg ? item.msg : JSON.stringify(item);
                }).join('; ');
            }
            if (typeof detail === 'string' && detail) {
                return detail;
            }
            return 'Unknown error';
        }

        async function predict() {
            const text = document.getElementById('textInput').value.trim();
            const resultDiv = document.getElementById('result');
            const loadingDiv = document.getElementById('loading');
            const predictBtn = document.getElementById('predictBtn');

            // Ignore keyboard submits while a request is already in flight.
            if (predictBtn.disabled) {
                return;
            }

            if (!text) {
                showResult('error', 'Mohon masukkan teks untuk diprediksi.');
                return;
            }

            // Show loading
            loadingDiv.style.display = 'block';
            resultDiv.style.display = 'none';
            predictBtn.disabled = true;

            try {
                const response = await fetch('/predict', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ text: text })
                });

                // A proxy or crashed worker may answer with a non-JSON body.
                let data = null;
                try {
                    data = await response.json();
                } catch (parseError) {
                    data = null;
                }

                if (response.ok && data) {
                    showResult('success', buildPredictionView(data));
                } else {
                    const detail = data ? formatErrorDetail(data.detail) : 'HTTP ' + response.status;
                    showResult('error', 'Error: ' + detail);
                }
            } catch (error) {
                showResult('error', 'Network error: ' + error.message);
            } finally {
                loadingDiv.style.display = 'none';
                predictBtn.disabled = false;
            }
        }

        // Render plain text or a prebuilt DOM fragment; markup is never parsed,
        // so labels and error messages always display as text.
        function showResult(type, content) {
            const resultDiv = document.getElementById('result');
            resultDiv.className = 'result ' + type;
            resultDiv.textContent = '';
            if (typeof content === 'string') {
                resultDiv.textContent = content;
            } else {
                resultDiv.appendChild(content);
            }
            resultDiv.style.display = 'block';
        }

        // Ctrl+Enter (Cmd+Enter on macOS) submits; a plain Enter inserts a newline.
        // keydown is used because browsers do not reliably fire the older
        // key-press event for modifier combinations.
        document.getElementById('textInput').addEventListener('keydown', function (e) {
            if (e.key === 'Enter' && (e.ctrlKey || e.metaKey)) {
                e.preventDefault();
                predict();
            }
        });
    </script>
</body>
</html>
"""
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the single-page HTML interface used for manual testing."""
    return HTMLResponse(content=HTML_TEMPLATE)
|
|
|
@app.get("/health")
def health_check():
    """Report whether the model components are loaded.

    Returns ``{"status": "unhealthy", ...}`` until the startup hook has set the
    interpreter, tokenizer and label encoder, and ``{"status": "healthy", ...}``
    with basic model metadata afterwards. Both cases respond with HTTP 200.
    """
    if interpreter is None or tokenizer is None or label_encoder is None:
        return {"status": "unhealthy", "message": "Model components not loaded"}

    # TFLite reports shapes as numpy arrays, which FastAPI's JSON encoder cannot
    # serialize (the endpoint would fail with a 500); convert to lists of ints.
    def as_list(shape):
        return np.asarray(shape).tolist()

    return {
        "status": "healthy",
        "message": "TFLite model is ready",
        "model_info": {
            "input_shapes": [as_list(detail["shape"]) for detail in input_details or []],
            "output_shape": as_list(output_details[0]["shape"]) if output_details else None,
            "max_length": MAX_LENGTH,
        },
    }
|
|
|
@app.post("/predict", response_model=PredictionResponse)
def predict(input: InputText):
    """Classify a report text submitted as JSON ``{"text": ...}``.

    Declared as a plain ``def`` so FastAPI runs the blocking inference in its
    worker thread pool rather than on the event loop.

    Raises:
        HTTPException: 400 for blank text; 503/500 propagated unchanged from
            ``predict_tflite``; 500 for any other unexpected failure.
    """
    # ``input`` shadows the builtin, but the parameter name is kept unchanged.
    if not input.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    try:
        result = predict_tflite(input.text)
        return PredictionResponse(
            label=result["label"],
            confidence=result["confidence"],
            probabilities=result["probabilities"],
        )
    except HTTPException:
        # Already carries the right status code and message.
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|
|
|
@app.get("/test")
def test_endpoint():
    """Smoke-test route: confirms the API responds and lists the main routes."""
    routes = {
        "ui": "/",
        "predict": "/predict",
        "health": "/health",
        "docs": "/docs",
    }
    return {"message": "TFLite API is working!", "status": "ok", "endpoints": routes}
|
|
|
|
|
if __name__ == "__main__":
    # 7860 remains the default port; set the PORT environment variable to override it.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))