import random
from typing import Any, Dict, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .evaluation_utils import extract_answer_mmlu
from .prompt_templates import get_mmlu_prompt

class MMLUBenchmark(BaseBenchmark):
    """MMLU (Massive Multitask Language Understanding) benchmark"""
    
    SUBJECTS = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics',
        'clinical_knowledge', 'college_biology', 'college_chemistry',
        'college_computer_science', 'college_mathematics', 'college_medicine',
        'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics',
        'formal_logic', 'global_facts', 'high_school_biology',
        'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography',
        'high_school_government_and_politics', 'high_school_macroeconomics',
        'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging',
        'human_sexuality', 'international_law', 'jurisprudence',
        'logical_fallacies', 'machine_learning', 'management', 'marketing',
        'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting',
        'professional_law', 'professional_medicine', 'professional_psychology',
        'public_relations', 'security_studies', 'sociology', 'us_foreign_policy',
        'virology', 'world_religions'
    ]
    
    def __init__(self):
        super().__init__(name="MMLU", dataset_name="cais/mmlu")
        
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load MMLU test data, plus 5-shot dev examples per subject.

        Accepts a `subjects` kwarg: a list of MMLU subject names, or ['all']
        (the default) to load every subject in SUBJECTS.
        """
        subjects = kwargs.get('subjects', ['all'])
        
        if 'all' in subjects:
            subjects = self.SUBJECTS
        else:
            subjects = [s for s in subjects if s in self.SUBJECTS]
        
        self.dataset = []
        self.few_shot_examples = {}  # Store few-shot examples per subject
        
        for subject in subjects:
            try:
                # Load dev split for few-shot examples
                dev_ds = load_dataset(self.dataset_name, subject, split='dev')
                # Standard MMLU uses 5-shot
                self.few_shot_examples[subject] = [
                    {
                        'question': ex['question'],
                        'choices': ex['choices'],
                        'answer': ex['answer']
                    }
                    for ex in list(dev_ds)[:5]
                ]
                
                # Load test split for evaluation
                test_ds = load_dataset(self.dataset_name, subject, split='test')
                
                for sample in test_ds:
                    self.dataset.append({
                        'subject': subject,
                        'question': sample['question'],
                        'choices': sample['choices'],
                        'answer': sample['answer'],  # 0-3 index
                        'raw_sample': sample
                    })
            except Exception as e:
                print(f"Error loading {subject}: {e}")
                continue
        
        # Shuffle so any sample_size truncation draws a random mix of subjects
        random.shuffle(self.dataset)
        
        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]
    
    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format MMLU question as prompt with few-shot examples"""
        subject = sample['subject']
        few_shot_examples = self.few_shot_examples.get(subject, [])
        
        return get_mmlu_prompt(
            question=sample['question'],
            choices=sample['choices'],
            subject=subject.replace('_', ' ').title(),
            few_shot_examples=few_shot_examples
        )
    
    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single MMLU sample"""
        prompt = self.format_prompt(sample)
        
        try:
            response = await api.generate_with_retry(prompt, **kwargs)
            
            # Extract answer from response using standard extraction
            predicted_letter = extract_answer_mmlu(response)
            
            if predicted_letter:
                predicted_index = ord(predicted_letter) - ord('A')
            else:
                # If no clear answer, mark as incorrect
                predicted_index = -1
            
            correct_index = sample['answer']
            is_correct = predicted_index == correct_index
            
            result = {
                'subject': sample['subject'],
                'question': sample['question'],
                'choices': sample['choices'],
                'correct_answer': correct_index,
                'predicted_answer': predicted_index,
                'model_response': response,
                'is_correct': is_correct
            }
            
            return is_correct, result
            
        except Exception as e:
            result = {
                'subject': sample['subject'],
                'question': sample['question'],
                'error': str(e),
                'is_correct': False
            }
            return False, result
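
# Usage sketch (illustrative only). The `api` argument to evaluate_sample is
# assumed to be whatever client object the surrounding benchmark runner
# supplies; the only requirement visible in this module is an async
# generate_with_retry(prompt, **kwargs) method. Because this file uses
# relative imports, the sketch is shown as a comment rather than a __main__
# block:
#
#     import asyncio
#
#     async def run_mmlu(api):
#         benchmark = MMLUBenchmark()
#         # 5-shot dev examples are loaded per subject automatically.
#         await benchmark.load_dataset(sample_size=100, subjects=['anatomy', 'virology'])
#         correct = 0
#         for sample in benchmark.dataset:
#             is_correct, result = await benchmark.evaluate_sample(api, sample)
#             correct += is_correct
#         print(f"Accuracy: {correct / len(benchmark.dataset):.3f}")
#
#     asyncio.run(run_mmlu(my_api_client))  # my_api_client is hypothetical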