"""FastAPI service that classifies Damkar report text with a TFLite model.

The module loads a TFLite interpreter and a Hugging Face tokenizer at
startup and exposes a small HTTP API: a UI page (`/`), a health check
(`/health`), a prediction endpoint (`/predict`) and a smoke-test endpoint
(`/test`).
"""

import json
import logging
import os
import threading
from typing import Any, Dict, List

import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import AutoTokenizer

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
MODEL_PATH = "model.tflite"
TOKENIZER_PATH = "tokenizer"
MAX_LENGTH = 128

# Class label mapping (model output index -> human-readable label)
CLASS_LABELS = {
    0: "Evakuasi/Penyelamatan Hewan",
    1: "Kebakaran",
    2: "Layanan Lingkungan & Fasilitas Umum",
    3: "Penyelamatan Non Hewan & Bantuan Teknis",
}

# FastAPI initialisation
app = FastAPI(
    title="Damkar Classification API (TFLite)",
    description="API untuk klasifikasi tipe laporan damkar menggunakan TFLite model",
    version="1.1.0",
)

# Global state, populated once by load_model() at startup
interpreter = None
tokenizer = None
input_details = None
output_details = None

# The interpreter is a single shared object and the sync routes below are run
# on a thread pool, so every resize/allocate/set/invoke/get sequence must be
# serialized: no other interpreter method may run while invoke() is in flight.
_interpreter_lock = threading.Lock()


@app.on_event("startup")
async def load_model():
    """Load the TFLite model and the tokenizer when the application starts.

    Populates the module-level ``interpreter``, ``tokenizer``,
    ``input_details`` and ``output_details``.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
        Exception: any other loading error is logged and re-raised.
    """
    global interpreter, tokenizer, input_details, output_details

    try:
        logger.info("Loading TFLite model...")

        # Load TFLite model
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

        interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
        interpreter.allocate_tensors()

        # Get input/output details
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        logger.info("Model loaded successfully!")
        logger.info("Input details: %s", input_details)
        logger.info("Output details: %s", output_details)

        # Load tokenizer: prefer the local directory, otherwise fall back to
        # the hosted IndoBERT tokenizer.
        logger.info("Loading tokenizer...")
        if os.path.exists(TOKENIZER_PATH):
            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
        else:
            logger.warning("Local tokenizer not found, using online tokenizer")
            tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")

        logger.info("All components loaded successfully!")

    except Exception as e:
        logger.error("Error loading model: %s", e)
        # Bare raise keeps the original traceback intact.
        raise


def predict_tflite(text: str) -> Dict[str, Any]:
    """Classify ``text`` with the TFLite model and return a labelled result.

    Args:
        text: The report text to classify.

    Returns:
        A dict with the predicted class index and label, the confidence
        (highest softmax probability), the raw model output, the per-class
        probabilities, the input text and a ``model_info`` sub-dict.

    Raises:
        HTTPException: 503 if the interpreter or tokenizer is not loaded,
            500 if tokenization or inference fails.
    """
    if interpreter is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model components not loaded")

    try:
        # Tokenize text (done outside the lock; it does not touch the interpreter)
        encoded = tokenizer(
            [text],
            max_length=MAX_LENGTH,
            padding='max_length',
            truncation=True,
            return_tensors='np',
        )

        # Convert to int32 for TFLite
        input_ids = encoded['input_ids'].astype(np.int32)
        token_type_ids = encoded['token_type_ids'].astype(np.int32)
        attention_mask = encoded['attention_mask'].astype(np.int32)

        # Positional mapping of model inputs to tokenizer outputs, kept
        # exactly as in the original code.
        # NOTE(review): confirm this order against the names in input_details.
        feeds = (
            (input_details[0]['index'], attention_mask),
            (input_details[1]['index'], input_ids),
            (input_details[2]['index'], token_type_ids),
        )

        with _interpreter_lock:
            # Resize input tensors (if needed)
            for tensor_index, _ in feeds:
                interpreter.resize_tensor_input(tensor_index, [1, MAX_LENGTH])
            interpreter.allocate_tensors()

            for tensor_index, values in feeds:
                interpreter.set_tensor(tensor_index, values)

            # Run inference
            interpreter.invoke()

            # Get raw output (logits)
            raw_output = interpreter.get_tensor(output_details[0]['index'])

        # Compute probabilities with softmax
        probabilities = tf.nn.softmax(raw_output[0]).numpy()

        # Predicted class is the index with the highest score
        predicted_class_index = int(np.argmax(raw_output, axis=1)[0])
        max_confidence = float(np.max(probabilities))

        # Look up the class label for the index
        predicted_class_label = CLASS_LABELS.get(predicted_class_index, "Unknown Class")

        return {
            "predicted_class_index": predicted_class_index,
            "predicted_class_label": predicted_class_label,
            "confidence": max_confidence,
            "raw_output": raw_output[0].tolist(),  # Convert numpy array to list
            "probabilities": probabilities.tolist(),
            "input_text": text,
            "model_info": {
                "output_shape": raw_output.shape,
                "num_classes": len(probabilities),
            },
        }

    except Exception as e:
        # logger.exception records the traceback, which would otherwise be
        # lost when the error is converted into an HTTP response.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e


