File size: 4,993 Bytes
ff13394 176294d ff13394 7e1e741 9a5bfef 176294d 9a5bfef ff13394 7e1e741 ff13394 176294d 08b8cd7 7e1e741 ff13394 7e1e741 08b8cd7 7e1e741 ff13394 c176b63 ff13394 7e1e741 c176b63 7e1e741 ff13394 7e1e741 ff13394 176294d ff13394 08b8cd7 176294d 08b8cd7 176294d 08b8cd7 ff13394 7e1e741 176294d ff13394 176294d ff13394 176294d ff13394 176294d ff13394 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template, make_response
from flask_cors import CORS
import io
import os
from PIL import Image
from diffusers import StableDiffusionPipeline
import os
from tinydb import TinyDB, Query
import uuid
# Hugging Face access token taken from the environment (None when unset).
# NOTE(review): not referenced by any code visible in this file; huggingface_hub
# may read HF_TOKEN from the environment on its own — confirm before removing.
token = os.environ.get("HF_TOKEN")
# Define the MIDM model
class MIDM(nn.Module):
    """Two-layer MLP mapping a feature vector to values in (0, 1).

    Layout: Linear(input_dim -> hidden_dim) -> ReLU ->
    Linear(hidden_dim -> output_dim) -> Sigmoid.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registered in this order and under these names so that parameter
        # initialisation order and state_dict keys stay stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(fc2(relu(fc1(x))))."""
        hidden = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))
app = Flask(__name__, template_folder='templates', static_folder='static')
CORS(app)  # enable Flask-CORS (with its default settings) on this app

# Model singletons: populated by load_models() on the first request that
# needs them, and re-populated when a request asks for a different model.
stable_diff_pipe = model = None

# TinyDB store of "likely member" records; the path is relative to the
# process working directory.
db = TinyDB('artwork_results.json')
def load_models(model_name="rupeshs/LCM-runwayml-stable-diffusion-v1-5"):
    """Load the diffusion pipeline and a fresh MIDM classifier into module globals.

    Args:
        model_name: repo id or local path passed to
            ``StableDiffusionPipeline.from_pretrained``.

    Side effects:
        Rebinds the module-level ``stable_diff_pipe`` and ``model``.
        The pipeline is moved to CUDA when available, otherwise to CPU.
        ``model`` is a newly constructed MIDM (10 -> 64 -> 1) in eval mode.
        No weights are loaded into it here, and it is not moved to the
        pipeline's device.
    """
    global stable_diff_pipe, model

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    stable_diff_pipe = StableDiffusionPipeline.from_pretrained(model_name)
    stable_diff_pipe.to(target_device)

    # Input width 10 matches the 10-value feature slice taken by the caller.
    model = MIDM(input_dim=10, hidden_dim=64, output_dim=1)
    model.eval()
# Function to extract features from the image using Stable Diffusion
def extract_image_features(image):
    """Encode *image* with the pipeline's VAE and return the latent mean.

    The image is preprocessed by ``stable_diff_pipe.feature_extractor``,
    moved to the pipeline's device, and encoded with ``stable_diff_pipe.vae``.
    The mean of the resulting latent distribution is returned (no sampling).

    NOTE(review): in standard Stable Diffusion pipelines ``feature_extractor``
    is presumably the CLIP image processor that feeds the safety checker. If
    so, the VAE receives CLIP-resized/normalised pixels rather than its usual
    input range — confirm this is intended.
    """
    pipe = stable_diff_pipe
    pixels = pipe.feature_extractor(image, return_tensors="pt").pixel_values
    pixels = pixels.to(pipe.device)
    with torch.no_grad():
        latents = pipe.vae.encode(pixels).latent_dist.mean
    return latents
# Helper to read the user_id cookie, minting a new id when it is missing
def get_user_id():
    """Return the request's ``user_id`` cookie, or a new UUID4 string if absent/empty.

    Only reads the cookie; persisting a new id on the response is left to the caller.
    """
    return request.cookies.get('user_id') or str(uuid.uuid4())
@app.route('/')
def index():
    """Render the front-end page ``index.html`` from the app's template folder."""
    template_name = 'index.html'
    return render_template(template_name)
