Spaces:
Sleeping
Sleeping
""" | |
Utility classes and functions for the GuardBench Leaderboard display. | |
""" | |
from dataclasses import dataclass, field, fields | |
from enum import Enum, auto | |
from typing import List, Optional | |
class ModelType(Enum):
    """Kinds of model that can be listed on the leaderboard."""

    Unknown = auto()
    OpenSource = auto()
    ClosedSource = auto()
    API = auto()

    def to_str(self, separator: str = " ") -> str:
        """Return a human-readable label for this model type.

        Args:
            separator: Text placed between the two words of the
                ``OpenSource`` / ``ClosedSource`` labels. It is not used
                for ``API`` or ``Unknown``.

        Returns:
            ``"API"``, ``"Open<separator>Source"``,
            ``"Closed<separator>Source"``, or ``"Unknown"`` for anything else.
        """
        if self is ModelType.API:
            return "API"
        if self in (ModelType.OpenSource, ModelType.ClosedSource):
            prefix = "Open" if self is ModelType.OpenSource else "Closed"
            return f"{prefix}{separator}Source"
        # ModelType.Unknown and any member without a dedicated label.
        return "Unknown"
class Precision(Enum):
    """Enumeration of model precision formats.

    ``str(member)`` yields the bare member name (e.g. ``"float16"``)
    instead of the default ``"Precision.float16"``.
    """

    Unknown = auto()
    float16 = auto()
    bfloat16 = auto()
    float32 = auto()
    int8 = auto()
    int4 = auto()

    def __str__(self) -> str:
        """Return the member name without the class prefix."""
        return self.name
class WeightType(Enum):
    """Enumeration of model weight kinds.

    ``str(member)`` yields the bare member name (e.g. ``"Delta"``)
    instead of the default ``"WeightType.Delta"``.
    """

    Original = auto()
    Delta = auto()
    Adapter = auto()

    def __str__(self) -> str:
        """Return the member name without the class prefix."""
        return self.name
@dataclass
class ColumnInfo:
    """Information about a column in the leaderboard.

    Must be a dataclass: instances are created with keyword arguments
    (``ColumnInfo(name=..., display_name=...)``), which requires the
    generated ``__init__``.

    Attributes:
        name: Internal column identifier (e.g. ``"average_f1"``).
        display_name: Human-readable label (e.g. ``"Average F1"``).
        type: Column value kind; ``"text"`` by default, ``"number"`` for
            metric columns.
        hidden: Flag marking the column as hidden; defaults to ``False``.
        never_hidden: Flag marking the column as one that must not be
            hidden; defaults to ``False``.
        displayed_by_default: Whether the column belongs to the default
            view; defaults to ``True``.
    """

    name: str
    display_name: str
    type: str = "text"
    hidden: bool = False
    never_hidden: bool = False
    displayed_by_default: bool = True


@dataclass
class GuardBenchColumn:
    """Columns for the GuardBench leaderboard.

    Must be a dataclass: the ``field(default_factory=...)`` defaults are only
    resolved into ``ColumnInfo`` instances by the generated ``__init__``, and
    ``dataclasses.fields()`` only accepts dataclass types or instances.

    Every attribute holds a ``ColumnInfo`` whose ``name`` equals the attribute
    name. Each instance gets its own fresh ``ColumnInfo`` objects.
    """

    model_name: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="model_name",
        display_name="Model",
        never_hidden=True,
        displayed_by_default=True,
    ))
    model_type: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="model_type",
        display_name="Type",
        displayed_by_default=True,
    ))

    # Metrics for all categories
    default_prompts_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_prompts_f1",
        display_name="Default Prompts F1",
        type="number",
        displayed_by_default=True,
    ))
    jailbreaked_prompts_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_prompts_f1",
        display_name="Jailbreaked Prompts F1",
        type="number",
        displayed_by_default=True,
    ))
    default_answers_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_answers_f1",
        display_name="Default Answers F1",
        type="number",
        displayed_by_default=True,
    ))
    jailbreaked_answers_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_answers_f1",
        display_name="Jailbreaked Answers F1",
        type="number",
        displayed_by_default=True,
    ))

    # Average metrics
    average_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="average_f1",
        display_name="Average F1",
        type="number",
        displayed_by_default=True,
        never_hidden=True,
    ))
    average_recall: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="average_recall",
        display_name="Average Recall",
        type="number",
        displayed_by_default=False,
    ))
    average_precision: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="average_precision",
        display_name="Average Precision",
        type="number",
        displayed_by_default=False,
    ))

    # Additional metadata
    submission_date: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="submission_date",
        display_name="Submission Date",
        displayed_by_default=False,
    ))
# Shared instance holding the column definitions.
GUARDBENCH_COLUMN = GuardBenchColumn()

# Attribute names of every column, in declaration order.
COLS = [f.name for f in fields(GUARDBENCH_COLUMN)]

# Resolve each ColumnInfo once so the filtered views below can share them.
_COLUMN_INFOS = [getattr(GUARDBENCH_COLUMN, attr) for attr in COLS]

# Column names grouped by ColumnInfo flag / type.
DISPLAY_COLS = [info.name for info in _COLUMN_INFOS if info.displayed_by_default]
METRIC_COLS = [info.name for info in _COLUMN_INFOS if info.type == "number"]
HIDDEN_COLS = [info.name for info in _COLUMN_INFOS if info.hidden]
NEVER_HIDDEN_COLS = [info.name for info in _COLUMN_INFOS if info.never_hidden]
# Category labels used by GuardBench.
# NOTE(review): "Self–Harm" is spelled with an en dash (U+2013), not an ASCII
# hyphen. It is kept verbatim; presumably it has to match labels used
# elsewhere, so confirm before normalizing it.
CATEGORIES = [
    "Criminal, Violent, and Terrorist Activity",
    "Manipulation, Deception, and Misinformation",
    "Creative Content Involving Illicit Themes",
    "Sexual Content and Violence",
    "Political Corruption and Legal Evasion",
    "Labor Exploitation and Human Trafficking",
    "Environmental and Industrial Harm",
    "Animal Cruelty and Exploitation",
    "Self–Harm and Suicidal Ideation",
    "Safe Prompts",
]
# Test types in GuardBench: each variant ("default", "jailbreaked") paired
# with each target ("prompts", "answers"). Resulting order:
#   default_prompts, jailbreaked_prompts, default_answers, jailbreaked_answers
TEST_TYPES = [
    f"{variant}_{target}"
    for target in ("prompts", "answers")
    for variant in ("default", "jailbreaked")
]
# Metric keys used by GuardBench, in their canonical order.
METRICS = [
    "f1_binary",
    "recall_binary",
    "precision_binary",
    "error_ratio",
    "avg_runtime_ms",
]