# Spaces:
# Sleeping
# Sleeping
"""
Utility classes and functions for the GuardBench Leaderboard display.
"""
from dataclasses import dataclass, field, fields
from enum import Enum, auto
from typing import List, Optional
class ModelType(Enum):
    """Kinds of model distinguished by the leaderboard."""

    Unknown = auto()
    OpenSource = auto()
    ClosedSource = auto()
    API = auto()

    def to_str(self, separator: str = " ") -> str:
        """Return a human-readable label for this model type.

        Args:
            separator: Text inserted between the two words of the
                "Open…Source" and "Closed…Source" labels. The one-word
                labels ("Unknown", "API") do not use it.

        Returns:
            The label for this member, or "Unknown" for any member that
            has no dedicated label.
        """
        labels = {
            ModelType.Unknown: "Unknown",
            ModelType.OpenSource: f"Open{separator}Source",
            ModelType.ClosedSource: f"Closed{separator}Source",
            ModelType.API: "API",
        }
        return labels.get(self, "Unknown")
class GuardModelType(str, Enum):
    """Guard model types known to the leaderboard.

    Members are also ``str`` instances whose value is the lowercase
    identifier on the right-hand side.
    """

    LLAMA_GUARD = "llama_guard"
    PROMPT_GUARD_CLF = "prompt_guard_clf"
    ATLA_SELENE = "atla_selene"
    GEMMA_SHIELD = "gemma_shield"
    LLM_REGEXP = "llm_regexp"
    LLM_SO = "llm_so"

    def __str__(self) -> str:
        """Render the member as its name (e.g. ``LLAMA_GUARD``), not its value."""
        member_name = self.name
        return member_name
class Precision(Enum):
    """Numeric precisions a model can be reported with."""

    Unknown = auto()
    float16 = auto()
    bfloat16 = auto()
    float32 = auto()
    int8 = auto()
    int4 = auto()

    def __str__(self) -> str:
        """Render the member as its bare name (e.g. ``bfloat16``)."""
        member_name = self.name
        return member_name
class WeightType(Enum):
    """Kinds of model weights a submission can declare."""

    Original = auto()
    Delta = auto()
    Adapter = auto()

    def __str__(self) -> str:
        """Render the member as its bare name (e.g. ``Adapter``)."""
        member_name = self.name
        return member_name
@dataclass
class ColumnInfo:
    """Display metadata for a single leaderboard column.

    This must be a dataclass: ``GuardBenchColumn`` builds instances with
    keyword arguments, which needs the generated ``__init__``.
    """

    name: str  # internal column identifier
    display_name: str  # human-readable column label
    type: str = "text"  # "text" unless stated; metric columns use "number"
    hidden: bool = False  # flag read when building HIDDEN_COLS
    never_hidden: bool = False  # flag read when building NEVER_HIDDEN_COLS
    displayed_by_default: bool = True  # flag read when building DISPLAY_COLS


def _column(name: str, display_name: str, **options):
    """Return a dataclass field whose default is a fresh ``ColumnInfo``.

    Args:
        name: Internal column name stored on the ``ColumnInfo``.
        display_name: Label stored on the ``ColumnInfo``.
        **options: Any further ``ColumnInfo`` keyword arguments
            (``type``, ``hidden``, ``never_hidden``, ``displayed_by_default``).

    Returns:
        The object produced by ``dataclasses.field``; the default factory
        builds a new ``ColumnInfo`` for every ``GuardBenchColumn`` instance.
    """
    return field(
        default_factory=lambda: ColumnInfo(name=name, display_name=display_name, **options)
    )


def _metric_column(name: str, display_name: str, *, displayed_by_default: bool = False):
    """Return a ``_column`` field for a numeric metric (``type="number"``).

    Metric columns are not displayed by default unless requested.
    """
    return _column(
        name,
        display_name,
        type="number",
        displayed_by_default=displayed_by_default,
    )


