ShahzadAli44 commited on
Commit
da19d2e
Β·
1 Parent(s): 940c691
Files changed (1) hide show
  1. app/main.py +6 -7
app/main.py CHANGED
@@ -1,4 +1,5 @@
1
- from tensorflow.keras.models import load_model
 
2
  from fastapi import FastAPI, File, UploadFile
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import numpy as np
@@ -22,25 +23,23 @@ app.add_middleware(
22
 
23
  MODEL = None
24
 
25
- def load_model():
26
  global MODEL
27
  if MODEL is None:
28
  HF_MODEL_DIR = "/tmp/rice_model"
29
  os.makedirs(HF_MODEL_DIR, exist_ok=True)
30
  try:
31
-
32
  model_path = hf_hub_download(
33
  repo_id="ShahzadAli44/rice_cnn",
34
  filename="rice_cnn_model.keras",
35
  local_dir=HF_MODEL_DIR,
36
- local_dir_use_symlinks=False
37
  )
38
- MODEL = load_model(model_path)
39
  logging.info("βœ… Model loaded successfully.")
40
  except Exception as e:
41
  logging.error(f"❌ Failed to load model: {e}")
42
 
43
-
44
  CLASS_NAMES = [
45
  "bacterial_leaf_blight", "brown_spot", "healthy", "leaf_blast",
46
  "leaf_scald", "narrow_brown_spot", "rice_hispa", "sheath_blight", "tungro"
@@ -102,7 +101,7 @@ def home():
102
 
103
  @app.post("/predict")
104
  async def predict(file: UploadFile = File(...)):
105
- load_model()
106
  if MODEL is None:
107
  return {"error": "Model failed to load."}
108
 
 
1
+ from tensorflow.keras.models import load_model as keras_load_model
2
+
3
  from fastapi import FastAPI, File, UploadFile
4
  from fastapi.middleware.cors import CORSMiddleware
5
  import numpy as np
 
23
 
24
  MODEL = None
25
 
26
+ def initialize_model():
27
  global MODEL
28
  if MODEL is None:
29
  HF_MODEL_DIR = "/tmp/rice_model"
30
  os.makedirs(HF_MODEL_DIR, exist_ok=True)
31
  try:
 
32
  model_path = hf_hub_download(
33
  repo_id="ShahzadAli44/rice_cnn",
34
  filename="rice_cnn_model.keras",
35
  local_dir=HF_MODEL_DIR,
36
+ local_dir_use_symlinks=False
37
  )
38
+ MODEL = keras_load_model(model_path)
39
  logging.info("βœ… Model loaded successfully.")
40
  except Exception as e:
41
  logging.error(f"❌ Failed to load model: {e}")
42
 
 
43
  CLASS_NAMES = [
44
  "bacterial_leaf_blight", "brown_spot", "healthy", "leaf_blast",
45
  "leaf_scald", "narrow_brown_spot", "rice_hispa", "sheath_blight", "tungro"
 
101
 
102
  @app.post("/predict")
103
  async def predict(file: UploadFile = File(...)):
104
+ initialize_model()
105
  if MODEL is None:
106
  return {"error": "Model failed to load."}
107