import io
import os
import time
import uuid

import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline
from flask import Flask, jsonify, make_response, render_template, request
from flask_cors import CORS
from PIL import Image
from tinydb import Query, TinyDB

# Hugging Face access token read from the environment.
# NOTE(review): this is not passed to from_pretrained(), so gated/private
# models cannot be loaded with it — confirm whether it should be.
token = os.getenv("HF_TOKEN")

DEFAULT_MODEL_NAME = "rupeshs/LCM-runwayml-stable-diffusion-v1-5"

# Number of latent values fed to the MIDM classifier; this is both the
# classifier's input_dim and the slice width applied to the flattened latent.
MIDM_INPUT_DIM = 10
MIDM_HIDDEN_DIM = 64
MIDM_OUTPUT_DIM = 1

# Probability above which an image is reported as "Likely" in training data.
MEMBERSHIP_THRESHOLD = 0.5

# Lifetime of the user_id cookie: 5 years, in seconds.
USER_ID_COOKIE_MAX_AGE = 60 * 60 * 24 * 365 * 5


class MIDM(nn.Module):
    """Two-layer MLP mapping a feature vector to a probability in (0, 1).

    Architecture: Linear(input_dim -> hidden_dim) -> ReLU ->
    Linear(hidden_dim -> output_dim) -> Sigmoid.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(fc2(relu(fc1(x))))."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return self.sigmoid(out)


app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)

# Models are loaded lazily on the first request (and reloaded whenever a
# request selects a different model), then reused across requests.
stable_diff_pipe = None
model = None
# Name of the model currently held in stable_diff_pipe. Tracked explicitly so
# the reload check does not depend on pipeline internals.
loaded_model_name = None

db = TinyDB('artwork_results.json')


def load_models(model_name=DEFAULT_MODEL_NAME):
    """Load the Stable Diffusion pipeline and (re)create the MIDM classifier.

    Sets the module globals ``stable_diff_pipe``, ``model`` and
    ``loaded_model_name``. The pipeline is moved to CUDA when available,
    otherwise to CPU; the MIDM model is created on CPU and not moved.

    Args:
        model_name: Hub id or local path passed to
            ``StableDiffusionPipeline.from_pretrained``.
    """
    global stable_diff_pipe, model, loaded_model_name

    # SECURITY NOTE(review): model_name can come straight from the client
    # (check_membership reads it from the form). from_pretrained() will then
    # download arbitrary repositories and may unpickle non-safetensors
    # weights — consider restricting it to an allowlist.
    stable_diff_pipe = StableDiffusionPipeline.from_pretrained(model_name)
    stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): MIDM is freshly initialised with random weights on every
    # call — no trained checkpoint is loaded here, so its output is not a
    # meaningful membership score until weights are supplied.
    model = MIDM(MIDM_INPUT_DIM, MIDM_HIDDEN_DIM, MIDM_OUTPUT_DIM)
    model.eval()
    loaded_model_name = model_name


def extract_image_features(image):
    """Encode *image* with the loaded pipeline's VAE.

    The image is preprocessed with the pipeline's ``feature_extractor`` and
    passed through ``vae.encode``; the mean of the resulting latent
    distribution is returned, on the pipeline's device.

    Raises:
        RuntimeError: if the loaded pipeline has no ``feature_extractor``.
    """
    extractor = stable_diff_pipe.feature_extractor
    if extractor is None:
        raise RuntimeError("Loaded pipeline has no feature_extractor")
    # NOTE(review): feature_extractor is presumably the safety checker's CLIP
    # image processor (CLIP resize/normalisation), not the VAE's usual
    # [-1, 1] preprocessing — confirm this is the intended vae.encode input.
    pixel_values = extractor(image.convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(stable_diff_pipe.device)
    with torch.no_grad():
        return stable_diff_pipe.vae.encode(pixel_values).latent_dist.mean


def get_user_id():
    """Return the request's ``user_id`` cookie, or a new random UUID4 string.

    A newly generated id is not persisted here; check_membership sets the
    cookie on its response when the request carried none.
    """
    user_id = request.cookies.get('user_id')
    if not user_id:
        user_id = str(uuid.uuid4())
    return user_id


@app.route('/')
def index():
    """Serve the single-page frontend."""
    return render_template('index.html')


def _ensure_models_loaded(model_name):
    """Load models on first use, or reload them if *model_name* changed."""
    if stable_diff_pipe is None or model is None or loaded_model_name != model_name:
        load_models(model_name)


def _predict_membership(image):
    """Run the MIDM classifier on *image*; return ``(probability, predicted)``."""
    features = extract_image_features(image)
    # The latent lives on the pipeline's device (possibly CUDA) while MIDM
    # stays on CPU, so move the sliced features to the classifier's
    # device/dtype before the forward pass.
    param = next(model.parameters())
    processed = features.reshape(1, -1)[:, :MIDM_INPUT_DIM].to(
        device=param.device, dtype=param.dtype
    )
    with torch.no_grad():
        probability = model(processed).item()
    predicted = int(probability > MEMBERSHIP_THRESHOLD)
    return probability, predicted


def _record_likely_member(model_name, user_id, contact_info, opt_in):
    """Store one entry per (model, user_id) pair; repeat calls are no-ops.

    ``contact_info`` is stored only when the user opted in and supplied it.
    """
    entry = {
        'model': model_name,
        'user_id': user_id,
        # Unix epoch seconds. Rows written by older versions hold
        # uuid1().time here (100-ns ticks since 1582-10-15).
        'timestamp': int(time.time()),
    }
    if opt_in and contact_info:
        entry['contact_info'] = contact_info
    q = Query()
    if not db.contains((q.model == model_name) & (q.user_id == user_id)):
        db.insert(entry)


def _count_likely_users(model_name):
    """Return the number of distinct user_ids recorded for *model_name*."""
    q = Query()
    return len({row['user_id'] for row in db.search(q.model == model_name)})


@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Estimate whether an uploaded image was in a model's training data.

    Form fields: ``image`` (file, required), ``model`` (optional model id),
    ``contact_info`` (optional) and ``opt_in`` ("true" to store contact_info).

    Returns JSON with ``probability``, ``predicted_class``, ``message``,
    ``is_in_training_data`` and ``likely_count``. Responds 400 when the image
    is missing or cannot be decoded, 500 on any other error.
    """
    try:
        model_name = request.form.get('model', DEFAULT_MODEL_NAME)
        contact_info = request.form.get('contact_info')  # optional
        opt_in = request.form.get('opt_in', 'false').lower() == 'true'
        user_id = get_user_id()

        # Validate the upload before paying for a (re)load of the model.
        if 'image' not in request.files:
            return jsonify({'error': 'No image found in request'}), 400
        image_bytes = request.files['image'].read()
        try:
            image = Image.open(io.BytesIO(image_bytes))
            image.load()  # Image.open is lazy; force decoding now.
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # unrecognised data and OSError for truncated files.
            return jsonify({'error': 'Uploaded file is not a valid image'}), 400

        _ensure_models_loaded(model_name)
        probability, predicted = _predict_membership(image)

        if predicted == 1:
            _record_likely_member(model_name, user_id, contact_info, opt_in)
        count = _count_likely_users(model_name)

        response = make_response(jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely",
            'likely_count': count,
        }))
        # Persist a newly generated user_id so repeat visits are deduplicated.
        if not request.cookies.get('user_id'):
            response.set_cookie('user_id', user_id, max_age=USER_ID_COOKIE_MAX_AGE)
        return response
    except Exception as e:
        # Top-level request boundary: log with traceback, report to client.
        app.logger.exception("Error processing request")
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)