hoololi commited on
Commit
ad0e663
·
verified ·
1 Parent(s): 9bdeee9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -211
app.py CHANGED
@@ -1,272 +1,217 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoImageProcessor, AutoModelForObjectDetection
3
  from PIL import Image, ImageDraw, ImageFont
4
  import torch
5
  import spaces
6
  import numpy as np
 
7
 
8
- # Modèles disponibles sur Hugging Face Hub
9
- AVAILABLE_MODELS = {
 
10
  "DETR ResNet-50": "facebook/detr-resnet-50",
11
- "DETR ResNet-101": "facebook/detr-resnet-101",
12
- "Conditional DETR": "microsoft/conditional-detr-resnet-50",
13
- "Table Transformer": "microsoft/table-transformer-detection",
14
- "YOLOS Tiny": "hustvl/yolos-tiny",
15
  "YOLOS Small": "hustvl/yolos-small",
16
- "YOLOS Base": "hustvl/yolos-base",
17
- "RT-DETR": "PekingU/rtdetr_r50vd_coco_o365",
18
- "OWL-ViT": "google/owlvit-base-patch32"
19
  }
20
 
21
- # Cache pour éviter de recharger les modèles
22
- model_cache = {}
 
23
 
24
- def load_model(model_name):
25
- """Charge un modèle avec cache"""
26
- if model_name not in model_cache:
27
- print(f"Chargement du modèle: {model_name}")
28
-
29
- if "owlvit" in model_name:
30
- # OWL-ViT est un modèle de détection zero-shot
31
- model_cache[model_name] = pipeline(
32
- "zero-shot-object-detection",
33
- model=model_name,
34
- device=0 if torch.cuda.is_available() else -1
35
- )
36
- else:
37
- # Autres modèles de détection standard
38
- model_cache[model_name] = pipeline(
39
- "object-detection",
40
- model=model_name,
41
- device=0 if torch.cuda.is_available() else -1
42
- )
43
-
44
- return model_cache[model_name]
45
 
46
  @spaces.GPU
47
- def detect_objects(image, model_choice, confidence_threshold, custom_classes=""):
48
- """Détection d'objets avec modèles transformers"""
49
-
50
- if image is None:
51
- return None, "❌ Veuillez uploader une image"
 
 
52
 
53
  try:
54
- # Charger le modèle sélectionné
55
- model_id = AVAILABLE_MODELS[model_choice]
56
- detector = load_model(model_id)
57
 
58
- # Traitement spécial pour OWL-ViT (zero-shot)
59
- if "owlvit" in model_id.lower():
60
- if not custom_classes.strip():
61
- custom_classes = "person, car, dog, cat, chair, table, bottle, cup"
62
-
63
- class_list = [cls.strip() for cls in custom_classes.split(",")]
64
- results = detector(image, candidate_labels=class_list)
65
  else:
66
- # Modèles de détection standard
67
- results = detector(image)
68
 
69
- # Filtrer par seuil de confiance
70
- filtered_results = [
71
- obj for obj in results
72
- if obj['score'] >= confidence_threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  ]
74
 
75
- # Dessiner les détections
76
- annotated_image = draw_detections(image.copy(), filtered_results)
 
 
 
 
 
77
 
78
- # Créer le résumé
79
- summary = create_summary(filtered_results, model_choice)
80
 
81
- return annotated_image, summary
 
82
 
83
  except Exception as e:
84
- return image, f"❌ Erreur: {str(e)}"
 
85
 
86
- def draw_detections(image, detections):
87
- """Dessine les boîtes de détection sur l'image"""
 
 
 
88
  draw = ImageDraw.Draw(image)
89
 
90
- # Essayer de charger une police, sinon utiliser la police par défaut
91
  try:
92
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
93
- except:
94
  font = ImageFont.load_default()
 
 
95
 
96
- colors = [
97
- "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57",
98
- "#FF9FF3", "#54A0FF", "#5F27CD", "#00D2D3", "#FF9F43"
99
- ]
100
 
101
  for i, detection in enumerate(detections):
102
  box = detection['box']
103
  label = detection['label']
104
  score = detection['score']
105
 
106
- # Coordonnées de la boîte
107
  x1, y1 = box['xmin'], box['ymin']
108
  x2, y2 = box['xmax'], box['ymax']
109
 
110
- # Couleur pour cette classe
111
  color = colors[i % len(colors)]
112
 
113
- # Dessiner la boîte
114
- draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
115
-
116
- # Texte du label
117
- text = f"{label} ({score:.2f})"
118
 
