import cv2
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from ultralytics import YOLO
import time
import os
import tempfile
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import gradio as gr
# Initialize Flask app and Gradio interface
app = Flask(__name__)
# Global variable to store detection history
detection_history = []
# Emotion labels
emotions = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
# Load models (cache in Hugging Face Space)
def load_models():
    # Face detection model (the yolov8n-face.pt weights must be available locally in the Space)
    face_model = YOLO('yolov8n-face.pt')

    # Emotion model (simplified CNN)
    class EmotionCNN(torch.nn.Module):
        def __init__(self, num_classes=7):
            super().__init__()
            self.features = torch.nn.Sequential(
                torch.nn.Conv2d(1, 64, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2),
                torch.nn.Conv2d(64, 128, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2),
                torch.nn.Conv2d(128, 256, 3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2)
            )
            self.classifier = torch.nn.Sequential(
                torch.nn.Dropout(0.5),
                torch.nn.Linear(256 * 6 * 6, 1024),  # 48x48 input -> 6x6 after three 2x poolings
                torch.nn.ReLU(),
                torch.nn.Dropout(0.5),
                torch.nn.Linear(1024, num_classes)
            )

        def forward(self, x):
            x = self.features(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x

    emotion_model = EmotionCNN()
    # Load your pretrained weights here
    # emotion_model.load_state_dict(torch.load('emotion_model.pth'))
    emotion_model.eval()
    return face_model, emotion_model
face_model, emotion_model = load_models()
# Preprocessing function
def preprocess_face(face_img):
    transform = transforms.Compose([
        transforms.Resize((48, 48)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    face_pil = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
    return transform(face_pil).unsqueeze(0)
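
# A minimal usage sketch (the "face.jpg" path is hypothetical; it only
# illustrates how preprocess_face feeds the emotion model):
#
#   face_bgr = cv2.imread("face.jpg")            # BGR face crop, as OpenCV returns it
#   tensor = preprocess_face(face_bgr)           # shape (1, 1, 48, 48)
#   with torch.no_grad():
#       idx = torch.argmax(emotion_model(tensor)).item()
#   print(emotions[idx])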
# Process video function
def process_video(video_path):
    global detection_history
    detection_history = []

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return {"error": "Could not open video"}

    frame_count = 0
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = 30.0  # fall back to a nominal rate if the container reports no FPS
    frame_skip = max(1, int(fps / 3))  # Process ~3 frames per second

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        if frame_count % frame_skip != 0:
            continue

        # Face detection
        results = face_model(frame)
        for result in results:
            boxes = result.boxes
            if len(boxes) == 0:
                continue
            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                face_img = frame[y1:y2, x1:x2]
                if face_img.size == 0:
                    continue

                # Emotion prediction
                face_tensor = preprocess_face(face_img)
                with torch.no_grad():
                    output = emotion_model(face_tensor)
                    prob = torch.nn.functional.softmax(output, dim=1)[0]
                    pred_idx = torch.argmax(output).item()
                    confidence = prob[pred_idx].item()

                detection_history.append({
                    "frame": frame_count,
                    "time": frame_count / fps,
                    "emotion": emotions[pred_idx],
                    "confidence": confidence,
                    "box": [x1, y1, x2, y2]
                })

    cap.release()
    if not detection_history:
        return {"error": "No faces detected"}

    return {
        "detections": detection_history,
        "summary": {
            "total_frames": frame_count,
            "fps": fps,
            "duration": frame_count / fps
        }
    }
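
# Illustrative shape of the dict returned by process_video (values are
# placeholders, not real output):
#
#   {
#       "detections": [
#           {"frame": 3, "time": 0.1, "emotion": "Happy",
#            "confidence": 0.87, "box": [x1, y1, x2, y2]},
#           ...
#       ],
#       "summary": {"total_frames": 300, "fps": 30.0, "duration": 10.0}
#   }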
# Flask API endpoint
@app.route('/api/predict', methods=['POST'])
def api_predict():
    if 'file' not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400

    # Save to temp file (sanitize the client-supplied name)
    temp_path = os.path.join(tempfile.gettempdir(), secure_filename(file.filename))
    file.save(temp_path)

    # Process video, cleaning up the temp file even if processing fails
    try:
        result = process_video(temp_path)
    finally:
        os.remove(temp_path)

    return jsonify(result)
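
# Example client call for the Flask endpoint (assumes the API is served on
# localhost:5000 as in the __main__ block below; "clip.mp4" is a placeholder):
#
#   import requests
#   with open("clip.mp4", "rb") as fh:
#       resp = requests.post("http://localhost:5000/api/predict", files={"file": fh})
#   print(resp.json())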
# Gradio interface
def gradio_predict(video):
    # gr.Video passes the uploaded file to the function as a path on disk
    if video is None:
        return None, {"error": "No video provided"}
    result = process_video(video)
    if "error" in result:
        return None, result

    # Create visualization from the first frame
    cap = cv2.VideoCapture(video)
    ret, frame = cap.read()
    cap.release()
    if ret:
        # Draw last detection on frame
        last_det = result["detections"][-1]
        x1, y1, x2, y2 = last_det["box"]
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, f"{last_det['emotion']} ({last_det['confidence']:.2f})",
                    (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        # Convert to RGB for Gradio
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return frame, result
    return None, result
# Create Gradio interface
demo = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Video(label="Upload Video"),
    outputs=[
        gr.Image(label="Detection Preview"),
        gr.JSON(label="Results")
    ],
    title="Video Emotion Detection",
    description="Upload a video to detect emotions in faces"
)
# Serve the Gradio UI and the Flask API.
# Note: gr.mount_gradio_app expects a FastAPI/Starlette (ASGI) app, not Flask,
# so the Flask API runs in a background thread on a separate port (assumed
# here: 5000) while the Gradio UI is served on port 7860.
if __name__ == "__main__":
    import threading
    threading.Thread(
        target=lambda: app.run(host="0.0.0.0", port=5000),
        daemon=True
    ).start()
    demo.launch(server_name="0.0.0.0", server_port=7860)