import cv2
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from ultralytics import YOLO
import tempfile
import time
import os
import json
import gradio as gr
from fastapi import FastAPI, UploadFile, File, HTTPException
import uvicorn

# Initialize FastAPI
app = FastAPI()

# Global variable for face detections
largest_face_detections = []

# Load models
yolo_model_path = "yolov8n-face.pt"
emotion_model_path = "best_emotion_model.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Check if models exist
if os.path.exists(yolo_model_path):
    yolo_model = YOLO(yolo_model_path)
else:
    raise FileNotFoundError(f"YOLO model not found at {yolo_model_path}")

if os.path.exists(emotion_model_path):
    from torch import nn

    class EmotionCNN(nn.Module):
        def __init__(self, num_classes=7):
            super(EmotionCNN, self).__init__()
            self.conv1 = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            self.fc = nn.Sequential(
                nn.Linear(64 * 24 * 24, 1024),
                nn.ReLU(),
                nn.Linear(1024, num_classes)
            )

        def forward(self, x):
            x = self.conv1(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x

    emotion_model = EmotionCNN(num_classes=7)
    checkpoint = torch.load(emotion_model_path, map_location=device)
    emotion_model.load_state_dict(checkpoint['model_state_dict'])
    emotion_model.to(device)
    emotion_model.eval()
else:
    raise FileNotFoundError(f"Emotion model not found at {emotion_model_path}")

# Emotion labels
emotions = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']


def preprocess_face(face_img):
    """Preprocess face image for emotion detection"""
    transform = transforms.Compose([
        transforms.Resize((48, 48)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    # Convert the BGR crop from OpenCV to RGB, then to single-channel grayscale
    face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)).convert('L')
    face_tensor = transform(face_img).unsqueeze(0)
    return face_tensor


def process_video(video_path: str):
    """Process video and return emotion results"""
    global largest_face_detections
    largest_face_detections = []

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return {"success": False, "message": "Could not open video file"}

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        largest_face_area = 0
        current_detection = None

        results = yolo_model(frame, stream=True)
        for result in results:
            boxes = result.boxes
            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
                face_img = frame[y1:y2, x1:x2]
                if face_img.size == 0:
                    continue

                face_tensor = preprocess_face(face_img).to(device)
                with torch.no_grad():
                    output = emotion_model(face_tensor)
                    probabilities = torch.nn.functional.softmax(output, dim=1)
                    emotion_idx = torch.argmax(output, dim=1).item()
                    confidence = probabilities[0][emotion_idx].item()
                    emotion = emotions[emotion_idx]

                # Keep only the detection for the largest face in this frame
                if (x2 - x1) * (y2 - y1) > largest_face_area:
                    largest_face_area = (x2 - x1) * (y2 - y1)
                    current_detection = {"emotion": emotion, "confidence": confidence}

        if current_detection:
            largest_face_detections.append(current_detection)

    cap.release()

    if not largest_face_detections:
        return {"success": True, "message": "No faces detected", "results": []}

    return {
        "success": True,
        "message": "Video processed",
        "results": largest_face_detections
    }


@app.post("/api/video")
async def handle_video(file: UploadFile = File(...)):
    """API endpoint for video emotion detection"""
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
            tmp.write(await file.read())
            video_path = tmp.name
        result = process_video(video_path)
        os.remove(video_path)
        return result
    except Exception as e:
        return {"success": False, "message": "Error processing video", "error": str(e)}


# Gradio UI
def gradio_process(video):
    # gr.File hands the component's value to the callback as a path to the uploaded
    # file (the default type in recent Gradio versions), so it can be passed to
    # process_video directly instead of being rewritten to a temporary file.
    if video is None:
        return {"success": False, "message": "No video uploaded"}
    return process_video(video)


with gr.Blocks() as demo:
    gr.Markdown("# Video Emotion Analysis")
    with gr.Row():
        with gr.Column():
            video_input = gr.File(label="Upload a video", file_types=[".mp4"])
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output = gr.JSON(label="Results")

    submit_btn.click(fn=gradio_process, inputs=video_input, outputs=output)

app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
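
# Illustrative client call for the /api/video endpoint (not part of the app itself).
# Assumes the server is running locally on port 7860, the `requests` package is
# installed, and a file named sample.mp4 exists; both the filename and the use of
# `requests` are only examples, not requirements of this script.
#
#   import requests
#   with open("sample.mp4", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/api/video",
#           files={"file": ("sample.mp4", f, "video/mp4")},
#       )
#   print(resp.json())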