2nzi commited on
Commit
2da5287
·
1 Parent(s): 7d0ff20

update for line plot frontend

Browse files
Files changed (2) hide show
  1. api.py +703 -529
  2. requirements.txt +22 -24
api.py CHANGED
@@ -1,530 +1,704 @@
1
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
- from typing import Dict, List, Any, Optional
5
- import json
6
- import tempfile
7
- import os
8
- from PIL import Image
9
- import numpy as np
10
- import cv2
11
- import torch
12
- import torchvision.transforms as T
13
- import torchvision.transforms.functional as f
14
- import yaml
15
- from tqdm import tqdm
16
- from huggingface_hub import hf_hub_download
17
-
18
- from get_camera_params import get_camera_parameters
19
-
20
- # Imports pour l'inférence automatique
21
- from model.cls_hrnet import get_cls_net
22
- from model.cls_hrnet_l import get_cls_net as get_cls_net_l
23
- from utils.utils_calib import FramebyFrameCalib
24
- from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, complete_keypoints, coords_to_dict
25
-
26
- app = FastAPI(
27
- title="Football Vision Calibration API",
28
- description="API pour la calibration de caméras à partir de lignes de terrain de football",
29
- version="1.0.0"
30
- )
31
-
32
- # Configuration CORS pour autoriser les requêtes depuis le frontend
33
- app.add_middleware(
34
- CORSMiddleware,
35
- allow_origins=["*"], # En production, spécifiez les domaines autorisés
36
- allow_credentials=True,
37
- allow_methods=["*"],
38
- allow_headers=["*"],
39
- )
40
-
41
- # Paramètres par défaut pour l'inférence
42
- WEIGHTS_KP = "models/SV_FT_TSWC_kp"
43
- WEIGHTS_LINE = "models/SV_FT_TSWC_lines"
44
- # DEVICE = "cuda:0"
45
- DEVICE = "cpu"
46
- KP_THRESHOLD = 0.15
47
- LINE_THRESHOLD = 0.15
48
- PNL_REFINE = True
49
- FRAME_STEP = 5
50
-
51
- # Cache pour les modèles (éviter de les recharger à chaque requête)
52
- _models_cache = None
53
-
54
- # Paramètres pour HF Hub
55
- HF_MODEL_REPO = "2nzi/SV_FT_TSWC_kp" # Remplacez par votre repo
56
- WEIGHTS_KP_FILE = "SV_FT_TSWC_kp" # Nom du fichier dans le repo
57
- WEIGHTS_LINE_FILE = "SV_FT_TSWC_lines" # Nom du fichier dans le repo
58
-
59
- def load_inference_models():
60
- """Charge les modèles d'inférence depuis Hugging Face Hub"""
61
- global _models_cache
62
-
63
- if _models_cache is not None:
64
- return _models_cache
65
-
66
- try:
67
- # Device detection
68
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
69
- print(f"Using device: {device}")
70
-
71
- # Télécharger les modèles depuis HF Hub
72
- print("Téléchargement des modèles depuis Hugging Face Hub...")
73
-
74
- weights_kp_path = hf_hub_download(
75
- repo_id=HF_MODEL_REPO,
76
- filename=WEIGHTS_KP_FILE,
77
- cache_dir="./hf_cache"
78
- )
79
-
80
- weights_line_path = hf_hub_download(
81
- repo_id=HF_MODEL_REPO,
82
- filename=WEIGHTS_LINE_FILE,
83
- cache_dir="./hf_cache"
84
- )
85
-
86
- print(f"Modèles téléchargés:")
87
- print(f" - Keypoints: {weights_kp_path}")
88
- print(f" - Lines: {weights_line_path}")
89
-
90
- # Vérifier l'existence des fichiers de configuration
91
- config_files = ["config/hrnetv2_w48.yaml", "config/hrnetv2_w48_l.yaml"]
92
- for config_file in config_files:
93
- if not os.path.exists(config_file):
94
- raise FileNotFoundError(f"Fichier de configuration manquant: {config_file}")
95
-
96
- # Charger les configurations
97
- with open("config/hrnetv2_w48.yaml", 'r') as f:
98
- cfg = yaml.safe_load(f)
99
- with open("config/hrnetv2_w48_l.yaml", 'r') as f:
100
- cfg_l = yaml.safe_load(f)
101
-
102
- # Modèle keypoints
103
- model = get_cls_net(cfg)
104
- model.load_state_dict(torch.load(weights_kp_path, map_location=device))
105
- model.to(device)
106
- model.eval()
107
-
108
- # Modèle lignes
109
- model_l = get_cls_net_l(cfg_l)
110
- model_l.load_state_dict(torch.load(weights_line_path, map_location=device))
111
- model_l.to(device)
112
- model_l.eval()
113
-
114
- _models_cache = (model, model_l, device)
115
- print("✅ Modèles chargés avec succès depuis HF Hub!")
116
- return _models_cache
117
-
118
- except Exception as e:
119
- print(f"❌ Erreur lors du chargement des modèles: {e}")
120
- raise HTTPException(
121
- status_code=503,
122
- detail=f"Modèles non disponibles: {str(e)}. Veuillez réessayer plus tard."
123
- )
124
-
125
- def process_frame_inference(frame, model, model_l, device, frame_width, frame_height):
126
- """Traite une frame et retourne les paramètres de caméra"""
127
- transform = T.Resize((540, 960))
128
-
129
- # Préparer la frame pour l'inférence
130
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
131
- frame_pil = Image.fromarray(frame_rgb)
132
- frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0)
133
-
134
- if frame_tensor.size()[-1] != 960:
135
- frame_tensor = transform(frame_tensor)
136
-
137
- frame_tensor = frame_tensor.to(device)
138
- b, c, h, w = frame_tensor.size()
139
-
140
- # Inférence
141
- with torch.no_grad():
142
- heatmaps = model(frame_tensor)
143
- heatmaps_l = model_l(frame_tensor)
144
-
145
- # Extraire les keypoints et lignes
146
- kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
147
- line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
148
- kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD)
149
- lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD)
150
- kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
151
-
152
- # Calibration
153
- cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)
154
- cam.update(kp_dict, lines_dict)
155
- final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE)
156
-
157
- return final_params_dict
158
-
159
- # Modèles Pydantic pour la validation des données
160
- class Point(BaseModel):
161
- x: float
162
- y: float
163
-
164
- class LinePolygon(BaseModel):
165
- points: List[Point]
166
-
167
- class CalibrationRequest(BaseModel):
168
- lines: Dict[str, List[Point]]
169
-
170
- class CalibrationResponse(BaseModel):
171
- status: str
172
- camera_parameters: Dict[str, Any]
173
- input_lines: Dict[str, List[Point]]
174
- message: str
175
-
176
- class InferenceImageResponse(BaseModel):
177
- status: str
178
- camera_parameters: Optional[Dict[str, Any]]
179
- image_info: Dict[str, Any]
180
- message: str
181
-
182
- class InferenceVideoResponse(BaseModel):
183
- status: str
184
- camera_parameters: List[Dict[str, Any]]
185
- video_info: Dict[str, Any]
186
- frames_processed: int
187
- message: str
188
-
189
- @app.get("/")
190
- async def root():
191
- return {
192
- "message": "Football Vision Calibration API",
193
- "version": "1.0.0",
194
- "endpoints": {
195
- "/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes",
196
- "/inference/image": "POST - Extraire les paramètres de caméra d'une image automatiquement",
197
- "/inference/video": "POST - Extraire les paramètres de caméra d'une vidéo automatiquement",
198
- "/health": "GET - Vérifier l'état de l'API"
199
- }
200
- }
201
-
202
- @app.get("/health")
203
- async def health_check():
204
- return {"status": "healthy", "message": "API is running"}
205
-
206
- @app.post("/calibrate", response_model=CalibrationResponse)
207
- async def calibrate_camera(
208
- image: UploadFile = File(..., description="Image du terrain de football"),
209
- lines_data: str = Form(..., description="JSON des lignes du terrain")
210
- ):
211
- """
212
- Calibrer une caméra à partir d'une image et des lignes du terrain.
213
-
214
- Args:
215
- image: Image du terrain de football (formats: jpg, jpeg, png)
216
- lines_data: JSON contenant les lignes du terrain au format:
217
- {"nom_ligne": [{"x": float, "y": float}, ...], ...}
218
-
219
- Returns:
220
- Paramètres de calibration de la caméra et lignes d'entrée
221
- """
222
- try:
223
- # Validation du format d'image - version robuste
224
- content_type = getattr(image, 'content_type', None) or ""
225
- filename = getattr(image, 'filename', "") or ""
226
-
227
- # Vérifier le type MIME ou l'extension du fichier
228
- image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
229
- is_image_content = content_type.startswith('image/') if content_type else False
230
- is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
231
-
232
- if not is_image_content and not is_image_extension:
233
- raise HTTPException(
234
- status_code=400,
235
- detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
236
- )
237
-
238
- # Parse des données de lignes
239
- try:
240
- lines_dict = json.loads(lines_data)
241
- except json.JSONDecodeError:
242
- raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes")
243
-
244
- # Validation de la structure des lignes
245
- validated_lines = {}
246
- for line_name, points in lines_dict.items():
247
- if not isinstance(points, list):
248
- raise HTTPException(
249
- status_code=400,
250
- detail=f"Les points de la ligne '{line_name}' doivent être une liste"
251
- )
252
-
253
- validated_points = []
254
- for i, point in enumerate(points):
255
- if not isinstance(point, dict) or 'x' not in point or 'y' not in point:
256
- raise HTTPException(
257
- status_code=400,
258
- detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'"
259
- )
260
- try:
261
- validated_points.append({
262
- "x": float(point['x']),
263
- "y": float(point['y'])
264
- })
265
- except (ValueError, TypeError):
266
- raise HTTPException(
267
- status_code=400,
268
- detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'"
269
- )
270
-
271
- validated_lines[line_name] = validated_points
272
-
273
- # Sauvegarde temporaire de l'image
274
- file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
275
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
276
- content = await image.read()
277
- temp_file.write(content)
278
- temp_image_path = temp_file.name
279
-
280
- try:
281
- # Validation de l'image
282
- pil_image = Image.open(temp_image_path)
283
- pil_image.verify() # Vérification de l'intégrité de l'image
284
-
285
- # Calibration de la caméra
286
- camera_params = get_camera_parameters(temp_image_path, validated_lines)
287
-
288
- # Formatage de la réponse
289
- response = CalibrationResponse(
290
- status="success",
291
- camera_parameters=camera_params,
292
- input_lines=validated_lines,
293
- message="Calibration réussie"
294
- )
295
-
296
- return response
297
-
298
- except Exception as e:
299
- raise HTTPException(
300
- status_code=500,
301
- detail=f"Erreur lors de la calibration: {str(e)}"
302
- )
303
-
304
- finally:
305
- # Nettoyage du fichier temporaire
306
- if os.path.exists(temp_image_path):
307
- os.unlink(temp_image_path)
308
-
309
- except HTTPException:
310
- raise
311
- except Exception as e:
312
- raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
313
-
314
- @app.post("/inference/image", response_model=InferenceImageResponse)
315
- async def inference_image(
316
- image: UploadFile = File(..., description="Image du terrain de football"),
317
- kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
318
- line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes")
319
- ):
320
- """
321
- Extraire automatiquement les paramètres de caméra à partir d'une image.
322
-
323
- Args:
324
- image: Image du terrain de football (formats: jpg, jpeg, png)
325
- kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
326
- line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
327
-
328
- Returns:
329
- Paramètres de calibration de la caméra extraits automatiquement
330
- """
331
- params = None # Initialiser params
332
- try:
333
- # Validation du format d'image - version robuste
334
- content_type = getattr(image, 'content_type', None) or ""
335
- filename = getattr(image, 'filename', "") or ""
336
-
337
- # Vérifier le type MIME ou l'extension du fichier
338
- image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
339
- is_image_content = content_type.startswith('image/') if content_type else False
340
- is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
341
-
342
- if not is_image_content and not is_image_extension:
343
- raise HTTPException(
344
- status_code=400,
345
- detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
346
- )
347
-
348
- # Sauvegarde temporaire de l'image
349
- file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
350
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
351
- content = await image.read()
352
- temp_file.write(content)
353
- temp_image_path = temp_file.name
354
-
355
- try:
356
- # Charger les modèles
357
- model, model_l, device = load_inference_models()
358
-
359
- # Lire l'image
360
- frame = cv2.imread(temp_image_path)
361
- if frame is None:
362
- raise HTTPException(status_code=400, detail="Impossible de lire l'image")
363
-
364
- frame_height, frame_width = frame.shape[:2]
365
-
366
- # Mettre à jour les seuils globaux
367
- global KP_THRESHOLD, LINE_THRESHOLD
368
- KP_THRESHOLD = kp_threshold
369
- LINE_THRESHOLD = line_threshold
370
-
371
- # Traitement
372
- params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
373
-
374
- # Formatage de la réponse
375
- response = InferenceImageResponse(
376
- status="success" if params is not None else "failed",
377
- camera_parameters=params,
378
- image_info={
379
- "filename": filename,
380
- "width": frame_width,
381
- "height": frame_height,
382
- "kp_threshold": kp_threshold,
383
- "line_threshold": line_threshold
384
- },
385
- message="Paramètres extraits avec succès" if params is not None else "Échec de l'extraction des paramètres"
386
- )
387
-
388
- return response
389
-
390
- except Exception as e:
391
- raise HTTPException(
392
- status_code=500,
393
- detail=f"Erreur lors de l'inférence: {str(e)}"
394
- )
395
-
396
- finally:
397
- # Nettoyage du fichier temporaire
398
- if os.path.exists(temp_image_path):
399
- os.unlink(temp_image_path)
400
-
401
- except HTTPException:
402
- raise
403
- except Exception as e:
404
- raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
405
-
406
- @app.post("/inference/video", response_model=InferenceVideoResponse)
407
- async def inference_video(
408
- video: UploadFile = File(..., description="Vidéo du terrain de football"),
409
- kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
410
- line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes"),
411
- frame_step: int = Form(FRAME_STEP, description="Traiter 1 frame sur N")
412
- ):
413
- """
414
- Extraire automatiquement les paramètres de caméra à partir d'une vidéo.
415
-
416
- Args:
417
- video: Vidéo du terrain de football (formats: mp4, avi, mov, etc.)
418
- kp_threshold: Seuil pour la détection des keypoints (défaut: 0.15)
419
- line_threshold: Seuil pour la détection des lignes (défaut: 0.15)
420
- frame_step: Traiter 1 frame sur N pour accélérer le traitement (défaut: 5)
421
-
422
- Returns:
423
- Liste des paramètres de calibration de la caméra pour chaque frame traitée
424
- """
425
- try:
426
- # Validation du format vidéo - version robuste
427
- content_type = getattr(video, 'content_type', None) or ""
428
- filename = getattr(video, 'filename', "") or ""
429
-
430
- video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv']
431
- is_video_content = content_type.startswith('video/') if content_type else False
432
- is_video_extension = any(filename.lower().endswith(ext) for ext in video_extensions)
433
-
434
- if not is_video_content and not is_video_extension:
435
- raise HTTPException(
436
- status_code=400,
437
- detail=f"Le fichier doit être une vidéo. Type détecté: {content_type}, Fichier: {filename}"
438
- )
439
-
440
- # Sauvegarde temporaire de la vidéo
441
- file_extension = os.path.splitext(filename)[1] if filename else '.mp4'
442
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
443
- content = await video.read()
444
- temp_file.write(content)
445
- temp_video_path = temp_file.name
446
-
447
- try:
448
- # Charger les modèles
449
- model, model_l, device = load_inference_models()
450
-
451
- # Ouvrir la vidéo
452
- cap = cv2.VideoCapture(temp_video_path)
453
- if not cap.isOpened():
454
- raise HTTPException(status_code=400, detail="Impossible d'ouvrir la vidéo")
455
-
456
- frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
457
- frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
458
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
459
- fps = int(cap.get(cv2.CAP_PROP_FPS))
460
-
461
- # Mettre à jour les seuils globaux
462
- global KP_THRESHOLD, LINE_THRESHOLD
463
- KP_THRESHOLD = kp_threshold
464
- LINE_THRESHOLD = line_threshold
465
-
466
- all_params = []
467
- frame_count = 0
468
- processed_count = 0
469
-
470
- while cap.isOpened():
471
- ret, frame = cap.read()
472
- if not ret:
473
- break
474
-
475
- # Traiter seulement 1 frame sur frame_step
476
- if frame_count % frame_step != 0:
477
- frame_count += 1
478
- continue
479
-
480
- # Traitement
481
- params = process_frame_inference(frame, model, model_l, device, frame_width, frame_height)
482
-
483
- if params is not None:
484
- params['frame_number'] = frame_count
485
- params['timestamp_seconds'] = frame_count / fps
486
- all_params.append(params)
487
- processed_count += 1
488
-
489
- frame_count += 1
490
-
491
- cap.release()
492
-
493
- # Formatage de la réponse
494
- response = InferenceVideoResponse(
495
- status="success" if all_params else "failed",
496
- camera_parameters=all_params,
497
- video_info={
498
- "filename": filename,
499
- "width": frame_width,
500
- "height": frame_height,
501
- "total_frames": total_frames,
502
- "fps": fps,
503
- "duration_seconds": total_frames / fps,
504
- "kp_threshold": kp_threshold,
505
- "line_threshold": line_threshold,
506
- "frame_step": frame_step
507
- },
508
- frames_processed=processed_count,
509
- message=f"Paramètres extraits de {processed_count} frames" if all_params else "Aucun paramètre extrait"
510
- )
511
-
512
- return response
513
-
514
- except Exception as e:
515
- raise HTTPException(
516
- status_code=500,
517
- detail=f"Erreur lors de l'inférence vidéo: {str(e)}"
518
- )
519
-
520
- finally:
521
- # Nettoyage du fichier temporaire
522
- if os.path.exists(temp_video_path):
523
- os.unlink(temp_video_path)
524
-
525
- except HTTPException:
526
- raise
527
- except Exception as e:
528
- raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
529
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  app_instance = app
 
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import Dict, List, Any, Optional
5
+ import json
6
+ import tempfile
7
+ import os
8
+ from PIL import Image
9
+ import numpy as np
10
+ import cv2
11
+ import torch
12
+ import torchvision.transforms as T
13
+ import torchvision.transforms.functional as f
14
+ import yaml
15
+ from tqdm import tqdm
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ from get_camera_params import get_camera_parameters
19
+
20
+ # Imports pour l'inférence automatique
21
+ from model.cls_hrnet import get_cls_net
22
+ from model.cls_hrnet_l import get_cls_net as get_cls_net_l
23
+ from utils.utils_calib import FramebyFrameCalib
24
+ from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, complete_keypoints, coords_to_dict
25
+
26
+ app = FastAPI(
27
+ title="Football Vision Calibration API",
28
+ description="API pour la calibration de caméras à partir de lignes de terrain de football",
29
+ version="1.0.0"
30
+ )
31
+
32
+ # Configuration CORS pour autoriser les requêtes depuis le frontend
33
+ app.add_middleware(
34
+ CORSMiddleware,
35
+ allow_origins=["*"], # En production, spécifiez les domaines autorisés
36
+ allow_credentials=True,
37
+ allow_methods=["*"],
38
+ allow_headers=["*"],
39
+ )
40
+
41
+ # Paramètres par défaut pour l'inférence
42
+ WEIGHTS_KP = "models/SV_FT_TSWC_kp"
43
+ WEIGHTS_LINE = "models/SV_FT_TSWC_lines"
44
+ # DEVICE = "cuda:0"
45
+ DEVICE = "cpu"
46
+ KP_THRESHOLD = 0.15
47
+ LINE_THRESHOLD = 0.15
48
+ PNL_REFINE = True
49
+ FRAME_STEP = 5
50
+
51
+ # Cache pour les modèles (éviter de les recharger à chaque requête)
52
+ _models_cache = None
53
+
54
+ # Paramètres pour HF Hub
55
+ HF_MODEL_REPO = "2nzi/SV_FT_TSWC_kp" # Remplacez par votre repo
56
+ WEIGHTS_KP_FILE = "SV_FT_TSWC_kp" # Nom du fichier dans le repo
57
+ WEIGHTS_LINE_FILE = "SV_FT_TSWC_lines" # Nom du fichier dans le repo
58
+
59
+ def load_inference_models():
60
+ """Charge les modèles d'inférence depuis Hugging Face Hub"""
61
+ global _models_cache
62
+
63
+ if _models_cache is not None:
64
+ return _models_cache
65
+
66
+ try:
67
+ # Device detection
68
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
69
+ print(f"Using device: {device}")
70
+
71
+ # Télécharger les modèles depuis HF Hub
72
+ print("Téléchargement des modèles depuis Hugging Face Hub...")
73
+
74
+ weights_kp_path = hf_hub_download(
75
+ repo_id=HF_MODEL_REPO,
76
+ filename=WEIGHTS_KP_FILE,
77
+ cache_dir="./hf_cache"
78
+ )
79
+
80
+ weights_line_path = hf_hub_download(
81
+ repo_id=HF_MODEL_REPO,
82
+ filename=WEIGHTS_LINE_FILE,
83
+ cache_dir="./hf_cache"
84
+ )
85
+
86
+ print(f"Modèles téléchargés:")
87
+ print(f" - Keypoints: {weights_kp_path}")
88
+ print(f" - Lines: {weights_line_path}")
89
+
90
+ # Vérifier l'existence des fichiers de configuration
91
+ config_files = ["config/hrnetv2_w48.yaml", "config/hrnetv2_w48_l.yaml"]
92
+ for config_file in config_files:
93
+ if not os.path.exists(config_file):
94
+ raise FileNotFoundError(f"Fichier de configuration manquant: {config_file}")
95
+
96
+ # Charger les configurations
97
+ with open("config/hrnetv2_w48.yaml", 'r') as f:
98
+ cfg = yaml.safe_load(f)
99
+ with open("config/hrnetv2_w48_l.yaml", 'r') as f:
100
+ cfg_l = yaml.safe_load(f)
101
+
102
+ # Modèle keypoints
103
+ model = get_cls_net(cfg)
104
+ model.load_state_dict(torch.load(weights_kp_path, map_location=device))
105
+ model.to(device)
106
+ model.eval()
107
+
108
+ # Modèle lignes
109
+ model_l = get_cls_net_l(cfg_l)
110
+ model_l.load_state_dict(torch.load(weights_line_path, map_location=device))
111
+ model_l.to(device)
112
+ model_l.eval()
113
+
114
+ _models_cache = (model, model_l, device)
115
+ print("✅ Modèles chargés avec succès depuis HF Hub!")
116
+ return _models_cache
117
+
118
+ except Exception as e:
119
+ print(f"❌ Erreur lors du chargement des modèles: {e}")
120
+ raise HTTPException(
121
+ status_code=503,
122
+ detail=f"Modèles non disponibles: {str(e)}. Veuillez réessayer plus tard."
123
+ )
124
+
125
+ def process_frame_inference(frame, model, model_l, device, frame_width, frame_height):
126
+ """Traite une frame et retourne les paramètres de caméra"""
127
+ transform = T.Resize((540, 960))
128
+
129
+ # Préparer la frame pour l'inférence
130
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
131
+ frame_pil = Image.fromarray(frame_rgb)
132
+ frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0)
133
+
134
+ if frame_tensor.size()[-1] != 960:
135
+ frame_tensor = transform(frame_tensor)
136
+
137
+ frame_tensor = frame_tensor.to(device)
138
+ b, c, h, w = frame_tensor.size()
139
+
140
+ # Inférence
141
+ with torch.no_grad():
142
+ heatmaps = model(frame_tensor)
143
+ heatmaps_l = model_l(frame_tensor)
144
+
145
+ # Extraire les keypoints et lignes
146
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
147
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
148
+ kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD)
149
+ lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD)
150
+ kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
151
+
152
+ # Calibration
153
+ cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)
154
+ cam.update(kp_dict, lines_dict)
155
+ final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE)
156
+
157
+ return final_params_dict
158
+
159
+ # Modèles Pydantic pour la validation des données
160
+ class Point(BaseModel):
161
+ x: float
162
+ y: float
163
+
164
+ class LinePolygon(BaseModel):
165
+ points: List[Point]
166
+
167
+ class CalibrationRequest(BaseModel):
168
+ lines: Dict[str, List[Point]]
169
+
170
+ # Nouvelles classes pour les coordonnées
171
+ class ImageCoordinate(BaseModel):
172
+ x: float
173
+ y: float
174
+
175
+ class WorldCoordinate(BaseModel):
176
+ x: float # Coordonnée X sur le terrain (mètres)
177
+ y: float # Coordonnée Y sur le terrain (mètres)
178
+
179
+ class KeypointData(BaseModel):
180
+ id: int
181
+ image_coords: ImageCoordinate
182
+ world_coords: Optional[WorldCoordinate] = None
183
+ confidence: Optional[float] = None
184
+
185
+ class LineData(BaseModel):
186
+ id: int
187
+ start_image: ImageCoordinate
188
+ end_image: ImageCoordinate
189
+ start_world: Optional[WorldCoordinate] = None
190
+ end_world: Optional[WorldCoordinate] = None
191
+ confidence: Optional[float] = None
192
+
193
+ class DetectionData(BaseModel):
194
+ keypoints: List[KeypointData]
195
+ lines: List[LineData]
196
+
197
+ class CalibrationResponse(BaseModel):
198
+ status: str
199
+ camera_parameters: Dict[str, Any]
200
+ input_lines: Dict[str, List[Point]]
201
+ detections: Optional[DetectionData] = None
202
+ message: str
203
+
204
+ class InferenceImageResponse(BaseModel):
205
+ status: str
206
+ camera_parameters: Optional[Dict[str, Any]]
207
+ image_info: Dict[str, Any]
208
+ detections: Optional[DetectionData] = None
209
+ message: str
210
+
211
+ class FrameResult(BaseModel):
212
+ frame_number: int
213
+ timestamp_seconds: float
214
+ camera_parameters: Optional[Dict[str, Any]]
215
+ detections: Optional[DetectionData] = None
216
+
217
+ class InferenceVideoResponse(BaseModel):
218
+ status: str
219
+ video_info: Dict[str, Any]
220
+ frames_processed: int
221
+ frames_results: List[FrameResult]
222
+ message: str
223
+
224
+ # Fonction de conversion coordonnées image -> terrain
225
+ def image_to_world(point_2d, cam_params):
226
+ """
227
+ Convertit un point 2D de l'image vers des coordonnées 3D du terrain.
228
+
229
+ Args:
230
+ point_2d: [x, y] coordonnées dans l'image
231
+ cam_params: Paramètres de la caméra
232
+
233
+ Returns:
234
+ [x, y] coordonnées sur le terrain (Z=0)
235
+ """
236
+ try:
237
+ # Matrice de calibration
238
+ K = np.array([
239
+ [cam_params["cam_params"]["x_focal_length"], 0, cam_params["cam_params"]["principal_point"][0]],
240
+ [0, cam_params["cam_params"]["y_focal_length"], cam_params["cam_params"]["principal_point"][1]],
241
+ [0, 0, 1]
242
+ ])
243
+
244
+ # Matrice de rotation et position
245
+ R = np.array(cam_params["cam_params"]["rotation_matrix"])
246
+ camera_pos = np.array(cam_params["cam_params"]["position_meters"])
247
+
248
+ # Point 2D en coordonnées homogènes
249
+ point_2d_h = np.array([point_2d[0], point_2d[1], 1])
250
+
251
+ # Back-projection du rayon depuis la caméra
252
+ ray = np.linalg.inv(K) @ point_2d_h
253
+ ray = R.T @ ray
254
+
255
+ # Intersection avec le plan Z=0 (terrain)
256
+ if abs(ray[2]) < 1e-6: # Éviter division par zéro
257
+ return None
258
+
259
+ t = -camera_pos[2] / ray[2]
260
+ world_point = camera_pos + t * ray
261
+
262
+ return world_point[:2] # Retourner seulement X,Y (Z=0)
263
+
264
+ except Exception as e:
265
+ print(f"Erreur conversion image->monde: {e}")
266
+ return None
267
+
268
+ def process_detections(kp_dict, lines_dict, cam_params):
269
+ """
270
+ Traite les détections et convertit les coordonnées image vers terrain.
271
+
272
+ Args:
273
+ kp_dict: Dictionnaire des keypoints détectés
274
+ lines_dict: Dictionnaire des lignes détectées
275
+ cam_params: Paramètres de la caméra
276
+
277
+ Returns:
278
+ DetectionData: Données structurées avec coordonnées image et terrain
279
+ """
280
+ keypoints_data = []
281
+ lines_data = []
282
+
283
+ # Traiter les keypoints
284
+ if kp_dict:
285
+ for kp_id, kp_value in kp_dict.items():
286
+ image_coords = ImageCoordinate(x=kp_value['x'], y=kp_value['y'])
287
+
288
+ # Convertir vers coordonnées terrain
289
+ world_coords = None
290
+ if cam_params and 'cam_params' in cam_params:
291
+ world_point = image_to_world([kp_value['x'], kp_value['y']], cam_params)
292
+ if world_point is not None:
293
+ world_coords = WorldCoordinate(x=float(world_point[0]), y=float(world_point[1]))
294
+
295
+ keypoints_data.append(KeypointData(
296
+ id=kp_id,
297
+ image_coords=image_coords,
298
+ world_coords=world_coords,
299
+ confidence=kp_value.get('confidence', None)
300
+ ))
301
+
302
+ # Traiter les lignes
303
+ if lines_dict:
304
+ for line_id, line_value in lines_dict.items():
305
+ start_image = ImageCoordinate(x=line_value['x_1'], y=line_value['y_1'])
306
+ end_image = ImageCoordinate(x=line_value['x_2'], y=line_value['y_2'])
307
+
308
+ # Convertir vers coordonnées terrain
309
+ start_world = None
310
+ end_world = None
311
+ if cam_params and 'cam_params' in cam_params:
312
+ start_point = image_to_world([line_value['x_1'], line_value['y_1']], cam_params)
313
+ end_point = image_to_world([line_value['x_2'], line_value['y_2']], cam_params)
314
+
315
+ if start_point is not None:
316
+ start_world = WorldCoordinate(x=float(start_point[0]), y=float(start_point[1]))
317
+ if end_point is not None:
318
+ end_world = WorldCoordinate(x=float(end_point[0]), y=float(end_point[1]))
319
+
320
+ lines_data.append(LineData(
321
+ id=line_id,
322
+ start_image=start_image,
323
+ end_image=end_image,
324
+ start_world=start_world,
325
+ end_world=end_world,
326
+ confidence=line_value.get('confidence', None)
327
+ ))
328
+
329
+ return DetectionData(keypoints=keypoints_data, lines=lines_data)
330
+
331
+ def process_frame_inference_with_coords(frame, model, model_l, device, frame_width, frame_height):
332
+ """
333
+ Version enrichie qui retourne les paramètres de caméra ET les coordonnées détectées.
334
+ """
335
+ transform = T.Resize((540, 960))
336
+
337
+ # Préparer la frame pour l'inférence
338
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
339
+ frame_pil = Image.fromarray(frame_rgb)
340
+ frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0)
341
+
342
+ if frame_tensor.size()[-1] != 960:
343
+ frame_tensor = transform(frame_tensor)
344
+
345
+ frame_tensor = frame_tensor.to(device)
346
+ b, c, h, w = frame_tensor.size()
347
+
348
+ # Inférence
349
+ with torch.no_grad():
350
+ heatmaps = model(frame_tensor)
351
+ heatmaps_l = model_l(frame_tensor)
352
+
353
+ # Extraire les keypoints et lignes
354
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
355
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
356
+ kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD)
357
+ lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD)
358
+ kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
359
+
360
+ # Calibration
361
+ cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)
362
+ cam.update(kp_dict, lines_dict)
363
+ final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE)
364
+
365
+ return final_params_dict, kp_dict, lines_dict
366
+
367
+ @app.get("/")
368
+ async def root():
369
+ return {
370
+ "message": "Football Vision Calibration API",
371
+ "version": "1.0.0",
372
+ "endpoints": {
373
+ "/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes",
374
+ "/inference/image": "POST - Extraire les paramètres de caméra d'une image automatiquement",
375
+ "/inference/video": "POST - Extraire les paramètres de caméra d'une vidéo automatiquement",
376
+ "/health": "GET - Vérifier l'état de l'API"
377
+ }
378
+ }
379
+
380
+ @app.get("/health")
381
+ async def health_check():
382
+ return {"status": "healthy", "message": "API is running"}
383
+
384
+ @app.post("/calibrate", response_model=CalibrationResponse)
385
+ async def calibrate_camera(
386
+ image: UploadFile = File(..., description="Image du terrain de football"),
387
+ lines_data: str = Form(..., description="JSON des lignes du terrain")
388
+ ):
389
+ """
390
+ Calibrer une caméra à partir d'une image et des lignes du terrain.
391
+ Retourne aussi les coordonnées détectées sur l'image et le terrain.
392
+ """
393
+ try:
394
+ # Validation du format d'image - version robuste
395
+ content_type = getattr(image, 'content_type', None) or ""
396
+ filename = getattr(image, 'filename', "") or ""
397
+
398
+ # Vérifier le type MIME ou l'extension du fichier
399
+ image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
400
+ is_image_content = content_type.startswith('image/') if content_type else False
401
+ is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
402
+
403
+ if not is_image_content and not is_image_extension:
404
+ raise HTTPException(
405
+ status_code=400,
406
+ detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
407
+ )
408
+
409
+ # Parse des données de lignes
410
+ try:
411
+ lines_dict = json.loads(lines_data)
412
+ except json.JSONDecodeError:
413
+ raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes")
414
+
415
+ # Validation de la structure des lignes
416
+ validated_lines = {}
417
+ for line_name, points in lines_dict.items():
418
+ if not isinstance(points, list):
419
+ raise HTTPException(
420
+ status_code=400,
421
+ detail=f"Les points de la ligne '{line_name}' doivent être une liste"
422
+ )
423
+
424
+ validated_points = []
425
+ for i, point in enumerate(points):
426
+ if not isinstance(point, dict) or 'x' not in point or 'y' not in point:
427
+ raise HTTPException(
428
+ status_code=400,
429
+ detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'"
430
+ )
431
+ try:
432
+ validated_points.append({
433
+ "x": float(point['x']),
434
+ "y": float(point['y'])
435
+ })
436
+ except (ValueError, TypeError):
437
+ raise HTTPException(
438
+ status_code=400,
439
+ detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'"
440
+ )
441
+
442
+ validated_lines[line_name] = validated_points
443
+
444
+ # Sauvegarde temporaire de l'image
445
+ file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
446
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
447
+ content = await image.read()
448
+ temp_file.write(content)
449
+ temp_image_path = temp_file.name
450
+
451
+ try:
452
+ # Validation de l'image
453
+ pil_image = Image.open(temp_image_path)
454
+ pil_image.verify()
455
+
456
+ # Calibration de la caméra
457
+ camera_params = get_camera_parameters(temp_image_path, validated_lines)
458
+
459
+ # Pour l'endpoint calibrate, nous n'avons pas de détections automatiques
460
+ # mais nous pouvons traiter les lignes manuelles fournies
461
+ detections = None
462
+
463
+ # Formatage de la réponse
464
+ response = CalibrationResponse(
465
+ status="success",
466
+ camera_parameters=camera_params,
467
+ input_lines=validated_lines,
468
+ detections=detections,
469
+ message="Calibration réussie"
470
+ )
471
+
472
+ return response
473
+
474
+ except Exception as e:
475
+ raise HTTPException(
476
+ status_code=500,
477
+ detail=f"Erreur lors de la calibration: {str(e)}"
478
+ )
479
+
480
+ finally:
481
+ # Nettoyage du fichier temporaire
482
+ if os.path.exists(temp_image_path):
483
+ os.unlink(temp_image_path)
484
+
485
+ except HTTPException:
486
+ raise
487
+ except Exception as e:
488
+ raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
489
+
490
+ @app.post("/inference/image", response_model=InferenceImageResponse)
491
+ async def inference_image(
492
+ image: UploadFile = File(..., description="Image du terrain de football"),
493
+ kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
494
+ line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes")
495
+ ):
496
+ """
497
+ Extraire automatiquement les paramètres de caméra à partir d'une image.
498
+ Retourne les coordonnées détectées sur l'image et le terrain.
499
+ """
500
+ params = None
501
+ temp_image_path = None
502
+
503
+ try:
504
+ # Validation du format d'image - version robuste
505
+ content_type = getattr(image, 'content_type', None) or ""
506
+ filename = getattr(image, 'filename', "") or ""
507
+
508
+ # Vérifier le type MIME ou l'extension du fichier
509
+ image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
510
+ is_image_content = content_type.startswith('image/') if content_type else False
511
+ is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions)
512
+
513
+ if not is_image_content and not is_image_extension:
514
+ raise HTTPException(
515
+ status_code=400,
516
+ detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}"
517
+ )
518
+
519
+ # Sauvegarde temporaire de l'image
520
+ file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
521
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
522
+ content = await image.read()
523
+ temp_file.write(content)
524
+ temp_image_path = temp_file.name
525
+
526
+ # Charger les modèles
527
+ model, model_l, device = load_inference_models()
528
+
529
+ # Lire l'image
530
+ frame = cv2.imread(temp_image_path)
531
+ if frame is None:
532
+ raise HTTPException(status_code=400, detail="Impossible de lire l'image")
533
+
534
+ frame_height, frame_width = frame.shape[:2]
535
+
536
+ # Mettre à jour les seuils globaux
537
+ global KP_THRESHOLD, LINE_THRESHOLD
538
+ KP_THRESHOLD = kp_threshold
539
+ LINE_THRESHOLD = line_threshold
540
+
541
+ # Traitement avec coordonnées
542
+ params, kp_dict, lines_dict = process_frame_inference_with_coords(
543
+ frame, model, model_l, device, frame_width, frame_height
544
+ )
545
+
546
+ # Traiter les détections (coordonnées image + terrain)
547
+ detections = process_detections(kp_dict, lines_dict, params)
548
+
549
+ # Formatage de la réponse
550
+ response = InferenceImageResponse(
551
+ status="success" if params is not None else "failed",
552
+ camera_parameters=params,
553
+ image_info={
554
+ "filename": filename,
555
+ "width": frame_width,
556
+ "height": frame_height,
557
+ "kp_threshold": kp_threshold,
558
+ "line_threshold": line_threshold
559
+ },
560
+ detections=detections,
561
+ message="Paramètres extraits avec succès" if params is not None else "Échec de l'extraction des paramètres"
562
+ )
563
+
564
+ return response
565
+
566
+ except HTTPException:
567
+ raise
568
+ except Exception as e:
569
+ raise HTTPException(
570
+ status_code=500,
571
+ detail=f"Erreur lors de l'inférence: {str(e)}"
572
+ )
573
+ finally:
574
+ # Nettoyage du fichier temporaire
575
+ if temp_image_path and os.path.exists(temp_image_path):
576
+ os.unlink(temp_image_path)
577
+
578
+ @app.post("/inference/video", response_model=InferenceVideoResponse)
579
+ async def inference_video(
580
+ video: UploadFile = File(..., description="Vidéo du terrain de football"),
581
+ kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"),
582
+ line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes"),
583
+ frame_step: int = Form(FRAME_STEP, description="Traiter 1 frame sur N")
584
+ ):
585
+ """
586
+ Extraire automatiquement les paramètres de caméra à partir d'une vidéo.
587
+ Retourne les coordonnées détectées pour chaque frame traitée.
588
+ """
589
+ try:
590
+ # Validation du format vidéo - version robuste
591
+ content_type = getattr(video, 'content_type', None) or ""
592
+ filename = getattr(video, 'filename', "") or ""
593
+
594
+ video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv']
595
+ is_video_content = content_type.startswith('video/') if content_type else False
596
+ is_video_extension = any(filename.lower().endswith(ext) for ext in video_extensions)
597
+
598
+ if not is_video_content and not is_video_extension:
599
+ raise HTTPException(
600
+ status_code=400,
601
+ detail=f"Le fichier doit être une vidéo. Type détecté: {content_type}, Fichier: {filename}"
602
+ )
603
+
604
+ # Sauvegarde temporaire de la vidéo
605
+ file_extension = os.path.splitext(filename)[1] if filename else '.mp4'
606
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
607
+ content = await video.read()
608
+ temp_file.write(content)
609
+ temp_video_path = temp_file.name
610
+
611
+ try:
612
+ # Charger les modèles
613
+ model, model_l, device = load_inference_models()
614
+
615
+ # Ouvrir la vidéo
616
+ cap = cv2.VideoCapture(temp_video_path)
617
+ if not cap.isOpened():
618
+ raise HTTPException(status_code=400, detail="Impossible d'ouvrir la vidéo")
619
+
620
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
621
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
622
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
623
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
624
+
625
+ # Mettre à jour les seuils globaux
626
+ global KP_THRESHOLD, LINE_THRESHOLD
627
+ KP_THRESHOLD = kp_threshold
628
+ LINE_THRESHOLD = line_threshold
629
+
630
+ frames_results = []
631
+ frame_count = 0
632
+ processed_count = 0
633
+
634
+ while cap.isOpened():
635
+ ret, frame = cap.read()
636
+ if not ret:
637
+ break
638
+
639
+ # Traiter seulement 1 frame sur frame_step
640
+ if frame_count % frame_step != 0:
641
+ frame_count += 1
642
+ continue
643
+
644
+ # Traitement avec coordonnées
645
+ params, kp_dict, lines_dict = process_frame_inference_with_coords(
646
+ frame, model, model_l, device, frame_width, frame_height
647
+ )
648
+
649
+ if params is not None:
650
+ # Traiter les détections (coordonnées image + terrain)
651
+ detections = process_detections(kp_dict, lines_dict, params)
652
+
653
+ frame_result = FrameResult(
654
+ frame_number=frame_count,
655
+ timestamp_seconds=frame_count / fps,
656
+ camera_parameters=params,
657
+ detections=detections
658
+ )
659
+
660
+ frames_results.append(frame_result)
661
+ processed_count += 1
662
+
663
+ frame_count += 1
664
+
665
+ cap.release()
666
+
667
+ # Formatage de la réponse
668
+ response = InferenceVideoResponse(
669
+ status="success" if frames_results else "failed",
670
+ video_info={
671
+ "filename": filename,
672
+ "width": frame_width,
673
+ "height": frame_height,
674
+ "total_frames": total_frames,
675
+ "fps": fps,
676
+ "duration_seconds": total_frames / fps,
677
+ "kp_threshold": kp_threshold,
678
+ "line_threshold": line_threshold,
679
+ "frame_step": frame_step
680
+ },
681
+ frames_processed=processed_count,
682
+ frames_results=frames_results,
683
+ message=f"Paramètres extraits de {processed_count} frames" if frames_results else "Aucun paramètre extrait"
684
+ )
685
+
686
+ return response
687
+
688
+ except Exception as e:
689
+ raise HTTPException(
690
+ status_code=500,
691
+ detail=f"Erreur lors de l'inférence vidéo: {str(e)}"
692
+ )
693
+
694
+ finally:
695
+ # Nettoyage du fichier temporaire
696
+ if os.path.exists(temp_video_path):
697
+ os.unlink(temp_video_path)
698
+
699
+ except HTTPException:
700
+ raise
701
+ except Exception as e:
702
+ raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
703
+
704
  app_instance = app
