File size: 7,489 Bytes
da2daf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from transformers import AutoModel, AutoFeatureExtractor
import timm
import numpy as np
import json
import base64
from io import BytesIO

class AIDetectionGradCAM:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.feature_extractors = {}
        self.target_layers = {}
        
        # Initialiser les modèles
        self._load_models()
    
    def _load_models(self):
        """Charge les modèles pour la détection"""
        try:
            # Modèle Swin Transformer
            model_name = "microsoft/swin-base-patch4-window7-224-in22k"
            self.models['swin'] = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=2)
            self.feature_extractors['swin'] = AutoFeatureExtractor.from_pretrained(model_name)
            
            # Définir les couches cibles pour GradCAM
            self.target_layers['swin'] = [self.models['swin'].layers[-1].blocks[-1].norm1]
            
            # Mettre en mode évaluation
            for model in self.models.values():
                model.eval()
                model.to(self.device)
                
        except Exception as e:
            print(f"Erreur lors du chargement des modèles: {e}")
    
    def _preprocess_image(self, image, model_type='swin'):
        """Prétraite l'image pour le modèle"""
        if isinstance(image, str):
            # Si c'est un chemin ou base64
            if image.startswith('data:image'):
                # Décoder base64
                header, data = image.split(',', 1)
                image_data = base64.b64decode(data)
                image = Image.open(BytesIO(image_data))
            else:
                image = Image.open(image)
        
        # Convertir en RGB si nécessaire
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Redimensionner
        image = image.resize((224, 224))
        
        # Normalisation standard
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        
        tensor = transform(image).unsqueeze(0).to(self.device)
        return tensor, np.array(image) / 255.0
    
    def _generate_gradcam(self, image_tensor, rgb_img, model_type='swin'):
        """Génère la carte de saillance GradCAM"""
        try:
            model = self.models[model_type]
            target_layers = self.target_layers[model_type]
            
            # Créer l'objet GradCAM
            cam = GradCAM(model=model, target_layers=target_layers)
            
            # Générer la carte de saillance
            grayscale_cam = cam(input_tensor=image_tensor, targets=None)
            grayscale_cam = grayscale_cam[0, :]
            
            # Superposer sur l'image originale
            cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
            
            return cam_image
            
        except Exception as e:
            print(f"Erreur GradCAM: {e}")
            return rgb_img * 255
    
    def predict_and_explain(self, image):
        """Prédiction avec explication GradCAM"""
        try:
            # Prétraitement
            image_tensor, rgb_img = self._preprocess_image(image)
            
            # Prédiction
            with torch.no_grad():
                outputs = self.models['swin'](image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence = probabilities.max().item()
                prediction = probabilities.argmax().item()
            
            # Génération GradCAM
            cam_image = self._generate_gradcam(image_tensor, rgb_img)
            
            # Résultats
            class_names = ['Real', 'AI-Generated']
            predicted_class = class_names[prediction]
            
            result = {
                'prediction': prediction,
                'confidence': confidence,
                'predicted_class': predicted_class,
                'probabilities': {
                    'Real': probabilities[0][0].item(),
                    'AI-Generated': probabilities[0][1].item()
                }
            }
            
            return cam_image.astype(np.uint8), result
            
        except Exception as e:
            return image, {'error': str(e)}

# Initialiser le détecteur
detector = AIDetectionGradCAM()

def analyze_image(image):
    """Fonction pour l'interface Gradio"""
    if image is None:
        return None, "Veuillez télécharger une image"
    
    try:
        cam_image, result = detector.predict_and_explain(image)
        
        if 'error' in result:
            return image, f"Erreur: {result['error']}"
        
        # Formatage du résultat
        confidence_percent = result['confidence'] * 100
        predicted_class = result['predicted_class']
        
        analysis_text = f"""

        ## 🔍 Analyse de l'image

        

        **Prédiction:** {predicted_class}

        **Confiance:** {confidence_percent:.1f}%

        

        **Probabilités détaillées:**

        - Real: {result['probabilities']['Real']:.3f}

        - AI-Generated: {result['probabilities']['AI-Generated']:.3f}

        

        La carte de saillance (GradCAM) montre les zones que le modèle considère comme importantes pour sa décision.

        """
        
        return cam_image, analysis_text
        
    except Exception as e:
        return image, f"Erreur lors de l'analyse: {str(e)}"

# Interface Gradio
with gr.Blocks(theme=gr.themes.Soft(), title="VerifAI - Détection d'images IA avec GradCAM") as app:
    gr.Markdown("""

    # 🔍 VerifAI - Détecteur d'images IA avec GradCAM

    

    Téléchargez une image pour déterminer si elle a été générée par une IA. 

    L'application utilise GradCAM pour expliquer visuellement sa décision.

    """)
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="📸 Téléchargez votre image",
                height=400
            )
            analyze_btn = gr.Button("🔍 Analyser l'image", variant="primary", size="lg")
        
        with gr.Column():
            output_image = gr.Image(
                label="🎯 Carte de saillance GradCAM",
                height=400
            )
            result_text = gr.Markdown(label="📊 Résultats de l'analyse")
    
    analyze_btn.click(
        fn=analyze_image,
        inputs=[input_image],
        outputs=[output_image, result_text]
    )
    
    gr.Markdown("""

    ---

    ### 💡 Comment interpréter les résultats

    

    - **Real**: L'image semble être une vraie photo

    - **AI-Generated**: L'image semble être générée par IA

    - **Carte de saillance**: Les zones colorées indiquent les régions importantes pour la décision

    """)

if __name__ == "__main__":
    app.launch()