Rivalcoder committed
Commit 9a2edf3 · 1 Parent(s): d674b3d
Files changed (2)
  1. app.py +211 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,211 @@
+ import cv2
+ import torch
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from ultralytics import YOLO
+ import threading
+ import time
+ import os
+ import tempfile
+ from flask import Flask, request, jsonify
+ import gradio as gr
+
+ # Initialize Flask app (the Gradio interface is defined further below)
+ app = Flask(__name__)
+
+ # Global variable to store detection history
+ detection_history = []
+
+ # Emotion labels
+ emotions = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
+
+ # Load models (cached for the lifetime of the Hugging Face Space process)
+ def load_models():
+     # Face detection model. Note: 'yolov8n-face.pt' is not one of the standard
+     # Ultralytics checkpoints, so the weight file must be present in the Space repo.
+     face_model = YOLO('yolov8n-face.pt')
+
+     # Emotion model (simplified version of your CNN)
+     class EmotionCNN(torch.nn.Module):
+         def __init__(self, num_classes=7):
+             super().__init__()
+             self.features = torch.nn.Sequential(
+                 torch.nn.Conv2d(1, 64, 3, padding=1),
+                 torch.nn.ReLU(),
+                 torch.nn.MaxPool2d(2),
+                 torch.nn.Conv2d(64, 128, 3, padding=1),
+                 torch.nn.ReLU(),
+                 torch.nn.MaxPool2d(2),
+                 torch.nn.Conv2d(128, 256, 3, padding=1),
+                 torch.nn.ReLU(),
+                 torch.nn.MaxPool2d(2)
+             )
+             # A 48x48 input is halved by each of the three poolings
+             # (48 -> 24 -> 12 -> 6), so the flattened feature size is 256*6*6
+             self.classifier = torch.nn.Sequential(
+                 torch.nn.Dropout(0.5),
+                 torch.nn.Linear(256*6*6, 1024),
+                 torch.nn.ReLU(),
+                 torch.nn.Dropout(0.5),
+                 torch.nn.Linear(1024, num_classes)
+             )
+
+         def forward(self, x):
+             x = self.features(x)
+             x = torch.flatten(x, 1)
+             x = self.classifier(x)
+             return x
+
+     emotion_model = EmotionCNN()
+     # Load your pretrained weights here; without them the classifier keeps its
+     # random initialization and its predictions are meaningless.
+     # emotion_model.load_state_dict(torch.load('emotion_model.pth'))
+     emotion_model.eval()
+
+     return face_model, emotion_model
+
+ face_model, emotion_model = load_models()
+
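+ # Quick shape sanity check for the emotion model (illustrative only; the
+ # 1x1x48x48 input matches what preprocess_face below produces):
+ #   dummy = torch.zeros(1, 1, 48, 48)
+ #   print(emotion_model(dummy).shape)  # -> torch.Size([1, 7])
+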
+ # Preprocessing function
+ def preprocess_face(face_img):
+     transform = transforms.Compose([
+         transforms.Resize((48, 48)),
+         transforms.Grayscale(),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.5], std=[0.5])
+     ])
+     face_pil = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
+     return transform(face_pil).unsqueeze(0)
+
+ # Process video function
+ def process_video(video_path):
+     global detection_history
+     detection_history = []
+
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         return {"error": "Could not open video"}
+
+     frame_count = 0
+     fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back to a nominal 30 fps if the container reports none
+     frame_skip = max(1, int(fps / 3))  # Process ~3 frames per second
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         frame_count += 1
+         if frame_count % frame_skip != 0:
+             continue
+
+         # Face detection
+         results = face_model(frame)
+
+         for result in results:
+             boxes = result.boxes
+             if len(boxes) == 0:
+                 continue
+
+             for box in boxes:
+                 x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
+                 face_img = frame[y1:y2, x1:x2]
+
+                 if face_img.size == 0:
+                     continue
+
+                 # Emotion prediction
+                 face_tensor = preprocess_face(face_img)
+                 with torch.no_grad():
+                     output = emotion_model(face_tensor)
+                     prob = torch.nn.functional.softmax(output, dim=1)[0]
+                     pred_idx = torch.argmax(output).item()
+                     confidence = prob[pred_idx].item()
+
+                 detection_history.append({
+                     "frame": frame_count,
+                     "time": frame_count / fps,
+                     "emotion": emotions[pred_idx],
+                     "confidence": confidence,
+                     "box": [x1, y1, x2, y2]
+                 })
+
+     cap.release()
+
+     if not detection_history:
+         return {"error": "No faces detected"}
+
+     return {
+         "detections": detection_history,
+         "summary": {
+             "total_frames": frame_count,
+             "fps": fps,
+             "duration": frame_count / fps
+         }
+     }
+
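+ # For reference, each record in the "detections" list above has this shape
+ # (the concrete values here are purely illustrative):
+ #   {"frame": 30, "time": 1.0, "emotion": "Happy",
+ #    "confidence": 0.91, "box": [120, 80, 220, 200]}
+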
+ # Flask API endpoint
+ @app.route('/api/predict', methods=['POST'])
+ def api_predict():
+     if 'file' not in request.files:
+         return jsonify({"error": "No file provided"}), 400
+
+     file = request.files['file']
+     if file.filename == '':
+         return jsonify({"error": "No selected file"}), 400
+
+     # Save to a temp file
+     temp_path = os.path.join(tempfile.gettempdir(), file.filename)
+     file.save(temp_path)
+
+     try:
+         # Process video
+         result = process_video(temp_path)
+     finally:
+         # Clean up even if processing raises
+         os.remove(temp_path)
+
+     return jsonify(result)
+
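+ # Example client call against /api/predict (illustrative sketch; assumes the
+ # Flask API is reachable on port 5000 as set in __main__ below, and that the
+ # `requests` package, which is not listed in requirements.txt, is installed):
+ #   import requests
+ #   with open("clip.mp4", "rb") as f:
+ #       r = requests.post("http://localhost:5000/api/predict", files={"file": f})
+ #   print(r.json())
+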
+ # Gradio interface
+ def gradio_predict(video):
+     # gr.Video passes the uploaded video to the function as a filepath string,
+     # so it can be handed straight to process_video / OpenCV
+     if video is None:
+         return None, {"error": "No video provided"}
+
+     result = process_video(video)
+
+     if "error" in result:
+         return None, result
+
+     # Create visualization: draw the last detection on the first frame
+     cap = cv2.VideoCapture(video)
+     ret, frame = cap.read()
+     cap.release()
+
+     if ret:
+         last_det = result["detections"][-1]
+         x1, y1, x2, y2 = last_det["box"]
+         cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
+         cv2.putText(frame, f"{last_det['emotion']} ({last_det['confidence']:.2f})",
+                     (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
+
+         # Convert to RGB for Gradio
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         return frame, result
+
+     # Always return an (image, json) pair to match the two interface outputs
+     return None, result
+
+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=gradio_predict,
+     inputs=gr.Video(label="Upload Video"),
+     outputs=[
+         gr.Image(label="Detection Preview"),
+         gr.JSON(label="Results")
+     ],
+     title="Video Emotion Detection",
+     description="Upload a video to detect emotions in faces"
+ )
+
+ # gr.mount_gradio_app expects a FastAPI (ASGI) app and cannot mount onto Flask
+ # (WSGI), so the Flask API runs in a background thread on its own port (5000)
+ # while the Gradio UI is launched on port 7860, the port Hugging Face Spaces exposes.
+ if __name__ == "__main__":
+     threading.Thread(
+         target=lambda: app.run(host="0.0.0.0", port=5000),
+         daemon=True
+     ).start()
+     demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ torchvision
+ opencv-python
+ ultralytics
+ gradio
+ flask
+ numpy
+ Pillow