""" | |
GuardBench Leaderboard Application | |
""" | |
import os | |
import json | |
import tempfile | |
import logging | |
import gradio as gr | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from apscheduler.schedulers.background import BackgroundScheduler | |
import numpy as np | |
from gradio.themes.utils import fonts, colors | |
from dataclasses import fields, dataclass | |
from src.about import ( | |
CITATION_BUTTON_LABEL, | |
CITATION_BUTTON_TEXT, | |
EVALUATION_QUEUE_TEXT, | |
INTRODUCTION_TEXT, | |
LLM_BENCHMARKS_TEXT, | |
TITLE, | |
) | |
from src.display.css_html_js import custom_css
from src.display.utils import (
    GUARDBENCH_COLUMN,
    DISPLAY_COLS,
    METRIC_COLS,
    HIDDEN_COLS,
    NEVER_HIDDEN_COLS,
    CATEGORIES,
    TEST_TYPES,
    ModelType,
    Mode,
    Precision,
    WeightType,
    GuardModelType,
    get_all_column_choices,
    get_default_visible_columns,
)
from src.display.formatting import styled_message, styled_error, styled_warning
from src.envs import (
    ADMIN_USERNAME,
    ADMIN_PASSWORD,
    RESULTS_DATASET_ID,
    SUBMITTER_TOKEN,
    TOKEN,
    DATA_PATH,
)
from src.populate import get_leaderboard_df, get_category_leaderboard_df
from src.submission.submit import process_submission

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure the data directory exists
os.makedirs(DATA_PATH, exist_ok=True)

# Available benchmark versions
BENCHMARK_VERSIONS = ["v0"]
CURRENT_VERSION = "v0"

# Initialize leaderboard data
try:
    logger.info("Initializing leaderboard data...")
    LEADERBOARD_DF = get_leaderboard_df(version=CURRENT_VERSION)
    logger.info(f"Loaded leaderboard with {len(LEADERBOARD_DF)} entries")
except Exception as e:
    logger.error(f"Error loading leaderboard data: {e}")
    LEADERBOARD_DF = pd.DataFrame()

custom_theme = gr.themes.Default(
    primary_hue=colors.slate,
    secondary_hue=colors.slate,
    neutral_hue=colors.neutral,
    font=(fonts.GoogleFont("Inter"), "sans-serif")
).set(
    # font_size="16px",
    body_background_fill="#0f0f10",
    body_background_fill_dark="#0f0f10",
    body_text_color="#f4f4f5",
    body_text_color_subdued="#a1a1aa",
    block_background_fill="#1e1e1e",  # Cooler Grey
    block_border_color="#333333",  # Cooler Grey
    block_shadow="none",
    # Swapped primary and secondary button styles
    button_primary_background_fill="#121212",  # Specific color for the Refresh button
    button_primary_text_color="#f4f4f5",
    button_primary_border_color="#333333",  # Keep border grey or change to #121212?
    button_secondary_background_fill="#f4f4f5",
    button_secondary_text_color="#0f0f10",
    button_secondary_border_color="#f4f4f5",
    input_background_fill="#1e1e1e",  # Cooler Grey
    input_border_color="#333333",  # Cooler Grey
    input_placeholder_color="#71717a",
    table_border_color="#333333",  # Cooler Grey
    table_even_background_fill="#2d2d2d",  # Cooler Grey (slightly lighter)
    table_odd_background_fill="#1e1e1e",  # Cooler Grey
    table_text_color="#f4f4f5",
    link_text_color="#ffffff",
    border_color_primary="#333333",  # Cooler Grey
    background_fill_secondary="#333333",  # Cooler Grey
    color_accent="#f4f4f5",
    border_color_accent="#333333",  # Cooler Grey
    button_primary_background_fill_hover="#424242",  # Cooler Grey
    block_title_text_color="#f4f4f5",
    accordion_text_color="#f4f4f5",
    panel_background_fill="#1e1e1e",  # Cooler Grey
    panel_border_color="#333333",  # Cooler Grey
    # Explicitly set primary/secondary/accent colors and borders
    background_fill_primary="#0f0f10",
    background_fill_primary_dark="#0f0f10",
    background_fill_secondary_dark="#333333",  # Cooler Grey
    border_color_primary_dark="#333333",  # Cooler Grey
    border_color_accent_dark="#333333",  # Cooler Grey
    border_color_accent_subdued="#424242",  # Cooler Grey
    border_color_accent_subdued_dark="#424242",  # Cooler Grey
    color_accent_soft="#a1a1aa",
    color_accent_soft_dark="#a1a1aa",
    # Explicitly set input hover/focus states
    input_background_fill_dark="#1e1e1e",  # Cooler Grey
    input_background_fill_focus="#424242",  # Cooler Grey
    input_background_fill_focus_dark="#424242",  # Cooler Grey
    input_background_fill_hover="#2d2d2d",  # Cooler Grey
    input_background_fill_hover_dark="#2d2d2d",  # Cooler Grey
    input_border_color_dark="#333333",  # Cooler Grey
    input_border_color_focus="#f4f4f5",
    input_border_color_focus_dark="#f4f4f5",
    input_border_color_hover="#424242",  # Cooler Grey
    input_border_color_hover_dark="#424242",  # Cooler Grey
    input_placeholder_color_dark="#71717a",
    # Explicitly set dark variants for table backgrounds
    table_even_background_fill_dark="#2d2d2d",  # Cooler Grey
    table_odd_background_fill_dark="#1e1e1e",  # Cooler Grey
    # Explicitly set dark text variants
    body_text_color_dark="#f4f4f5",
    body_text_color_subdued_dark="#a1a1aa",
    block_title_text_color_dark="#f4f4f5",
    accordion_text_color_dark="#f4f4f5",
    table_text_color_dark="#f4f4f5",
    # Explicitly set dark panel/block variants
    panel_background_fill_dark="#1e1e1e",  # Cooler Grey
    panel_border_color_dark="#333333",  # Cooler Grey
    block_background_fill_dark="#1e1e1e",  # Cooler Grey
    block_border_color_dark="#333333",  # Cooler Grey
)

@dataclass
class ColumnInfo:
    """Information about a column in the leaderboard."""
    name: str
    display_name: str
    type: str = "text"
    hidden: bool = False
    never_hidden: bool = False
    displayed_by_default: bool = True

def update_column_choices(df):
    """Update column choices based on what's actually in the dataframe"""
    if df is None or df.empty:
        return get_all_column_choices()

    # Get columns that actually exist in the dataframe
    existing_columns = list(df.columns)

    # Get all possible columns with their display names
    all_columns = get_all_column_choices()

    # Filter to only include columns that exist in the dataframe
    valid_columns = [(col_name, display_name) for col_name, display_name in all_columns
                     if col_name in existing_columns]

    # Return default if there are no valid columns
    if not valid_columns:
        return get_all_column_choices()

    return valid_columns

# Used to initialize the column_selector dropdown
def get_initial_columns():
    """Get initial columns to show in the dropdown"""
    try:
        # Get available columns in the main dataframe
        available_cols = list(LEADERBOARD_DF.columns)
        logger.info(f"Available columns in LEADERBOARD_DF: {available_cols}")

        # If dataframe is empty, use default visible columns
        if not available_cols:
            return get_default_visible_columns()

        # Get default visible columns that actually exist in the dataframe
        valid_defaults = [col for col in get_default_visible_columns() if col in available_cols]

        # If none of the defaults exist, return all available columns
        if not valid_defaults:
            return available_cols

        return valid_defaults
    except Exception as e:
        logger.error(f"Error getting initial columns: {e}")
        return get_default_visible_columns()
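
# Helper shared by the filter/refresh callbacks below. It assumes a column's
# display name differs from its internal dataframe name only by casing,
# spaces, parentheses, and the "_binary" suffix on recall/precision metrics
# (e.g. "Recall (Binary)" -> "recall_binary").
def display_to_internal_name(display_name: str) -> str:
    """Convert a column's display name back to its internal dataframe name."""
    return (
        display_name.lower()
        .replace(" ", "_")
        .replace("(", "")
        .replace(")", "")
        .replace("_recall", "_recall_binary")
        .replace("_precision", "_precision_binary")
    )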

def init_leaderboard(dataframe, visible_columns=None):
    """
    Initialize a standard Gradio Dataframe component for the leaderboard.
    """
    if dataframe is None or dataframe.empty:
        # Create an empty dataframe with the right columns
        columns = [getattr(GUARDBENCH_COLUMN, col).name for col in DISPLAY_COLS]
        dataframe = pd.DataFrame(columns=columns)
        logger.warning("Initializing empty leaderboard")

    # Determine which columns to display
    display_column_names = [getattr(GUARDBENCH_COLUMN, col).name for col in DISPLAY_COLS]
    hidden_column_names = [getattr(GUARDBENCH_COLUMN, col).name for col in HIDDEN_COLS]

    # Columns that should always be shown
    always_visible = [getattr(GUARDBENCH_COLUMN, col).name for col in NEVER_HIDDEN_COLS]

    # Use provided visible columns if specified, otherwise use default
    if visible_columns is None:
        # Determine which columns to show initially
        visible_columns = [col for col in display_column_names if col not in hidden_column_names]

    # Always include the never-hidden columns
    for col in always_visible:
        if col not in visible_columns and col in dataframe.columns:
            visible_columns.append(col)

    # Make sure we only include columns that actually exist in the dataframe
    visible_columns = [col for col in visible_columns if col in dataframe.columns]

    # Map GuardBench column types to Gradio's expected datatype strings.
    # Valid Gradio datatypes are: 'str', 'number', 'bool', 'date', 'markdown', 'html', 'image'
    type_mapping = {
        'text': 'str',
        'number': 'number',
        'bool': 'bool',
        'date': 'date',
        'markdown': 'markdown',
        'html': 'html',
        'image': 'image'
    }
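    # For example, a column declared with type "text" is rendered with the
    # Gradio datatype "str"; unknown types likewise fall back to "str" below.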

    # Create a list of datatypes in the format Gradio expects
    datatypes = []
    for col in visible_columns:
        # Find the corresponding GUARDBENCH_COLUMN entry
        col_type = None
        for display_col in DISPLAY_COLS:
            if getattr(GUARDBENCH_COLUMN, display_col).name == col:
                orig_type = getattr(GUARDBENCH_COLUMN, display_col).type
                # Map to Gradio's expected types
                col_type = type_mapping.get(orig_type, 'str')
                break
        # Default to 'str' if type not found or not mappable
        if col_type is None:
            col_type = 'str'
        datatypes.append(col_type)

    # Create a dummy column for search functionality if it doesn't exist
    if 'search_dummy' not in dataframe.columns:
        dataframe['search_dummy'] = dataframe.apply(
            lambda row: ' '.join(str(val) for val in row.values if pd.notna(val)),
            axis=1
        )

    # Select only the visible columns for display, with model_name first
    if 'model_name' in dataframe.columns:
        if 'model_name' in visible_columns:
            visible_columns.remove('model_name')
        visible_columns = ['model_name'] + visible_columns
    display_df = dataframe[visible_columns].copy()

    # Round numeric columns to 3 decimal places for display
    numeric_cols = display_df.select_dtypes(include=np.number).columns
    for col in numeric_cols:
        # Avoid rounding integer columns like counts
        if not pd.api.types.is_integer_dtype(display_df[col]):
            # Format floats to exactly 3 decimal places, preserving trailing zeros
            display_df[col] = display_df[col].apply(lambda x: f"{x:.3f}" if pd.notna(x) else None)

    column_info_map = {f.name: getattr(GUARDBENCH_COLUMN, f.name) for f in fields(GUARDBENCH_COLUMN)}
    column_mapping = {col: column_info_map.get(col, ColumnInfo(col, col)).display_name for col in visible_columns}

    # Rename columns in the DataFrame
    display_df.rename(columns=column_mapping, inplace=True)

    # Apply styling - note: styling might need adjustment if it relies on column names
    styler = display_df.style.set_properties(**{'text-align': 'right'})

    return gr.Dataframe(
        value=styler,
        datatype=datatypes,
        interactive=False,
        wrap=True,
        elem_id="leaderboard-table",
        row_count=len(display_df)
    )
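
# Usage sketch (hypothetical column subset): the component can be built from
# the default visible columns or an explicit selection:
#
#   table = init_leaderboard(LEADERBOARD_DF)                          # defaults
#   table = init_leaderboard(LEADERBOARD_DF, ["model_name", "macro_accuracy"])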

def search_filter_leaderboard(df, search_query="", model_types=None, version=CURRENT_VERSION):
    """
    Filter the leaderboard based on search query and model types.
    """
    if df is None or df.empty:
        return df

    filtered_df = df.copy()

    # Add search dummy column if it doesn't exist
    if 'search_dummy' not in filtered_df.columns:
        filtered_df['search_dummy'] = filtered_df.apply(
            lambda row: ' '.join(str(val) for val in row.values if pd.notna(val)),
            axis=1
        )

    # Apply model type filter
    if model_types and len(model_types) > 0:
        filtered_df = filtered_df[filtered_df[GUARDBENCH_COLUMN.model_type.name].isin(model_types)]

    # Apply search query
    if search_query:
        search_terms = [term.strip() for term in search_query.split(";") if term.strip()]
        if search_terms:
            combined_mask = None
            for term in search_terms:
                mask = filtered_df['search_dummy'].str.contains(term, case=False, na=False)
                if combined_mask is None:
                    combined_mask = mask
                else:
                    combined_mask = combined_mask | mask
            if combined_mask is not None:
                filtered_df = filtered_df[combined_mask]

    # Drop the search dummy column before returning
    visible_columns = [col for col in filtered_df.columns if col != 'search_dummy']
    return filtered_df[visible_columns]
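
# Example (hypothetical model names): search terms are split on ";" and OR-ed
# together, matching case-insensitively against every field of a row:
#
#   df = pd.DataFrame({"model_name": ["LlamaGuard-7B", "ShieldGemma-2B"]})
#   search_filter_leaderboard(df, search_query="llama; shield")  # keeps both rows
#   search_filter_leaderboard(df, search_query="llama")          # keeps only the first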

def refresh_data_with_filters(version=CURRENT_VERSION, search_query="", model_types=None, selected_columns=None):
    """
    Refresh the leaderboard data and update all components with filtering.
    Ensures we handle cases where dataframes might have limited columns.
    """
    try:
        logger.info("Performing refresh of leaderboard data with filters...")

        # Get new data
        main_df = get_leaderboard_df(version=version)
        category_dfs = [get_category_leaderboard_df(category, version=version) for category in CATEGORIES]

        # Log the actual columns we have
        logger.info(f"Main dataframe columns: {list(main_df.columns)}")

        # Apply filters to each dataframe
        filtered_main_df = search_filter_leaderboard(main_df, search_query, model_types, version)
        filtered_category_dfs = [
            search_filter_leaderboard(df, search_query, model_types, version)
            for df in category_dfs
        ]

        # Get available columns from the dataframe
        available_columns = list(filtered_main_df.columns)

        # Filter selected columns to only those available in the data
        if selected_columns:
            # Convert display names to internal names first
            internal_selected_columns = [display_to_internal_name(x) for x in selected_columns]
            valid_selected_columns = [col for col in internal_selected_columns if col in available_columns]
            if not valid_selected_columns and 'model_name' in available_columns:
                # Fallback if conversion/filtering leads to empty selection
                valid_selected_columns = ['model_name'] + [col for col in get_default_visible_columns() if col in available_columns]
        else:
            # If no columns were selected in the dropdown, use default visible columns that exist
            valid_selected_columns = [col for col in get_default_visible_columns() if col in available_columns]

        # Initialize dataframes for display with valid selected columns
        main_dataframe = init_leaderboard(filtered_main_df, valid_selected_columns)

        # For category dataframes, get columns that actually exist in each one
        category_dataframes = []
        for df in filtered_category_dfs:
            df_columns = list(df.columns)
            df_valid_columns = [col for col in valid_selected_columns if col in df_columns]
            if not df_valid_columns and 'model_name' in df_columns:
                df_valid_columns = ['model_name'] + get_default_visible_columns()
            category_dataframes.append(init_leaderboard(df, df_valid_columns))

        return main_dataframe, *category_dataframes
    except Exception as e:
        logger.error(f"Error in refresh with filters: {e}")
        # Return the current leaderboards on error
        return leaderboard, *[tab.children[0] for tab in category_tabs.children[1:len(CATEGORIES) + 1]]

def submit_results(
    model_name: str,
    base_model: str,
    revision: str,
    precision: str,
    weight_type: str,
    model_type: str,
    mode: str,
    submission_file: tempfile._TemporaryFileWrapper,
    version: str,
    guard_model_type: str,  # the dropdown passes the GuardModelType member name
):
""" | |
Handle submission of results with model metadata. | |
""" | |
if submission_file is None: | |
return styled_error("No submission file provided") | |
if not model_name: | |
return styled_error("Model name is required") | |
if not model_type: | |
return styled_error("Please select a model type") | |
if not mode: | |
return styled_error("Please select an inference mode") | |
file_path = submission_file.name | |
logger.info(f"Received submission for model {model_name}: {file_path}") | |
# Add metadata to the submission | |
metadata = { | |
"model_name": model_name, | |
"base_model": base_model, | |
"revision": revision if revision else "main", | |
"precision": precision, | |
"weight_type": weight_type, | |
"model_type": model_type, | |
"mode": mode, | |
"version": version, | |
"guard_model_type": guard_model_type | |
} | |
# Process the submission | |
result = process_submission(file_path, metadata, version=version) | |
# Refresh the leaderboard data | |
global LEADERBOARD_DF | |
try: | |
logger.info(f"Refreshing leaderboard data after submission for version {version}...") | |
LEADERBOARD_DF = get_leaderboard_df(version=version) | |
logger.info("Refreshed leaderboard data after submission") | |
except Exception as e: | |
logger.error(f"Error refreshing leaderboard data: {e}") | |
return result | |

def refresh_data(version=CURRENT_VERSION):
    """
    Refresh the leaderboard data and update all components.
    """
    try:
        logger.info("Performing scheduled refresh of leaderboard data...")

        # Get new data
        main_df = get_leaderboard_df(version=version)
        category_dfs = [get_category_leaderboard_df(category, version=version) for category in CATEGORIES]

        # For gr.Dataframe, we return the actual dataframes
        return main_df, *category_dfs
    except Exception as e:
        logger.error(f"Error in scheduled refresh: {e}")
        return None, *[None for _ in CATEGORIES]

def update_leaderboards(version):
    """
    Update all leaderboard components with data for the selected version.
    """
    try:
        new_df = get_leaderboard_df(version=version)
        category_dfs = [get_category_leaderboard_df(category, version=version) for category in CATEGORIES]
        return new_df, *category_dfs
    except Exception as e:
        logger.error(f"Error updating leaderboards for version {version}: {e}")
        return None, *[None for _ in CATEGORIES]

def create_performance_plot(selected_models, category, metric="f1_binary", version=CURRENT_VERSION):
    """
    Create a radar plot comparing model performance for selected models.
    """
    if category == "All Results":
        df = get_leaderboard_df(version=version)
    else:
        df = get_category_leaderboard_df(category, version=version)

    if df.empty:
        return go.Figure()

    # Filter for selected models
    df = df[df['model_name'].isin(selected_models)]

    # Get the relevant metric columns
    metric_cols = [col for col in df.columns if metric in col]

    # Create figure
    fig = go.Figure()
    # Custom colors for the model traces (distinct from the gradio `colors` module)
    trace_colors = ['#8FCCCC', '#C2A4B6', '#98B4A6', '#B68F7C']  # Pale Cyan, Pale Pink, Pale Green, Pale Orange

    # Add traces for each model
    for idx, model in enumerate(selected_models):
        model_data = df[df['model_name'] == model]
        if not model_data.empty:
            values = model_data[metric_cols].values[0].tolist()
            # Add the first value again at the end to complete the polygon
            values = values + [values[0]]

            # Clean up test type names
            categories = [col.replace(f'_{metric}', '') for col in metric_cols]
            # Add the first category again at the end to complete the polygon
            categories = categories + [categories[0]]

            fig.add_trace(go.Scatterpolar(
                r=values,
                theta=categories,
                name=model,
                line_color=trace_colors[idx % len(trace_colors)],
                fill='toself'
            ))

    # Update layout with all settings at once
    fig.update_layout(
        paper_bgcolor='#000000',
        plot_bgcolor='#000000',
        font={'color': '#ffffff'},
        title={
            'text': f'{category} - {metric.upper()} Score Comparison',
            'font': {'color': '#ffffff', 'size': 24}
        },
        polar=dict(
            bgcolor='#000000',
            radialaxis=dict(
                visible=True,
                range=[0, 1],
                gridcolor='#333333',
                linecolor='#333333',
                tickfont={'color': '#ffffff'},
            ),
            angularaxis=dict(
                gridcolor='#333333',
                linecolor='#333333',
                tickfont={'color': '#ffffff'},
            )
        ),
        height=600,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
            bgcolor='rgba(0,0,0,0.5)',
            font={'color': '#ffffff'}
        )
    )

    return fig
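
# Usage sketch (hypothetical model names): compare two models on binary F1
# across all test types for the current benchmark version:
#
#   fig = create_performance_plot(
#       ["LlamaGuard-7B", "ShieldGemma-2B"],
#       category="All Results",
#       metric="f1_binary",
#       version=CURRENT_VERSION,
#   )
#   fig.show()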

def update_model_choices(version):
    """
    Update the list of available models for the given version.
    """
    df = get_leaderboard_df(version=version)
    if df.empty:
        return []
    return sorted(df['model_name'].unique().tolist())


def update_visualization(selected_models, selected_category, selected_metric, version):
    """
    Update the visualization based on user selections.
    """
    if not selected_models:
        return go.Figure()
    return create_performance_plot(selected_models, selected_category, selected_metric, version)

# Create Gradio app
demo = gr.Blocks(css=custom_css, theme=custom_theme)

CATEGORY_DISPLAY_MAP = {
    'Political Corruption and Legal Evasion': 'Corruption & Legal Evasion',
    'Financial Fraud and Unethical Business': 'Financial Fraud',
    'AI Manipulation and Jailbreaking': 'AI Jailbreaking',
    'Child Exploitation and Abuse': 'Child Exploitation',
    'Hate Speech, Extremism, and Discrimination': 'Hate Speech',
    'Labor Exploitation and Human Trafficking': 'Labor Exploitation',
    'Manipulation, Deception, and Misinformation': 'Misinformation',
    'Environmental and Industrial Harm': 'Environmental Harm',
    'Academic Dishonesty and Cheating': 'Academic Dishonesty',
    'Self–Harm and Suicidal Ideation': 'Self-Harm',
    'Animal Cruelty and Exploitation': 'Animal Harm',
    'Criminal, Violent, and Terrorist Activity': 'Crime & Violence',
    'Drug– and Substance–Related Activities': 'Drug Use',
    'Sexual Content and Violence': 'Sexual Content',
    'Weapon, Explosives, and Hazardous Materials': 'Weapons & Harmful Materials',
    'Cybercrime, Hacking, and Digital Exploits': 'Cybercrime',
    'Creative Content Involving Illicit Themes': 'Illicit Creative',
    'Safe Prompts': 'Safe Prompts'
}

# Create reverse mapping for lookups
CATEGORY_REVERSE_MAP = {v: k for k, v in CATEGORY_DISPLAY_MAP.items()}
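
# Example round trip: the display map shortens a category name for its tab
# label, and the reverse map recovers the canonical name for data lookups:
#
#   CATEGORY_DISPLAY_MAP['Cybercrime, Hacking, and Digital Exploits']  # -> 'Cybercrime'
#   CATEGORY_REVERSE_MAP['Cybercrime']  # -> 'Cybercrime, Hacking, and Digital Exploits'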

with demo:
    gr.HTML(TITLE)
    # gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
    gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")

    with gr.Row():
        tabs = gr.Tabs(elem_classes="tab-buttons")

    with tabs:
        with gr.TabItem("Leaderboard", elem_id="guardbench-leaderboard-tab", id=0):
            with gr.Row():
                version_selector = gr.Dropdown(
                    choices=BENCHMARK_VERSIONS,
                    label="Benchmark Version",
                    value=CURRENT_VERSION,
                    interactive=True,
                    elem_classes="version-selector",
                    scale=1,
                    visible=False
                )
            with gr.Row():
                search_input = gr.Textbox(
                    placeholder="Search by models (use ; to split)",
                    label="Search",
                    elem_id="search-bar",
                    scale=2
                )
                model_type_filter = gr.Dropdown(
                    choices=[t.to_str(" : ") for t in ModelType if t != ModelType.Unknown],
                    label="Access Type",
                    multiselect=True,
                    value=[],
                    interactive=True,
                    scale=1
                )
                column_selector = gr.Dropdown(
                    choices=get_all_column_choices(),
                    label="Columns",
                    multiselect=True,
                    value=get_initial_columns(),
                    interactive=True,
                    scale=1
                )
            with gr.Row():
                refresh_button = gr.Button("Refresh", scale=0, elem_id="refresh-button")

            # Create tabs for each category
            with gr.Tabs(elem_classes="category-tabs") as category_tabs:
                # First tab for average metrics across all categories
                with gr.TabItem("All Results", elem_id="overall-tab"):
                    leaderboard = init_leaderboard(LEADERBOARD_DF)

                # Create a tab for each category using display names
                for category in CATEGORIES:
                    display_name = CATEGORY_DISPLAY_MAP.get(category, category)
                    elem_id = f"category-{display_name.lower().replace(' ', '-').replace('&', 'and')}-tab"
                    with gr.TabItem(display_name, elem_id=elem_id):
                        category_df = get_category_leaderboard_df(category, version=CURRENT_VERSION)
                        category_leaderboard = init_leaderboard(category_df)

            # Connect search and filter inputs to update function
            def update_with_search_filters(version=CURRENT_VERSION, search_query="", model_types=None, selected_columns=None):
                """
                Update the leaderboards with search and filter settings.
                """
                return refresh_data_with_filters(version, search_query, model_types, selected_columns)

            # Refresh button functionality
            refresh_button.click(
                fn=refresh_data_with_filters,
                inputs=[version_selector, search_input, model_type_filter, column_selector],
                outputs=[leaderboard] + [category_tabs.children[i].children[0] for i in range(1, len(CATEGORIES) + 1)]
            )

            # Search input functionality
            search_input.change(
                fn=refresh_data_with_filters,
                inputs=[version_selector, search_input, model_type_filter, column_selector],
                outputs=[leaderboard] + [category_tabs.children[i].children[0] for i in range(1, len(CATEGORIES) + 1)]
            )

            # Model type filter functionality
            model_type_filter.change(
                fn=refresh_data_with_filters,
                inputs=[version_selector, search_input, model_type_filter, column_selector],
                outputs=[leaderboard] + [category_tabs.children[i].children[0] for i in range(1, len(CATEGORIES) + 1)]
            )

            # Version selector functionality
            version_selector.change(
                fn=refresh_data_with_filters,
                inputs=[version_selector, search_input, model_type_filter, column_selector],
                outputs=[leaderboard] + [category_tabs.children[i].children[0] for i in range(1, len(CATEGORIES) + 1)]
            )

            # Update all tabs at once when the column selection changes
            def update_columns(selected_columns, version=CURRENT_VERSION):
                """
                Update all leaderboards to show the selected columns.
                Ensures all selected columns are preserved in the update.
                """
                try:
                    logger.info(f"Updating columns to show: {selected_columns}")

                    # If no columns are selected, use default visible columns
                    if not selected_columns or len(selected_columns) == 0:
                        selected_columns = get_default_visible_columns()
                        logger.info(f"No columns selected, using defaults: {selected_columns}")

                    # Convert display names to internal names
                    internal_selected_columns = [display_to_internal_name(x) for x in selected_columns]

                    # Get the current data with ALL columns preserved
                    main_df = get_leaderboard_df(version=version)

                    # Get category dataframes with ALL columns preserved
                    category_dfs = [get_category_leaderboard_df(category, version=version)
                                    for category in CATEGORIES]

                    # Log columns for debugging
                    logger.info(f"Main dataframe columns: {list(main_df.columns)}")
                    logger.info(f"Selected columns (internal): {internal_selected_columns}")

                    # IMPORTANT: Make sure model_name is always included
                    if 'model_name' in main_df.columns and 'model_name' not in internal_selected_columns:
                        internal_selected_columns = ['model_name'] + internal_selected_columns

                    # Initialize the main leaderboard with the selected columns.
                    # We're passing the internal_selected_columns directly to preserve the selection
                    main_leaderboard = init_leaderboard(main_df, internal_selected_columns)

                    # Initialize category dataframes with the same selected columns.
                    # This ensures consistency across all tabs
                    category_leaderboards = []
                    for df in category_dfs:
                        # Use the same selected columns for each category;
                        # init_leaderboard will automatically handle filtering to columns that exist
                        category_leaderboards.append(init_leaderboard(df, internal_selected_columns))

                    return main_leaderboard, *category_leaderboards
                except Exception as e:
                    logger.error(f"Error updating columns: {e}")
                    import traceback
                    logger.error(traceback.format_exc())
                    return leaderboard, *[tab.children[0] for tab in category_tabs.children[1:len(CATEGORIES) + 1]]

            # Connect column selector to update function
            column_selector.change(
                fn=update_columns,
                inputs=[column_selector, version_selector],
                outputs=[leaderboard] + [category_tabs.children[i].children[0] for i in range(1, len(CATEGORIES) + 1)]
            )
with gr.TabItem("Visualize", elem_id="guardbench-viz-tab", id=1): | |
with gr.Row(): | |
with gr.Column(): | |
viz_version_selector = gr.Dropdown( | |
choices=BENCHMARK_VERSIONS, | |
label="Benchmark Version", | |
value=CURRENT_VERSION, | |
interactive=True, | |
visible=False | |
) | |
model_selector = gr.Dropdown( | |
choices=update_model_choices(CURRENT_VERSION), | |
label="Select Models to Compare", | |
multiselect=True, | |
interactive=True | |
) | |
with gr.Column(): | |
# Add Overall Performance to categories, use display names | |
viz_categories_display = ["All Results"] + [CATEGORY_DISPLAY_MAP.get(cat, cat) for cat in CATEGORIES] | |
category_selector = gr.Dropdown( | |
choices=viz_categories_display, | |
label="Select Category", | |
value=viz_categories_display[0], | |
interactive=True | |
) | |
metric_selector = gr.Dropdown( | |
choices=["f1_binary", "precision_binary", "recall_binary"], | |
label="Select Metric", | |
value="f1_binary", | |
interactive=True | |
) | |
plot_output = gr.Plot() | |
# Update visualization when any selector changes | |
for control in [viz_version_selector, model_selector, category_selector, metric_selector]: | |
control.change( | |
fn=lambda sm, sc, s_metric, v: update_visualization(sm, CATEGORY_REVERSE_MAP.get(sc, sc), s_metric, v), | |
inputs=[model_selector, category_selector, metric_selector, viz_version_selector], | |
outputs=plot_output | |
) | |
# Update model choices when version changes | |
viz_version_selector.change( | |
fn=update_model_choices, | |
inputs=[viz_version_selector], | |
outputs=[model_selector] | |
) | |
# with gr.TabItem("About", elem_id="guardbench-about-tab", id=2): | |
# gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text") | |
with gr.TabItem("Submit", elem_id="guardbench-submit-tab", id=3): | |
gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text") | |
with gr.Row(): | |
# with gr.Column(scale=3): | |
# gr.Markdown("# ✉️✨ Submit your results here!", elem_classes="markdown-text") | |
with gr.Column(scale=1): | |
# Add version selector specifically for the submission tab | |
submission_version_selector = gr.Dropdown( | |
choices=BENCHMARK_VERSIONS, | |
label="Benchmark Version", | |
value=CURRENT_VERSION, | |
interactive=True, | |
elem_classes="version-selector", | |
visible=False | |
) | |
with gr.Row(): | |
with gr.Column(): | |
model_name_textbox = gr.Textbox(label="Model name") | |
mode_selector = gr.Dropdown( | |
choices=[m.name for m in Mode], | |
label="Mode", | |
multiselect=False, | |
value=None, | |
interactive=True, | |
) | |
revision_name_textbox = gr.Textbox(label="Revision commit", placeholder="main") | |
model_type = gr.Dropdown( | |
choices=[t.to_str(" : ") for t in ModelType if t != ModelType.Unknown], | |
label="Model type", | |
multiselect=False, | |
value=None, | |
interactive=True, | |
) | |
guard_model_type = gr.Dropdown( | |
choices=[t.name for t in GuardModelType], | |
label="Guard model type", | |
multiselect=False, | |
value=GuardModelType.LLM_REGEXP.name, | |
interactive=True, | |
) | |
with gr.Column(): | |
precision = gr.Dropdown( | |
choices=[i.name for i in Precision if i != Precision.Unknown], | |
label="Precision", | |
multiselect=False, | |
value="float16", | |
interactive=True, | |
) | |
weight_type = gr.Dropdown( | |
choices=[i.name for i in WeightType], | |
label="Weights type", | |
multiselect=False, | |
value="Original", | |
interactive=True, | |
) | |
base_model_name_textbox = gr.Textbox(label="Base model (for delta or adapter weights)") | |
with gr.Row(): | |
file_input = gr.File( | |
label="Upload JSONL Results File", | |
file_types=[".jsonl"] | |
) | |
submit_button = gr.Button("Submit Results") | |
result_output = gr.Markdown() | |
submit_button.click( | |
fn=submit_results, | |
inputs=[ | |
model_name_textbox, | |
base_model_name_textbox, | |
revision_name_textbox, | |
precision, | |
weight_type, | |
model_type, | |
mode_selector, | |
file_input, | |
submission_version_selector, | |
guard_model_type | |
], | |
outputs=result_output | |
) | |

    # Version selector functionality
    version_selector.change(
        fn=update_leaderboards,
        inputs=[version_selector],
        outputs=[leaderboard] + [category_tabs.children[i].children[0] for i in range(1, len(CATEGORIES) + 1)]
    )

# Set up the scheduler to refresh data periodically
scheduler = BackgroundScheduler()
scheduler.add_job(refresh_data, 'interval', minutes=30)
scheduler.start()

# Launch the app
if __name__ == "__main__":
    demo.launch()