requirements.txt CHANGED
@@ -1,25 +1,23 @@
1
- # API Framework
2
- fastapi==0.104.1
3
- uvicorn[standard]==0.24.0
4
- python-multipart==0.0.6
5
- pydantic==2.5.0
6
-
7
- # Core dependencies (compatibles Python 3.10)
8
- numpy==1.24.3
9
- opencv-python-headless==4.8.1.78
10
- pillow==10.1.0
11
- scipy==1.11.4
12
- PyYAML==6.0.1
13
- lsq-ellipse==2.2.1
14
- shapely==2.0.2
15
-
16
- # PyTorch CPU (compatible Python 3.10)
17
- torch==2.1.0
18
- torchvision==0.16.0
19
-
20
- # Utilities
21
- tqdm==4.66.1
22
-
23
-
24
-
25
  huggingface_hub
 
1
+ # API Framework
2
+ fastapi==0.104.1
3
+ uvicorn[standard]==0.24.0
4
+ python-multipart==0.0.6
5
+ pydantic==2.5.0
6
+
7
+ # Core dependencies (compatibles Python 3.10)
8
+ numpy==1.24.3
9
+ opencv-python-headless==4.8.1.78
10
+ pillow==10.1.0
11
+ scipy==1.11.4
12
+ PyYAML==6.0.1
13
+ lsq-ellipse==2.2.1
14
+ shapely==2.0.2
15
+
16
+ # PyTorch CPU (compatible Python 3.10)
17
+ torch==2.1.0
18
+ torchvision==0.16.0
19
+
20
+ # Utilities
21
+ tqdm==4.66.1
22
+
 
 
23
  huggingface_hub