# SuperBench-Eval / app.py
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset, get_dataset_config_names
import torch
import re
import json
import pandas as pd
import traceback
import spaces
from datetime import datetime
# --- Environment and Caching ---
# It's good practice to ensure the cache directory exists.
CACHE_DIR = "evaluation_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
EVAL_FILE = os.path.join(CACHE_DIR, "evals.jsonl")
# Cache to avoid reloading models and dataset configs
model_cache = {}
benchmark_subject_cache = {}
# Use environment variable for the Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")
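# The token is optional; it is only needed for gated or private models and datasets.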
# --- Constants for Benchmarks ---
MMLU_DATASET = "cais/mmlu"
BENCHMARK_MAP = {
"MMLU": MMLU_DATASET,
}
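# Additional benchmarks can be registered here as "Name": "dataset_id" pairs; note that the
# Radio components in the UI below hard-code ["MMLU"] and would need to be extended to match.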
# --- Data Loading and Preparation ---
def get_all_benchmark_options():
"""
Fetches and caches the available subjects (configs) for each benchmark dataset.
Populates a module-level cache so repeated calls do not hit the Hub API again.
"""
if benchmark_subject_cache:
return benchmark_subject_cache
print("Fetching benchmark configurations for the first time...")
for key, dataset_id in BENCHMARK_MAP.items():
try:
subjects = get_dataset_config_names(dataset_id, token=HF_TOKEN)
benchmark_subject_cache[key] = ["ALL"] + sorted([s for s in subjects if s != 'all'])
except Exception as e:
print(f"Warning: Could not load configs for {key} ({dataset_id}). It might be private or unavailable. Error: {e}")
benchmark_subject_cache[key] = ["ALL"]
print("Benchmark configurations cached.")
return benchmark_subject_cache
# Initialize the cache on startup
ALL_BENCHMARK_SUBJECTS = get_all_benchmark_options()
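# ALL_BENCHMARK_SUBJECTS maps benchmark name -> ["ALL", subject, ...]; it backs the subject
# dropdown and is used to expand an "ALL" run into every individual subject.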
@spaces.GPU()
def load_model(model_id):
"""
Loads a Hugging Face model and tokenizer, creating a text-generation pipeline.
Uses a cache to avoid reloading models.
"""
if not model_id:
raise ValueError("Model ID cannot be empty.")
gr.Info(f"Attempting to load model: {model_id}...")
if model_id in model_cache:
gr.Info(f"Model '{model_id}' found in cache.")
return model_cache[model_id]
try:
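# Prefer bfloat16 on GPUs that support it to roughly halve memory use; fall back to float32 on CPU.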
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
token=HF_TOKEN,
torch_dtype=dtype,
trust_remote_code=True,
low_cpu_mem_usage=True,
).to("cuda" if torch.cuda.is_available() else "cpu")
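# low_cpu_mem_usage=True avoids materialising a second full copy of the weights in CPU RAM while loading.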
generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device=0 if torch.cuda.is_available() else -1
)
model_cache[model_id] = generator
gr.Info(f"Model '{model_id}' loaded successfully.")
return generator
except Exception as e:
raise RuntimeError(f"Failed to load model '{model_id}'. Please verify the model ID and your Hugging Face token. Error: {e}")
# --- Evaluation Logic ---
def format_prompt(item):
"""Formats the MMLU question and choices into a standardized prompt."""
prompt = f"Question: {item['question']}\n\nChoices:\nA. {item['choices'][0]}\nB. {item['choices'][1]}\nC. {item['choices'][2]}\nD. {item['choices'][3]}\n\nAnswer:"
return prompt, item['answer']
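# Illustrative prompt shape produced above (not a real MMLU item):
#   Question: <question text>
#
#   Choices:
#   A. <choice 0>
#   B. <choice 1>
#   C. <choice 2>
#   D. <choice 3>
#
#   Answer: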
def get_choice_letter(index):
"""Converts a numerical choice index (0-3) to a letter (A-D)."""
return chr(ord('A') + index) if 0 <= index <= 3 else None
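# e.g. 0 -> 'A', 3 -> 'D'; MMLU stores the gold answer as an integer index into `choices`.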
def extract_predicted_letter(output_text):
"""Extracts the predicted letter from the model's output."""
match = re.search(r"Answer:\s*([ABCD])", output_text.strip(), re.IGNORECASE)
if match:
return match.group(1).upper()
match = re.search(r"^\s*([ABCD])\b", output_text.strip())
if match:
return match.group(1).upper()
return None
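# e.g. "Answer: B" -> "B", "C. something" -> "C"; unparseable output returns None and is scored as incorrect.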
def make_progress_html(text, percentage):
"""Helper function to create the HTML for the progress bar."""
return f"""
<div class="progress-container">
<div class="progress-bar" style="width: {percentage}%;">
{text}
</div>
</div>
"""
@spaces.GPU()
def run_evaluation(model_id, benchmark_category, subject_name, sample_count):
"""
Main generator function to orchestrate the evaluation, yielding progress updates.
"""
try:
# 1. Initial yield to set up the UI for loading state
yield {
progress_box: gr.update(visible=True),
progress_text_output: gr.update(value=f"Preparing evaluation for **{model_id}**..."),
progress_bar_output: gr.update(value=make_progress_html("Loading Model...", 0)),
result_summary_box: gr.update(visible=False),
details_box: gr.update(visible=False),
error_box: gr.update(visible=False),
}
generator = load_model(model_id)
dataset_id = BENCHMARK_MAP.get(benchmark_category)
if not dataset_id:
raise ValueError(f"Invalid benchmark category: {benchmark_category}")
subjects_to_run = []
if subject_name == "ALL":
subjects_to_run = [s for s in ALL_BENCHMARK_SUBJECTS.get(benchmark_category, []) if s != "ALL"]
else:
subjects_to_run = [subject_name]
if not subjects_to_run:
gr.Warning(f"No subjects found for '{benchmark_category}'.")
yield { progress_box: gr.update(visible=False) }
return
all_results_details = []
summary_lines = []
total_correct = 0
total_samples = 0
# 2. Main evaluation loop
for i, subject in enumerate(subjects_to_run):
overall_progress_text = f"**Overall Progress ({i+1}/{len(subjects_to_run)} subjects)**"
yield {
progress_text_output: gr.update(value=f"{overall_progress_text}\n\nLoading dataset for **{subject}**...")
}
try:
# Load dataset for the current subject
dataset = load_dataset(dataset_id, subject, token=HF_TOKEN, split="test")
num_samples = min(sample_count, len(dataset))
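# A fixed seed (42) keeps the sampled subset stable across runs, so scores for different models stay comparable.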
dataset = dataset.shuffle(seed=42).select(range(num_samples))
correct_predictions_subject = 0
subject_details = []
# Loop over samples within the subject
for j, item in enumerate(dataset):
prompt, correct_answer_idx = format_prompt(item)
expected_letter = get_choice_letter(correct_answer_idx)
# Ask the pipeline for only the newly generated text (return_full_text=False) instead of
# slicing the prompt off by character count, which can drift when tokenization does not round-trip exactly.
raw_output = generator(prompt, max_new_tokens=5, do_sample=False, return_full_text=False, pad_token_id=generator.tokenizer.eos_token_id)[0]["generated_text"]
generated_text_only = raw_output.strip()
predicted_letter = extract_predicted_letter(generated_text_only)
is_correct = (predicted_letter == expected_letter)
if is_correct:
correct_predictions_subject += 1
subject_details.append({
"Question": item['question'],
"Correct": "βœ…" if is_correct else "❌",
"Expected": expected_letter,
"Predicted": predicted_letter or "N/A",
"Model Output": generated_text_only
})
# Yield progress update for each sample
percentage = ((j + 1) / num_samples) * 100
progress_bar_text = f"Evaluating: {subject} ({j+1}/{num_samples})"
yield {
progress_bar_output: gr.update(value=make_progress_html(f"{percentage:.1f}%", percentage)),
progress_text_output: gr.update(value=f"{overall_progress_text}\n\n{progress_bar_text}")
}
accuracy = (correct_predictions_subject / num_samples) * 100 if num_samples > 0 else 0
all_results_details.extend(subject_details)
total_correct += correct_predictions_subject
total_samples += num_samples
summary_lines.append(f"- **{subject}**: {accuracy:.2f}% ({correct_predictions_subject}/{num_samples})")
except Exception as e:
error_trace = traceback.format_exc()
gr.Error(f"Skipping {subject} due to an error: {e}")
summary_lines.append(f"- **{subject}**: Evaluation failed. See logs for details:\n```\n{error_trace}\n```")
continue
# 3. Final processing and result preparation
overall_accuracy = (total_correct / total_samples) * 100 if total_samples > 0 else 0
if subject_name == "ALL":
result_summary = f"### Overall Average Accuracy: {overall_accuracy:.2f}%\n"
result_summary += f"across {total_samples:,} total samples from {len(subjects_to_run)} subjects.\n\n---\n\n**Breakdown by Subject:**\n"
result_summary += "\n".join(summary_lines)
else:
result_summary = f"### Accuracy for {benchmark_category} - {subject_name}: {overall_accuracy:.2f}%\n"
result_summary += f"({total_correct:,}/{total_samples:,} correct)"
# Write final result to the JSONL file
record = {
"model_id": model_id,
"benchmark": benchmark_category,
"accuracy": overall_accuracy,
"subject": subject_name,
"sample_count": total_samples,
"timestamp": datetime.now().isoformat()
}
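# Append one JSON object per line (JSONL); runs are never overwritten, and
# load_leaderboard later keeps only the most recent "ALL" run per model.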
with open(EVAL_FILE, "a") as f:
f.write(json.dumps(record) + "\n")
gr.Info("Evaluation completed successfully!")
df_details = pd.DataFrame(all_results_details)
# 4. Final yield to show results and hide progress UI
yield {
progress_box: gr.update(visible=False),
result_summary_box: gr.update(visible=True),
result_summary_output: gr.update(value=result_summary),
details_box: gr.update(visible=True),
detailed_results_df: gr.update(value=df_details),
error_box: gr.update(visible=False)
}
except Exception as e:
error_message = f"An unexpected error occurred: {e}"
error_details = traceback.format_exc()
# Surface the error as a toast without aborting the generator (gr.Error would have to be raised).
gr.Warning(error_message)
# Yield to show error message and hide progress UI
yield {
progress_box: gr.update(visible=False),
result_summary_box: gr.update(visible=False),
details_box: gr.update(visible=False),
error_box: gr.update(visible=True),
error_output: gr.update(value=error_message),
error_details_output: gr.update(value=error_details),
}
# --- UI Helper Functions ---
def update_subject_dropdown(benchmark_category):
"""Updates the subject dropdown choices based on the selected benchmark."""
choices = ALL_BENCHMARK_SUBJECTS.get(benchmark_category, [])
default_value = "ALL" if "ALL" in choices else (choices[0] if choices else None)
return gr.update(choices=choices, value=default_value)
def load_leaderboard(benchmark_filter, progress=gr.Progress()):
"""
Loads and processes evaluation data to display on the leaderboard.
"""
progress(0, desc="Loading Leaderboard...")
try:
if not os.path.exists(EVAL_FILE):
return pd.DataFrame(columns=["Rank", "Model ID", "Avg. Accuracy (%)", "Total Samples", "Date"])
df = pd.read_json(EVAL_FILE, lines=True)
if df.empty:
return pd.DataFrame(columns=["Rank", "Model ID", "Avg. Accuracy (%)", "Total Samples", "Date"])
df['accuracy'] = pd.to_numeric(df['accuracy'], errors='coerce')
df.dropna(subset=['accuracy'], inplace=True)
# Filter for 'ALL' subject runs for the selected benchmark
df_filtered = df[(df['benchmark'] == benchmark_filter) & (df['subject'] == 'ALL')].copy()
if df_filtered.empty:
return pd.DataFrame(columns=["Rank", "Model ID", "Avg. Accuracy (%)", "Total Samples", "Date"])
df_filtered['timestamp'] = pd.to_datetime(df_filtered['timestamp'])
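# For each model, keep only its most recent evaluation (idxmax of the timestamp within each model_id group).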
latest_evals = df_filtered.loc[df_filtered.groupby('model_id')['timestamp'].idxmax()].copy()
leaderboard_df = latest_evals.sort_values(by="accuracy", ascending=False).copy()
leaderboard_df.insert(0, 'Rank', range(1, len(leaderboard_df) + 1))
leaderboard_df.rename(columns={
'model_id': 'Model ID',
'accuracy': 'Avg. Accuracy (%)',
'sample_count': 'Total Samples',
'timestamp': 'Date'
}, inplace=True)
leaderboard_df['Avg. Accuracy (%)'] = leaderboard_df['Avg. Accuracy (%)'].map('{:.2f}'.format)
leaderboard_df['Date'] = leaderboard_df['Date'].dt.strftime('%Y-%m-%d')
progress(1, desc="Done.")
return leaderboard_df[['Rank', 'Model ID', 'Avg. Accuracy (%)', 'Total Samples', 'Date']]
except Exception as e:
gr.Error(f"Error loading leaderboard: {e}")
traceback.print_exc()
return pd.DataFrame(columns=["Rank", "Model ID", "Avg. Accuracy (%)", "Total Samples", "Date"])
# --- Gradio Interface Definition ---
custom_css = """
/* --- Global & Layout (Bigger to fit screen) --- */
body { font-family: 'Inter', sans-serif; background-color: #1a1a1a; color: #f0f0f0; } /* Dark background, light text */
.gradio-container { max-width: 95% !important; margin: auto; padding: 20px; } /* Wider container */
.gr-group { border-radius: 12px !important; box-shadow: 0 4px 12px rgba(0,0,0,0.3) !important; border: 1px solid #333 !important; background-color: #2a2a2a; }
.gr-panel { border-radius: 12px !important; box-shadow: 0 4px 12px rgba(0,0,0,0.3) !important; border: 1px solid #333 !important; background-color: #2a2a2a; }
/* --- Typography (Orange Hues) --- */
h1 { text-align: center; font-size: 3rem !important; font-weight: 800; color: #ff8c00; margin-bottom: 0.5rem; letter-spacing: -1.5px; } /* Orange title */
h3, h4 { color: #ffa500; } /* Orange headings */
.subtitle { text-align: center; color: #cccccc; font-size: 1.2rem; margin-bottom: 2.5rem; max-width: 900px; margin-left: auto; margin-right: auto;}
label { color: #f0f0f0 !important; } /* Label text color */
/* --- Progress Bar --- */
.progress-container { background-color: #3a3a3a; border-radius: 8px; overflow: hidden; border: 1px solid #555; height: 28px; padding: 4px; }
.progress-bar { background: linear-gradient(90deg, #ff8c00, #ffa500); height: 100%; border-radius: 5px; transition: width 0.3s ease-in-out; display: flex; align-items: center; justify-content: center; color: #1a1a1a; font-weight: 600; font-size: 0.9rem; }
/* --- Tabs --- */
.gradio-tabs { background-color: #2a2a2a; border-radius: 12px; }
.gradio-tabs button { background-color: #3a3a3a !important; color: #f0f0f0 !important; border-radius: 8px 8px 0 0 !important; transition: all 0.3s ease; }
.gradio-tabs button.selected { background-color: #ff8c00 !important; color: #1a1a1a !important; font-weight: 700; }
/* --- Inputs --- */
.gr-textbox, .gr-dropdown, .gr-slider { background-color: #3a3a3a !important; color: #f0f0f0 !important; border: 1px solid #555 !important; border-radius: 8px !important; }
/* --- Buttons --- */
.gr-button-primary { background-color: #ff8c00 !important; color: #1a1a1a !important; box-shadow: 0 4px 10px rgba(255, 140, 0, 0.3); border: none; }
.gr-button-primary:hover { transform: translateY(-2px); box-shadow: 0 6px 15px rgba(255, 140, 0, 0.5); background-color: #ffa500 !important; }
/* --- Dataframe / Table Styling --- */
.leaderboard-table .gr-dataframe thead th { background-color: #3a3a3a !important; color: #ffa500 !important; font-weight: 600 !important; text-align: left; padding: 12px 15px; border-bottom: 2px solid #555; }
.leaderboard-table .gr-dataframe tbody tr:nth-of-type(even) { background-color: #2f2f2f; }
.leaderboard-table .gr-dataframe tbody tr:hover { background-color: #4a4a4a; }
.leaderboard-table .gr-dataframe tbody td { padding: 12px 15px; border-bottom: 1px solid #3a3a3a; color: #f0f0f0; }
/* --- Error & Result Panes --- */
#error-display-box { background-color: #4a1e1e !important; border-color: #8c2f2f !important; color: #ffc9c9 !important; }
#result-summary-box { background-color: #1e3a2a !important; border-color: #2f8c4a !important; color: #c9ffc9 !important; }
.gr-markdown p { color: #f0f0f0 !important; } .gr-markdown strong { color: #ffa500 !important; }
.gradio-message { background-color: #ff8c00 !important; color: #1a1a1a !important; border: 1px solid #ff8c00 !important; }
"""
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
gr.Markdown("<h1>πŸ† SuperBench Eval: Evaluate models and view leaderboards πŸ†</h1>")
gr.Markdown("<p class='subtitle'>Benchmark leading models on MMLU. Your results contribute to a live leaderboard. Select a benchmark and run an evaluation, or view the current standings.</p>")
with gr.Tabs() as tabs:
# --- Leaderboard Tab ---
with gr.TabItem("πŸ“Š Leaderboard", id=0):
with gr.Column():
with gr.Row():
leaderboard_type_toggle = gr.Radio(
["MMLU"], label="Select Benchmark", value="MMLU", interactive=True
)
refresh_button = gr.Button("πŸ”„ Refresh", size="sm")
leaderboard_table_output = gr.DataFrame(
headers=["Rank", "Model ID", "Avg. Accuracy (%)", "Total Samples", "Date"],
interactive=False, datatype=["number", "str", "str", "number", "str"],
row_count=15, elem_classes="leaderboard-table",
)
# --- Evaluation Tab ---
with gr.TabItem("πŸš€ Run Evaluation", id=1):
with gr.Row(variant='panel'):
with gr.Column(scale=2):
with gr.Group():
gr.Markdown("### 1. Configure Evaluation")
model_id_input = gr.Textbox(
label="Hugging Face Model ID", placeholder="e.g., meta-llama/Meta-Llama-3-8B-Instruct",
interactive=True, scale=2
)
benchmark_selection_radio = gr.Radio(
["MMLU"], label="Benchmark", value="MMLU", interactive=True
)
with gr.Row():
benchmark_subject_dropdown = gr.Dropdown(
label="Subject", choices=ALL_BENCHMARK_SUBJECTS.get("MMLU", []),
value="ALL", interactive=True
)
sample_count_slider = gr.Slider(
label="Samples per Subject", minimum=5, maximum=100, value=10, step=5, interactive=True
)
run_button = gr.Button("Start Evaluation", variant="primary", scale=1)
with gr.Column(scale=3):
gr.Markdown("### 2. View Results")
# Progress bar UI (hidden until an evaluation is running)
with gr.Group(visible=False) as progress_box:
progress_text_output = gr.Markdown("Starting...")
progress_bar_output = gr.HTML(make_progress_html("Waiting...", 0))
# Panel for displaying the summary of results
with gr.Group(visible=False) as result_summary_box:
result_summary_output = gr.Markdown(elem_id="result-summary-box")
# Panel for displaying errors
with gr.Group(visible=False) as error_box:
error_output = gr.Textbox(label="Error Message", interactive=False, elem_id="error-display-box")
error_details_output = gr.Textbox(label="Error Details (Traceback)", interactive=False, lines=8)
# Panel for detailed, row-by-row results
with gr.Group(visible=False) as details_box:
gr.Markdown("#### Detailed Evaluation Log")
detailed_results_df = gr.DataFrame(
headers=["Question", "Correct", "Expected", "Predicted", "Model Output"],
datatype=["str", "str", "str", "str", "str"],
interactive=False, row_count=10, wrap=True,
)
# --- Event Handlers & Logic ---
benchmark_selection_radio.change(
fn=update_subject_dropdown,
inputs=[benchmark_selection_radio],
outputs=[benchmark_subject_dropdown]
)
# Main evaluation trigger; run_evaluation is a generator that streams progress updates
run_button.click(
fn=run_evaluation,
inputs=[model_id_input, benchmark_selection_radio, benchmark_subject_dropdown, sample_count_slider],
outputs=[
progress_box, progress_text_output, progress_bar_output,
result_summary_box, result_summary_output,
error_box, error_output, error_details_output,
details_box, detailed_results_df
]
).then(
# After evaluation, refresh the leaderboard
load_leaderboard, inputs=[leaderboard_type_toggle], outputs=[leaderboard_table_output]
)
# --- Leaderboard Loading Logic ---
demo.load(
fn=load_leaderboard,
inputs=[leaderboard_type_toggle],
outputs=[leaderboard_table_output]
)
leaderboard_type_toggle.change(
fn=load_leaderboard,
inputs=[leaderboard_type_toggle],
outputs=[leaderboard_table_output],
show_progress='minimal'
)
refresh_button.click(
fn=load_leaderboard,
inputs=[leaderboard_type_toggle],
outputs=[leaderboard_table_output],
show_progress='full'
)
if __name__ == "__main__":
demo.launch(debug=True)