@dataclass
class GuardBenchColumn:
    """Columns for the GuardBench leaderboard.

    Every attribute holds the ``ColumnInfo`` of one column. This must be a
    dataclass so that ``dataclasses.fields`` can enumerate the columns in
    declaration order; the module-level column lists rely on that.
    """

    # Core metadata
    model_name: ColumnInfo = _column(
        "model_name", "Model", never_hidden=True, displayed_by_default=True
    )
    model_type: ColumnInfo = _column("model_type", "Type", displayed_by_default=True)
    submission_date: ColumnInfo = _column(
        "submission_date", "Submission Date", displayed_by_default=False
    )
    version: ColumnInfo = _column("version", "Version", displayed_by_default=False)
    guard_model_type: ColumnInfo = _column(
        "guard_model_type", "Guard Model Type", displayed_by_default=True
    )
    base_model: ColumnInfo = _column("base_model", "Base Model", displayed_by_default=False)
    revision: ColumnInfo = _column("revision", "Revision", displayed_by_default=False)
    precision: ColumnInfo = _column("precision", "Precision", displayed_by_default=False)
    weight_type: ColumnInfo = _column("weight_type", "Weight Type", displayed_by_default=False)

    # Default prompts metrics
    default_prompts_f1_binary: ColumnInfo = _metric_column(
        "default_prompts_f1_binary", "Default Prompts F1 Binary"
    )
    default_prompts_f1: ColumnInfo = _metric_column(
        "default_prompts_f1", "Default Prompts F1", displayed_by_default=True
    )
    default_prompts_recall_binary: ColumnInfo = _metric_column(
        "default_prompts_recall_binary", "Default Prompts Recall"
    )
    default_prompts_precision_binary: ColumnInfo = _metric_column(
        "default_prompts_precision_binary", "Default Prompts Precision"
    )
    default_prompts_error_ratio: ColumnInfo = _metric_column(
        "default_prompts_error_ratio", "Default Prompts Error Ratio"
    )
    default_prompts_avg_runtime_ms: ColumnInfo = _metric_column(
        "default_prompts_avg_runtime_ms", "Default Prompts Avg Runtime (ms)"
    )

    # Jailbreaked prompts metrics
    jailbreaked_prompts_f1_binary: ColumnInfo = _metric_column(
        "jailbreaked_prompts_f1_binary", "Jailbreaked Prompts F1 Binary"
    )
    jailbreaked_prompts_f1: ColumnInfo = _metric_column(
        "jailbreaked_prompts_f1", "Jailbreaked Prompts F1", displayed_by_default=True
    )
    jailbreaked_prompts_recall_binary: ColumnInfo = _metric_column(
        "jailbreaked_prompts_recall_binary", "Jailbreaked Prompts Recall"
    )
    jailbreaked_prompts_precision_binary: ColumnInfo = _metric_column(
        "jailbreaked_prompts_precision_binary", "Jailbreaked Prompts Precision"
    )
    jailbreaked_prompts_error_ratio: ColumnInfo = _metric_column(
        "jailbreaked_prompts_error_ratio", "Jailbreaked Prompts Error Ratio"
    )
    jailbreaked_prompts_avg_runtime_ms: ColumnInfo = _metric_column(
        "jailbreaked_prompts_avg_runtime_ms", "Jailbreaked Prompts Avg Runtime (ms)"
    )

    # Default answers metrics
    default_answers_f1_binary: ColumnInfo = _metric_column(
        "default_answers_f1_binary", "Default Answers F1 Binary"
    )
    default_answers_f1: ColumnInfo = _metric_column(
        "default_answers_f1", "Default Answers F1", displayed_by_default=True
    )
    default_answers_recall_binary: ColumnInfo = _metric_column(
        "default_answers_recall_binary", "Default Answers Recall"
    )
    default_answers_precision_binary: ColumnInfo = _metric_column(
        "default_answers_precision_binary", "Default Answers Precision"
    )
    default_answers_error_ratio: ColumnInfo = _metric_column(
        "default_answers_error_ratio", "Default Answers Error Ratio"
    )
    default_answers_avg_runtime_ms: ColumnInfo = _metric_column(
        "default_answers_avg_runtime_ms", "Default Answers Avg Runtime (ms)"
    )

    # Jailbreaked answers metrics
    jailbreaked_answers_f1_binary: ColumnInfo = _metric_column(
        "jailbreaked_answers_f1_binary", "Jailbreaked Answers F1 Binary"
    )
    jailbreaked_answers_f1: ColumnInfo = _metric_column(
        "jailbreaked_answers_f1", "Jailbreaked Answers F1", displayed_by_default=True
    )
    jailbreaked_answers_recall_binary: ColumnInfo = _metric_column(
        "jailbreaked_answers_recall_binary", "Jailbreaked Answers Recall"
    )
    jailbreaked_answers_precision_binary: ColumnInfo = _metric_column(
        "jailbreaked_answers_precision_binary", "Jailbreaked Answers Precision"
    )
    jailbreaked_answers_error_ratio: ColumnInfo = _metric_column(
        "jailbreaked_answers_error_ratio", "Jailbreaked Answers Error Ratio"
    )
    jailbreaked_answers_avg_runtime_ms: ColumnInfo = _metric_column(
        "jailbreaked_answers_avg_runtime_ms", "Jailbreaked Answers Avg Runtime (ms)"
    )
