|
import torch |
|
import torch.nn as nn |
|
from flask import Flask, request, jsonify, render_template, make_response, send_from_directory |
|
from flask_cors import CORS |
|
import io |
|
import os |
|
from PIL import Image |
|
from diffusers import StableDiffusionPipeline |
|
import os |
|
|
|
# Hugging Face access token read from the environment (None when unset).
# NOTE(review): this name is not referenced anywhere in the visible code —
# confirm whether it is meant to be passed to ``from_pretrained``.
token = os.environ.get("HF_TOKEN")
|
|
|
|
|
class MIDM(nn.Module):
    """Two-layer perceptron producing scores in the open interval (0, 1).

    Pipeline: Linear(input_dim -> hidden_dim) -> ReLU ->
    Linear(hidden_dim -> output_dim) -> Sigmoid.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of sigmoid scores produced per input row.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # The attribute names below become the state_dict keys
        # (fc1.weight, fc2.bias, ...); keep them stable so saved weights load.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(fc2(relu(fc1(x))))``.

        Leading dimensions of ``x`` are preserved; the last dimension changes
        from ``input_dim`` to ``output_dim``.
        """
        hidden = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))
|
|
|
# Flask application serving the HTML pages and the membership-check API.
app = Flask(__name__, template_folder='templates', static_folder='static')

# Enable CORS using flask_cors' default configuration for this app.
CORS(app)

# Model handles start empty; load_models() rebinds both, and the API handler
# calls it on the first request (or when a different model is requested).
stable_diff_pipe = model = None
|
|
|
def load_models(model_name="CompVis/stable-diffusion-v1-4"):
    """Load the Stable Diffusion pipeline and (re)build the membership classifier.

    Rebinds the module-level ``stable_diff_pipe`` and ``model`` globals. Both
    are placed on the same device (CUDA when available, otherwise CPU). This
    lets the classifier consume VAE latents produced by the pipeline without
    a device mismatch.

    Args:
        model_name: Model id or path passed to
            ``StableDiffusionPipeline.from_pretrained``.
    """
    global stable_diff_pipe, model

    # Decide the device once so the pipeline and classifier agree on it.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    stable_diff_pipe = StableDiffusionPipeline.from_pretrained(model_name)
    stable_diff_pipe.to(device)

    # The API handler keeps only the first 10 latent values per image, so
    # input_dim must stay 10.
    input_dim = 10
    hidden_dim = 64
    output_dim = 1

    # NOTE(review): no trained weights are loaded, so the classifier is
    # randomly initialised on every call — confirm whether a checkpoint should
    # be restored here with ``model.load_state_dict(...)``.
    model = MIDM(input_dim, hidden_dim, output_dim).to(device)
    model.eval()
|
|
|
|
|
def extract_image_features(image):
    """Encode ``image`` with the loaded pipeline's VAE and return the latent mean.

    The image is preprocessed by ``stable_diff_pipe.feature_extractor``. The
    resulting pixel tensor is moved to ``stable_diff_pipe.device`` and encoded
    with ``stable_diff_pipe.vae`` with autograd disabled.

    Args:
        image: Input accepted by the pipeline's feature extractor (the API
            handler in this file passes a PIL image).

    Returns:
        The ``latent_dist.mean`` tensor from ``vae.encode``.

    NOTE(review): in diffusers' StableDiffusionPipeline, ``feature_extractor``
    is typically the safety checker's CLIP image processor (CLIP resize and
    normalisation), not the VAE's own preprocessing — confirm this is intended.
    """
    pipe = stable_diff_pipe
    pixels = pipe.feature_extractor(image, return_tensors="pt").pixel_values
    pixels = pixels.to(pipe.device)

    with torch.no_grad():
        latents = pipe.vae.encode(pixels).latent_dist.mean

    return latents
|
|
|
@app.route('/')
def index():
    """Serve the landing page ``index.html`` from the ``'.'`` directory."""
    page = 'index.html'
    return send_from_directory('.', page)
|
|
|
@app.route('/resources')
def resources():
    """Serve the ``resources.html`` page from the ``'.'`` directory."""
    page = 'resources.html'
    return send_from_directory('.', page)
|
|
|
@app.route('/get-organized')
def get_organized():
    """Serve the ``get-organized.html`` page from the ``'.'`` directory."""
    page = 'get-organized.html'
    return send_from_directory('.', page)
|
|
|
@app.route('/static/<path:filename>')
def static_files(filename):
    """Serve ``filename`` from the ``static`` directory.

    NOTE(review): the app is constructed with ``static_folder='static'``,
    which makes Flask register its own ``/static/<path:filename>`` rule first.
    That built-in rule most likely wins URL matching, so this handler may
    never run — confirm before relying on it.
    """
    static_dir = 'static'
    return send_from_directory(static_dir, filename)
|
|
|
@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Score how likely an uploaded image is to belong to a model's training data.

    Request (multipart form):
        model: optional model id for ``load_models``. An absent or empty
            value falls back to ``'CompVis/stable-diffusion-v1-4'``.
        image: required image file.

    Returns:
        JSON with ``probability``, ``predicted_class`` (1 when probability
        > 0.5, else 0), ``message`` and ``is_in_training_data``.

        Errors come back as ``{'error': ...}``:
        - status 400 when the image is missing or cannot be decoded;
        - status 500 for any other failure.
    """
    try:
        # Reject uploads without an image before the (expensive) model load.
        if 'image' not in request.files:
            return jsonify({'error': 'No image found in request'}), 400

        # ``or`` also covers a present-but-empty ``model`` field.
        model_name = request.form.get('model') or 'CompVis/stable-diffusion-v1-4'

        # SECURITY NOTE(review): ``model_name`` is client-controlled and is
        # forwarded to ``from_pretrained`` by ``load_models``. Any caller can
        # therefore make the server download and load an arbitrary repository
        # (possibly including pickle-based weight files). Consider an
        # allowlist of supported models.
        needs_load = (
            stable_diff_pipe is None
            or model is None
            or stable_diff_pipe.name_or_path != model_name
        )
        if needs_load:
            load_models(model_name)

        image_bytes = request.files['image'].read()
        try:
            # convert() forces a full decode, so corrupt or truncated uploads
            # fail here instead of deep inside the feature extractor.
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except OSError:
            # PIL signals unreadable image data with OSError (including
            # its UnidentifiedImageError subclass).
            return jsonify({'error': 'Could not decode image'}), 400

        image_features = extract_image_features(image)
        # Flatten the latents and keep the first 10 values (MIDM's input_dim).
        processed_features = image_features.reshape(1, -1)[:, :10]

        # Latents live on the pipeline's device; run the classifier on
        # whichever device its weights are on to avoid a device mismatch.
        classifier_device = next(model.parameters()).device
        with torch.no_grad():
            output = model(processed_features.to(classifier_device))

        probability = output.item()
        predicted = int(probability > 0.5)

        return jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
        })

    except Exception as e:
        # Top-level request boundary: log the full traceback, report the message.
        app.logger.exception("Error processing request: %s", e)
        return jsonify({'error': str(e)}), 500
|
|
|
if __name__ == '__main__':
    # Listen on all interfaces; the PORT environment variable overrides 7860.
    port = int(os.getenv('PORT', '7860'))
    app.run(port=port, host='0.0.0.0')