Spaces:
Sleeping
Sleeping
Commit
Β·
da19d2e
1
Parent(s):
940c691
add code
Browse files- app/main.py +6 -7
app/main.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
from tensorflow.keras.models import load_model
|
|
|
2 |
from fastapi import FastAPI, File, UploadFile
|
3 |
from fastapi.middleware.cors import CORSMiddleware
|
4 |
import numpy as np
|
@@ -22,25 +23,23 @@ app.add_middleware(
|
|
22 |
|
23 |
MODEL = None
|
24 |
|
25 |
-
def
|
26 |
global MODEL
|
27 |
if MODEL is None:
|
28 |
HF_MODEL_DIR = "/tmp/rice_model"
|
29 |
os.makedirs(HF_MODEL_DIR, exist_ok=True)
|
30 |
try:
|
31 |
-
|
32 |
model_path = hf_hub_download(
|
33 |
repo_id="ShahzadAli44/rice_cnn",
|
34 |
filename="rice_cnn_model.keras",
|
35 |
local_dir=HF_MODEL_DIR,
|
36 |
-
local_dir_use_symlinks=False
|
37 |
)
|
38 |
-
MODEL =
|
39 |
logging.info("β
Model loaded successfully.")
|
40 |
except Exception as e:
|
41 |
logging.error(f"β Failed to load model: {e}")
|
42 |
|
43 |
-
|
44 |
CLASS_NAMES = [
|
45 |
"bacterial_leaf_blight", "brown_spot", "healthy", "leaf_blast",
|
46 |
"leaf_scald", "narrow_brown_spot", "rice_hispa", "sheath_blight", "tungro"
|
@@ -102,7 +101,7 @@ def home():
|
|
102 |
|
103 |
@app.post("/predict")
|
104 |
async def predict(file: UploadFile = File(...)):
|
105 |
-
|
106 |
if MODEL is None:
|
107 |
return {"error": "Model failed to load."}
|
108 |
|
|
|
1 |
+
from tensorflow.keras.models import load_model as keras_load_model
|
2 |
+
|
3 |
from fastapi import FastAPI, File, UploadFile
|
4 |
from fastapi.middleware.cors import CORSMiddleware
|
5 |
import numpy as np
|
|
|
23 |
|
24 |
MODEL = None
|
25 |
|
26 |
+
def initialize_model():
|
27 |
global MODEL
|
28 |
if MODEL is None:
|
29 |
HF_MODEL_DIR = "/tmp/rice_model"
|
30 |
os.makedirs(HF_MODEL_DIR, exist_ok=True)
|
31 |
try:
|
|
|
32 |
model_path = hf_hub_download(
|
33 |
repo_id="ShahzadAli44/rice_cnn",
|
34 |
filename="rice_cnn_model.keras",
|
35 |
local_dir=HF_MODEL_DIR,
|
36 |
+
local_dir_use_symlinks=False
|
37 |
)
|
38 |
+
MODEL = keras_load_model(model_path)
|
39 |
logging.info("β
Model loaded successfully.")
|
40 |
except Exception as e:
|
41 |
logging.error(f"β Failed to load model: {e}")
|
42 |
|
|
|
43 |
CLASS_NAMES = [
|
44 |
"bacterial_leaf_blight", "brown_spot", "healthy", "leaf_blast",
|
45 |
"leaf_scald", "narrow_brown_spot", "rice_hispa", "sheath_blight", "tungro"
|
|
|
101 |
|
102 |
@app.post("/predict")
|
103 |
async def predict(file: UploadFile = File(...)):
|
104 |
+
initialize_model()
|
105 |
if MODEL is None:
|
106 |
return {"error": "Model failed to load."}
|
107 |
|