# Shared instance through which the column metadata is looked up
GUARDBENCH_COLUMN = GuardBenchColumn()

# ColumnInfo objects in field-declaration order, resolved once for the lists below
_COLUMN_INFOS = [getattr(GUARDBENCH_COLUMN, f.name) for f in fields(GUARDBENCH_COLUMN)]

# Column name lists for the different views
COLS = [f.name for f in fields(GUARDBENCH_COLUMN)]  # attribute name of every column
DISPLAY_COLS = [info.name for info in _COLUMN_INFOS if info.displayed_by_default]
METRIC_COLS = [info.name for info in _COLUMN_INFOS if info.type == "number"]
HIDDEN_COLS = [info.name for info in _COLUMN_INFOS if info.hidden]
NEVER_HIDDEN_COLS = [info.name for info in _COLUMN_INFOS if info.never_hidden]
# Category labels used by GuardBench
CATEGORIES = [
    "Criminal, Violent, and Terrorist Activity",
    "Manipulation, Deception, and Misinformation",
    "Creative Content Involving Illicit Themes",
    "Sexual Content and Violence",
    "Political Corruption and Legal Evasion",
    "Labor Exploitation and Human Trafficking",
    "Environmental and Industrial Harm",
    "Animal Cruelty and Exploitation",
    "Self–Harm and Suicidal Ideation",
    "Safe Prompts",
]

# Test types used by GuardBench; these match the prefixes of the metric column names
TEST_TYPES = [
    "default_prompts",
    "jailbreaked_prompts",
    "default_answers",
    "jailbreaked_answers",
]

# Metric name suffixes used by GuardBench
METRICS = [
    "f1_binary",
    "recall_binary",
    "precision_binary",
    "error_ratio",
    "avg_runtime_ms",
]
def get_all_column_choices():
    """
    Get the column choices for the multiselect dropdown.

    Columns returned by ``get_default_visible_columns()`` are left out, so
    only columns that are not visible by default are offered.

    Returns:
        List of (column_name, display_name) tuples in field declaration order.
    """
    already_visible = get_default_visible_columns()
    infos = (getattr(GUARDBENCH_COLUMN, f.name) for f in fields(GUARDBENCH_COLUMN))
    # Pair each remaining column's internal name with its display name
    return [
        (info.name, info.display_name)
        for info in infos
        if info.name not in already_visible
    ]
def get_default_visible_columns():
    """
    Get the names of the columns flagged as displayed by default.

    Returns:
        List of column names, in field declaration order, whose
        ``displayed_by_default`` flag is true.
    """
    infos = (getattr(GUARDBENCH_COLUMN, f.name) for f in fields(GUARDBENCH_COLUMN))
    return [info.name for info in infos if info.displayed_by_default]