119
- # Fond du texte
120
- bbox = draw.textbbox((x1, y1-25), text, font=font)
121
- draw.rectangle(bbox, fill=color)
122
 
123
- # Texte
124
- draw.text((x1, y1-25), text, fill="white", font=font)
 
 
 
 
 
125
 
126
  return image
127
 
128
- def create_summary(detections, model_name):
129
- """Crée un résumé des détections"""
130
- if not detections:
131
- return "🔍 Aucun objet détecté"
132
-
133
- summary = f"🎯 **{len(detections)} objets détectés** avec {model_name}\n\n"
134
-
135
- # Grouper par classe
136
- class_counts = {}
137
- for det in detections:
138
- label = det['label']
139
- score = det['score']
140
-
141
- if label not in class_counts:
142
- class_counts[label] = []
143
- class_counts[label].append(score)
144
-
145
- # Afficher le résumé
146
- for label, scores in class_counts.items():
147
- count = len(scores)
148
- avg_score = sum(scores) / len(scores)
149
- max_score = max(scores)
150
-
151
- summary += f"**{label}**: {count}x (confiance: {avg_score:.2f} avg, {max_score:.2f} max)\n"
152
-
153
- return summary
154
-
155
- # Interface Gradio
156
- with gr.Blocks(title="🤖 Object Detection avec Transformers", theme=gr.themes.Soft()) as demo:
157
 
158
  gr.Markdown("""
159
- # 🤖 Object Detection avec Transformers
160
 
161
- Utilisez les meilleurs modèles de détection d'objets disponibles sur Hugging Face Hub !
162
 
163
- **✨ Fonctionnalités:**
164
- - 🔄 Changement de modèle en temps réel
165
- - 🎯 Seuil de confiance ajustable
166
- - 🏷️ Classes personnalisées (OWL-ViT)
167
- - 📊 Résumé détaillé des détections
168
  """)
169
 
170
  with gr.Row():
171
- with gr.Column(scale=1):
172
- # Input
173
- image_input = gr.Image(
174
- type="pil",
175
- label="📸 Image à analyser",
176
- height=400
177
- )
178
-
179
- # Sélection du modèle
180
- model_dropdown = gr.Dropdown(
181
- choices=list(AVAILABLE_MODELS.keys()),
182
- value="DETR ResNet-50",
183
- label="🤖 Modèle de détection",
184
- info="Chaque modèle a ses spécialités"
185
- )
186
-
187
- # Paramètres
188
- confidence_slider = gr.Slider(
189
- minimum=0.1,
190
- maximum=1.0,
191
- value=0.5,
192
- step=0.05,
193
- label="🎯 Seuil de confiance minimum"
194
- )
195
-
196
- # Classes personnalisées pour OWL-ViT
197
- custom_classes_input = gr.Textbox(
198
- label="🏷️ Classes personnalisées (pour OWL-ViT)",
199
- placeholder="person, car, dog, bottle, phone",
200
- info="Séparées par des virgules. Uniquement pour OWL-ViT."
201
- )
202
-
203
- # Bouton de détection
204
- detect_btn = gr.Button(
205
- "🔍 Détecter les objets",
206
- variant="primary",
207
- size="lg"
208
  )
209
 
210
  with gr.Column(scale=1):
211
- # Outputs
212
- output_image = gr.Image(
213
- label="📊 Résultats de détection",
214
- height=400
215
- )
216
 
217
- detection_summary = gr.Textbox(
218
- label="📈 Résumé des détections",
219
- lines=8,
220
- max_lines=15
221
- )
222
-
223
- # Event handlers
224
- detect_btn.click(
225
- fn=detect_objects,
226
- inputs=[image_input, model_dropdown, confidence_slider, custom_classes_input],
227
- outputs=[output_image, detection_summary]
228
- )
229
-
230
- # Auto-detect en changeant de modèle
231
- model_dropdown.change(
232
- fn=detect_objects,
233
- inputs=[image_input, model_dropdown, confidence_slider, custom_classes_input],
234
- outputs=[output_image, detection_summary]
235
- )
236
-
237
- with gr.Accordion("📚 Guide des modèles", open=False):
238
- gr.Markdown("""
239
- ## 🎯 Guide de sélection des modèles
240
-
241
- ### **DETR (Detection Transformer)**
242
- - **ResNet-50**: Équilibre vitesse/précision ⚖️
243
- - **ResNet-101**: Plus précis, plus lent 🎯
244
- - **Conditional DETR**: Version optimisée 🚀
245
-
246
- ### **YOLOS (You Only Look Once Transformer)**
247
- - **Tiny**: Ultra-rapide ⚡
248
- - **Small**: Bon compromis 🎯
249
- - **Base**: Maximum de précision 🔍
250
-
251
- ### **OWL-ViT (Zero-shot Detection)**
252
- - Détecte **n'importe quoi** que vous décrivez ! 🎨
253
- - Tapez vos propres classes dans le champ "Classes personnalisées"
254
-
255
- ### **RT-DETR**
256
- - Optimisé pour le temps réel ⚡
257
-
258
- ### **Table Transformer**
259
- - Spécialisé dans la détection de tableaux 📊
260
- """)
261
-
262
- # Exemples
263
- gr.Examples(
264
- examples=[
265
- ["example1.jpg", "DETR ResNet-50", 0.5, ""],
266
- ["example2.jpg", "OWL-ViT", 0.3, "smartphone, laptop, coffee cup"],
267
  ],
268
- inputs=[image_input, model_dropdown, confidence_slider, custom_classes_input]
 
 
 
269
  )
