adn commited on
Commit
53d7564
·
verified ·
1 Parent(s): cf5ac94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -77
app.py CHANGED
@@ -1,14 +1,12 @@
1
  from fastapi import FastAPI, Request, HTTPException
2
  from fastapi.responses import HTMLResponse
3
- from fastapi.staticfiles import StaticFiles
4
  from pydantic import BaseModel
5
  import tensorflow as tf
6
  import numpy as np
7
  import uvicorn
8
  import os
9
  import logging
10
- import pickle
11
- from typing import Dict, Any
12
  from transformers import AutoTokenizer
13
 
14
  # Setup logging
@@ -17,28 +15,26 @@ logger = logging.getLogger(__name__)
17
 
18
  # Configuration
19
  MODEL_PATH = "model.tflite"
20
- TOKENIZER_PATH = "tokenizer"
21
- LABEL_ENCODER_PATH = "label_encoder.pkl"
22
  MAX_LENGTH = 128
23
 
24
  # Inisialisasi FastAPI
25
  app = FastAPI(
26
  title="Damkar Classification API (TFLite)",
27
- description="API untuk klasifikasi tipe laporan damkar menggunakan TFLite model",
28
  version="1.0.0"
29
  )
30
 
31
  # Global variables
32
  interpreter = None
33
  tokenizer = None
34
- label_encoder = None
35
  input_details = None
36
  output_details = None
37
 
38
  @app.on_event("startup")
39
  async def load_model():
40
  """Load model dan dependencies saat aplikasi startup"""
41
- global interpreter, tokenizer, label_encoder, input_details, output_details
42
 
43
  try:
44
  logger.info("Loading TFLite model...")
@@ -54,62 +50,36 @@ async def load_model():
54
  input_details = interpreter.get_input_details()
55
  output_details = interpreter.get_output_details()
56
 
57
- logger.info(f"Model loaded. Input shape: {[detail['shape'] for detail in input_details]}")
 
 
58
 
59
  # Load tokenizer
60
  logger.info("Loading tokenizer...")
61
  if os.path.exists(TOKENIZER_PATH):
62
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
63
  else:
64
- # Fallback ke tokenizer online jika tidak ada lokal
65
  logger.warning("Local tokenizer not found, using online tokenizer")
66
  tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
67
 
68
- # Load label encoder
69
- logger.info("Loading label encoder...")
70
- if os.path.exists(LABEL_ENCODER_PATH):
71
- with open(LABEL_ENCODER_PATH, 'rb') as f:
72
- label_encoder = pickle.load(f)
73
- else:
74
- # Default labels jika tidak ada label encoder
75
- logger.warning("Label encoder not found, using default labels")
76
- label_encoder = create_default_label_encoder()
77
-
78
  logger.info("All components loaded successfully!")
79
 
80
  except Exception as e:
81
  logger.error(f"Error loading model: {str(e)}")
82
  raise e
83
 
84
- def create_default_label_encoder():
85
- """Create default label encoder jika file tidak ada"""
86
- class DefaultLabelEncoder:
87
- def __init__(self):
88
- # Sesuaikan dengan kategori yang Anda miliki
89
- self.classes_ = [
90
- "Kebakaran",
91
- "Evakuasi/Penyelamatan Hewan",
92
- "Penyelamatan Non Hewan & Bantuan Teknis",
93
- "Lain-lain"
94
- ]
95
-
96
- def inverse_transform(self, encoded):
97
- return [self.classes_[i] for i in encoded]
98
-
99
- return DefaultLabelEncoder()
100
-
101
  def predict_tflite(text: str) -> Dict[str, Any]:
102
- """Fungsi prediksi menggunakan TFLite model"""
103
- global interpreter, tokenizer, label_encoder, input_details, output_details
104
 
105
- if not all([interpreter, tokenizer, label_encoder]):
106
  raise HTTPException(status_code=503, detail="Model components not loaded")
107
 
108
  try:
109
  # Resize input tensors
110
- interpreter.resize_tensor_input(0, [1, MAX_LENGTH]) # attention_mask
111
- interpreter.resize_tensor_input(1, [1, MAX_LENGTH]) # input_ids
112
- interpreter.resize_tensor_input(2, [1, MAX_LENGTH]) # token_type_ids
113
  interpreter.allocate_tensors()
114
 
115
  # Tokenize text
@@ -126,7 +96,7 @@ def predict_tflite(text: str) -> Dict[str, Any]:
126
  token_type_ids = encoded['token_type_ids'].astype(np.int32)
127
  attention_mask = encoded['attention_mask'].astype(np.int32)
128
 
