# kenkaneki's picture
# cahages
# f990f50
"""
Utility classes and functions for the CodeReview Bench Leaderboard display.
"""
from dataclasses import dataclass, field, fields
from enum import Enum, auto
from typing import List, Optional
class Mode(Enum):
    """Inference modes available for the review model."""

    CoT = auto()  # Chain of Thought
    Strict = auto()

    def __str__(self) -> str:
        """Render the member as its bare name, e.g. ``"CoT"``."""
        return self.name
class ModelType(Enum):
    """Kinds of model distinguished by the leaderboard."""

    Unknown = auto()
    OpenSource = auto()
    ClosedSource = auto()
    API = auto()

    def to_str(self, separator: str = "-") -> str:
        """Return a human-readable label for this member.

        Args:
            separator: Text placed between the two words of the
                ``OpenSource`` and ``ClosedSource`` labels.

        Returns:
            ``"Open<separator>Source"``, ``"Closed<separator>Source"``,
            ``"API"``, or ``"Unknown"`` for anything else.
        """
        labels = {
            ModelType.OpenSource: f"Open{separator}Source",
            ModelType.ClosedSource: f"Closed{separator}Source",
            ModelType.API: "API",
        }
        # ``Unknown`` (and any member without an entry) falls back here.
        return labels.get(self, "Unknown")
class ReviewModelType(str, Enum):
    """Review model families recognised by the leaderboard.

    Members are also ``str`` instances equal to their value.
    """

    GPT_4 = "gpt-4"
    GPT_3_5 = "gpt-3.5-turbo"
    CLAUDE = "claude"
    LLAMA = "llama"
    GEMINI = "gemini"
    CUSTOM = "custom"

    def __str__(self) -> str:
        """Render the member as its raw string value, e.g. ``"gpt-4"``."""
        return self.value
class Precision(Enum):
    """Numeric precision options for a model."""

    Unknown = auto()
    float16 = auto()
    bfloat16 = auto()
    float32 = auto()
    int8 = auto()
    int4 = auto()
    NA = auto()

    def __str__(self) -> str:
        """Render the member as its bare name, e.g. ``"bfloat16"``."""
        return self.name
class WeightType(Enum):
    """Weight formats a model can be provided in."""

    Original = auto()
    Delta = auto()
    Adapter = auto()

    def __str__(self) -> str:
        """Render the member as its bare name, e.g. ``"Adapter"``."""
        return self.name
@dataclass
class ColumnInfo:
    """Describes a single leaderboard column."""

    # Internal column identifier.
    name: str
    # Label associated with the column for display purposes.
    display_name: str
    # Kind of data in the column; "number" marks a metric column.
    type: str = "text"
    # True places the column in the hidden-columns list.
    hidden: bool = False
    # True places the column in the never-hidden list.
    never_hidden: bool = False
    # True places the column in the default display list.
    displayed_by_default: bool = True
def _column(name, display_name, **options):
    """Build a dataclass field whose default is a fresh ``ColumnInfo``.

    ``options`` are forwarded unchanged to ``ColumnInfo`` as keyword
    arguments; a new ``ColumnInfo`` is created for every instance.
    """
    return field(
        default_factory=lambda: ColumnInfo(
            name=name, display_name=display_name, **options
        )
    )


