s-ahal commited on
Commit
176294d
·
verified ·
1 Parent(s): 08b8cd7

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +44 -20
server.py CHANGED
@@ -1,12 +1,14 @@
1
  import torch
2
  import torch.nn as nn
3
- from flask import Flask, request, jsonify, render_template
4
  from flask_cors import CORS
5
  import io
6
  import os
7
  from PIL import Image
8
  from diffusers import StableDiffusionPipeline
9
  import os
 
 
10
 
11
  token = os.getenv("HF_TOKEN")
12
 
@@ -33,6 +35,8 @@ CORS(app)
33
  stable_diff_pipe = None
34
  model = None
35
 
 
 
36
  def load_models(model_name="rupeshs/LCM-runwayml-stable-diffusion-v1-5"):
37
  global stable_diff_pipe, model
38
 
@@ -60,6 +64,13 @@ def extract_image_features(image):
60
 
61
  return generated_features
62
 
 
 
 
 
 
 
 
63
  @app.route('/')
64
  def index():
65
  return render_template('index.html')
@@ -67,45 +78,58 @@ def index():
67
  @app.route('/api/check-membership', methods=['POST'])
68
  def check_membership():
69
  try:
70
- # Get the model name from the request
71
  model_name = request.form.get('model', 'rupeshs/LCM-runwayml-stable-diffusion-v1-5')
72
-
 
 
73
  # Ensure models are loaded with the selected model
 
74
  if stable_diff_pipe is None or model is None:
75
  load_models(model_name)
76
  elif stable_diff_pipe.name_or_path != model_name:
77
- # Reload the model if a different one is selected
78
  load_models(model_name)
79
-
80
  if 'image' not in request.files:
81
  return jsonify({'error': 'No image found in request'}), 400
82
-
83
- # Get the image from the request
84
  file = request.files['image']
85
  image_bytes = file.read()
86
  image = Image.open(io.BytesIO(image_bytes))
87
-
88
- # Get image features using Stable Diffusion
89
  image_features = extract_image_features(image)
90
-
91
- # Preprocess the features for MIDM model
92
- processed_features = image_features.reshape(1, -1)[:, :10] # Select first 10 features
93
-
94
- # Perform inference
95
  with torch.no_grad():
96
  output = model(processed_features)
97
  probability = output.item()
98
  predicted = int(output > 0.5)
99
-
100
- return jsonify({
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  'probability': probability,
102
  'predicted_class': predicted,
103
  'message': f"Predicted membership probability: {probability}",
104
- 'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
105
- })
106
-
 
 
 
 
107
  except Exception as e:
108
- print(f"Error processing request: {str(e)}") # Add logging
109
  return jsonify({'error': str(e)}), 500
110
 
111
  if __name__ == '__main__':
 
1
  import torch
2
  import torch.nn as nn
3
+ from flask import Flask, request, jsonify, render_template, make_response
4
  from flask_cors import CORS
5
  import io
6
  import os
7
  from PIL import Image
8
  from diffusers import StableDiffusionPipeline
9
  import os
10
+ from tinydb import TinyDB, Query
11
+ import uuid
12
 
13
  token = os.getenv("HF_TOKEN")
14
 
 
35
  stable_diff_pipe = None
36
  model = None
37
 
38
+ db = TinyDB('artwork_results.json')
39
+
40
  def load_models(model_name="rupeshs/LCM-runwayml-stable-diffusion-v1-5"):
41
  global stable_diff_pipe, model
42
 
 
64
 
65
  return generated_features
66
 
67
+ # Helper to get or set a unique user_id via cookies
68
+ def get_user_id():
69
+ user_id = request.cookies.get('user_id')
70
+ if not user_id:
71
+ user_id = str(uuid.uuid4())
72
+ return user_id
73
+
74
  @app.route('/')
75
  def index():
76
  return render_template('index.html')
 
78
  @app.route('/api/check-membership', methods=['POST'])
79
  def check_membership():
80
  try:
 
81
  model_name = request.form.get('model', 'rupeshs/LCM-runwayml-stable-diffusion-v1-5')
82
+ contact_info = request.form.get('contact_info') # Optional, if frontend sends it
83
+ opt_in = request.form.get('opt_in', 'false').lower() == 'true'
84
+ user_id = get_user_id()
85
  # Ensure models are loaded with the selected model
86
+ global stable_diff_pipe, model
87
  if stable_diff_pipe is None or model is None:
88
  load_models(model_name)
89
  elif stable_diff_pipe.name_or_path != model_name:
 
90
  load_models(model_name)
 
91
  if 'image' not in request.files:
92
  return jsonify({'error': 'No image found in request'}), 400
 
 
93
  file = request.files['image']
94
  image_bytes = file.read()
95
  image = Image.open(io.BytesIO(image_bytes))
 
 
96
  image_features = extract_image_features(image)
97
+ processed_features = image_features.reshape(1, -1)[:, :10]
 
 
 
 
98
  with torch.no_grad():
99
  output = model(processed_features)
100
  probability = output.item()
101
  predicted = int(output > 0.5)
102
+ # Store in DB if likely and opt-in
103
+ if predicted == 1:
104
+ entry = {
105
+ 'model': model_name,
106
+ 'user_id': user_id,
107
+ 'timestamp': int(uuid.uuid1().time),
108
+ }
109
+ if opt_in and contact_info:
110
+ entry['contact_info'] = contact_info
111
+ # Only store if not already present for this user/model
112
+ User = Query()
113
+ exists = db.search((User.model == model_name) & (User.user_id == user_id))
114
+ if not exists:
115
+ db.insert(entry)
116
+ # Count unique user_ids for this model
117
+ User = Query()
118
+ user_ids = set(r['user_id'] for r in db.search(User.model == model_name))
119
+ count = len(user_ids)
120
+ response = make_response(jsonify({
121
  'probability': probability,
122
  'predicted_class': predicted,
123
  'message': f"Predicted membership probability: {probability}",
124
+ 'is_in_training_data': "Likely" if predicted == 1 else "Unlikely",
125
+ 'likely_count': count
126
+ }))
127
+ # Set user_id cookie if not present
128
+ if not request.cookies.get('user_id'):
129
+ response.set_cookie('user_id', user_id, max_age=60*60*24*365*5) # 5 years
130
+ return response
131
  except Exception as e:
132
+ print(f"Error processing request: {str(e)}")
133
  return jsonify({'error': str(e)}), 500
134
 
135
  if __name__ == '__main__':