muskan19 commited on
Commit
1ba0338
·
verified ·
1 Parent(s): b9d4152

Update src/predict.py

Browse files
Files changed (1) hide show
  1. src/predict.py +19 -2
src/predict.py CHANGED
@@ -2,7 +2,7 @@ from tensorflow.keras.models import load_model
2
 
3
  #def load_trained_model(path):
4
  #return load_model(path)
5
-
6
  import os
7
  import urllib.request
8
  from tensorflow.keras.models import load_model
@@ -16,10 +16,27 @@ def load_trained_model(path="violence_model.h5"):
16
  print("Model downloaded.")
17
 
18
  return load_model(path, compile=False)
19
-
20
 
21
  def run_prediction(model, processed_frame): # Renamed function
22
  prediction = model.predict(processed_frame)
23
  return prediction[0][0]
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  #def load_trained_model(path):
4
  #return load_model(path)
5
+ '''
6
  import os
7
  import urllib.request
8
  from tensorflow.keras.models import load_model
 
16
  print("Model downloaded.")
17
 
18
  return load_model(path, compile=False)
19
+ '''
20
 
21
  def run_prediction(model, processed_frame): # Renamed function
22
  prediction = model.predict(processed_frame)
23
  return prediction[0][0]
24
 
25
 
26
+
27
+ import os
28
+ import urllib.request
29
+ from tensorflow.keras.models import load_model
30
+
31
+ def load_trained_model(path="violence_model.h5"):
32
+ url = "https://huggingface.co/spaces/muskan19/Violence_Detector/resolve/main/violence_model.h5"
33
+
34
+ if not os.path.exists(path):
35
+ print("Downloading model...")
36
+ urllib.request.urlretrieve(url, path)
37
+ print("Download complete.")
38
+
39
+ return load_model(path, compile=False)
40
+
41
+
42
+