@dataclass
class CodeReviewBenchColumn:
    """Column definitions for the CodeReview Bench leaderboard.

    Every attribute holds the ``ColumnInfo`` of one column; the attribute
    name matches the column's internal ``name``.
    """

    # --- Core metadata ---
    model_name: ColumnInfo = _column(
        "model_name", "Model", never_hidden=True, displayed_by_default=True
    )
    mode: ColumnInfo = _column("mode", "Mode", displayed_by_default=True)
    model_type: ColumnInfo = _column(
        "model_type", "Access_Type", displayed_by_default=True
    )
    submission_date: ColumnInfo = _column(
        "submission_date", "Submission_Date", displayed_by_default=False
    )
    version: ColumnInfo = _column("version", "Version", displayed_by_default=False)
    review_model_type: ColumnInfo = _column(
        "review_model_type", "Type", displayed_by_default=False
    )
    base_model: ColumnInfo = _column(
        "base_model", "Base Model", displayed_by_default=False
    )
    revision: ColumnInfo = _column("revision", "Revision", displayed_by_default=False)
    precision: ColumnInfo = _column(
        "precision", "Precision", displayed_by_default=False
    )
    weight_type: ColumnInfo = _column(
        "weight_type", "Weight Type", displayed_by_default=False
    )
    topic: ColumnInfo = _column("topic", "Topic", displayed_by_default=True)

    # --- LLM-based multimetric scores ---
    readability: ColumnInfo = _column(
        "readability", "Readability", type="number", displayed_by_default=True
    )
    relevance: ColumnInfo = _column(
        "relevance", "Relevance", type="number", displayed_by_default=True
    )
    explanation_clarity: ColumnInfo = _column(
        "explanation_clarity",
        "Explanation_Clarity",
        type="number",
        displayed_by_default=True,
    )
    problem_identification: ColumnInfo = _column(
        "problem_identification",
        "Problem_Identification",
        type="number",
        displayed_by_default=True,
    )
    actionability: ColumnInfo = _column(
        "actionability", "Actionability", type="number", displayed_by_default=True
    )
    completeness: ColumnInfo = _column(
        "completeness", "Completeness", type="number", displayed_by_default=True
    )
    specificity: ColumnInfo = _column(
        "specificity", "Specificity", type="number", displayed_by_default=True
    )
    contextual_adequacy: ColumnInfo = _column(
        "contextual_adequacy",
        "Contextual_Adequacy",
        type="number",
        displayed_by_default=True,
    )
    consistency: ColumnInfo = _column(
        "consistency", "Consistency", type="number", displayed_by_default=True
    )
    brevity: ColumnInfo = _column(
        "brevity", "Brevity", type="number", displayed_by_default=True
    )

    # --- LLM-based exact-match metrics ---
    pass_at_1: ColumnInfo = _column(
        "pass_at_1", "Pass@1", type="number", displayed_by_default=True
    )
    pass_at_5: ColumnInfo = _column(
        "pass_at_5", "Pass@5", type="number", displayed_by_default=True
    )
    pass_at_10: ColumnInfo = _column(
        "pass_at_10", "Pass@10", type="number", displayed_by_default=True
    )
    bleu_at_10: ColumnInfo = _column(
        "bleu_at_10", "BLEU@10", type="number", displayed_by_default=True
    )

    # --- Overall aggregated metrics ---
    overall_score: ColumnInfo = _column(
        "overall_score", "Overall_Score", type="number", displayed_by_default=True
    )
    multimetric_average: ColumnInfo = _column(
        "multimetric_average",
        "Multimetric_Average",
        type="number",
        displayed_by_default=True,
    )
    exact_match_average: ColumnInfo = _column(
        "exact_match_average",
        "Exact_Match_Average",
        type="number",
        displayed_by_default=True,
    )
    total_evaluations: ColumnInfo = _column(
        "total_evaluations",
        "Total_Evaluations",
        type="number",
        displayed_by_default=True,
    )

    # --- Language-specific metrics (Russian) ---
    ru_readability: ColumnInfo = _column(
        "ru_readability", "RU_Readability", type="number", displayed_by_default=False
    )
    ru_relevance: ColumnInfo = _column(
        "ru_relevance", "RU_Relevance", type="number", displayed_by_default=False
    )
    ru_overall_score: ColumnInfo = _column(
        "ru_overall_score",
        "RU_Overall_Score",
        type="number",
        displayed_by_default=False,
    )

    # --- Language-specific metrics (English) ---
    en_readability: ColumnInfo = _column(
        "en_readability", "EN_Readability", type="number", displayed_by_default=False
    )
    en_relevance: ColumnInfo = _column(
        "en_relevance", "EN_Relevance", type="number", displayed_by_default=False
    )
    en_overall_score: ColumnInfo = _column(
        "en_overall_score",
        "EN_Overall_Score",
        type="number",
        displayed_by_default=False,
    )
# Module-level instance holding every column definition
CODEREVIEW_COLUMN = CodeReviewBenchColumn()

