import io
import os

import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline
from flask import Flask, request, jsonify, render_template, make_response, send_from_directory
from flask_cors import CORS
from PIL import Image, UnidentifiedImageError

# Optional Hugging Face access token; forwarded to from_pretrained only when set.
token = os.getenv("HF_TOKEN")

# Stable Diffusion checkpoint used when the client does not choose one.
DEFAULT_MODEL = "CompVis/stable-diffusion-v1-4"
# Number of leading latent values fed into the MIDM classifier (its input width).
MIDM_INPUT_DIM = 10


class MIDM(nn.Module):
    """Two-layer MLP mapping a feature vector to a probability in (0, 1).

    Architecture: Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim,
    output_dim) -> Sigmoid.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Submodule names determine state_dict keys; keep them stable if
        # checkpoints are ever saved or loaded.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(fc2(relu(fc1(x))))."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return self.sigmoid(out)


app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)

# Models are loaded lazily by the first membership request and reloaded only
# when a different Stable Diffusion checkpoint is requested.
# NOTE(review): these globals are swapped without a lock; if the server handles
# requests concurrently, two requests for different models can race.
stable_diff_pipe = None
model = None


def load_models(model_name=DEFAULT_MODEL):
    """Load the Stable Diffusion pipeline and (re)create the MIDM classifier.

    Sets the module-level ``stable_diff_pipe`` and ``model`` globals. The
    globals are assigned only after both models were built successfully.

    Args:
        model_name: Hub id or local path passed to
            ``StableDiffusionPipeline.from_pretrained``.
    """
    global stable_diff_pipe, model

    # Release the previous pipeline first so two full pipelines never have to
    # be resident in (GPU) memory at the same time.
    stable_diff_pipe = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Only forward the token when one is configured, so unauthenticated setups
    # call from_pretrained exactly as before.
    pretrained_kwargs = {"token": token} if token else {}
    pipe = StableDiffusionPipeline.from_pretrained(model_name, **pretrained_kwargs)
    pipe.to(device)

    # NOTE(review): no trained weights are loaded for MIDM in this file, so its
    # output comes from a randomly initialised network — confirm intended.
    midm = MIDM(input_dim=MIDM_INPUT_DIM, hidden_dim=64, output_dim=1)
    midm.eval()

    stable_diff_pipe = pipe
    model = midm


def extract_image_features(image):
    """Encode *image* with the pipeline's VAE and return the latent-distribution mean.

    Args:
        image: A PIL image.

    Returns:
        The ``latent_dist.mean`` tensor from ``stable_diff_pipe.vae.encode``,
        on the pipeline's device.

    Raises:
        RuntimeError: If ``load_models`` has not been called yet.
    """
    if stable_diff_pipe is None:
        raise RuntimeError("Models are not loaded; call load_models() first")

    # NOTE(review): feature_extractor is the pipeline's image processor (used by
    # the safety checker); its resizing/normalisation may not match what the
    # VAE was trained on — confirm this preprocessing is intended.
    image_input = stable_diff_pipe.feature_extractor(
        image, return_tensors="pt"
    ).pixel_values.to(stable_diff_pipe.device)

    with torch.no_grad():
        return stable_diff_pipe.vae.encode(image_input).latent_dist.mean


@app.route('/')
def index():
    """Serve index.html from the working directory."""
    return send_from_directory('.', 'index.html')


@app.route('/resources')
def resources():
    """Serve resources.html from the working directory."""
    return send_from_directory('.', 'resources.html')


@app.route('/get-organized')
def get_organized():
    """Serve get-organized.html from the working directory."""
    return send_from_directory('.', 'get-organized.html')


# The URL rule must capture the path so Flask can pass it as ``filename``.
# Flask also registers its own endpoint for static_folder='static' on this
# pattern; both serve files from the same directory.
@app.route('/static/<path:filename>')
def static_files(filename):
    """Serve *filename* from the ``static`` directory."""
    return send_from_directory('static', filename)


@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Estimate whether the uploaded image was in the model's training data.

    Form fields:
        image: Uploaded image file (required).
        model: Stable Diffusion checkpoint id (optional; defaults to
            ``DEFAULT_MODEL``).

    Returns:
        JSON with ``probability``, ``predicted_class``, ``message`` and
        ``is_in_training_data``; 400 if the image is missing or unreadable;
        500 with the error message on any other failure.
    """
    # Validate the request before any (expensive) model loading happens.
    if 'image' not in request.files:
        return jsonify({'error': 'No image found in request'}), 400

    # `or` also covers an empty "model" form field.
    # NOTE(review): this name comes from the client and is passed straight to
    # from_pretrained; consider an allowlist, since arbitrary repos/paths can
    # trigger large downloads or load untrusted checkpoints.
    model_name = request.form.get('model') or DEFAULT_MODEL

    try:
        image_bytes = request.files['image'].read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        if (
            stable_diff_pipe is None
            or model is None
            or stable_diff_pipe.name_or_path != model_name
        ):
            load_models(model_name)

        image_features = extract_image_features(image)

        # The VAE output lives on the pipeline's device; move the slice to
        # wherever MIDM's parameters are so the forward pass doesn't fail
        # with a device mismatch on GPU hosts.
        midm_device = next(model.parameters()).device
        processed_features = (
            image_features.reshape(1, -1)[:, :MIDM_INPUT_DIM].to(midm_device)
        )

        with torch.no_grad():
            output = model(processed_features)

        probability = output.item()
        predicted = int(probability > 0.5)

        return jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
        })
    except UnidentifiedImageError:
        return jsonify({'error': 'Uploaded file is not a readable image'}), 400
    except Exception as e:
        # Top-level request boundary: log with traceback, report to client.
        app.logger.exception("Error processing request")
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)