Spaces:
Build error
Build error
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
from transformers import AutoModel, AutoFeatureExtractor | |
import timm | |
import numpy as np | |
import json | |
import base64 | |
from io import BytesIO | |
class AIDetectionGradCAM: | |
def __init__(self): | |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
self.models = {} | |
self.feature_extractors = {} | |
self.target_layers = {} | |
# Initialiser les modèles | |
self._load_models() | |
def _load_models(self): | |
"""Charge les modèles pour la détection""" | |
try: | |
# Modèle Swin Transformer | |
model_name = "microsoft/swin-base-patch4-window7-224-in22k" | |
self.models['swin'] = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=2) | |
self.feature_extractors['swin'] = AutoFeatureExtractor.from_pretrained(model_name) | |
# Définir les couches cibles pour GradCAM | |
self.target_layers['swin'] = [self.models['swin'].layers[-1].blocks[-1].norm1] | |
# Mettre en mode évaluation | |
for model in self.models.values(): | |
model.eval() | |
model.to(self.device) | |
except Exception as e: | |
print(f"Erreur lors du chargement des modèles: {e}") | |
def _preprocess_image(self, image, model_type='swin'): | |
"""Prétraite l'image pour le modèle""" | |
if isinstance(image, str): | |
# Si c'est un chemin ou base64 | |
if image.startswith('data:image'): | |
# Décoder base64 | |
header, data = image.split(',', 1) | |
image_data = base64.b64decode(data) | |
image = Image.open(BytesIO(image_data)) | |
else: | |
image = Image.open(image) | |
# Convertir en RGB si nécessaire | |
if image.mode != 'RGB': | |
image = image.convert('RGB') | |
# Redimensionner | |
image = image.resize((224, 224)) | |
# Normalisation standard | |
transform = transforms.Compose([ | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
std=[0.229, 0.224, 0.225]) | |
]) | |
tensor = transform(image).unsqueeze(0).to(self.device) | |
return tensor, np.array(image) / 255.0 | |
def _generate_gradcam(self, image_tensor, rgb_img, model_type='swin'): | |
"""Génère la carte de saillance GradCAM""" | |
try: | |
model = self.models[model_type] | |
target_layers = self.target_layers[model_type] | |
# Créer l'objet GradCAM | |
cam = GradCAM(model=model, target_layers=target_layers) | |
# Générer la carte de saillance | |
grayscale_cam = cam(input_tensor=image_tensor, targets=None) | |
grayscale_cam = grayscale_cam[0, :] | |
# Superposer sur l'image originale | |
cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True) | |
return cam_image | |
except Exception as e: | |
print(f"Erreur GradCAM: {e}") | |
return rgb_img * 255 | |
def predict_and_explain(self, image): | |
"""Prédiction avec explication GradCAM""" | |
try: | |
# Prétraitement | |
image_tensor, rgb_img = self._preprocess_image(image) | |
# Prédiction | |
with torch.no_grad(): | |
outputs = self.models['swin'](image_tensor) | |
probabilities = torch.nn.functional.softmax(outputs, dim=1) | |
confidence = probabilities.max().item() | |
prediction = probabilities.argmax().item() | |
# Génération GradCAM | |
cam_image = self._generate_gradcam(image_tensor, rgb_img) | |
# Résultats | |
class_names = ['Real', 'AI-Generated'] | |
predicted_class = class_names[prediction] | |
result = { | |
'prediction': prediction, | |
'confidence': confidence, | |
'predicted_class': predicted_class, | |
'probabilities': { | |
'Real': probabilities[0][0].item(), | |
'AI-Generated': probabilities[0][1].item() | |
} | |
} | |
return cam_image.astype(np.uint8), result | |
except Exception as e: | |
return image, {'error': str(e)} | |
# Initialiser le détecteur | |
detector = AIDetectionGradCAM() | |
def analyze_image(image): | |
"""Fonction pour l'interface Gradio""" | |
if image is None: | |
return None, "Veuillez télécharger une image" | |
try: | |
cam_image, result = detector.predict_and_explain(image) | |
if 'error' in result: | |
return image, f"Erreur: {result['error']}" | |
# Formatage du résultat | |
confidence_percent = result['confidence'] * 100 | |
predicted_class = result['predicted_class'] | |
analysis_text = f""" | |
## 🔍 Analyse de l'image | |
**Prédiction:** {predicted_class} | |
**Confiance:** {confidence_percent:.1f}% | |
**Probabilités détaillées:** | |
- Real: {result['probabilities']['Real']:.3f} | |
- AI-Generated: {result['probabilities']['AI-Generated']:.3f} | |
La carte de saillance (GradCAM) montre les zones que le modèle considère comme importantes pour sa décision. | |
""" | |
return cam_image, analysis_text | |
except Exception as e: | |
return image, f"Erreur lors de l'analyse: {str(e)}" | |
# Interface Gradio | |
with gr.Blocks(theme=gr.themes.Soft(), title="VerifAI - Détection d'images IA avec GradCAM") as app: | |
gr.Markdown(""" | |
# 🔍 VerifAI - Détecteur d'images IA avec GradCAM | |
Téléchargez une image pour déterminer si elle a été générée par une IA. | |
L'application utilise GradCAM pour expliquer visuellement sa décision. | |
""") | |
with gr.Row(): | |
with gr.Column(): | |
input_image = gr.Image( | |
type="pil", | |
label="📸 Téléchargez votre image", | |
height=400 | |
) | |
analyze_btn = gr.Button("🔍 Analyser l'image", variant="primary", size="lg") | |
with gr.Column(): | |
output_image = gr.Image( | |
label="🎯 Carte de saillance GradCAM", | |
height=400 | |
) | |
result_text = gr.Markdown(label="📊 Résultats de l'analyse") | |
analyze_btn.click( | |
fn=analyze_image, | |
inputs=[input_image], | |
outputs=[output_image, result_text] | |
) | |
gr.Markdown(""" | |
--- | |
### 💡 Comment interpréter les résultats | |
- **Real**: L'image semble être une vraie photo | |
- **AI-Generated**: L'image semble être générée par IA | |
- **Carte de saillance**: Les zones colorées indiquent les régions importantes pour la décision | |
""") | |
if __name__ == "__main__": | |
app.launch() |