s-ahal commited on
Commit
08b8cd7
·
verified ·
1 Parent(s): a3efd32

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +16 -19
server.py CHANGED
@@ -33,15 +33,11 @@ CORS(app)
33
  stable_diff_pipe = None
34
  model = None
35
 
36
- def load_models(model_name="CompVis/stable-diffusion-v1-4"):
37
  global stable_diff_pipe, model
38
 
39
  # Load Stable Diffusion model pipeline
40
- stable_diff_pipe = StableDiffusionPipeline.from_pretrained(
41
- model_name,
42
- safety_checker=None, # Disable safety checker for feature extraction
43
- requires_safety_checker=False # Explicitly indicate we don't need the safety checker
44
- )
45
  stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
46
 
47
  # Initialize MIDM model
@@ -70,20 +66,20 @@ def index():
70
 
71
  @app.route('/api/check-membership', methods=['POST'])
72
  def check_membership():
73
- # Get the model name from the request
74
- model_name = request.form.get('model', 'CompVis/stable-diffusion-v1-4')
75
-
76
- # Ensure models are loaded with the selected model
77
- if stable_diff_pipe is None or model is None:
78
- load_models(model_name)
79
- elif stable_diff_pipe.name_or_path != model_name:
80
- # Reload the model if a different one is selected
81
- load_models(model_name)
82
-
83
- if 'image' not in request.files:
84
- return jsonify({'error': 'No image found in request'}), 400
85
-
86
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # Get the image from the request
88
  file = request.files['image']
89
  image_bytes = file.read()
@@ -109,6 +105,7 @@ def check_membership():
109
  })
110
 
111
  except Exception as e:
 
112
  return jsonify({'error': str(e)}), 500
113
 
114
  if __name__ == '__main__':
 
33
  stable_diff_pipe = None
34
  model = None
35
 
36
+ def load_models(model_name="rupeshs/LCM-runwayml-stable-diffusion-v1-5"):
37
  global stable_diff_pipe, model
38
 
39
  # Load Stable Diffusion model pipeline
40
+ stable_diff_pipe = StableDiffusionPipeline.from_pretrained(model_name)
 
 
 
 
41
  stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
  # Initialize MIDM model
 
66
 
67
  @app.route('/api/check-membership', methods=['POST'])
68
  def check_membership():
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  try:
70
+ # Get the model name from the request
71
+ model_name = request.form.get('model', 'rupeshs/LCM-runwayml-stable-diffusion-v1-5')
72
+
73
+ # Ensure models are loaded with the selected model
74
+ if stable_diff_pipe is None or model is None:
75
+ load_models(model_name)
76
+ elif stable_diff_pipe.name_or_path != model_name:
77
+ # Reload the model if a different one is selected
78
+ load_models(model_name)
79
+
80
+ if 'image' not in request.files:
81
+ return jsonify({'error': 'No image found in request'}), 400
82
+
83
  # Get the image from the request
84
  file = request.files['image']
85
  image_bytes = file.read()
 
105
  })
106
 
107
  except Exception as e:
108
+ print(f"Error processing request: {str(e)}") # Add logging
109
  return jsonify({'error': str(e)}), 500
110
 
111
  if __name__ == '__main__':