# infer/server.py
import io
import os

import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template, make_response, send_from_directory
from flask_cors import CORS
from PIL import Image
from diffusers import StableDiffusionPipeline

# Hugging Face access token, used when the requested checkpoint needs authentication
token = os.getenv("HF_TOKEN")
# Define the MIDM (membership inference) model: a two-layer MLP that maps a
# feature vector to a single membership probability.
class MIDM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MIDM, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out
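
# Minimal sanity check of the MIDM interface (illustrative only; the dimensions
# below mirror the defaults used in load_models, not a trained checkpoint):
#
#     midm = MIDM(input_dim=10, hidden_dim=64, output_dim=1)
#     probs = midm(torch.randn(2, 10))  # shape (2, 1), values in (0, 1)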
app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)
# Model handles are cached in module-level globals so they are loaded lazily on
# the first request (or when a different checkpoint is requested) rather than
# re-loaded for every call
stable_diff_pipe = None
model = None
def load_models(model_name="CompVis/stable-diffusion-v1-4"):
    global stable_diff_pipe, model
    # Load the Stable Diffusion pipeline, passing the HF token in case the
    # checkpoint requires authentication
    stable_diff_pipe = StableDiffusionPipeline.from_pretrained(model_name, token=token)
    stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the MIDM classifier. No trained weights are loaded here, so the
    # network starts from its random initialization unless a checkpoint is
    # restored separately.
    input_dim = 10
    hidden_dim = 64
    output_dim = 1
    model = MIDM(input_dim, hidden_dim, output_dim)
    model.eval()
# Function to extract features from the image using the Stable Diffusion pipeline
def extract_image_features(image):
    """Extracts image features using the Stable Diffusion pipeline."""
    # Preprocess the image with the pipeline's feature extractor to get pixel values
    image_input = stable_diff_pipe.feature_extractor(
        image, return_tensors="pt"
    ).pixel_values.to(stable_diff_pipe.device)
    # Encode the preprocessed image with the VAE and use the mean of the latent
    # distribution as the feature vector
    with torch.no_grad():
        generated_features = stable_diff_pipe.vae.encode(image_input).latent_dist.mean
    return generated_features
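
# For reference (assuming the default CompVis/stable-diffusion-v1-4 checkpoint):
# the feature extractor resizes images to 224x224 and the VAE downsamples by a
# factor of 8, so the returned latent mean has shape (1, 4, 28, 28), i.e. 3136
# values, of which only the first 10 are fed to the MIDM classifier below.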
@app.route('/')
def index():
    return send_from_directory('.', 'index.html')


@app.route('/resources')
def resources():
    return send_from_directory('.', 'resources.html')


@app.route('/get-organized')
def get_organized():
    return send_from_directory('.', 'get-organized.html')


@app.route('/static/<path:filename>')
def static_files(filename):
    return send_from_directory('static', filename)
@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    try:
        model_name = request.form.get('model', 'CompVis/stable-diffusion-v1-4')
        global stable_diff_pipe, model
        # Load the models on first use, or reload if a different checkpoint was requested
        if stable_diff_pipe is None or model is None:
            load_models(model_name)
        elif stable_diff_pipe.name_or_path != model_name:
            load_models(model_name)

        if 'image' not in request.files:
            return jsonify({'error': 'No image found in request'}), 400

        file = request.files['image']
        image_bytes = file.read()
        # Convert to RGB so grayscale or RGBA uploads don't break the feature extractor
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        image_features = extract_image_features(image)
        # Flatten the latents and keep only the first 10 values to match the MIDM
        # input dimension; move to CPU since the classifier lives there
        processed_features = image_features.reshape(1, -1)[:, :10].cpu()

        with torch.no_grad():
            output = model(processed_features)
        probability = output.item()
        predicted = int(output > 0.5)

        return jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
        })
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        return jsonify({'error': str(e)}), 500
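
# Example client call (illustrative; assumes the server is running locally on
# port 7860 and that example.png exists):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:7860/api/check-membership",
#         files={"image": open("example.png", "rb")},
#         data={"model": "CompVis/stable-diffusion-v1-4"},
#     )
#     print(resp.json())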

if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)