from fastapi import FastAPI, HTTPException, UploadFile, File, Form from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import Dict, List, Any, Optional import json import tempfile import os from PIL import Image import numpy as np import cv2 import torch import torchvision.transforms as T import torchvision.transforms.functional as f import yaml from tqdm import tqdm from huggingface_hub import hf_hub_download from get_camera_params import get_camera_parameters # Imports pour l'inférence automatique from model.cls_hrnet import get_cls_net from model.cls_hrnet_l import get_cls_net as get_cls_net_l from utils.utils_calib import FramebyFrameCalib from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, complete_keypoints, coords_to_dict app = FastAPI( title="Football Vision Calibration API", description="API pour la calibration de caméras à partir de lignes de terrain de football", version="1.0.0" ) # Configuration CORS pour autoriser les requêtes depuis le frontend app.add_middleware( CORSMiddleware, allow_origins=["*"], # En production, spécifiez les domaines autorisés allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Paramètres par défaut pour l'inférence WEIGHTS_KP = "models/SV_FT_TSWC_kp" WEIGHTS_LINE = "models/SV_FT_TSWC_lines" # DEVICE = "cuda:0" DEVICE = "cpu" KP_THRESHOLD = 0.15 LINE_THRESHOLD = 0.15 PNL_REFINE = True FRAME_STEP = 5 # Cache pour les modèles (éviter de les recharger à chaque requête) _models_cache = None # Paramètres pour HF Hub HF_MODEL_REPO = "2nzi/SV_FT_TSWC_kp" # Remplacez par votre repo WEIGHTS_KP_FILE = "SV_FT_TSWC_kp" # Nom du fichier dans le repo WEIGHTS_LINE_FILE = "SV_FT_TSWC_lines" # Nom du fichier dans le repo def load_inference_models(): """Charge les modèles d'inférence depuis Hugging Face Hub""" global _models_cache if _models_cache is not None: return _models_cache try: # Device detection device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") # Télécharger les modèles depuis HF Hub print("Téléchargement des modèles depuis Hugging Face Hub...") weights_kp_path = hf_hub_download( repo_id=HF_MODEL_REPO, filename=WEIGHTS_KP_FILE, cache_dir="./hf_cache" ) weights_line_path = hf_hub_download( repo_id=HF_MODEL_REPO, filename=WEIGHTS_LINE_FILE, cache_dir="./hf_cache" ) print(f"Modèles téléchargés:") print(f" - Keypoints: {weights_kp_path}") print(f" - Lines: {weights_line_path}") # Vérifier l'existence des fichiers de configuration config_files = ["config/hrnetv2_w48.yaml", "config/hrnetv2_w48_l.yaml"] for config_file in config_files: if not os.path.exists(config_file): raise FileNotFoundError(f"Fichier de configuration manquant: {config_file}") # Charger les configurations with open("config/hrnetv2_w48.yaml", 'r') as f: cfg = yaml.safe_load(f) with open("config/hrnetv2_w48_l.yaml", 'r') as f: cfg_l = yaml.safe_load(f) # Modèle keypoints model = get_cls_net(cfg) model.load_state_dict(torch.load(weights_kp_path, map_location=device)) model.to(device) model.eval() # Modèle lignes model_l = get_cls_net_l(cfg_l) model_l.load_state_dict(torch.load(weights_line_path, map_location=device)) model_l.to(device) model_l.eval() _models_cache = (model, model_l, device) print("✅ Modèles chargés avec succès depuis HF Hub!") return _models_cache except Exception as e: print(f"❌ Erreur lors du chargement des modèles: {e}") raise HTTPException( status_code=503, detail=f"Modèles non disponibles: {str(e)}. Veuillez réessayer plus tard." ) def process_frame_inference(frame, model, model_l, device, frame_width, frame_height): """Traite une frame et retourne les paramètres de caméra""" transform = T.Resize((540, 960)) # Préparer la frame pour l'inférence frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame_pil = Image.fromarray(frame_rgb) frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0) if frame_tensor.size()[-1] != 960: frame_tensor = transform(frame_tensor) frame_tensor = frame_tensor.to(device) b, c, h, w = frame_tensor.size() # Inférence with torch.no_grad(): heatmaps = model(frame_tensor) heatmaps_l = model_l(frame_tensor) # Extraire les keypoints et lignes kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:]) line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:]) kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD) lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD) kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True) # Calibration cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True) cam.update(kp_dict, lines_dict) final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE) return final_params_dict # Modèles Pydantic pour la validation des données class Point(BaseModel): x: float y: float class LinePolygon(BaseModel): points: List[Point] class CalibrationRequest(BaseModel): lines: Dict[str, List[Point]] # Nouvelles classes pour les coordonnées class ImageCoordinate(BaseModel): x: float y: float class WorldCoordinate(BaseModel): x: float # Coordonnée X sur le terrain (mètres) y: float # Coordonnée Y sur le terrain (mètres) class KeypointData(BaseModel): id: int image_coords: ImageCoordinate world_coords: Optional[WorldCoordinate] = None confidence: Optional[float] = None class LineData(BaseModel): id: int start_image: ImageCoordinate end_image: ImageCoordinate start_world: Optional[WorldCoordinate] = None end_world: Optional[WorldCoordinate] = None confidence: Optional[float] = None class DetectionData(BaseModel): keypoints: List[KeypointData] lines: List[LineData] class CalibrationResponse(BaseModel): status: str camera_parameters: Dict[str, Any] input_lines: Dict[str, List[Point]] detections: Optional[DetectionData] = None message: str class InferenceImageResponse(BaseModel): status: str camera_parameters: Optional[Dict[str, Any]] image_info: Dict[str, Any] detections: Optional[DetectionData] = None message: str class FrameResult(BaseModel): frame_number: int timestamp_seconds: float camera_parameters: Optional[Dict[str, Any]] detections: Optional[DetectionData] = None class InferenceVideoResponse(BaseModel): status: str video_info: Dict[str, Any] frames_processed: int frames_results: List[FrameResult] message: str # Fonction de conversion coordonnées image -> terrain def image_to_world(point_2d, cam_params): """ Convertit un point 2D de l'image vers des coordonnées 3D du terrain. Args: point_2d: [x, y] coordonnées dans l'image cam_params: Paramètres de la caméra Returns: [x, y] coordonnées sur le terrain (Z=0) """ try: # Matrice de calibration K = np.array([ [cam_params["cam_params"]["x_focal_length"], 0, cam_params["cam_params"]["principal_point"][0]], [0, cam_params["cam_params"]["y_focal_length"], cam_params["cam_params"]["principal_point"][1]], [0, 0, 1] ]) # Matrice de rotation et position R = np.array(cam_params["cam_params"]["rotation_matrix"]) camera_pos = np.array(cam_params["cam_params"]["position_meters"]) # Point 2D en coordonnées homogènes point_2d_h = np.array([point_2d[0], point_2d[1], 1]) # Back-projection du rayon depuis la caméra ray = np.linalg.inv(K) @ point_2d_h ray = R.T @ ray # Intersection avec le plan Z=0 (terrain) if abs(ray[2]) < 1e-6: # Éviter division par zéro return None t = -camera_pos[2] / ray[2] world_point = camera_pos + t * ray return world_point[:2] # Retourner seulement X,Y (Z=0) except Exception as e: print(f"Erreur conversion image->monde: {e}") return None def process_detections(kp_dict, lines_dict, cam_params): """ Traite les détections et convertit les coordonnées image vers terrain. Args: kp_dict: Dictionnaire des keypoints détectés lines_dict: Dictionnaire des lignes détectées cam_params: Paramètres de la caméra Returns: DetectionData: Données structurées avec coordonnées image et terrain """ keypoints_data = [] lines_data = [] # Traiter les keypoints if kp_dict: for kp_id, kp_value in kp_dict.items(): image_coords = ImageCoordinate(x=kp_value['x'], y=kp_value['y']) # Convertir vers coordonnées terrain world_coords = None if cam_params and 'cam_params' in cam_params: world_point = image_to_world([kp_value['x'], kp_value['y']], cam_params) if world_point is not None: world_coords = WorldCoordinate(x=float(world_point[0]), y=float(world_point[1])) keypoints_data.append(KeypointData( id=kp_id, image_coords=image_coords, world_coords=world_coords, confidence=kp_value.get('confidence', None) )) # Traiter les lignes if lines_dict: for line_id, line_value in lines_dict.items(): start_image = ImageCoordinate(x=line_value['x_1'], y=line_value['y_1']) end_image = ImageCoordinate(x=line_value['x_2'], y=line_value['y_2']) # Convertir vers coordonnées terrain start_world = None end_world = None if cam_params and 'cam_params' in cam_params: start_point = image_to_world([line_value['x_1'], line_value['y_1']], cam_params) end_point = image_to_world([line_value['x_2'], line_value['y_2']], cam_params) if start_point is not None: start_world = WorldCoordinate(x=float(start_point[0]), y=float(start_point[1])) if end_point is not None: end_world = WorldCoordinate(x=float(end_point[0]), y=float(end_point[1])) lines_data.append(LineData( id=line_id, start_image=start_image, end_image=end_image, start_world=start_world, end_world=end_world, confidence=line_value.get('confidence', None) )) return DetectionData(keypoints=keypoints_data, lines=lines_data) def process_frame_inference_with_coords(frame, model, model_l, device, frame_width, frame_height): """ Version enrichie qui retourne les paramètres de caméra ET les coordonnées détectées. """ transform = T.Resize((540, 960)) # Préparer la frame pour l'inférence frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame_pil = Image.fromarray(frame_rgb) frame_tensor = f.to_tensor(frame_pil).float().unsqueeze(0) if frame_tensor.size()[-1] != 960: frame_tensor = transform(frame_tensor) frame_tensor = frame_tensor.to(device) b, c, h, w = frame_tensor.size() # Inférence with torch.no_grad(): heatmaps = model(frame_tensor) heatmaps_l = model_l(frame_tensor) # Extraire les keypoints et lignes kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:]) line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:]) kp_dict = coords_to_dict(kp_coords, threshold=KP_THRESHOLD) lines_dict = coords_to_dict(line_coords, threshold=LINE_THRESHOLD) kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True) # Calibration cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True) cam.update(kp_dict, lines_dict) final_params_dict = cam.heuristic_voting(refine_lines=PNL_REFINE) return final_params_dict, kp_dict, lines_dict @app.get("/") async def root(): return { "message": "Football Vision Calibration API", "version": "1.0.0", "endpoints": { "/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes", "/inference/image": "POST - Extraire les paramètres de caméra d'une image automatiquement", "/inference/video": "POST - Extraire les paramètres de caméra d'une vidéo automatiquement", "/health": "GET - Vérifier l'état de l'API" } } @app.get("/health") async def health_check(): return {"status": "healthy", "message": "API is running"} @app.post("/calibrate", response_model=CalibrationResponse) async def calibrate_camera( image: UploadFile = File(..., description="Image du terrain de football"), lines_data: str = Form(..., description="JSON des lignes du terrain") ): """ Calibrer une caméra à partir d'une image et des lignes du terrain. Retourne aussi les coordonnées détectées sur l'image et le terrain. """ try: # Validation du format d'image - version robuste content_type = getattr(image, 'content_type', None) or "" filename = getattr(image, 'filename', "") or "" # Vérifier le type MIME ou l'extension du fichier image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'] is_image_content = content_type.startswith('image/') if content_type else False is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions) if not is_image_content and not is_image_extension: raise HTTPException( status_code=400, detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}" ) # Parse des données de lignes try: lines_dict = json.loads(lines_data) except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes") # Validation de la structure des lignes validated_lines = {} for line_name, points in lines_dict.items(): if not isinstance(points, list): raise HTTPException( status_code=400, detail=f"Les points de la ligne '{line_name}' doivent être une liste" ) validated_points = [] for i, point in enumerate(points): if not isinstance(point, dict) or 'x' not in point or 'y' not in point: raise HTTPException( status_code=400, detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'" ) try: validated_points.append({ "x": float(point['x']), "y": float(point['y']) }) except (ValueError, TypeError): raise HTTPException( status_code=400, detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'" ) validated_lines[line_name] = validated_points # Sauvegarde temporaire de l'image file_extension = os.path.splitext(filename)[1] if filename else '.jpg' with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file: content = await image.read() temp_file.write(content) temp_image_path = temp_file.name try: # Validation de l'image pil_image = Image.open(temp_image_path) pil_image.verify() # Calibration de la caméra camera_params = get_camera_parameters(temp_image_path, validated_lines) # Pour l'endpoint calibrate, nous n'avons pas de détections automatiques # mais nous pouvons traiter les lignes manuelles fournies detections = None # Formatage de la réponse response = CalibrationResponse( status="success", camera_parameters=camera_params, input_lines=validated_lines, detections=detections, message="Calibration réussie" ) return response except Exception as e: raise HTTPException( status_code=500, detail=f"Erreur lors de la calibration: {str(e)}" ) finally: # Nettoyage du fichier temporaire if os.path.exists(temp_image_path): os.unlink(temp_image_path) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}") @app.post("/inference/image", response_model=InferenceImageResponse) async def inference_image( image: UploadFile = File(..., description="Image du terrain de football"), kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"), line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes") ): """ Extraire automatiquement les paramètres de caméra à partir d'une image. Retourne les coordonnées détectées sur l'image et le terrain. """ params = None temp_image_path = None try: # Validation du format d'image - version robuste content_type = getattr(image, 'content_type', None) or "" filename = getattr(image, 'filename', "") or "" # Vérifier le type MIME ou l'extension du fichier image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'] is_image_content = content_type.startswith('image/') if content_type else False is_image_extension = any(filename.lower().endswith(ext) for ext in image_extensions) if not is_image_content and not is_image_extension: raise HTTPException( status_code=400, detail=f"Le fichier doit être une image. Type détecté: {content_type}, Fichier: {filename}" ) # Sauvegarde temporaire de l'image file_extension = os.path.splitext(filename)[1] if filename else '.jpg' with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file: content = await image.read() temp_file.write(content) temp_image_path = temp_file.name # Charger les modèles model, model_l, device = load_inference_models() # Lire l'image frame = cv2.imread(temp_image_path) if frame is None: raise HTTPException(status_code=400, detail="Impossible de lire l'image") frame_height, frame_width = frame.shape[:2] # Mettre à jour les seuils globaux global KP_THRESHOLD, LINE_THRESHOLD KP_THRESHOLD = kp_threshold LINE_THRESHOLD = line_threshold # Traitement avec coordonnées params, kp_dict, lines_dict = process_frame_inference_with_coords( frame, model, model_l, device, frame_width, frame_height ) # Traiter les détections (coordonnées image + terrain) detections = process_detections(kp_dict, lines_dict, params) # Formatage de la réponse response = InferenceImageResponse( status="success" if params is not None else "failed", camera_parameters=params, image_info={ "filename": filename, "width": frame_width, "height": frame_height, "kp_threshold": kp_threshold, "line_threshold": line_threshold }, detections=detections, message="Paramètres extraits avec succès" if params is not None else "Échec de l'extraction des paramètres" ) return response except HTTPException: raise except Exception as e: raise HTTPException( status_code=500, detail=f"Erreur lors de l'inférence: {str(e)}" ) finally: # Nettoyage du fichier temporaire if temp_image_path and os.path.exists(temp_image_path): os.unlink(temp_image_path) @app.post("/inference/video", response_model=InferenceVideoResponse) async def inference_video( video: UploadFile = File(..., description="Vidéo du terrain de football"), kp_threshold: float = Form(KP_THRESHOLD, description="Seuil pour les keypoints"), line_threshold: float = Form(LINE_THRESHOLD, description="Seuil pour les lignes"), frame_step: int = Form(FRAME_STEP, description="Traiter 1 frame sur N") ): """ Extraire automatiquement les paramètres de caméra à partir d'une vidéo. Retourne les coordonnées détectées pour chaque frame traitée. """ try: # Validation du format vidéo - version robuste content_type = getattr(video, 'content_type', None) or "" filename = getattr(video, 'filename', "") or "" video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv'] is_video_content = content_type.startswith('video/') if content_type else False is_video_extension = any(filename.lower().endswith(ext) for ext in video_extensions) if not is_video_content and not is_video_extension: raise HTTPException( status_code=400, detail=f"Le fichier doit être une vidéo. Type détecté: {content_type}, Fichier: {filename}" ) # Sauvegarde temporaire de la vidéo file_extension = os.path.splitext(filename)[1] if filename else '.mp4' with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file: content = await video.read() temp_file.write(content) temp_video_path = temp_file.name try: # Charger les modèles model, model_l, device = load_inference_models() # Ouvrir la vidéo cap = cv2.VideoCapture(temp_video_path) if not cap.isOpened(): raise HTTPException(status_code=400, detail="Impossible d'ouvrir la vidéo") frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = int(cap.get(cv2.CAP_PROP_FPS)) # Mettre à jour les seuils globaux global KP_THRESHOLD, LINE_THRESHOLD KP_THRESHOLD = kp_threshold LINE_THRESHOLD = line_threshold frames_results = [] frame_count = 0 processed_count = 0 while cap.isOpened(): ret, frame = cap.read() if not ret: break # Traiter seulement 1 frame sur frame_step if frame_count % frame_step != 0: frame_count += 1 continue # Traitement avec coordonnées params, kp_dict, lines_dict = process_frame_inference_with_coords( frame, model, model_l, device, frame_width, frame_height ) if params is not None: # Traiter les détections (coordonnées image + terrain) detections = process_detections(kp_dict, lines_dict, params) frame_result = FrameResult( frame_number=frame_count, timestamp_seconds=frame_count / fps, camera_parameters=params, detections=detections ) frames_results.append(frame_result) processed_count += 1 frame_count += 1 cap.release() # Formatage de la réponse response = InferenceVideoResponse( status="success" if frames_results else "failed", video_info={ "filename": filename, "width": frame_width, "height": frame_height, "total_frames": total_frames, "fps": fps, "duration_seconds": total_frames / fps, "kp_threshold": kp_threshold, "line_threshold": line_threshold, "frame_step": frame_step }, frames_processed=processed_count, frames_results=frames_results, message=f"Paramètres extraits de {processed_count} frames" if frames_results else "Aucun paramètre extrait" ) return response except Exception as e: raise HTTPException( status_code=500, detail=f"Erreur lors de l'inférence vidéo: {str(e)}" ) finally: # Nettoyage du fichier temporaire if os.path.exists(temp_video_path): os.unlink(temp_video_path) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}") app_instance = app