270
 
271
  if __name__ == "__main__":
272
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
  from PIL import Image, ImageDraw, ImageFont
4
  import torch
5
  import spaces
6
  import numpy as np
7
+ import cv2
8
 
9
+ # Modèles optimisés pour le temps réel
10
+ REALTIME_MODELS = {
11
+ "YOLOS Tiny (ultra-rapide)": "hustvl/yolos-tiny",
12
  "DETR ResNet-50": "facebook/detr-resnet-50",
 
 
 
 
13
  "YOLOS Small": "hustvl/yolos-small",
14
+ "Conditional DETR": "microsoft/conditional-detr-resnet-50"
 
 
15
  }
16
 
17
+ # Cache global pour le modèle
18
+ current_detector = None
19
+ current_model_name = None
20
 
21
+ def load_detector(model_name):
22
+ """Charge le détecteur avec cache"""
23
+ global current_detector, current_model_name
24
+
25
+ if current_model_name != model_name:
26
+ print(f"🔄 Chargement du modèle: {model_name}")
27
+ model_id = REALTIME_MODELS[model_name]
28
+ current_detector = pipeline(
29
+ "object-detection",
30
+ model=model_id,
31
+ device=0 if torch.cuda.is_available() else -1
32
+ )
33
+ current_model_name = model_name
34
+ print(f"✅ Modèle chargé: {model_name}")
35
+
36
+ return current_detector
 
 
 
 
 
37
 
38
  @spaces.GPU
39
+ def process_webcam_frame(frame, model_choice, confidence_threshold):
40
+ """
41
+ Traite chaque frame de la webcam en temps réel
42
+ Cette fonction est appelée automatiquement pour chaque frame
43
+ """
44
+ if frame is None:
45
+ return frame
46
 
47
  try:
48
+ # Charger le détecteur
49
+ detector = load_detector(model_choice)
 
50
 
51
+ # Convertir numpy array en PIL Image si nécessaire
52
+ if isinstance(frame, np.ndarray):
53
+ # Gradio webcam donne du RGB
54
+ pil_image = Image.fromarray(frame)
 
 
 
55
  else:
56
+ pil_image = frame
 
57
 
58
+ # Redimensionner pour accélérer le traitement
59
+ original_size = pil_image.size
60
+ max_size = 640 # Réduire la taille pour plus de vitesse
61
+
62
+ if max(original_size) > max_size:
63
+ ratio = max_size / max(original_size)
64
+ new_size = (int(original_size[0] * ratio), int(original_size[1] * ratio))
65
+ resized_image = pil_image.resize(new_size)
66
+ else:
67
+ resized_image = pil_image
68
+ ratio = 1.0
69
+
70
+ # Détection sur l'image redimensionnée
71
+ detections = detector(resized_image)
72
+
73
+ # Filtrer par confiance
74
+ filtered_detections = [
75
+ det for det in detections
76
+ if det['score'] >= confidence_threshold
77
  ]
78
 
79
+ # Redimensionner les coordonnées vers la taille originale
80
+ for det in filtered_detections:
81
+ if ratio != 1.0:
82
+ det['box']['xmin'] = int(det['box']['xmin'] / ratio)
83
+ det['box']['ymin'] = int(det['box']['ymin'] / ratio)
84
+ det['box']['xmax'] = int(det['box']['xmax'] / ratio)
85
+ det['box']['ymax'] = int(det['box']['ymax'] / ratio)
86
 
87
+ # Dessiner les détections sur l'image originale
88
+ annotated_image = draw_detections_fast(pil_image, filtered_detections)
89
 
90
+ # Convertir back en numpy pour Gradio
91
+ return np.array(annotated_image)
92
 