129
- # Set tensors
130
  interpreter.set_tensor(input_details[0]['index'], attention_mask)
131
  interpreter.set_tensor(input_details[1]['index'], input_ids)
132
  interpreter.set_tensor(input_details[2]['index'], token_type_ids)
@@ -134,20 +104,25 @@ def predict_tflite(text: str) -> Dict[str, Any]:
134
  # Run inference
135
  interpreter.invoke()
136
 
137
- # Get output
138
- output = interpreter.get_tensor(output_details[0]['index'])
 
 
 
139
 
140
- # Get predictions
141
- probabilities = tf.nn.softmax(output[0]).numpy()
142
- pred_encoded = np.argmax(output, axis=1)
143
- predicted_label = label_encoder.inverse_transform(pred_encoded)[0]
144
- confidence = float(np.max(probabilities))
145
 
146
  return {
147
- "label": predicted_label,
148
- "confidence": confidence,
149
- "probabilities": {
150
- label: float(prob) for label, prob in zip(label_encoder.classes_, probabilities)
 
 
 
 
151
  }
152
  }
153
 
@@ -160,9 +135,12 @@ class InputText(BaseModel):
160
  text: str
161
 
162
  class PredictionResponse(BaseModel):
163
- label: str
164
  confidence: float
165
- probabilities: Dict[str, float]
 
 
 
166
  status: str = "success"
167
 
168
  # HTML template untuk UI
@@ -170,13 +148,13 @@ HTML_TEMPLATE = """
170
  <!DOCTYPE html>
171
  <html>
172
  <head>
173
- <title>Damkar Classification</title>
174
  <meta charset="UTF-8">
175
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
176
  <style>
177
  body {
178
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
179
- max-width: 800px;
180
  margin: 0 auto;
181
  padding: 20px;
182
  background-color: #f5f5f5;
@@ -256,9 +234,10 @@ HTML_TEMPLATE = """
256
  display: flex;
257
  justify-content: space-between;
258
  margin: 5px 0;
259
- padding: 5px;
260
  background-color: #f8f9fa;
261
  border-radius: 4px;
 
262
  }
