infer / server.py
s-ahal's picture
Update server.py
176294d verified
raw
history blame
4.99 kB
import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template, make_response
from flask_cors import CORS
import io
import os
from PIL import Image
from diffusers import StableDiffusionPipeline
import os
from tinydb import TinyDB, Query
import uuid
# Hugging Face access token from the environment (None when HF_TOKEN is unset).
# NOTE(review): not referenced by the code shown here; presumably the Hub
# client reads HF_TOKEN from the environment by itself — confirm before removing.
token = os.environ.get("HF_TOKEN")
# Define the MIDM model
class MIDM(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(MIDM, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, output_dim)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out = self.sigmoid(out)
return out
app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)
# Load models once when the app starts to avoid reloading for each request
stable_diff_pipe = None
model = None
db = TinyDB('artwork_results.json')
def load_models(model_name="rupeshs/LCM-runwayml-stable-diffusion-v1-5"):
    """Load the Stable Diffusion pipeline and a MIDM classifier into globals.

    Populates the module globals ``stable_diff_pipe`` and ``model``. Both are
    placed on CUDA when available, otherwise on CPU.

    Args:
        model_name: Hub id (or local path) passed to
            ``StableDiffusionPipeline.from_pretrained``.
    """
    global stable_diff_pipe, model
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build into locals first: the globals are replaced only once both models
    # are ready, so a failed download leaves the previous pair untouched.
    pipe = StableDiffusionPipeline.from_pretrained(model_name)
    pipe.to(device)

    # NOTE(review): no trained weights are loaded here, so the classifier is
    # randomly initialised on every call and its output carries no learned
    # signal. Load a state_dict here if trained weights exist.
    classifier = MIDM(input_dim=10, hidden_dim=64, output_dim=1)
    # Put the classifier on the same device as the pipeline. The VAE features
    # it receives live there, and a CPU-only classifier would raise a
    # device-mismatch error on GPU hosts.
    classifier.to(device)
    classifier.eval()

    stable_diff_pipe, model = pipe, classifier
# Image -> VAE latent features, using the globally loaded pipeline.
def extract_image_features(image):
    """Encode ``image`` with the pipeline's VAE and return the latent mean.

    The image is preprocessed with ``stable_diff_pipe.feature_extractor``,
    moved to the pipeline's device and passed through ``vae.encode``. The
    mean of the resulting latent distribution is returned.

    NOTE(review): ``feature_extractor`` is the pipeline's image processor
    (presumably CLIP-style normalisation for the safety checker), which may
    not match the value range the VAE was trained on — confirm it is intended.
    """
    pipe = stable_diff_pipe
    batch = pipe.feature_extractor(image, return_tensors="pt")
    pixels = batch.pixel_values.to(pipe.device)
    with torch.no_grad():
        latents = pipe.vae.encode(pixels).latent_dist.mean
    return latents
# Helper that reads the caller's user_id cookie, minting a new id when absent.
def get_user_id():
    """Return the request's ``user_id`` cookie, or a fresh UUID4 string.

    A new id is generated when the cookie is missing or empty. This helper
    does not set the cookie; the caller must attach a newly minted id to its
    response.
    """
    return request.cookies.get('user_id') or str(uuid.uuid4())
@app.route('/')
def index():
    """Render the landing page from the templates folder."""
    page = 'index.html'
    return render_template(page)
def _ensure_models_loaded(model_name):
    """(Re)load the global models when missing or built for another model."""
    if (
        stable_diff_pipe is None
        or model is None
        or stable_diff_pipe.name_or_path != model_name
    ):
        load_models(model_name)


def _record_likely_member(model_name, user_id, contact_info, opt_in):
    """Store one DB row per (model, user) pair flagged as a likely member."""
    User = Query()
    if db.search((User.model == model_name) & (User.user_id == user_id)):
        return
    entry = {
        'model': model_name,
        'user_id': user_id,
        # NOTE(review): uuid1().time counts 100-ns intervals since 1582-10-15,
        # not Unix seconds. It is kept as-is so new rows stay comparable with
        # rows already in the DB.
        'timestamp': int(uuid.uuid1().time),
    }
    # Contact details are kept only with explicit opt-in.
    if opt_in and contact_info:
        entry['contact_info'] = contact_info
    db.insert(entry)


def _count_likely_users(model_name):
    """Return the number of distinct user_ids stored for ``model_name``."""
    User = Query()
    return len({r['user_id'] for r in db.search(User.model == model_name)})


@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Estimate whether an uploaded image was in a model's training data.

    Form fields: ``image`` (file, required), ``model`` (Hub id, optional),
    ``contact_info`` (optional), ``opt_in`` ("true"/"false").

    Returns JSON with ``probability``, ``predicted_class``, ``message``,
    ``is_in_training_data`` and ``likely_count``. Returns 400 when the image
    is missing or cannot be decoded, and 500 with the error text on any
    other failure. A ``user_id`` cookie is set when the request had none.
    """
    try:
        model_name = request.form.get('model', 'rupeshs/LCM-runwayml-stable-diffusion-v1-5')
        contact_info = request.form.get('contact_info')  # Optional, if frontend sends it
        opt_in = request.form.get('opt_in', 'false').lower() == 'true'
        user_id = get_user_id()

        # Validate the upload before touching the models, so a bad request
        # cannot trigger a (possibly multi-GB) model download.
        if 'image' not in request.files:
            return jsonify({'error': 'No image found in request'}), 400
        try:
            image = Image.open(io.BytesIO(request.files['image'].read()))
            image.load()  # decode now so truncated files are rejected here
        except OSError:  # includes PIL.UnidentifiedImageError
            return jsonify({'error': 'Invalid image file'}), 400

        # NOTE(review): model_name comes straight from the client, so any caller
        # can make the server download an arbitrary Hub repository. Consider an
        # allow-list.
        _ensure_models_loaded(model_name)

        image_features = extract_image_features(image)
        # The classifier consumes the first 10 latent values. Match its device
        # and dtype, because the VAE output lives on the pipeline's device.
        param = next(model.parameters())
        processed_features = image_features.reshape(1, -1)[:, :10].to(
            device=param.device, dtype=param.dtype
        )
        with torch.no_grad():
            output = model(processed_features)
        probability = output.item()
        predicted = int(probability > 0.5)

        if predicted == 1:
            _record_likely_member(model_name, user_id, contact_info, opt_in)
        count = _count_likely_users(model_name)

        response = make_response(jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely",
            'likely_count': count
        }))
        # Persist a newly minted user_id so repeat visits are deduplicated.
        if not request.cookies.get('user_id'):
            response.set_cookie('user_id', user_id, max_age=60 * 60 * 24 * 365 * 5)  # 5 years
        return response
    except Exception as e:
        # Top-level boundary: log the full traceback, report the error text.
        app.logger.exception("Error processing request: %s", e)
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Honour a PORT supplied by the hosting environment, defaulting to 7860,
    # and listen on all interfaces.
    port = int(os.getenv('PORT', '7860'))
    app.run(port=port, host='0.0.0.0')