Spaces:
Sleeping
Sleeping
"""
Utility classes and functions for the GuardBench leaderboard display.
"""
from dataclasses import dataclass, field, fields
from enum import Enum, auto
from typing import List, Optional
class Mode(Enum):
    """Inference mode used for the guard model."""

    CoT = auto()  # Chain of Thought
    Strict = auto()

    def __str__(self) -> str:
        """Render the mode as its bare member name, e.g. ``"CoT"``."""
        member_name = self.name
        return member_name
class ModelType(Enum):
    """Kinds of model listed on the leaderboard."""

    Unknown = auto()
    OpenSource = auto()
    ClosedSource = auto()
    API = auto()

    def to_str(self, separator: str = " ") -> str:
        """
        Render the model type as a human-readable label.

        Args:
            separator: Text placed between the two words of the two-word
                labels ("Open"/"Source" and "Closed"/"Source").

        Returns:
            The label, e.g. "Open Source" or "API"; "Unknown" for any
            member without a dedicated label.
        """
        # Templates without a {sep} placeholder ignore the separator entirely.
        templates = {
            ModelType.OpenSource: "Open{sep}Source",
            ModelType.ClosedSource: "Closed{sep}Source",
            ModelType.API: "API",
        }
        template = templates.get(self, "Unknown")
        return template.format(sep=separator)
class GuardModelType(str, Enum):
    """Guard model types for the leaderboard, keyed by string identifier."""

    LLAMA_GUARD = "llama_guard"
    PROMPT_GUARD_CLF = "prompt_guard_clf"
    ATLA_SELENE = "atla_selene"
    GEMMA_SHIELD = "gemma_shield"
    LLM_REGEXP = "llm_regexp"
    LLM_SO = "llm_so"
    WC_GUARD = "wc_guard"

    def __str__(self) -> str:
        """Render the member by its upper-case name (e.g. "LLAMA_GUARD"), not its value."""
        return f"{self.name}"
class Precision(Enum):
    """Precision types a model can be tagged with."""

    Unknown = auto()
    float16 = auto()
    bfloat16 = auto()
    float32 = auto()
    int8 = auto()
    int4 = auto()
    NA = auto()

    def __str__(self) -> str:
        """Render the precision as its member name, e.g. ``"bfloat16"``."""
        label = self.name
        return label
class WeightType(Enum):
    """Kinds of model weights a submission can provide."""

    Original = auto()
    Delta = auto()
    Adapter = auto()

    def __str__(self) -> str:
        """Render the weight type as its member name, e.g. ``"Adapter"``."""
        label = self.name
        return label
@dataclass
class ColumnInfo:
    """
    Display metadata for a single leaderboard column.

    Attributes:
        name: Internal column identifier.
        display_name: Label shown for the column.
        type: Column data type; "text" by default, "number" for metric columns.
        hidden: Whether the column is flagged as hidden.
        never_hidden: Whether the column is flagged as never hidden.
        displayed_by_default: Whether the column is part of the default view.
    """

    # @dataclass is required: this class is instantiated with keyword
    # arguments (ColumnInfo(name=..., display_name=...)) and inspected with
    # dataclasses.fields(); a plain class supports neither.
    name: str
    display_name: str
    type: str = "text"
    hidden: bool = False
    never_hidden: bool = False
    displayed_by_default: bool = True
def _column(name: str, display_name: str, **options):
    """
    Return a dataclass field whose default value is a new ColumnInfo.

    Args:
        name: Internal column name stored on the ColumnInfo.
        display_name: Label stored on the ColumnInfo.
        **options: Remaining ColumnInfo keyword arguments (type, hidden,
            never_hidden, displayed_by_default).
    """
    # default_factory gives every GuardBenchColumn instance its own ColumnInfo
    # objects instead of one mutable default shared by all instances.
    return field(default_factory=lambda: ColumnInfo(name=name, display_name=display_name, **options))


