from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import torch
import torchvision.transforms as transforms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from transformers import AutoFeatureExtractor
import timm
import numpy as np
import json
import base64
from io import BytesIO
import uvicorn

app = FastAPI(title="VerifAI GradCAM API", description="API for detecting AI-generated images with GradCAM")

# CORS configuration (open to all origins; restrict for production use)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class AIDetectionGradCAM:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.feature_extractors = {}
        self.target_layers = {}
        
        # Load the models
        self._load_models()
    
    def _load_models(self):
        """Charge les modèles pour la détection"""
        try:
            # Modèle Swin Transformer
            model_name = "microsoft/swin-base-patch4-window7-224-in22k"
            self.models['swin'] = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=2)
            self.feature_extractors['swin'] = AutoFeatureExtractor.from_pretrained(model_name)
            
            # Définir les couches cibles pour GradCAM
            self.target_layers['swin'] = [self.models['swin'].layers[-1].blocks[-1].norm1]
            
            # Mettre en mode évaluation
            for model in self.models.values():
                model.eval()
                model.to(self.device)
                
        except Exception as e:
            print(f"Erreur lors du chargement des modèles: {e}")
    
    def _preprocess_image(self, image, model_type='swin'):
        """Prétraite l'image pour le modèle"""
        if isinstance(image, str):
            # Si c'est un chemin ou base64
            if image.startswith('data:image'):
                # Décoder base64
                header, data = image.split(',', 1)
                image_data = base64.b64decode(data)
                image = Image.open(BytesIO(image_data))
            else:
                image = Image.open(image)
        
        # Convertir en RGB si nécessaire
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Redimensionner
        image = image.resize((224, 224))
        
        # Normalisation standard
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        
        tensor = transform(image).unsqueeze(0).to(self.device)
        return tensor, np.array(image) / 255.0
    
    def _generate_gradcam(self, image_tensor, rgb_img, model_type='swin'):
        """Génère la carte de saillance GradCAM"""
        try:
            model = self.models[model_type]
            target_layers = self.target_layers[model_type]
            
            # Créer l'objet GradCAM
            cam = GradCAM(model=model, target_layers=target_layers)
            
            # Générer la carte de saillance
            grayscale_cam = cam(input_tensor=image_tensor, targets=None)
            grayscale_cam = grayscale_cam[0, :]
            
            # Superposer sur l'image originale
            cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
            
            return cam_image
            
        except Exception as e:
            print(f"Erreur GradCAM: {e}")
            return rgb_img * 255
    
    def predict_and_explain(self, image):
        """Prédiction avec explication GradCAM"""
        try:
            # Prétraitement
            image_tensor, rgb_img = self._preprocess_image(image)
            
            # Prédiction
            with torch.no_grad():
                outputs = self.models['swin'](image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence = probabilities.max().item()
                prediction = probabilities.argmax().item()
            
            # Génération GradCAM
            cam_image = self._generate_gradcam(image_tensor, rgb_img)
            
            # Convertir l'image GradCAM en base64
            pil_image = Image.fromarray(cam_image.astype(np.uint8))
            buffer = BytesIO()
            pil_image.save(buffer, format='PNG')
            cam_base64 = base64.b64encode(buffer.getvalue()).decode()
            
            # Résultats
            result = {
                'prediction': prediction,
                'confidence': confidence,
                'class_probabilities': {
                    'Real': probabilities[0][0].item(),
                    'AI-Generated': probabilities[0][1].item()
                },
                'cam_image': f"data:image/png;base64,{cam_base64}",
                'status': 'success'
            }
            
            return result
            
        except Exception as e:
            return {'status': 'error', 'message': str(e)}

# Initialize the detector (models are loaded at import time)
detector = AIDetectionGradCAM()

@app.get("/")
async def root():
    return {"message": "VerifAI GradCAM API", "status": "running"}

@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Endpoint pour analyser une image"""
    try:
        # Lire l'image
        image_data = await file.read()
        image = Image.open(BytesIO(image_data))
        
        # Analyser
        result = detector.predict_and_explain(image)
        
        return result
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/predict-base64")
async def predict_base64(data: dict):
    """Endpoint pour analyser une image en base64"""
    try:
        if 'image' not in data:
            raise HTTPException(status_code=400, detail="Champ 'image' requis")
        
        image_b64 = data['image']
        
        # Analyser
        result = detector.predict_and_explain(image_b64)
        
        return result
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
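
# A Pydantic model would make the /predict-base64 payload explicit instead of
# accepting a raw dict. Minimal sketch (the name Base64ImageRequest is
# hypothetical and not used above):
#
#     from pydantic import BaseModel
#
#     class Base64ImageRequest(BaseModel):
#         image: str  # data URL ("data:image/...;base64,...") or plain base64 string
#
#     @app.post("/predict-base64")
#     async def predict_base64(data: Base64ImageRequest):
#         return detector.predict_and_explain(data.image)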

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
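
# Example client call (a sketch; assumes the server is running locally on port 7860,
# "sample.jpg" is a placeholder path for any image on disk, and the `requests`
# package is installed):
#
#     import requests
#
#     with open("sample.jpg", "rb") as f:
#         response = requests.post(
#             "http://localhost:7860/predict",
#             files={"file": ("sample.jpg", f, "image/jpeg")},
#         )
#     print(response.json())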