# Column-name lists derived from the definitions, in field order
COLS = [column_field.name for column_field in fields(CODEREVIEW_COLUMN)]
DISPLAY_COLS = [
    info.name
    for info in (
        getattr(CODEREVIEW_COLUMN, column_field.name)
        for column_field in fields(CODEREVIEW_COLUMN)
    )
    if info.displayed_by_default
]
# Keep 'mode' directly after 'model_name' in the display order
def reorder_display_cols(cols: Optional[List[str]] = None) -> List[str]:
    """Return the display columns with ``'mode'`` right after ``'model_name'``.

    Args:
        cols: Column names to reorder. When omitted, the module-level
            ``DISPLAY_COLS`` list is used.

    Returns:
        A new list with ``'mode'`` moved to the position immediately after
        ``'model_name'``. If either name is absent the columns are returned
        in their existing order. The input list is never modified.
    """
    source = DISPLAY_COLS if cols is None else cols
    # Work on a copy so the caller's (or the module's) list is not mutated.
    ordered = list(source)
    if 'model_name' in ordered and 'mode' in ordered:
        ordered.remove('mode')
        # Look up the index after the removal: it shifts when 'mode'
        # originally preceded 'model_name'.
        ordered.insert(ordered.index('model_name') + 1, 'mode')
    return ordered
DISPLAY_COLS = reorder_display_cols()

# Columns whose type is "number"
METRIC_COLS = [
    info.name
    for info in (
        getattr(CODEREVIEW_COLUMN, column_field.name)
        for column_field in fields(CODEREVIEW_COLUMN)
    )
    if info.type == "number"
]

# Columns flagged as hidden
HIDDEN_COLS = [
    info.name
    for info in (
        getattr(CODEREVIEW_COLUMN, column_field.name)
        for column_field in fields(CODEREVIEW_COLUMN)
    )
    if info.hidden
]

# Columns flagged as never hidden
NEVER_HIDDEN_COLS = [
    info.name
    for info in (
        getattr(CODEREVIEW_COLUMN, column_field.name)
        for column_field in fields(CODEREVIEW_COLUMN)
    )
    if info.never_hidden
]
# Programming languages used as CodeReview Bench categories
CATEGORIES = ["Python", "Java", "Scala", "Go"]

# Comment languages covered by CodeReview Bench
COMMENT_LANGUAGES = [
    "ru",  # Russian
    "en",  # English
]

# Review topics for CodeReview Bench
TOPICS = [
    "Code Reliability",
    "Coding Standards",
    "Code Organization",
    "Performance Issues",
    "Validation",
    "Variables",
]

# Example categories
EXAMPLE_CATEGORIES = [
    "Bug_Fix",
    "Code_Style",
    "Performance",
    "Security",
    "Refactoring",
    "Documentation",
    "Testing",
    "Architecture",
    "Other",
]

# Names of the multimetric scores
MULTIMETRIC_METRICS = [
    "readability",
    "relevance",
    "explanation_clarity",
    "problem_identification",
    "actionability",
    "completeness",
    "specificity",
    "contextual_adequacy",
    "consistency",
    "brevity",
]

# Names of the exact-match metrics
EXACT_MATCH_METRICS = [
    "pass_at_1",
    "pass_at_5",
    "pass_at_10",
    "bleu_at_10",
]
def get_all_column_choices():
    """Collect the optional column choices for the multiselect dropdown.

    Columns that are displayed by default are left out, so only the
    remaining columns are offered as choices.

    Returns:
        List of ``(column_name, display_name)`` tuples, in field order, for
        every column that is not displayed by default.
    """
    already_visible = get_default_visible_columns()
    column_infos = (
        getattr(CODEREVIEW_COLUMN, column_field.name)
        for column_field in fields(CODEREVIEW_COLUMN)
    )
    # Pair each remaining column's internal name with its display name.
    return [
        (info.name, info.display_name)
        for info in column_infos
        if info.name not in already_visible
    ]
def get_default_visible_columns():
    """List the columns that are shown without the user selecting them.

    Returns:
        Names of all columns whose ``displayed_by_default`` flag is set, in
        field order.
    """
    visible = []
    for column_field in fields(CODEREVIEW_COLUMN):
        info = getattr(CODEREVIEW_COLUMN, column_field.name)
        if info.displayed_by_default:
            visible.append(info.name)
    return visible