import cv2
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from ultralytics import YOLO
import time
import os
import tempfile
import threading
from flask import Flask, request, jsonify
import gradio as gr
# Initialize Flask app (the Gradio interface is defined further below)
app = Flask(__name__)

# Global variable to store detection history
detection_history = []

# Emotion labels
emotions = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
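# Note: these seven classes are listed in the order commonly used for FER-2013;
# it is assumed here that this matches the class order of the pretrained weights.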
# Load models (cache in Hugging Face Space)
def load_models():
    # Face detection model
    face_model = YOLO('yolov8n-face.pt')

    # Emotion model (simplified version of your CNN)
    class EmotionCNN(torch.nn.Module):
        def __init__(self, num_classes=7):
            super().__init__()
            self.features = torch.nn.Sequential(
                torch.nn.Conv2d(1, 64, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2),
                torch.nn.Conv2d(64, 128, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2),
                torch.nn.Conv2d(128, 256, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2)
            )
            # Three 2x2 max-pools reduce the 48x48 input to 6x6, so the flattened
            # feature size is 256 * 6 * 6.
            self.classifier = torch.nn.Sequential(
                torch.nn.Dropout(0.5),
                torch.nn.Linear(256 * 6 * 6, 1024),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.5),
                torch.nn.Linear(1024, num_classes)
            )

        def forward(self, x):
            x = self.features(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x

    emotion_model = EmotionCNN()
    # Load your pretrained weights here
    # emotion_model.load_state_dict(torch.load('emotion_model.pth'))
    emotion_model.eval()

    return face_model, emotion_model

face_model, emotion_model = load_models()
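# Both models are created once at module import time, so a Space worker loads them a
# single time and reuses them across all API calls and Gradio requests.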
# Preprocessing function
def preprocess_face(face_img):
    transform = transforms.Compose([
        transforms.Resize((48, 48)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    face_pil = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
    return transform(face_pil).unsqueeze(0)
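# The tensor returned above has shape [1, 1, 48, 48] (batch, grayscale channel, height,
# width), which is what EmotionCNN's first Conv2d(1, 64, ...) layer expects.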
# Process video function
def process_video(video_path):
    global detection_history
    detection_history = []

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return {"error": "Could not open video"}
    frame_count = 0
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = 30.0  # Fall back to a nominal frame rate if the container reports none
    frame_skip = max(1, int(fps / 3))  # Process ~3 frames per second
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        if frame_count % frame_skip != 0:
            continue

        # Face detection
        results = face_model(frame)
        for result in results:
            boxes = result.boxes
            if len(boxes) == 0:
                continue
            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                face_img = frame[y1:y2, x1:x2]
                if face_img.size == 0:
                    continue

                # Emotion prediction
                face_tensor = preprocess_face(face_img)
                with torch.no_grad():
                    output = emotion_model(face_tensor)
                    prob = torch.nn.functional.softmax(output, dim=1)[0]
                    pred_idx = torch.argmax(output).item()
                    confidence = prob[pred_idx].item()

                detection_history.append({
                    "frame": frame_count,
                    "time": frame_count / fps,
                    "emotion": emotions[pred_idx],
                    "confidence": confidence,
                    "box": [x1, y1, x2, y2]
                })

    cap.release()

    if not detection_history:
        return {"error": "No faces detected"}

    return {
        "detections": detection_history,
        "summary": {
            "total_frames": frame_count,
            "fps": fps,
            "duration": frame_count / fps
        }
    }
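# For reference, process_video() returns a dict shaped like the following
# (values here are purely illustrative):
# {
#     "detections": [
#         {"frame": 10, "time": 0.33, "emotion": "Happy",
#          "confidence": 0.91, "box": [120, 80, 220, 200]},
#         ...
#     ],
#     "summary": {"total_frames": 300, "fps": 30.0, "duration": 10.0}
# }
# or {"error": "..."} when the video cannot be opened or no faces are found.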
# Flask API endpoint
@app.route('/predict', methods=['POST'])  # Route path chosen here; adjust as needed
def api_predict():
    if 'file' not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400

    # Save to temp file
    temp_path = os.path.join(tempfile.gettempdir(), file.filename)
    file.save(temp_path)

    # Process video
    result = process_video(temp_path)

    # Clean up
    os.remove(temp_path)

    return jsonify(result)
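# Example call to the endpoint above, assuming the Flask server is running locally on
# port 5000 as in the __main__ block below (the file name is illustrative):
#   curl -X POST -F "file=@sample.mp4" http://localhost:5000/predict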
# Gradio interface
def gradio_predict(video):
    # gr.Video passes the path of the uploaded file, so it can be processed directly
    result = process_video(video)
    if "error" in result:
        return None, result

    # Create visualization from the first frame
    cap = cv2.VideoCapture(video)
    ret, frame = cap.read()
    cap.release()
    if ret:
        # Draw last detection on frame
        last_det = result["detections"][-1]
        x1, y1, x2, y2 = last_det["box"]
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, f"{last_det['emotion']} ({last_det['confidence']:.2f})",
                    (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        # Convert to RGB for Gradio
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return frame, result
    return None, result
# Create Gradio interface
demo = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Video(label="Upload Video"),
    outputs=[
        gr.Image(label="Detection Preview"),
        gr.JSON(label="Results")
    ],
    title="Video Emotion Detection",
    description="Upload a video to detect emotions in faces"
)
# gr.mount_gradio_app() expects a FastAPI/ASGI app, not Flask, so instead the Flask API
# runs on its own port (5000, an arbitrary choice) in a background thread while Gradio
# serves the UI on port 7860.
if __name__ == "__main__":
    threading.Thread(target=lambda: app.run(host="0.0.0.0", port=5000), daemon=True).start()
    demo.launch(server_name="0.0.0.0", server_port=7860)