def _record_likely_member(model_name, user_id, contact_info, opt_in):
    """Insert a "likely member" row for (model_name, user_id) unless one already exists.

    ``contact_info`` is stored only when the user opted in and supplied it.
    An existing row for the same user/model is left untouched.
    """
    entry = {
        'model': model_name,
        'user_id': user_id,
        # uuid1().time is the RFC 4122 60-bit count of 100-ns intervals since
        # 1582-10-15, not a Unix timestamp. It is kept as-is so new rows stay
        # consistent with rows already stored.
        'timestamp': int(uuid.uuid1().time),
    }
    if opt_in and contact_info:
        entry['contact_info'] = contact_info
    row = Query()
    if not db.search((row.model == model_name) & (row.user_id == user_id)):
        db.insert(entry)


def _count_likely_users(model_name):
    """Return the number of distinct user_ids with a stored row for *model_name*."""
    row = Query()
    return len({r['user_id'] for r in db.search(row.model == model_name)})


@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Estimate whether an uploaded image was part of a model's training data.

    Form fields:
        image (file, required): the image to score.
        model (str, optional): diffusers model id; defaults to
            ``rupeshs/LCM-runwayml-stable-diffusion-v1-5``.
        contact_info (str, optional): stored only together with ``opt_in``.
        opt_in (str, optional): ``"true"`` (case-insensitive) to store contact_info.

    Returns:
        200 JSON with ``probability``, ``predicted_class``, ``message``,
        ``is_in_training_data`` and ``likely_count``. A 5-year ``user_id``
        cookie is set when the request carried none.
        400 JSON ``{'error': ...}`` when the image is missing or cannot be decoded.
        500 JSON ``{'error': ...}`` for any other failure.
    """
    try:
        model_name = request.form.get('model', 'rupeshs/LCM-runwayml-stable-diffusion-v1-5')
        contact_info = request.form.get('contact_info')  # Optional, if frontend sends it
        opt_in = request.form.get('opt_in', 'false').lower() == 'true'
        user_id = get_user_id()

        # Validate the upload before touching the models, so a bad request
        # cannot trigger an expensive pipeline download/swap.
        if 'image' not in request.files:
            return jsonify({'error': 'No image found in request'}), 400
        image_bytes = request.files['image'].read()
        try:
            image = Image.open(io.BytesIO(image_bytes))
            image.load()  # Image.open is lazy; force a full decode now
        except (OSError, Image.DecompressionBombError) as e:
            # UnidentifiedImageError and truncated-file errors subclass OSError.
            return jsonify({'error': f'Invalid image file: {e}'}), 400

        # SECURITY NOTE(review): model_name is client-controlled and goes
        # straight to from_pretrained. Any caller can make the server download
        # and load an arbitrary repo (disk/RAM exhaustion, and pickle-based
        # weights if the repo lacks safetensors). Consider an allowlist.
        # NOTE(review): stable_diff_pipe/model are process-wide globals; with a
        # threaded server, concurrent requests for different models can race.
        # (Globals are only read here, so no `global` statement is needed.)
        if (stable_diff_pipe is None or model is None
                or stable_diff_pipe.name_or_path != model_name):
            load_models(model_name)

        image_features = extract_image_features(image)
        # MIDM takes 10 inputs: flatten the latent and keep its first 10 values.
        # The pipeline may live on CUDA while MIDM stays on CPU, so move the
        # features to the classifier's device/dtype to avoid a device mismatch.
        classifier_param = next(model.parameters())
        processed_features = image_features.reshape(1, -1)[:, :10].to(
            device=classifier_param.device, dtype=classifier_param.dtype
        )
        with torch.no_grad():
            output = model(processed_features)
        probability = output.item()
        predicted = int(probability > 0.5)

        # Record likely members (contact info only when opted in).
        if predicted == 1:
            _record_likely_member(model_name, user_id, contact_info, opt_in)
        count = _count_likely_users(model_name)

        response = make_response(jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely",
            'likely_count': count
        }))
        # Persist the user_id for first-time callers.
        if not request.cookies.get('user_id'):
            response.set_cookie('user_id', user_id, max_age=60*60*24*365*5)  # 5 years
        return response
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from $PORT and defaults to 7860.
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', '7860')))