File size: 4,993 Bytes
ff13394
 
176294d
ff13394
 
 
 
7e1e741
9a5bfef
176294d
 
9a5bfef
 
ff13394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e1e741
ff13394
 
176294d
 
08b8cd7
7e1e741
ff13394
7e1e741
08b8cd7
7e1e741
ff13394
 
c176b63
ff13394
 
 
 
 
 
7e1e741
 
c176b63
7e1e741
 
 
 
ff13394
7e1e741
 
 
ff13394
176294d
 
 
 
 
 
 
ff13394
 
 
 
 
 
 
08b8cd7
176294d
 
 
08b8cd7
176294d
08b8cd7
 
 
 
 
 
ff13394
 
 
7e1e741
176294d
ff13394
 
 
 
176294d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff13394
 
 
176294d
 
 
 
 
 
 
ff13394
176294d
ff13394
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template, make_response
from flask_cors import CORS
import io
import os
from PIL import Image
from diffusers import StableDiffusionPipeline
import os
from tinydb import TinyDB, Query
import uuid

# Value of the HF_TOKEN environment variable (None when it is unset);
# presumably a Hugging Face access token.
# NOTE(review): `token` is not referenced anywhere else in the visible code —
# confirm whether it was meant to be passed to the model loader.
token = os.getenv("HF_TOKEN")