def _metric(name: str, display_name: str, displayed_by_default: bool = False):
    """Return a field for a numeric (type="number") column, not displayed by default unless requested."""
    return _column(name, display_name, type="number", displayed_by_default=displayed_by_default)


@dataclass
class GuardBenchColumn:
    """
    Column schema for the GuardBench leaderboard.

    Every attribute holds a ColumnInfo. dataclasses.fields() reports the
    attributes in declaration order, which is the column order used by the
    module-level column lists built from this class.
    """

    # Core metadata
    model_name: ColumnInfo = _column(
        "model_name", "Model", never_hidden=True, displayed_by_default=True)
    mode: ColumnInfo = _column("mode", "Mode", displayed_by_default=True)
    model_type: ColumnInfo = _column(
        "model_type", "Access_Type", displayed_by_default=True)
    submission_date: ColumnInfo = _column(
        "submission_date", "Submission_Date", displayed_by_default=False)
    version: ColumnInfo = _column("version", "Version", displayed_by_default=False)
    guard_model_type: ColumnInfo = _column(
        "guard_model_type", "Type", displayed_by_default=True)
    # NOTE(review): "Base Model", "Weight Type" and "Macro Precision" use spaces
    # where the other labels use underscores; kept as-is because callers may
    # match on these labels.
    base_model: ColumnInfo = _column(
        "base_model", "Base Model", displayed_by_default=False)
    revision: ColumnInfo = _column("revision", "Revision", displayed_by_default=False)
    precision: ColumnInfo = _column(
        "precision", "Precision", displayed_by_default=False)
    weight_type: ColumnInfo = _column(
        "weight_type", "Weight Type", displayed_by_default=False)

    # Default prompts metrics
    default_prompts_f1_binary: ColumnInfo = _metric(
        "default_prompts_f1_binary", "Default_Prompts_F1_Binary")
    default_prompts_f1: ColumnInfo = _metric(
        "default_prompts_f1", "Default_Prompts_F1")
    default_prompts_recall_binary: ColumnInfo = _metric(
        "default_prompts_recall_binary", "Default_Prompts_Recall")
    default_prompts_precision_binary: ColumnInfo = _metric(
        "default_prompts_precision_binary", "Default_Prompts_Precision")
    default_prompts_error_ratio: ColumnInfo = _metric(
        "default_prompts_error_ratio", "Default_Prompts_Error_Ratio")
    default_prompts_avg_runtime_ms: ColumnInfo = _metric(
        "default_prompts_avg_runtime_ms", "Default_Prompts_Avg_Runtime_ms")

    # Jailbreaked prompts metrics
    jailbreaked_prompts_f1_binary: ColumnInfo = _metric(
        "jailbreaked_prompts_f1_binary", "Jailbreaked_Prompts_F1_Binary")
    jailbreaked_prompts_f1: ColumnInfo = _metric(
        "jailbreaked_prompts_f1", "Jailbreaked_Prompts_F1")
    jailbreaked_prompts_recall_binary: ColumnInfo = _metric(
        "jailbreaked_prompts_recall_binary", "Jailbreaked_Prompts_Recall")
    jailbreaked_prompts_precision_binary: ColumnInfo = _metric(
        "jailbreaked_prompts_precision_binary", "Jailbreaked_Prompts_Precision")
    jailbreaked_prompts_error_ratio: ColumnInfo = _metric(
        "jailbreaked_prompts_error_ratio", "Jailbreaked_Prompts_Error_Ratio")
    jailbreaked_prompts_avg_runtime_ms: ColumnInfo = _metric(
        "jailbreaked_prompts_avg_runtime_ms", "Jailbreaked_Prompts_Avg_Runtime_ms")

    # Default answers metrics
    default_answers_f1_binary: ColumnInfo = _metric(
        "default_answers_f1_binary", "Default_Answers_F1_Binary")
    default_answers_f1: ColumnInfo = _metric(
        "default_answers_f1", "Default_Answers_F1")
    default_answers_recall_binary: ColumnInfo = _metric(
        "default_answers_recall_binary", "Default_Answers_Recall")
    default_answers_precision_binary: ColumnInfo = _metric(
        "default_answers_precision_binary", "Default_Answers_Precision")
    default_answers_error_ratio: ColumnInfo = _metric(
        "default_answers_error_ratio", "Default_Answers_Error_Ratio")
    default_answers_avg_runtime_ms: ColumnInfo = _metric(
        "default_answers_avg_runtime_ms", "Default_Answers_Avg_Runtime_ms")

    # Jailbreaked answers metrics
    jailbreaked_answers_f1_binary: ColumnInfo = _metric(
        "jailbreaked_answers_f1_binary", "Jailbreaked_Answers_F1_Binary")
    jailbreaked_answers_f1: ColumnInfo = _metric(
        "jailbreaked_answers_f1", "Jailbreaked_Answers_F1")
    jailbreaked_answers_recall_binary: ColumnInfo = _metric(
        "jailbreaked_answers_recall_binary", "Jailbreaked_Answers_Recall")
    jailbreaked_answers_precision_binary: ColumnInfo = _metric(
        "jailbreaked_answers_precision_binary", "Jailbreaked_Answers_Precision")
    jailbreaked_answers_error_ratio: ColumnInfo = _metric(
        "jailbreaked_answers_error_ratio", "Jailbreaked_Answers_Error_Ratio")
    jailbreaked_answers_avg_runtime_ms: ColumnInfo = _metric(
        "jailbreaked_answers_avg_runtime_ms", "Jailbreaked_Answers_Avg_Runtime_ms")

    # Integral score
    integral_score: ColumnInfo = _metric(
        "integral_score", "Integral_Score", displayed_by_default=True)

    # Calculated overall metrics (renamed)
    macro_accuracy: ColumnInfo = _metric(
        "macro_accuracy", "Macro_Accuracy", displayed_by_default=True)
    macro_recall: ColumnInfo = _metric(
        "macro_recall", "Macro_Recall", displayed_by_default=True)
    macro_precision: ColumnInfo = _metric("macro_precision", "Macro Precision")

    # Summary metrics
    micro_avg_error_ratio: ColumnInfo = _metric(
        "micro_avg_error_ratio", "Micro_Error", displayed_by_default=True)
    micro_avg_runtime_ms: ColumnInfo = _metric(
        "micro_avg_runtime_ms", "Micro_Avg_time_ms", displayed_by_default=True)
    total_evals_count: ColumnInfo = _metric(
        "total_evals_count", "Total_Count", displayed_by_default=True)
