Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
4 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
5 |
+
from sklearn.model_selection import train_test_split
|
6 |
+
from sklearn.linear_model import LogisticRegression
|
7 |
+
from sklearn.multioutput import MultiOutputClassifier
|
8 |
+
from sklearn.metrics import classification_report, f1_score, accuracy_score, hamming_loss
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
# Load dataset
|
12 |
+
splits = {'train': 'simplified/train-00000-of-00001.parquet'}
|
13 |
+
df = pd.read_parquet("hf://datasets/google-research-datasets/go_emotions/" + splits["train"])
|
14 |
+
|
15 |
+
emotion_labels = [
|
16 |
+
"admiration", "amusement", "anger", "annoyance", "approval",
|
17 |
+
"caring", "confusion", "curiosity", "desire", "disappointment",
|
18 |
+
"disapproval", "disgust", "embarrassment", "excitement", "fear",
|
19 |
+
"gratitude", "grief", "joy", "love", "nervousness",
|
20 |
+
"optimism", "pride", "realization", "relief", "remorse",
|
21 |
+
"sadness", "surprise", "neutral"
|
22 |
+
]
|
23 |
+
|
24 |
+
index_to_emotion = {i: label for i, label in enumerate(emotion_labels)}
|
25 |
+
|
26 |
+
mlb = MultiLabelBinarizer(classes=range(28))
|
27 |
+
y = mlb.fit_transform(df['labels'])
|
28 |
+
|
29 |
+
vectorizer = TfidfVectorizer(max_features=5000)
|
30 |
+
X = vectorizer.fit_transform(df['text'])
|
31 |
+
|
32 |
+
# Placeholder for trained model
|
33 |
+
model = None
|
34 |
+
metrics_report = ""
|
35 |
+
|
36 |
+
def train_model(test_size=0.2, max_iter=1000, random_state=42):
|
37 |
+
global model, metrics_report
|
38 |
+
|
39 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
40 |
+
X, y, test_size=test_size, random_state=random_state
|
41 |
+
)
|
42 |
+
|
43 |
+
model = MultiOutputClassifier(LogisticRegression(max_iter=max_iter))
|
44 |
+
model.fit(X_train, y_train)
|
45 |
+
y_pred = model.predict(X_test)
|
46 |
+
|
47 |
+
# Calculate standard classification report + other metrics
|
48 |
+
report = classification_report(
|
49 |
+
y_test, y_pred, target_names=[str(i) for i in range(28)]
|
50 |
+
)
|
51 |
+
micro_f1 = f1_score(y_test, y_pred, average="micro")
|
52 |
+
macro_f1 = f1_score(y_test, y_pred, average="macro")
|
53 |
+
acc = accuracy_score(y_test, y_pred)
|
54 |
+
hamming = hamming_loss(y_test, y_pred)
|
55 |
+
|
56 |
+
metrics_summary = f"""
|
57 |
+
Micro F1-score: {micro_f1:.4f}
|
58 |
+
Macro F1-score: {macro_f1:.4f}
|
59 |
+
Accuracy (Exact Match): {acc:.4f}
|
60 |
+
Hamming Loss: {hamming:.4f}
|
61 |
+
"""
|
62 |
+
|
63 |
+
# Save the full report
|
64 |
+
metrics_report = metrics_summary.strip() + "\n\n" + report
|
65 |
+
|
66 |
+
return "Training Complete!"
|
67 |
+
|
68 |
+
def predict_emotions(text):
|
69 |
+
if model is None:
|
70 |
+
return "Please train the model first.", ""
|
71 |
+
|
72 |
+
vectorized = vectorizer.transform([text])
|
73 |
+
probas = model.predict_proba(vectorized)
|
74 |
+
|
75 |
+
result = {}
|
76 |
+
for i, emotion in enumerate(mlb.classes_):
|
77 |
+
prob_class_1 = probas[i][0][1]
|
78 |
+
result[emotion] = round(prob_class_1 * 100, 2)
|
79 |
+
|
80 |
+
sorted_result = sorted(result.items(), key=lambda x: x[1], reverse=True)
|
81 |
+
return sorted_result
|
82 |
+
|
83 |
+
def predict_and_display(sentence):
|
84 |
+
predictions = predict_emotions(sentence)
|
85 |
+
if isinstance(predictions, str):
|
86 |
+
return predictions, ""
|
87 |
+
|
88 |
+
max_len = max(len(index_to_emotion[emo_id]) for emo_id, _ in predictions)
|
89 |
+
result = "```" + "\nEmotion Predictions:\n\n"
|
90 |
+
for emo_id, score in predictions:
|
91 |
+
emo_name = index_to_emotion[emo_id]
|
92 |
+
result += f"{emo_name.ljust(max_len)} → {score}%\n"
|
93 |
+
result += "```"
|
94 |
+
top_emotion = index_to_emotion[predictions[0][0]]
|
95 |
+
return result, top_emotion
|
96 |
+
|
97 |
+
# Gradio App
|
98 |
+
with gr.Blocks(title="Interactive Emotion Detector", theme=gr.themes.Soft()) as demo:
|
99 |
+
with gr.Tabs():
|
100 |
+
with gr.Tab("Emotion Detection"):
|
101 |
+
gr.Markdown("## Emotion Detection")
|
102 |
+
with gr.Row():
|
103 |
+
with gr.Column():
|
104 |
+
input_text = gr.Textbox(
|
105 |
+
lines=3, placeholder="Enter a sentence...", label="Input Sentence"
|
106 |
+
)
|
107 |
+
submit_btn = gr.Button("Analyze Emotion")
|
108 |
+
with gr.Column():
|
109 |
+
output_text = gr.Markdown(label="Prediction Results")
|
110 |
+
top_emotion = gr.Label(label="Top Emotion")
|
111 |
+
submit_btn.click(
|
112 |
+
fn=predict_and_display,
|
113 |
+
inputs=input_text,
|
114 |
+
outputs=[output_text, top_emotion]
|
115 |
+
)
|
116 |
+
|
117 |
+
with gr.Tab("Dataset"):
|
118 |
+
gr.Markdown("## Dataset Information")
|
119 |
+
|
120 |
+
def dataset_info():
|
121 |
+
df = pd.read_parquet("hf://datasets/google-research-datasets/go_emotions/simplified/train-00000-of-00001.parquet")
|
122 |
+
|
123 |
+
total_samples = len(df)
|
124 |
+
emotions = sorted(set(e for label in df['labels'] for e in label))
|
125 |
+
emotion_names = [emotion_labels[i] for i in emotions]
|
126 |
+
|
127 |
+
# Count distribution
|
128 |
+
all_labels = [emotion_labels[i] for sublist in df['labels'] for i in sublist]
|
129 |
+
label_counts = pd.Series(all_labels).value_counts().sort_index()
|
130 |
+
label_df = pd.DataFrame({
|
131 |
+
"Emotion": label_counts.index,
|
132 |
+
"Count": label_counts.values
|
133 |
+
})
|
134 |
+
|
135 |
+
stats = f"""
|
136 |
+
**Total Samples**: {total_samples}
|
137 |
+
**Emotion Classes**: {', '.join(emotion_names)}
|
138 |
+
"""
|
139 |
+
|
140 |
+
return stats, label_df
|
141 |
+
|
142 |
+
stats_display = gr.Markdown()
|
143 |
+
dist_table = gr.Dataframe(headers=["Emotion", "Count"], interactive=False)
|
144 |
+
|
145 |
+
load_btn = gr.Button("Load Dataset Info")
|
146 |
+
load_btn.click(fn=dataset_info, inputs=[], outputs=[stats_display, dist_table])
|
147 |
+
|
148 |
+
|
149 |
+
with gr.Tab("EDA"):
|
150 |
+
gr.Markdown("## Exploratory Data Analysis")
|
151 |
+
|
152 |
+
eda_btn = gr.Button("Run EDA")
|
153 |
+
|
154 |
+
eda_output = gr.Plot(label="EDA Output")
|
155 |
+
|
156 |
+
def run_eda():
|
157 |
+
import matplotlib.pyplot as plt
|
158 |
+
from collections import Counter
|
159 |
+
import re
|
160 |
+
|
161 |
+
# Define the label map inside the function
|
162 |
+
label_map = [
|
163 |
+
'admiration', 'amusement', 'anger', 'annoyance', 'approval',
|
164 |
+
'caring', 'confusion', 'curiosity', 'desire', 'disappointment',
|
165 |
+
'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear',
|
166 |
+
'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism',
|
167 |
+
'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise',
|
168 |
+
'neutral'
|
169 |
+
]
|
170 |
+
|
171 |
+
fig, axs = plt.subplots(2, 2, figsize=(18, 10))
|
172 |
+
|
173 |
+
# Label distribution
|
174 |
+
label_counts = df['labels'].explode().value_counts().sort_index()
|
175 |
+
axs[0, 0].bar(label_map, label_counts)
|
176 |
+
axs[0, 0].set_title("Label Frequency Distribution")
|
177 |
+
axs[0, 0].tick_params(axis='x', rotation=45)
|
178 |
+
|
179 |
+
# Labels per example
|
180 |
+
df['num_labels'] = df['labels'].apply(len)
|
181 |
+
df['num_labels'].value_counts().sort_index().plot(kind='bar', ax=axs[0, 1])
|
182 |
+
axs[0, 1].set_title("Number of Labels per Example")
|
183 |
+
|
184 |
+
# Text length distribution
|
185 |
+
df['text_length'] = df['text'].apply(len)
|
186 |
+
df['text_length'].hist(bins=50, ax=axs[1, 0])
|
187 |
+
axs[1, 0].set_title("Distribution of Text Lengths")
|
188 |
+
axs[1, 0].set_xlabel("Text Length (characters)")
|
189 |
+
axs[1, 0].set_ylabel("Frequency")
|
190 |
+
|
191 |
+
# Most common words
|
192 |
+
all_words = " ".join(df['text']).lower()
|
193 |
+
tokens = re.findall(r'\b\w+\b', all_words)
|
194 |
+
common_words = Counter(tokens).most_common(20)
|
195 |
+
words, freqs = zip(*common_words)
|
196 |
+
axs[1, 1].bar(words, freqs)
|
197 |
+
axs[1, 1].set_title("Top 20 Most Common Words")
|
198 |
+
axs[1, 1].tick_params(axis='x', rotation=45)
|
199 |
+
|
200 |
+
plt.tight_layout()
|
201 |
+
return fig
|
202 |
+
|
203 |
+
eda_btn.click(fn=run_eda, inputs=[], outputs=eda_output)
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
with gr.Tab("Train Model"):
|
208 |
+
gr.Markdown("## Train Your Emotion Model")
|
209 |
+
test_size = gr.Slider(0.1, 0.5, step=0.05, value=0.2, label="Test Size")
|
210 |
+
max_iter = gr.Slider(100, 5000, step=100, value=1000, label="Max Iterations")
|
211 |
+
random_state = gr.Number(value=42, label="Random State")
|
212 |
+
train_button = gr.Button("Train Model")
|
213 |
+
train_status = gr.Textbox(label="Training Status")
|
214 |
+
train_button.click(
|
215 |
+
fn=train_model,
|
216 |
+
inputs=[test_size, max_iter, random_state],
|
217 |
+
outputs=train_status
|
218 |
+
)
|
219 |
+
|
220 |
+
with gr.Tab("Results"):
|
221 |
+
gr.Markdown("## Evaluation Metrics")
|
222 |
+
results_output = gr.Markdown(label="Classification Report")
|
223 |
+
|
224 |
+
def get_report():
|
225 |
+
return "```\n" + metrics_report + "\n```"
|
226 |
+
|
227 |
+
refresh_btn = gr.Button("Refresh Report")
|
228 |
+
refresh_btn.click(
|
229 |
+
fn=get_report,
|
230 |
+
inputs=[],
|
231 |
+
outputs=results_output
|
232 |
+
)
|
233 |
+
|
234 |
+
demo.launch(debug = True)
|