93
  except Exception as e:
94
+ print(f"❌ Erreur de traitement: {e}")
95
+ return frame
96
 
97
+ def draw_detections_fast(image, detections):
98
+ """Version optimisée pour dessiner les détections"""
99
+ if not detections:
100
+ return image
101
+
102
  draw = ImageDraw.Draw(image)
103
 
104
+ # Police par défaut pour la vitesse
105
  try:
 
 
106
  font = ImageFont.load_default()
107
+ except:
108
+ font = None
109
 
110
+ colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57"]
 
 
 
111
 
112
  for i, detection in enumerate(detections):
113
  box = detection['box']
114
  label = detection['label']
115
  score = detection['score']
116
 
117
+ # Coordonnées
118
  x1, y1 = box['xmin'], box['ymin']
119
  x2, y2 = box['xmax'], box['ymax']
120
 
121
+ # Couleur
122
  color = colors[i % len(colors)]
123
 
124
+ # Boîte
125
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
 
 
 
126
 
127
+ # Label avec score
128
+ text = f"{label} {score:.2f}"
 
129
 
130
+ # Fond du texte (simplifié)
131
+ if font:
132
+ bbox = draw.textbbox((x1, y1-20), text, font=font)
133
+ draw.rectangle(bbox, fill=color)
134
+ draw.text((x1, y1-20), text, fill="white", font=font)
135
+ else:
136
+ draw.text((x1, y1-15), text, fill=color)
137
 
138
  return image
139
 
140
+ # Interface Gradio avec streaming
141
+ with gr.Blocks(title="🎥 Détection Live", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  gr.Markdown("""
144
+ # 🎥 Détection d'Objets en Temps Réel
145
 
146
+ **Activez votre webcam** et voyez la détection se faire en direct !
147
 
148
+ **Optimisé pour la vitesse** avec des modèles légers
 
 
 
 
149
  """)
150
 
151
  with gr.Row():
152
+ with gr.Column(scale=2):
153
+ # Composant webcam avec streaming
154
+ webcam = gr.Interface(
155
+ fn=process_webcam_frame,
156
+ inputs=[
157
+ gr.Image(sources=["webcam"], streaming=True, type="numpy"),
158
+ gr.Dropdown(
159
+ choices=list(REALTIME_MODELS.keys()),
160
+ value="YOLOS Tiny (ultra-rapide)",
161
+ label="🤖 Modèle (changement en direct)"
162
+ ),
163
+ gr.Slider(
164
+ minimum=0.1,
165
+ maximum=1.0,
166
+ value=0.5,
167
+ step=0.1,
168
+ label="🎯 Seuil de confiance"
169
+ )
170
+ ],
171
+ outputs=gr.Image(type="numpy", streaming=True),
172
+ live=True, # ⭐ CRUCIAL: Active le mode live
173
+ title=None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  )
175
 
176
  with gr.Column(scale=1):
177
+ gr.Markdown("""
178
+ ## 📊 Informations Live
 
 
 
179
 
180
+ ### 🎛️ Contrôles en temps réel:
181
+ - **Modèle**: Change instantanément
182
+ - **Confiance**: Ajuste le filtrage
183
+ - **Streaming**: Traitement frame par frame
184
+
185
+ ### ⚡ Optimisations:
186
+ - Images redimensionnées à 640px
187
+ - Modèles légers prioritaires
188
+ - Cache intelligent des modèles
189
+ - Dessin optimisé
190
+
191
+ ### 🎯 Modèles recommandés:
192
+ - **YOLOS Tiny**: Maximum de vitesse
193
+ - **DETR ResNet-50**: Bon équilibre
194
+ """)
195
+
196
+ # Version alternative avec Interface simple
197
+ gr.Markdown("---")
198
+ gr.Markdown("## 🎥 Version Alternative (Interface Simple)")
199
+
200
+ alternative_interface = gr.Interface(
201
+ fn=process_webcam_frame,
202
+ inputs=[
203
+ gr.Image(sources=["webcam"], streaming=True),
204
+ gr.Dropdown(
205
+ choices=list(REALTIME_MODELS.keys()),
206
+ value="YOLOS Tiny (ultra-rapide)"
207
+ ),
208
+ gr.Slider(0.1, 1.0, 0.5, step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  ],
210
+ outputs=gr.Image(streaming=True),
211
+ live=True, # ⭐ Mode live activé
212
+ title="Détection Webcam Live",
213
+ description="Cliquez sur la webcam pour démarrer le streaming live!"
214
  )
215
 
216
  if __name__ == "__main__":
217
+ demo.launch()