# Shared module-level instance of the column schema.
GUARDBENCH_COLUMN = GuardBenchColumn()

# Column lists for the different leaderboard views.
# COLS: attribute name of every column, in declaration order.
COLS = [column_field.name for column_field in fields(GUARDBENCH_COLUMN)]
# DISPLAY_COLS: names of the columns flagged displayed_by_default.
DISPLAY_COLS = [
    column_info.name
    for column_info in (
        getattr(GUARDBENCH_COLUMN, column_field.name)
        for column_field in fields(GUARDBENCH_COLUMN)
    )
    if column_info.displayed_by_default
]
# Reorder the display columns so that 'mode' sits directly after 'model_name'.
def reorder_display_cols(cols: Optional[List[str]] = None) -> List[str]:
    """
    Return a copy of the display columns with 'mode' placed right after 'model_name'.

    Args:
        cols: Column names to reorder. Defaults to the module-level
            DISPLAY_COLS.

    Returns:
        A new list; the input list (and DISPLAY_COLS) is never modified.
        If either 'model_name' or 'mode' is absent, the copy keeps the
        original order.
    """
    # Work on a copy: the previous version aliased and mutated the global
    # DISPLAY_COLS, so callers editing the returned list changed the global.
    ordered = list(DISPLAY_COLS if cols is None else cols)
    if 'model_name' in ordered and 'mode' in ordered:
        # Remove 'mode' first so 'model_name' is located in the shortened list.
        ordered.remove('mode')
        ordered.insert(ordered.index('model_name') + 1, 'mode')
    return ordered
