muskan19 commited on
Commit
b9d4152
·
verified ·
1 Parent(s): 3dad993

Update src/predict.py

Browse files
Files changed (1) hide show
  1. src/predict.py +17 -2
src/predict.py CHANGED
@@ -1,7 +1,22 @@
1
  from tensorflow.keras.models import load_model
2
 
3
- def load_trained_model(path):
4
- return load_model(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def run_prediction(model, processed_frame): # Renamed function
7
  prediction = model.predict(processed_frame)
 
1
  from tensorflow.keras.models import load_model
2
 
3
+ #def load_trained_model(path):
4
+ #return load_model(path)
5
+
6
+ import os
7
+ import urllib.request
8
+ from tensorflow.keras.models import load_model
9
+
10
+ def load_trained_model(path="violence_model.h5"):
11
+ url = "https://huggingface.co/spaces/muskan19/Violence_Detector/resolve/main/violence_model.h5"
12
+
13
+ if not os.path.exists(path):
14
+ print("Downloading model...")
15
+ urllib.request.urlretrieve(url, path)
16
+ print("Model downloaded.")
17
+
18
+ return load_model(path, compile=False)
19
+
20
 
21
  def run_prediction(model, processed_frame): # Renamed function
22
  prediction = model.predict(processed_frame)