# Define the MIDM model
class MIDM(nn.Module):
    """Small feed-forward classifier: Linear -> ReLU -> Linear -> Sigmoid.

    Args:
        input_dim: width of the input feature vector.
        hidden_dim: width of the hidden layer.
        output_dim: width of the output; every output value lies in [0, 1]
            because of the final sigmoid.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in the same order as they are applied in forward().
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-squashed scores for the input batch ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))

# Flask application; static files come from 'static', templates from 'templates'.
app = Flask(__name__, static_folder='static', template_folder='templates')
# Enable cross-origin requests for the whole app (flask_cors defaults).
CORS(app)

# Module-level model handles. Both start as None and are filled in by
# load_models() the first time a request needs them, so the models are not
# reloaded for every request.
stable_diff_pipe = None
model = None

# TinyDB database backed by the 'artwork_results.json' file (path is relative
# to the process working directory).
db = TinyDB('artwork_results.json')

def load_models(model_name="rupeshs/LCM-runwayml-stable-diffusion-v1-5"):
    """Load the Stable Diffusion pipeline and create the MIDM classifier.

    Both objects are stored in the module-level ``stable_diff_pipe`` and
    ``model`` globals, replacing whatever was there before.

    Args:
        model_name: identifier handed to
            ``StableDiffusionPipeline.from_pretrained``.
    """
    global stable_diff_pipe, model

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Stable Diffusion model pipeline
    stable_diff_pipe = StableDiffusionPipeline.from_pretrained(model_name)
    stable_diff_pipe.to(device)

    # Initialize MIDM model.
    # NOTE: no trained weights are loaded here; the classifier keeps the
    # default random initialisation of its layers.
    model = MIDM(input_dim=10, hidden_dim=64, output_dim=1)

    # The classifier must live on the same device as the pipeline:
    # extract_image_features() returns latents on the pipeline's device, and
    # feeding CUDA tensors to a CPU module raises a device-mismatch error.
    model.to(device)
    model.eval()

# Function to extract features from the image using Stable Diffusion
def extract_image_features(image):
    """Encode ``image`` with the loaded pipeline's VAE and return the latent mean.

    The image is first run through the pipeline's feature extractor, the
    resulting pixel values are moved to the pipeline's device, and the VAE
    encoder is evaluated with gradients disabled.

    Args:
        image: the image to encode (the caller passes a PIL image).

    Returns:
        The ``mean`` tensor of the VAE's latent distribution.
    """
    # Preprocess the image into a tensor on the pipeline's device.
    preprocessed = stable_diff_pipe.feature_extractor(image, return_tensors="pt")
    pixel_values = preprocessed.pixel_values.to(stable_diff_pipe.device)

    # Encode without tracking gradients; only the distribution mean is used.
    with torch.no_grad():
        latent_mean = stable_diff_pipe.vae.encode(pixel_values).latent_dist.mean

    return latent_mean

# Helper to read the user_id cookie, generating a fresh UUID when it is absent
# (the cookie itself is set by the caller on the response)
def get_user_id():
    """Return the request's ``user_id`` cookie, or a new random UUID string.

    A missing or empty cookie yields a freshly generated ``uuid4`` value.
    This function does not set any cookie itself.
    """
    return request.cookies.get('user_id') or str(uuid.uuid4())

@app.route('/')
def index():
    """Serve the landing page by rendering the 'index.html' template."""
    return render_template('index.html')

@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Estimate whether an uploaded image was part of a model's training data.

    Form fields read from the request:
        image: uploaded image file (required).
        model: checkpoint name to load; defaults to
            'rupeshs/LCM-runwayml-stable-diffusion-v1-5'.
        contact_info: optional contact string, stored only when ``opt_in``
            is also given.
        opt_in: 'true' (case-insensitive) to allow storing ``contact_info``.

    Returns:
        A JSON response with ``probability``, ``predicted_class``,
        ``message``, ``is_in_training_data`` and ``likely_count`` (number of
        distinct user ids recorded for this model). Responds with 400 when
        the image is missing or cannot be decoded, and with 500 plus the
        error text on any other failure. A ``user_id`` cookie is set on the
        response when the request did not carry one.
    """
    global stable_diff_pipe, model
    try:
        # NOTE(review): `model` comes straight from the client, so any caller
        # can make the server download and load an arbitrary checkpoint;
        # consider restricting it to an allow-list.
        model_name = request.form.get('model', 'rupeshs/LCM-runwayml-stable-diffusion-v1-5')
        contact_info = request.form.get('contact_info')  # Optional, if frontend sends it
        opt_in = request.form.get('opt_in', 'false').lower() == 'true'
        user_id = get_user_id()

        # Validate the upload before touching the models: loading a pipeline
        # is expensive and pointless for a request that cannot be served.
        if 'image' not in request.files:
            return jsonify({'error': 'No image found in request'}), 400
        image_bytes = request.files['image'].read()
        try:
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy; force decoding so corrupt or truncated
            # files are rejected here as a client error.
            image.load()
        except OSError:
            # Pillow's UnidentifiedImageError is a subclass of OSError.
            return jsonify({'error': 'Uploaded file is not a valid image'}), 400

        # (Re)load when nothing is loaded yet or a different checkpoint was requested.
        needs_load = (
            stable_diff_pipe is None
            or model is None
            or stable_diff_pipe.name_or_path != model_name
        )
        if needs_load:
            load_models(model_name)

        image_features = extract_image_features(image)
        # Flatten the latents and keep as many leading values as the
        # classifier's first layer accepts, then move them to the
        # classifier's device (the latents live on the pipeline's device).
        classifier_device = next(model.parameters()).device
        processed_features = (
            image_features.reshape(1, -1)[:, :model.fc1.in_features].to(classifier_device)
        )
        with torch.no_grad():
            probability = model(processed_features).item()
        predicted = int(probability > 0.5)

        User = Query()
        # Store in DB if likely and opt-in
        if predicted == 1:
            entry = {
                'model': model_name,
                'user_id': user_id,
                # uuid1().time is the UUID's 60-bit timestamp (100-ns
                # intervals since 1582-10-15), not Unix epoch seconds; kept
                # unchanged so new entries match already-stored ones.
                'timestamp': int(uuid.uuid1().time),
            }
            if opt_in and contact_info:
                entry['contact_info'] = contact_info
            # Only store if not already present for this user/model
            if not db.contains((User.model == model_name) & (User.user_id == user_id)):
                db.insert(entry)

        # Count unique user_ids for this model
        likely_users = {row['user_id'] for row in db.search(User.model == model_name)}
        response = make_response(jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely",
            'likely_count': len(likely_users)
        }))
        # Set user_id cookie if not present
        if not request.cookies.get('user_id'):
            response.set_cookie('user_id', user_id, max_age=60*60*24*365*5)  # 5 years
        return response
    except Exception as e:
        # Top-level boundary of the endpoint: log with traceback and report
        # the failure to the client instead of letting the request crash.
        # NOTE(review): the raw exception text is returned to the client;
        # kept for API compatibility, but it may expose internal details.
        app.logger.exception("Error processing request")
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Listen on every interface; the port comes from the PORT environment
    # variable and falls back to 7860 when it is unset.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))