DISPLAY_COLS = reorder_display_cols()

# METRIC_COLS: names of the numeric (type == "number") columns.
METRIC_COLS = [
    column_info.name
    for column_info in (
        getattr(GUARDBENCH_COLUMN, column_field.name)
        for column_field in fields(GUARDBENCH_COLUMN)
    )
    if column_info.type == "number"
]
# HIDDEN_COLS: names of the columns flagged hidden.
HIDDEN_COLS = [
    column_info.name
    for column_info in (
        getattr(GUARDBENCH_COLUMN, column_field.name)
        for column_field in fields(GUARDBENCH_COLUMN)
    )
    if column_info.hidden
]
# NEVER_HIDDEN_COLS: names of the columns flagged never_hidden.
NEVER_HIDDEN_COLS = [
    column_info.name
    for column_info in (
        getattr(GUARDBENCH_COLUMN, column_field.name)
        for column_field in fields(GUARDBENCH_COLUMN)
    )
    if column_info.never_hidden
]
# Categories in GuardBench. The strings are kept verbatim (including the en
# dashes); presumably they must match the benchmark's category labels exactly.
CATEGORIES = [
    "Political Corruption and Legal Evasion",
    "Financial Fraud and Unethical Business",
    "AI Manipulation and Jailbreaking",
    "Child Exploitation and Abuse",
    "Hate Speech, Extremism, and Discrimination",
    "Labor Exploitation and Human Trafficking",
    "Manipulation, Deception, and Misinformation",
    "Environmental and Industrial Harm",
    "Academic Dishonesty and Cheating",
    "Self–Harm and Suicidal Ideation",
    "Animal Cruelty and Exploitation",
    "Criminal, Violent, and Terrorist Activity",
    "Drug– and Substance–Related Activities",
    "Sexual Content and Violence",
    "Weapon, Explosives, and Hazardous Materials",
    "Cybercrime, Hacking, and Digital Exploits",
    "Creative Content Involving Illicit Themes",
    "Safe Prompts",
]
# Test types in GuardBench; each one prefixes a group of per-type metric
# columns in GuardBenchColumn (e.g. "default_prompts_f1").
TEST_TYPES = [
    "default_prompts",
    "jailbreaked_prompts",
    "default_answers",
    "jailbreaked_answers",
]
# Metrics in GuardBench.
# NOTE(review): GuardBenchColumn has no "<test_type>_accuracy" columns, while
# "<test_type>_f1" columns exist but "f1" is not listed here; confirm intended.
METRICS = [
    "f1_binary",
    "recall_binary",
    "precision_binary",
    "error_ratio",
    "avg_runtime_ms",
    "accuracy",
]
def get_all_column_choices():
    """
    Get the column choices offered by the column-selection multiselect dropdown.

    Only columns that are NOT displayed by default are included; the
    default-visible columns returned by get_default_visible_columns() are
    excluded.

    Returns:
        List of (column_name, display_name) tuples, in the declaration order
        of GuardBenchColumn.
    """
    # A set makes each membership check below O(1).
    default_visible_columns = set(get_default_visible_columns())
    return [
        # Pair the internal column name with its human-readable label.
        (column_info.name, column_info.display_name)
        for column_info in (
            getattr(GUARDBENCH_COLUMN, column_field.name)
            for column_field in fields(GUARDBENCH_COLUMN)
        )
        if column_info.name not in default_visible_columns
    ]
def get_default_visible_columns():
    """
    Collect the names of the columns flagged displayed_by_default.

    Returns:
        Column names in the declaration order of GuardBenchColumn. Unlike
        DISPLAY_COLS, no 'mode'-after-'model_name' reordering is applied.
    """
    visible = []
    for column_field in fields(GUARDBENCH_COLUMN):
        column_info = getattr(GUARDBENCH_COLUMN, column_field.name)
        if column_info.displayed_by_default:
            visible.append(column_info.name)
    return visible