TerenceG commited on
Commit
6ebfebb
·
verified ·
1 Parent(s): 58b6d0e

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, File, UploadFile
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ from pytorch_grad_cam import GradCAM
7
+ from pytorch_grad_cam.utils.image import show_cam_on_image
8
+ from transformers import AutoFeatureExtractor
9
+ import timm
10
+ import numpy as np
11
+ import json
12
+ import base64
13
+ from io import BytesIO
14
+ import uvicorn
15
+
16
+ app = FastAPI(title="VerifAI GradCAM API", description="API pour la détection d'images IA avec GradCAM")
17
+
18
+ # Configuration CORS
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"],
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+ class AIDetectionGradCAM:
28
+ def __init__(self):
29
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
+ self.models = {}
31
+ self.feature_extractors = {}
32
+ self.target_layers = {}
33
+
34
+ # Initialiser les modèles
35
+ self._load_models()
36
+
37
+ def _load_models(self):
38
+ """Charge les modèles pour la détection"""
39
+ try:
40
+ # Modèle Swin Transformer
41
+ model_name = "microsoft/swin-base-patch4-window7-224-in22k"
42
+ self.models['swin'] = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=2)
43
+ self.feature_extractors['swin'] = AutoFeatureExtractor.from_pretrained(model_name)
44
+
45
+ # Définir les couches cibles pour GradCAM
46
+ self.target_layers['swin'] = [self.models['swin'].layers[-1].blocks[-1].norm1]
47
+
48
+ # Mettre en mode évaluation
49
+ for model in self.models.values():
50
+ model.eval()
51
+ model.to(self.device)
52
+
53
+ except Exception as e:
54
+ print(f"Erreur lors du chargement des modèles: {e}")
55
+
56
+ def _preprocess_image(self, image, model_type='swin'):
57
+ """Prétraite l'image pour le modèle"""
58
+ if isinstance(image, str):
59
+ # Si c'est un chemin ou base64
60
+ if image.startswith('data:image'):
61
+ # Décoder base64
62
+ header, data = image.split(',', 1)
63
+ image_data = base64.b64decode(data)
64
+ image = Image.open(BytesIO(image_data))
65
+ else:
66
+ image = Image.open(image)
67
+
68
+ # Convertir en RGB si nécessaire
69
+ if image.mode != 'RGB':
70
+ image = image.convert('RGB')
71
+
72
+ # Redimensionner
73
+ image = image.resize((224, 224))
74
+
75
+ # Normalisation standard
76
+ transform = transforms.Compose([
77
+ transforms.ToTensor(),
78
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
79
+ std=[0.229, 0.224, 0.225])
80
+ ])
81
+
82
+ tensor = transform(image).unsqueeze(0).to(self.device)
83
+ return tensor, np.array(image) / 255.0
84
+
85
+ def _generate_gradcam(self, image_tensor, rgb_img, model_type='swin'):
86
+ """Génère la carte de saillance GradCAM"""
87
+ try:
88
+ model = self.models[model_type]
89
+ target_layers = self.target_layers[model_type]
90
+
91
+ # Créer l'objet GradCAM
92
+ cam = GradCAM(model=model, target_layers=target_layers)
93
+
94
+ # Générer la carte de saillance
95
+ grayscale_cam = cam(input_tensor=image_tensor, targets=None)
96
+ grayscale_cam = grayscale_cam[0, :]
97
+
98
+ # Superposer sur l'image originale
99
+ cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
100
+
101
+ return cam_image
102
+
103
+ except Exception as e:
104
+ print(f"Erreur GradCAM: {e}")
105
+ return rgb_img * 255
106
+
107
+ def predict_and_explain(self, image):
108
+ """Prédiction avec explication GradCAM"""
109
+ try:
110
+ # Prétraitement
111
+ image_tensor, rgb_img = self._preprocess_image(image)
112
+
113
+ # Prédiction
114
+ with torch.no_grad():
115
+ outputs = self.models['swin'](image_tensor)
116
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
117
+ confidence = probabilities.max().item()
118
+ prediction = probabilities.argmax().item()
119
+
120
+ # Génération GradCAM
121
+ cam_image = self._generate_gradcam(image_tensor, rgb_img)
122
+
123
+ # Convertir l'image GradCAM en base64
124
+ pil_image = Image.fromarray(cam_image.astype(np.uint8))
125
+ buffer = BytesIO()
126
+ pil_image.save(buffer, format='PNG')
127
+ cam_base64 = base64.b64encode(buffer.getvalue()).decode()
128
+
129
+ # Résultats
130
+ result = {
131
+ 'prediction': prediction,
132
+ 'confidence': confidence,
133
+ 'class_probabilities': {
134
+ 'Real': probabilities[0][0].item(),
135
+ 'AI-Generated': probabilities[0][1].item()
136
+ },
137
+ 'cam_image': f"data:image/png;base64,{cam_base64}",
138
+ 'status': 'success'
139
+ }
140
+
141
+ return result
142
+
143
+ except Exception as e:
144
+ return {'status': 'error', 'message': str(e)}
145
+
146
+ # Initialiser le détecteur
147
+ detector = AIDetectionGradCAM()
148
+
149
+ @app.get("/")
150
+ async def root():
151
+ return {"message": "VerifAI GradCAM API", "status": "running"}
152
+
153
+ @app.post("/predict")
154
+ async def predict_image(file: UploadFile = File(...)):
155
+ """Endpoint pour analyser une image"""
156
+ try:
157
+ # Lire l'image
158
+ image_data = await file.read()
159
+ image = Image.open(BytesIO(image_data))
160
+
161
+ # Analyser
162
+ result = detector.predict_and_explain(image)
163
+
164
+ return result
165
+
166
+ except Exception as e:
167
+ raise HTTPException(status_code=500, detail=str(e))
168
+
169
+ @app.post("/predict-base64")
170
+ async def predict_base64(data: dict):
171
+ """Endpoint pour analyser une image en base64"""
172
+ try:
173
+ if 'image' not in data:
174
+ raise HTTPException(status_code=400, detail="Champ 'image' requis")
175
+
176
+ image_b64 = data['image']
177
+
178
+ # Analyser
179
+ result = detector.predict_and_explain(image_b64)
180
+
181
+ return result
182
+
183
+ except Exception as e:
184
+ raise HTTPException(status_code=500, detail=str(e))
185
+
186
+ if __name__ == "__main__":
187
+ uvicorn.run(app, host="0.0.0.0", port=7860)