263
  .examples {
264
  margin-top: 20px;
@@ -275,11 +254,29 @@ HTML_TEMPLATE = """
275
  .example-text:hover {
276
  color: #0056b3;
277
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  </style>
279
  </head>
280
  <body>
281
  <div class="container">
282
- <h1>🚒 Klasifikasi Laporan Damkar</h1>
 
283
 
284
  <div class="form-group">
285
  <label for="textInput">Masukkan teks laporan:</label>
@@ -346,20 +343,33 @@ HTML_TEMPLATE = """
346
  if (response.ok) {
347
  let resultHTML = `
348
  <h3>Hasil Prediksi:</h3>
349
- <p><strong>Kategori:</strong> ${data.label}</p>
350
  <p><strong>Confidence:</strong> ${(data.confidence * 100).toFixed(2)}%</p>
351
- <h4>Detail Probabilitas:</h4>
 
 
 
 
 
 
 
352
  `;
353
 
354
- for (const [label, prob] of Object.entries(data.probabilities)) {
355
- const percentage = (prob * 100).toFixed(2);
 
356
  resultHTML += `
357
- <div class="prob-item">
358
- <span>${label}</span>
359
  <span>${percentage}%</span>
360
  </div>
361
  `;
362
- }
 
 
 
 
 
363
 
364
  showResult('success', resultHTML);
365
  } else {
@@ -400,17 +410,29 @@ def read_root():
400
  @app.get("/health")
401
  def health_check():
402
  """Health check endpoint"""
403
- global interpreter, tokenizer, label_encoder
404
 
405
- if not all([interpreter, tokenizer, label_encoder]):
406
  return {"status": "unhealthy", "message": "Model components not loaded"}
407
 
408
  return {
409
  "status": "healthy",
410
  "message": "TFLite model is ready",
411
  "model_info": {
412
- "input_shapes": [detail['shape'] for detail in input_details],
413
- "output_shape": output_details[0]['shape'] if output_details else None,
 
 
 
 
 
 
 
 
 
 
 
 
414
  "max_length": MAX_LENGTH
415
  }
416
  }
@@ -427,11 +449,7 @@ def predict(input: InputText):
427
  # Lakukan prediksi
428
  result = predict_tflite(input.text)
429
 
430
- return PredictionResponse(
431
- label=result["label"],
432
- confidence=result["confidence"],
433
- probabilities=result["probabilities"]
434
- )
435
 
436
  except HTTPException:
437
  raise
@@ -445,6 +463,7 @@ def test_endpoint():
445
  return {
446
  "message": "TFLite API is working!",
447
  "status": "ok",
 
448
  "endpoints": {
449
  "ui": "/",
450
  "predict": "/predict",
 
1
  from fastapi import FastAPI, Request, HTTPException
2
  from fastapi.responses import HTMLResponse
 
3
  from pydantic import BaseModel
4
  import tensorflow as tf
5
  import numpy as np
6
  import uvicorn
7
  import os
8
  import logging
9
+ from typing import Dict, Any, List
 
10
  from transformers import AutoTokenizer
11
 
12
  # Setup logging
 
15
 
16
  # Configuration
17
  MODEL_PATH = "model.tflite"
18
+ TOKENIZER_PATH = "tokenizer"
 
19
  MAX_LENGTH = 128
20
 
21
  # Inisialisasi FastAPI
22
  app = FastAPI(
23
  title="Damkar Classification API (TFLite)",
24
+ description="API untuk klasifikasi tipe laporan damkar menggunakan TFLite model - Raw Output",
25
  version="1.0.0"
26
  )
27
 
28
  # Global variables
29
  interpreter = None
30
  tokenizer = None
 
31
  input_details = None
32
  output_details = None
33
 
34
  @app.on_event("startup")
35
  async def load_model():
36
  """Load model dan dependencies saat aplikasi startup"""
37
+ global interpreter, tokenizer, input_details, output_details
38
 
39
  try:
40
  logger.info("Loading TFLite model...")
 
50
  input_details = interpreter.get_input_details()
51
  output_details = interpreter.get_output_details()
52
 
53
+ logger.info(f"Model loaded successfully!")
54
+ logger.info(f"Input details: {input_details}")
55
+ logger.info(f"Output details: {output_details}")
56
 
57
  # Load tokenizer
58
  logger.info("Loading tokenizer...")
59
  if os.path.exists(TOKENIZER_PATH):
60
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
61
  else:
 
62
  logger.warning("Local tokenizer not found, using online tokenizer")
63
  tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
64
 
 
 
 
 
 
 
 
 
 
 
65
  logger.info("All components loaded successfully!")
66
 
67
  except Exception as e:
68
  logger.error(f"Error loading model: {str(e)}")
69
  raise e
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def predict_tflite(text: str) -> Dict[str, Any]:
72
+ """Fungsi prediksi menggunakan TFLite model - mengembalikan raw output"""
73
+ global interpreter, tokenizer, input_details, output_details
74
 
75
+ if not all([interpreter, tokenizer]):
76
  raise HTTPException(status_code=503, detail="Model components not loaded")
77
 
78
  try:
79
  # Resize input tensors
80
+ interpreter.resize_tensor_input(input_details[0]['index'], [1, MAX_LENGTH])
81
+ interpreter.resize_tensor_input(input_details[1]['index'], [1, MAX_LENGTH])
82
+ interpreter.resize_tensor_input(input_details[2]['index'], [1, MAX_LENGTH])
83
  interpreter.allocate_tensors()
84
 
85
  # Tokenize text
 
96
  token_type_ids = encoded['token_type_ids'].astype(np.int32)
97
  attention_mask = encoded['attention_mask'].astype(np.int32)
98
 
99
+ # Set tensors - gunakan urutan yang benar
100
  interpreter.set_tensor(input_details[0]['index'], attention_mask)
101
  interpreter.set_tensor(input_details[1]['index'], input_ids)
102
  interpreter.set_tensor(input_details[2]['index'], token_type_ids)
 
104
  # Run inference
105
  interpreter.invoke()
106
 
107
+ # Get raw output
108
+ raw_output = interpreter.get_tensor(output_details[0]['index'])
109
+
110
+ # Hitung probabilitas dengan softmax
111
+ probabilities = tf.nn.softmax(raw_output[0]).numpy()
112
 
113
+ # Prediksi kelas (index dengan probabilitas tertinggi)
114
+ predicted_class_index = int(np.argmax(raw_output, axis=1)[0])
115
+ max_confidence = float(np.max(probabilities))
 
 
116
 
117
  return {
118
+ "predicted_class_index": predicted_class_index,
119
+ "confidence": max_confidence,
120
+ "raw_output": raw_output[0].tolist(), # Convert numpy array to list
121
+ "probabilities": probabilities.tolist(),
122
+ "input_text": text,
123
+ "model_info": {
124
+ "output_shape": raw_output.shape,
125
+ "num_classes": len(probabilities)
126
  }
127
  }
128
 
 
135
  text: str
136
 
137
  class PredictionResponse(BaseModel):
138
+ predicted_class_index: int
139
  confidence: float
140
+ raw_output: List[float]
141
+ probabilities: List[float]
142
+ input_text: str
143
+ model_info: Dict[str, Any]
144
  status: str = "success"
145
 
146
  # HTML template untuk UI
 
148
  <!DOCTYPE html>
149
  <html>
150
  <head>
151
+ <title>Damkar Classification - Raw Output</title>
152
  <meta charset="UTF-8">
153
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
154
  <style>
155
  body {
156
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
157
+ max-width: 900px;
158
  margin: 0 auto;
159
  padding: 20px;
160
  background-color: #f5f5f5;
 
234
  display: flex;
235
  justify-content: space-between;
236
  margin: 5px 0;
237
+ padding: 8px;
238
  background-color: #f8f9fa;
239
  border-radius: 4px;
240
+ font-family: monospace;
241
  }
242
  .examples {
243
  margin-top: 20px;
 
254
  .example-text:hover {
255
  color: #0056b3;
256
  }
257
+ .raw-output {
258
+ background-color: #f8f9fa;
259
+ padding: 10px;
260
+ border-radius: 4px;
261
+ font-family: monospace;
262
+ font-size: 12px;
263
+ margin: 10px 0;
264
+ max-height: 200px;
265
+ overflow-y: auto;
266
+ }
267
+ .model-info {
268
+ background-color: #e7f3ff;
269
+ padding: 10px;
270
+ border-radius: 4px;
271
+ margin: 10px 0;
272
+ font-size: 14px;
273
+ }
274
  </style>
275
  </head>
276
  <body>
277
  <div class="container">
278
+ <h1>🚒 Damkar Classification - Raw Output</h1>
279
+ <p style="text-align: center; color: #666;">Menampilkan output mentah dari TFLite model tanpa label encoder</p>
280
 
281
  <div class="form-group">
282
  <label for="textInput">Masukkan teks laporan:</label>
 
343
  if (response.ok) {
344
  let resultHTML = `
345
  <h3>Hasil Prediksi:</h3>
346
+ <p><strong>Predicted Class Index:</strong> ${data.predicted_class_index}</p>
347
  <p><strong>Confidence:</strong> ${(data.confidence * 100).toFixed(2)}%</p>
348
+
349
+ <div class="model-info">
350
+ <strong>Model Info:</strong><br>
351
+ Output Shape: ${JSON.stringify(data.model_info.output_shape)}<br>
352
+ Number of Classes: ${data.model_info.num_classes}
353
+ </div>
354
+
355
+ <h4>Probabilitas per Class:</h4>
356
  `;
357
 
358
+ data.probabilities.forEach((prob, index) => {
359
+ const percentage = (prob * 100).toFixed(4);
360
+ const isMax = index === data.predicted_class_index;
361
  resultHTML += `
362
+ <div class="prob-item" style="${isMax ? 'background-color: #fff3cd; font-weight: bold;' : ''}">
363
+ <span>Class ${index}</span>
364
  <span>${percentage}%</span>
365
  </div>
366
  `;
367
+ });
368
+
369
+ resultHTML += `
370
+ <h4>Raw Output (Logits):</h4>
371
+ <div class="raw-output">${JSON.stringify(data.raw_output, null, 2)}</div>
372
+ `;
373
 
374
  showResult('success', resultHTML);
375
  } else {
 
410
  @app.get("/health")
411
  def health_check():
412
  """Health check endpoint"""
413
+ global interpreter, tokenizer
414
 
415
+ if not all([interpreter, tokenizer]):
416
  return {"status": "unhealthy", "message": "Model components not loaded"}
417
 
418
  return {
419
  "status": "healthy",
420
  "message": "TFLite model is ready",
421
  "model_info": {
422
+ "input_details": [
423
+ {
424
+ "name": detail.get('name', f'input_{i}'),
425
+ "shape": detail['shape'].tolist(),
426
+ "dtype": str(detail['dtype'])
427
+ } for i, detail in enumerate(input_details)
428
+ ],
429
+ "output_details": [
430
+ {
431
+ "name": detail.get('name', f'output_{i}'),
432
+ "shape": detail['shape'].tolist(),
433
+ "dtype": str(detail['dtype'])
434
+ } for i, detail in enumerate(output_details)
435
+ ],
436
  "max_length": MAX_LENGTH
437
  }
438
  }
 
449
  # Lakukan prediksi
450
  result = predict_tflite(input.text)
451
 
452
+ return PredictionResponse(**result)
 
 
 
 
453
 
454
  except HTTPException:
455
  raise
 
463
  return {
464
  "message": "TFLite API is working!",
465
  "status": "ok",
466
+ "version": "raw_output",
467
  "endpoints": {
468
  "ui": "/",
469
  "predict": "/predict",