# Hugging Face Space — "Running on Zero" (ZeroGPU) Gradio app for MMLU evaluation with a live leaderboard.
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset, get_dataset_config_names
import torch
import re
import json
import pandas as pd
import traceback
import spaces
from datetime import datetime
# --- Environment and Caching ---
# It's good practice to ensure the cache directory exists.
CACHE_DIR = "evaluation_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
EVAL_FILE = os.path.join(CACHE_DIR, "evals.jsonl")

# Caches to avoid reloading models and dataset configs
model_cache = {}
benchmark_subject_cache = {}

# Use an environment variable for the Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Constants for Benchmarks ---
MMLU_DATASET = "cais/mmlu"

BENCHMARK_MAP = {
    "MMLU": MMLU_DATASET,
}
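
# Additional multiple-choice benchmarks could in principle be registered here by mapping
# a display name to a dataset ID (hypothetical example, not wired into the UI):
#   BENCHMARK_MAP["ARC-Challenge"] = "allenai/ai2_arc"
# Note that format_prompt() below assumes MMLU-style `question`/`choices`/`answer` fields,
# and the benchmark radio buttons in the UI are currently hard-coded to ["MMLU"].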
# --- Data Loading and Preparation ---
def get_all_benchmark_options():
    """
    Fetches and caches the available subjects (configs) for each benchmark dataset.
    Results are stored in a module-level cache to avoid repeated API calls.
    """
    if benchmark_subject_cache:
        return benchmark_subject_cache
    print("Fetching benchmark configurations for the first time...")
    for key, dataset_id in BENCHMARK_MAP.items():
        try:
            subjects = get_dataset_config_names(dataset_id, token=HF_TOKEN)
            benchmark_subject_cache[key] = ["ALL"] + sorted([s for s in subjects if s != 'all'])
        except Exception as e:
            print(f"Warning: Could not load configs for {key} ({dataset_id}). It might be private or unavailable. Error: {e}")
            benchmark_subject_cache[key] = ["ALL"]
    print("Benchmark configurations cached.")
    return benchmark_subject_cache

# Initialize the cache on startup
ALL_BENCHMARK_SUBJECTS = get_all_benchmark_options()
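
# Illustrative shape of the cached structure above (the actual subject list is whatever
# configs the dataset exposes at runtime):
#   ALL_BENCHMARK_SUBJECTS == {"MMLU": ["ALL", "abstract_algebra", "anatomy", ..., "world_religions"]}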
def load_model(model_id):
    """
    Loads a Hugging Face model and tokenizer, creating a text-generation pipeline.
    Uses a cache to avoid reloading models.
    """
    if not model_id:
        raise ValueError("Model ID cannot be empty.")
    gr.Info(f"Attempting to load model: {model_id}...")
    if model_id in model_cache:
        gr.Info(f"Model '{model_id}' found in cache.")
        return model_cache[model_id]
    try:
        # Prefer bfloat16 on GPUs that support it; fall back to float32 otherwise.
        dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=HF_TOKEN,
            torch_dtype=dtype,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
        model_cache[model_id] = generator
        gr.Info(f"Model '{model_id}' loaded successfully.")
        return generator
    except Exception as e:
        raise RuntimeError(f"Failed to load model '{model_id}'. Please verify the model ID and your Hugging Face token. Error: {e}") from e
# --- Evaluation Logic ---
def format_prompt(item):
    """Formats the MMLU question and choices into a standardized prompt."""
    prompt = (
        f"Question: {item['question']}\n\n"
        f"Choices:\n"
        f"A. {item['choices'][0]}\n"
        f"B. {item['choices'][1]}\n"
        f"C. {item['choices'][2]}\n"
        f"D. {item['choices'][3]}\n\n"
        f"Answer:"
    )
    return prompt, item['answer']
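
# Example of the prompt produced above for a hypothetical MMLU-style item
# {"question": "What is 2 + 2?", "choices": ["3", "4", "5", "6"], "answer": 1}:
#
#   Question: What is 2 + 2?
#
#   Choices:
#   A. 3
#   B. 4
#   C. 5
#   D. 6
#
#   Answer:
#
# The returned answer index (1) maps to choice letter "B" via get_choice_letter() below.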
def get_choice_letter(index):
    """Converts a numerical choice index (0-3) to a letter (A-D)."""
    return chr(ord('A') + index) if 0 <= index <= 3 else None
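
# e.g. get_choice_letter(0) -> "A", get_choice_letter(3) -> "D", get_choice_letter(4) -> None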
def extract_predicted_letter(output_text):
    """Extracts the predicted letter from the model's output."""
    match = re.search(r"Answer:\s*([ABCD])", output_text.strip(), re.IGNORECASE)
    if match:
        return match.group(1).upper()
    match = re.search(r"^\s*([ABCD])\b", output_text.strip())
    if match:
        return match.group(1).upper()
    return None
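
# A minimal sanity check of the parser above, run once at import time on fixed strings
# (illustrative; remove if module-level asserts are undesirable):
assert extract_predicted_letter("Answer: B") == "B"
assert extract_predicted_letter(" C. Because...") == "C"
assert extract_predicted_letter("I am not sure.") is None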
def make_progress_html(text, percentage):
    """Helper function to create the HTML for the progress bar."""
    return f"""
    <div class="progress-container">
        <div class="progress-bar" style="width: {percentage}%;">
            {text}
        </div>
    </div>
    """
# ZeroGPU: request a GPU while this function runs. The Space runs on Zero and `spaces`
# is imported above but otherwise unused, so this decorator is assumed to be the
# intended hookup for GPU allocation.
@spaces.GPU
def run_evaluation(model_id, benchmark_category, subject_name, sample_count):
    """
    Main generator function to orchestrate the evaluation, yielding progress updates.
    """
    try:
        # 1. Initial yield to set up the UI for the loading state
        yield {
            progress_box: gr.update(visible=True),
            progress_text_output: gr.update(value=f"Preparing evaluation for **{model_id}**..."),
            progress_bar_output: gr.update(value=make_progress_html("Loading Model...", 0)),
            result_summary_box: gr.update(visible=False),
            details_box: gr.update(visible=False),
            error_box: gr.update(visible=False),
        }

        generator = load_model(model_id)

        dataset_id = BENCHMARK_MAP.get(benchmark_category)
        if not dataset_id:
            raise ValueError(f"Invalid benchmark category: {benchmark_category}")

        if subject_name == "ALL":
            subjects_to_run = [s for s in ALL_BENCHMARK_SUBJECTS.get(benchmark_category, []) if s != "ALL"]
        else:
            subjects_to_run = [subject_name]

        if not subjects_to_run:
            gr.Warning(f"No subjects found for '{benchmark_category}'.")
            yield {progress_box: gr.update(visible=False)}
            return

        all_results_details = []
        summary_lines = []
        total_correct = 0
        total_samples = 0
        # 2. Main evaluation loop
        for i, subject in enumerate(subjects_to_run):
            overall_progress_text = f"**Overall Progress ({i+1}/{len(subjects_to_run)} subjects)**"
            yield {
                progress_text_output: gr.update(value=f"{overall_progress_text}\n\nLoading dataset for **{subject}**...")
            }
            try:
                # Load the dataset for the current subject
                dataset = load_dataset(dataset_id, subject, token=HF_TOKEN, split="test")
                num_samples = min(sample_count, len(dataset))
                dataset = dataset.shuffle(seed=42).select(range(num_samples))

                correct_predictions_subject = 0
                subject_details = []

                # Loop over samples within the subject
                for j, item in enumerate(dataset):
                    prompt, correct_answer_idx = format_prompt(item)
                    expected_letter = get_choice_letter(correct_answer_idx)

                    # The pipeline returns the prompt plus the completion, re-decoded from
                    # tokens, so normalize the prompt the same way before slicing it off.
                    full_prompt_text = generator.tokenizer.decode(generator.tokenizer.encode(prompt), skip_special_tokens=True)
                    raw_output = generator(prompt, max_new_tokens=5, do_sample=False, pad_token_id=generator.tokenizer.eos_token_id)[0]["generated_text"]
                    generated_text_only = raw_output[len(full_prompt_text):].strip()

                    predicted_letter = extract_predicted_letter(generated_text_only)
                    is_correct = (predicted_letter == expected_letter)
                    if is_correct:
                        correct_predictions_subject += 1

                    subject_details.append({
                        "Question": item['question'],
                        "Correct": "✅" if is_correct else "❌",
                        "Expected": expected_letter,
                        "Predicted": predicted_letter or "N/A",
                        "Model Output": generated_text_only
                    })

                    # Yield a progress update for each sample
                    percentage = ((j + 1) / num_samples) * 100
                    progress_bar_text = f"Evaluating: {subject} ({j+1}/{num_samples})"
                    yield {
                        progress_bar_output: gr.update(value=make_progress_html(f"{percentage:.1f}%", percentage)),
                        progress_text_output: gr.update(value=f"{overall_progress_text}\n\n{progress_bar_text}")
                    }

                accuracy = (correct_predictions_subject / num_samples) * 100 if num_samples > 0 else 0
                all_results_details.extend(subject_details)
                total_correct += correct_predictions_subject
                total_samples += num_samples
                summary_lines.append(f"- **{subject}**: {accuracy:.2f}% ({correct_predictions_subject}/{num_samples})")
            except Exception as e:
                error_trace = traceback.format_exc()
                # gr.Error only displays when raised, so use gr.Warning for a non-fatal toast.
                gr.Warning(f"Skipping {subject} due to an error: {e}")
                summary_lines.append(f"- **{subject}**: Evaluation failed. See logs for details:\n```\n{error_trace}\n```")
                continue
        # 3. Final processing and result preparation
        overall_accuracy = (total_correct / total_samples) * 100 if total_samples > 0 else 0

        if subject_name == "ALL":
            result_summary = f"### Overall Average Accuracy: {overall_accuracy:.2f}%\n"
            result_summary += f"across {total_samples:,} total samples from {len(subjects_to_run)} subjects.\n\n---\n\n**Breakdown by Subject:**\n"
            result_summary += "\n".join(summary_lines)
        else:
            result_summary = f"### Accuracy for {benchmark_category} - {subject_name}: {overall_accuracy:.2f}%\n"
            result_summary += f"({total_correct:,}/{total_samples:,} correct)"

        # Write final result to the JSONL file
        record = {
            "model_id": model_id,
            "benchmark": benchmark_category,
            "accuracy": overall_accuracy,
            "subject": subject_name,
            "sample_count": total_samples,
            "timestamp": datetime.now().isoformat()
        }
        with open(EVAL_FILE, "a") as f:
            f.write(json.dumps(record) + "\n")
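
        # Illustrative shape of the JSONL line appended above (values are hypothetical):
        #   {"model_id": "org/model", "benchmark": "MMLU", "accuracy": 62.5, "subject": "ALL",
        #    "sample_count": 570, "timestamp": "2024-01-01T12:00:00"}
        # load_leaderboard() reads these records back with pd.read_json(..., lines=True).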
gr.Info("Evaluation completed successfully!") | |
df_details = pd.DataFrame(all_results_details) | |
# 4. Final yield to show results and hide progress UI | |
yield { | |
progress_box: gr.update(visible=False), | |
result_summary_box: gr.update(visible=True), | |
result_summary_output: gr.update(value=result_summary), | |
details_box: gr.update(visible=True), | |
detailed_results_df: gr.update(value=df_details), | |
error_box: gr.update(visible=False) | |
} | |
    except Exception as e:
        error_message = f"An unexpected error occurred: {e}"
        error_details = traceback.format_exc()
        # Use a non-fatal toast; the error panel below carries the full details.
        gr.Warning(error_message)
        # Yield to show error message and hide progress UI
        yield {
            progress_box: gr.update(visible=False),
            result_summary_box: gr.update(visible=False),
            details_box: gr.update(visible=False),
            error_box: gr.update(visible=True),
            error_output: gr.update(value=error_message),
            error_details_output: gr.update(value=error_details),
        }
# --- UI Helper Functions ---
def update_subject_dropdown(benchmark_category):
    """Updates the subject dropdown choices based on the selected benchmark."""
    choices = ALL_BENCHMARK_SUBJECTS.get(benchmark_category, [])
    default_value = "ALL" if "ALL" in choices else (choices[0] if choices else None)
    return gr.update(choices=choices, value=default_value)
def load_leaderboard(benchmark_filter, progress=gr.Progress()):
    """
    Loads and processes evaluation data to display on the leaderboard.
    """
    progress(0, desc="Loading Leaderboard...")
    empty_board = pd.DataFrame(columns=["Rank", "Model ID", "Avg. Accuracy (%)", "Total Samples", "Date"])
    try:
        if not os.path.exists(EVAL_FILE):
            return empty_board
        df = pd.read_json(EVAL_FILE, lines=True)
        if df.empty:
            return empty_board

        df['accuracy'] = pd.to_numeric(df['accuracy'], errors='coerce')
        df.dropna(subset=['accuracy'], inplace=True)

        # Filter for 'ALL' subject runs for the selected benchmark
        df_filtered = df[(df['benchmark'] == benchmark_filter) & (df['subject'] == 'ALL')].copy()
        if df_filtered.empty:
            return empty_board

        # Keep only the most recent evaluation per model
        df_filtered['timestamp'] = pd.to_datetime(df_filtered['timestamp'])
        latest_evals = df_filtered.loc[df_filtered.groupby('model_id')['timestamp'].idxmax()].copy()

        leaderboard_df = latest_evals.sort_values(by="accuracy", ascending=False).copy()
        leaderboard_df.insert(0, 'Rank', range(1, len(leaderboard_df) + 1))
        leaderboard_df.rename(columns={
            'model_id': 'Model ID',
            'accuracy': 'Avg. Accuracy (%)',
            'sample_count': 'Total Samples',
            'timestamp': 'Date'
        }, inplace=True)
        leaderboard_df['Avg. Accuracy (%)'] = leaderboard_df['Avg. Accuracy (%)'].map('{:.2f}'.format)
        leaderboard_df['Date'] = leaderboard_df['Date'].dt.strftime('%Y-%m-%d')

        progress(1, desc="Done.")
        return leaderboard_df[['Rank', 'Model ID', 'Avg. Accuracy (%)', 'Total Samples', 'Date']]
    except Exception as e:
        gr.Warning(f"Error loading leaderboard: {e}")
        traceback.print_exc()
        return empty_board
# --- Gradio Interface Definition ---
custom_css = """
/* --- Global & Layout (Bigger to fit screen) --- */
body { font-family: 'Inter', sans-serif; background-color: #1a1a1a; color: #f0f0f0; } /* Dark background, light text */
.gradio-container { max-width: 95% !important; margin: auto; padding: 20px; } /* Wider container */
.gr-group { border-radius: 12px !important; box-shadow: 0 4px 12px rgba(0,0,0,0.3) !important; border: 1px solid #333 !important; background-color: #2a2a2a; }
.gr-panel { border-radius: 12px !important; box-shadow: 0 4px 12px rgba(0,0,0,0.3) !important; border: 1px solid #333 !important; background-color: #2a2a2a; }

/* --- Typography (Orange Hues) --- */
h1 { text-align: center; font-size: 3rem !important; font-weight: 800; color: #ff8c00; margin-bottom: 0.5rem; letter-spacing: -1.5px; } /* Orange title */
h3, h4 { color: #ffa500; } /* Orange headings */
.subtitle { text-align: center; color: #cccccc; font-size: 1.2rem; margin-bottom: 2.5rem; max-width: 900px; margin-left: auto; margin-right: auto;}
label { color: #f0f0f0 !important; } /* Label text color */

/* --- Progress Bar --- */
.progress-container { background-color: #3a3a3a; border-radius: 8px; overflow: hidden; border: 1px solid #555; height: 28px; padding: 4px; }
.progress-bar { background: linear-gradient(90deg, #ff8c00, #ffa500); height: 100%; border-radius: 5px; transition: width 0.3s ease-in-out; display: flex; align-items: center; justify-content: center; color: #1a1a1a; font-weight: 600; font-size: 0.9rem; }

/* --- Tabs --- */
.gradio-tabs { background-color: #2a2a2a; border-radius: 12px; }
.gradio-tabs button { background-color: #3a3a3a !important; color: #f0f0f0 !important; border-radius: 8px 8px 0 0 !important; transition: all 0.3s ease; }
.gradio-tabs button.selected { background-color: #ff8c00 !important; color: #1a1a1a !important; font-weight: 700; }

/* --- Inputs --- */
.gr-textbox, .gr-dropdown, .gr-slider { background-color: #3a3a3a !important; color: #f0f0f0 !important; border: 1px solid #555 !important; border-radius: 8px !important; }

/* --- Buttons --- */
.gr-button-primary { background-color: #ff8c00 !important; color: #1a1a1a !important; box-shadow: 0 4px 10px rgba(255, 140, 0, 0.3); border: none; }
.gr-button-primary:hover { transform: translateY(-2px); box-shadow: 0 6px 15px rgba(255, 140, 0, 0.5); background-color: #ffa500 !important; }

/* --- Dataframe / Table Styling --- */
.leaderboard-table .gr-dataframe thead th { background-color: #3a3a3a !important; color: #ffa500 !important; font-weight: 600 !important; text-align: left; padding: 12px 15px; border-bottom: 2px solid #555; }
.leaderboard-table .gr-dataframe tbody tr:nth-of-type(even) { background-color: #2f2f2f; }
.leaderboard-table .gr-dataframe tbody tr:hover { background-color: #4a4a4a; }
.leaderboard-table .gr-dataframe tbody td { padding: 12px 15px; border-bottom: 1px solid #3a3a3a; color: #f0f0f0; }

/* --- Error & Result Panes --- */
#error-display-box { background-color: #4a1e1e !important; border-color: #8c2f2f !important; color: #ffc9c9 !important; }
#result-summary-box { background-color: #1e3a2a !important; border-color: #2f8c4a !important; color: #c9ffc9 !important; }
.gr-markdown p { color: #f0f0f0 !important; } .gr-markdown strong { color: #ffa500 !important; }
.gradio-message { background-color: #ff8c00 !important; color: #1a1a1a !important; border: 1px solid #ff8c00 !important; }
"""
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    gr.Markdown("<h1>SuperBench Eval: Evaluate models and view leaderboards</h1>")
    gr.Markdown("<p class='subtitle'>Benchmark leading models on MMLU. Your results contribute to a live leaderboard. Select a benchmark and run an evaluation, or view the current standings.</p>")

    with gr.Tabs() as tabs:
        # --- Leaderboard Tab ---
        with gr.TabItem("Leaderboard", id=0):
            with gr.Column():
                with gr.Row():
                    leaderboard_type_toggle = gr.Radio(
                        ["MMLU"], label="Select Benchmark", value="MMLU", interactive=True
                    )
                    refresh_button = gr.Button("Refresh", size="sm")
                leaderboard_table_output = gr.DataFrame(
                    headers=["Rank", "Model ID", "Avg. Accuracy (%)", "Total Samples", "Date"],
                    interactive=False, datatype=["number", "str", "str", "number", "str"],
                    row_count=15, elem_classes="leaderboard-table",
                )
        # --- Evaluation Tab ---
        with gr.TabItem("Run Evaluation", id=1):
            with gr.Row(variant='panel'):
                with gr.Column(scale=2):
                    with gr.Group():
                        gr.Markdown("### 1. Configure Evaluation")
                        model_id_input = gr.Textbox(
                            label="Hugging Face Model ID", placeholder="e.g., meta-llama/Meta-Llama-3-8B-Instruct",
                            interactive=True, scale=2
                        )
                        benchmark_selection_radio = gr.Radio(
                            ["MMLU"], label="Benchmark", value="MMLU", interactive=True
                        )
                        with gr.Row():
                            benchmark_subject_dropdown = gr.Dropdown(
                                label="Subject", choices=ALL_BENCHMARK_SUBJECTS.get("MMLU", []),
                                value="ALL", interactive=True
                            )
                            sample_count_slider = gr.Slider(
                                label="Samples per Subject", minimum=5, maximum=100, value=10, step=5, interactive=True
                            )
                        run_button = gr.Button("Start Evaluation", variant="primary", scale=1)

                with gr.Column(scale=3):
                    gr.Markdown("### 2. View Results")

                    # Progress bar UI, shown only while an evaluation is running
                    with gr.Group(visible=False) as progress_box:
                        progress_text_output = gr.Markdown("Starting...")
                        progress_bar_output = gr.HTML(make_progress_html("Waiting...", 0))

                    # Panel for displaying the summary of results
                    with gr.Group(visible=False) as result_summary_box:
                        result_summary_output = gr.Markdown(elem_id="result-summary-box")

                    # Panel for displaying errors
                    with gr.Group(visible=False) as error_box:
                        error_output = gr.Textbox(label="Error Message", interactive=False, elem_id="error-display-box")
                        error_details_output = gr.Textbox(label="Error Details (Traceback)", interactive=False, lines=8)

                    # Panel for detailed, row-by-row results
                    with gr.Group(visible=False) as details_box:
                        gr.Markdown("#### Detailed Evaluation Log")
                        detailed_results_df = gr.DataFrame(
                            headers=["Question", "Correct", "Expected", "Predicted", "Model Output"],
                            datatype=["str", "str", "str", "str", "str"],
                            interactive=False, row_count=10, wrap=True,
                        )
# --- Event Handlers & Logic --- | |
benchmark_selection_radio.change( | |
fn=update_subject_dropdown, | |
inputs=[benchmark_selection_radio], | |
outputs=[benchmark_subject_dropdown] | |
) | |
# Main evaluation trigger, now handles a generator for progress updates | |
run_button.click( | |
fn=run_evaluation, | |
inputs=[model_id_input, benchmark_selection_radio, benchmark_subject_dropdown, sample_count_slider], | |
outputs=[ | |
progress_box, progress_text_output, progress_bar_output, | |
result_summary_box, result_summary_output, | |
error_box, error_output, error_details_output, | |
details_box, detailed_results_df | |
] | |
).then( | |
# After evaluation, refresh the leaderboard | |
load_leaderboard, inputs=[leaderboard_type_toggle], outputs=[leaderboard_table_output] | |
) | |
# --- Leaderboard Loading Logic --- | |
demo.load( | |
fn=load_leaderboard, | |
inputs=[leaderboard_type_toggle], | |
outputs=[leaderboard_table_output] | |
) | |
leaderboard_type_toggle.change( | |
fn=load_leaderboard, | |
inputs=[leaderboard_type_toggle], | |
outputs=[leaderboard_table_output], | |
show_progress='minimal' | |
) | |
refresh_button.click( | |
fn=load_leaderboard, | |
inputs=[leaderboard_type_toggle], | |
outputs=[leaderboard_table_output], | |
show_progress='full' | |
) | |
if __name__ == "__main__": | |
demo.launch(debug=True) | |
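
# Running outside of Spaces (illustrative; package names follow the imports above, versions unpinned):
#   pip install gradio transformers datasets torch pandas spaces
#   HF_TOKEN=<your-token> python app.py
# On a ZeroGPU Space, the `spaces` package is preinstalled and a GPU is attached only while
# @spaces.GPU-decorated functions execute; elsewhere the app falls back to CPU via the
# torch.cuda.is_available() checks above.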