import random
import re
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .evaluation_utils import extract_answer_mmlu
from .prompt_templates import get_mmlu_prompt


class MMLUBenchmark(BaseBenchmark):
    """MMLU (Massive Multitask Language Understanding) benchmark"""

    SUBJECTS = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics',
        'clinical_knowledge', 'college_biology', 'college_chemistry',
        'college_computer_science', 'college_mathematics', 'college_medicine',
        'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics',
        'formal_logic', 'global_facts', 'high_school_biology',
        'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography',
        'high_school_government_and_politics', 'high_school_macroeconomics',
        'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging',
        'human_sexuality', 'international_law', 'jurisprudence',
        'logical_fallacies', 'machine_learning', 'management', 'marketing',
        'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting',
        'professional_law', 'professional_medicine', 'professional_psychology',
        'public_relations', 'security_studies', 'sociology', 'us_foreign_policy',
        'virology', 'world_religions'
    ]

    def __init__(self):
        super().__init__(name="MMLU", dataset_name="cais/mmlu")

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load MMLU dataset"""
        subjects = kwargs.get('subjects', ['all'])
        if 'all' in subjects:
            subjects = self.SUBJECTS
        else:
            subjects = [s for s in subjects if s in self.SUBJECTS]

        self.dataset = []
        self.few_shot_examples = {}  # Store few-shot examples per subject

        for subject in subjects:
            try:
                # Load dev split for few-shot examples
                dev_ds = load_dataset(self.dataset_name, subject, split='dev')
                # Standard MMLU uses 5-shot
                self.few_shot_examples[subject] = [
                    {
                        'question': ex['question'],
                        'choices': ex['choices'],
                        'answer': ex['answer']
                    }
                    for ex in list(dev_ds)[:5]
                ]

                # Load test split for evaluation
                test_ds = load_dataset(self.dataset_name, subject, split='test')
                for sample in test_ds:
                    self.dataset.append({
                        'subject': subject,
                        'question': sample['question'],
                        'choices': sample['choices'],
                        'answer': sample['answer'],  # 0-3 index
                        'raw_sample': sample
                    })
            except Exception as e:
                print(f"Error loading {subject}: {e}")
                continue

        # Shuffle dataset
        random.shuffle(self.dataset)

        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format MMLU question as prompt with few-shot examples"""
        subject = sample['subject']
        few_shot_examples = self.few_shot_examples.get(subject, [])
        return get_mmlu_prompt(
            question=sample['question'],
            choices=sample['choices'],
            subject=subject.replace('_', ' ').title(),
            few_shot_examples=few_shot_examples
        )

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single MMLU sample"""
        prompt = self.format_prompt(sample)
        try:
            response = await api.generate_with_retry(prompt, **kwargs)

            # Extract answer from response using standard extraction
            predicted_letter = extract_answer_mmlu(response)
            if predicted_letter:
                predicted_index = ord(predicted_letter) - ord('A')
            else:
                # If no clear answer, mark as incorrect
                predicted_index = -1

            correct_index = sample['answer']
            is_correct = predicted_index == correct_index

            result = {
                'subject': sample['subject'],
                'question': sample['question'],
                'choices': sample['choices'],
                'correct_answer': correct_index,
                'predicted_answer': predicted_index,
                'model_response': response,
                'is_correct': is_correct
            }
            return is_correct, result
        except Exception as e:
            result = {
                'subject': sample['subject'],
                'question': sample['question'],
                'error': str(e),
                'is_correct': False
            }
            return False, result
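

# Illustrative usage sketch. `evaluate_sample` only assumes an object exposing an
# async `generate_with_retry(prompt, **kwargs)` method, so the hypothetical
# `_EchoAPI` stub below is enough to exercise the benchmark end to end; swap in a
# real model client for actual evaluation. Loading data requires the `datasets`
# dependency and network access, and because this module uses relative imports it
# should be run as part of its package (e.g. via `python -m ...`).
if __name__ == "__main__":
    import asyncio

    class _EchoAPI:
        """Hypothetical stub client that always answers 'A'."""

        async def generate_with_retry(self, prompt: str, **kwargs) -> str:
            return "Answer: A"

    async def _demo():
        benchmark = MMLUBenchmark()
        # Load a small slice of one subject to keep the demo quick.
        await benchmark.load_dataset(sample_size=5, subjects=['anatomy'])
        for sample in benchmark.dataset:
            is_correct, result = await benchmark.evaluate_sample(_EchoAPI(), sample)
            print(f"{result['subject']}: correct={is_correct}")

    asyncio.run(_demo())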