# Request/Response models
class InputText(BaseModel):
    """Request body for ``/predict``."""

    text: str


class PredictionResponse(BaseModel):
    """Response body for ``/predict``; mirrors the dict from predict_tflite()."""

    predicted_class_index: int
    predicted_class_label: str
    confidence: float
    raw_output: List[float]
    probabilities: List[float]
    input_text: str
    model_info: Dict[str, Any]
    status: str = "success"


# HTML template for the UI
# NOTE(review): this string contains no HTML tags as it stands — the markup
# may have been stripped somewhere upstream; confirm against the deployed file.
HTML_TEMPLATE = """ Damkar Classification

🚒 Klasifikasi Laporan Damkar

Masukkan teks laporan untuk diklasifikasikan oleh model AI.

⏳ Sedang memproses...

Contoh Teks:

🔥 "ada kebakaran di gedung perkantoran lantai 5"
🐍 "ular masuk ke dalam rumah warga"
🌳 "pohon tumbang menghalangi jalan raya"
💍 "cincin tidak bisa dilepas dari jari"
"""


# Routes
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the UI page used for manual testing."""
    return HTML_TEMPLATE


@app.get("/health")
def health_check():
    """Report whether the model components are loaded, with model metadata."""
    if interpreter is None or tokenizer is None:
        return {"status": "unhealthy", "message": "Model components not loaded"}

    return {
        "status": "healthy",
        "message": "TFLite model is ready",
        "model_info": {
            "input_details": [
                {
                    "name": detail.get('name', f'input_{i}'),
                    "shape": detail['shape'].tolist(),
                    "dtype": str(detail['dtype']),
                }
                for i, detail in enumerate(input_details)
            ],
            "output_details": [
                {
                    "name": detail.get('name', f'output_{i}'),
                    "shape": detail['shape'].tolist(),
                    "dtype": str(detail['dtype']),
                }
                for i, detail in enumerate(output_details)
            ],
            "max_length": MAX_LENGTH,
            "class_labels": CLASS_LABELS,
        },
    }


@app.post("/predict", response_model=PredictionResponse)
def predict(input: InputText):
    """Classify the submitted text and return the prediction.

    Responds with 400 when the text is empty or whitespace-only, and
    propagates the 503/500 errors raised by predict_tflite().
    """
    # Validate input (the parameter name `input` is kept for compatibility)
    if not input.text or input.text.strip() == "":
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    try:
        # Run the prediction
        result = predict_tflite(input.text)
        return PredictionResponse(**result)

    except HTTPException:
        # Already an HTTP error with the right status code; pass it through.
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e


@app.get("/test")
def test_endpoint():
    """Return a static payload confirming the API is reachable."""
    return {
        "message": "TFLite API is working!",
        "status": "ok",
        "version": "1.1.0",
        "endpoints": {
            "ui": "/",
            "predict": "/predict",
            "health": "/health",
            "docs": "/docs",
        },
    }


# Run locally (for development)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)