# Captured from a Hugging Face Spaces page whose status read "Runtime error".
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from scipy.special import logit
# Load the raw results, discard chrF rows, and collapse rows that share the
# same (task, metric, bcp_47) key to their mean score.
# NOTE(review): what those duplicate rows represent (models? runs?) is not
# visible from this file — confirm before interpreting the averages.
# The relative path is resolved against the current working directory.
df = (
    pd.read_json("../results.json")
    .loc[lambda frame: frame["metric"] != "chrf"]
    .groupby(["task", "metric", "bcp_47"], as_index=False)["score"]
    .mean()
)
# Apply logit transformation to classification scores to reduce skewness | |
def transform_classification_scores(row): | |
if row['task'] == 'classification': | |
# Avoid division by zero and infinite values by clipping | |
score = np.clip(row['score'], 0.001, 0.999) | |
# Apply logit transformation (log(p/(1-p))) | |
return logit(score) | |
else: | |
return row['score'] | |
# Logit-transform classification scores row by row; other tasks pass through.
df['score'] = df.apply(transform_classification_scores, axis=1)

# Languages (bcp_47) become rows and tasks become columns; rows that land in
# the same (language, task) cell (e.g. different metrics) are averaged.
pivot_df = df.pivot_table(index='bcp_47', columns='task', values='score', aggfunc='mean')

# Fixed display order. Tasks missing from the data are skipped, and any task
# not listed here (e.g. 'truthfulqa') is dropped from the analysis.
ordered_tasks = [
    'translation_from',
    'translation_to',
    'classification',
    'mmlu',
    'arc',
    'mgsm',
]
present_tasks = [name for name in ordered_tasks if name in pivot_df.columns]
pivot_df = pivot_df.loc[:, present_tasks]

# Pairwise correlation between task columns (pandas default method).
correlation_matrix = pivot_df.corr()

# Hide the diagonal and upper triangle so each task pair appears exactly once.
mask = np.triu(np.ones(correlation_matrix.shape, dtype=bool))

plt.figure(figsize=(8, 6))
# NOTE(review): 'Blues' is a sequential colormap; with center=0 seaborn
# re-centres it so 0 maps to the middle of the palette. A diverging cmap with
# vmin=-1, vmax=1 might read more clearly — confirm the intended look.
heatmap_options = {
    'annot': True,
    'cmap': 'Blues',
    'center': 0,
    'square': True,
    'mask': mask,
    'cbar_kws': {"shrink": .8},
    'fmt': '.3f',
}
sns.heatmap(correlation_matrix, **heatmap_options)
plt.xlabel('Tasks', fontsize=12)
plt.ylabel('Tasks', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

plt.savefig('task_correlation_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

# Echo the numbers so they are available without opening the image.
print(
    "Correlation Matrix:",
    "Note: Classification scores have been logit-transformed to reduce skewness",
    correlation_matrix.round(3),
    sep="\n",
)
# Also create a scatter plot matrix for pairwise relationships with highlighted languages | |
highlighted_languages = ['en', 'zh', 'hi', 'es', 'ar'] | |
# Create color mapping | |
def get_color_and_label(lang_code): | |
if lang_code in highlighted_languages: | |
color_map = {'en': 'red', 'zh': 'blue', 'hi': 'green', 'es': 'orange', 'ar': 'purple'} | |
return color_map[lang_code], lang_code | |
else: | |
return 'lightgray', 'Other' | |
# Custom scatter-plot matrix. Cell (i, j) plots task j (x) against task i (y)
# for every language that has both scores; diagonal cells show a histogram of
# a single task.
tasks = pivot_df.columns.tolist()
n_tasks = len(tasks)
# squeeze=False keeps `axes` 2-D even when only one task is present; a plain
# plt.subplots(1, 1) returns a bare Axes and `axes[i, j]` would fail.
# sharex='col': every cell in column j plots task j on x, so sharing the axis
# makes the bottom-row tick labels valid for the whole column.
fig, axes = plt.subplots(
    n_tasks, n_tasks, figsize=(15, 12), sharex='col', squeeze=False
)
fig.suptitle('Pairwise Task Performance', fontsize=16, fontweight='bold')

# Legend: one marker per highlighted language plus a single 'Other' entry.
legend_elements = [
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor=get_color_and_label(lang)[0], markersize=8, label=lang)
    for lang in highlighted_languages
]
legend_elements.append(
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor='lightgray', markersize=8, label='Other')
)

for i, task_y in enumerate(tasks):
    for j, task_x in enumerate(tasks):
        ax = axes[i, j]
        if i == j:
            # Diagonal: distribution of this task's scores across languages.
            ax.hist(pivot_df[task_y].dropna(), bins=20, alpha=0.7,
                    color='skyblue', edgecolor='black')
            ax.set_title(f'{task_y}', fontsize=10)
        else:
            # Off-diagonal: only languages with a score for both tasks.
            pair = pivot_df[[task_x, task_y]].dropna()
            is_highlighted = pair.index.isin(highlighted_languages)
            background = pair[~is_highlighted]
            foreground = pair[is_highlighted]
            # One scatter call per group instead of one per language. The
            # highlighted group is drawn last so gray points never cover it.
            if not background.empty:
                ax.scatter(background[task_x], background[task_y],
                           c='lightgray', alpha=0.3, s=20)
            if not foreground.empty:
                fg_colors = [get_color_and_label(code)[0] for code in foreground.index]
                ax.scatter(foreground[task_x], foreground[task_y],
                           c=fg_colors, alpha=0.8, s=50)

        # Axis titles and tick labels only on the outer edge of the grid.
        # tick_params avoids set_*ticklabels([]), which warns when the ticks
        # are not fixed.
        if i == n_tasks - 1:
            ax.set_xlabel(task_x, fontsize=10)
        else:
            ax.tick_params(labelbottom=False)
        if j == 0:
            ax.set_ylabel(task_y, fontsize=10)
        else:
            # y-axes are not shared (diagonal cells plot counts), so inner
            # y tick labels are hidden purely to reduce clutter.
            ax.tick_params(labelleft=False)

# Legend sits below the grid; bbox_inches='tight' keeps it in the saved image.
fig.legend(
    handles=legend_elements,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.05),
    ncol=len(legend_elements),
    frameon=False,
    fontsize=10,
    handletextpad=0.5,
    columnspacing=1.0
)
plt.tight_layout()
plt.savefig('task_scatter_matrix.png', dpi=300, bbox_inches='tight')
plt.show()