import random
import re
from typing import Dict, Any, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .evaluation_utils import normalize_math_answer, is_math_equiv

class MATHBenchmark(BaseBenchmark):
    """MATH (Mathematics) benchmark for competition-level problems"""
    
    LEVELS = ['Level 1', 'Level 2', 'Level 3', 'Level 4', 'Level 5']
    TYPES = ['Algebra', 'Counting & Probability', 'Geometry', 'Intermediate Algebra',
             'Number Theory', 'Prealgebra', 'Precalculus']
    
    def __init__(self):
        super().__init__(name="MATH", dataset_name="hendrycks/competition_math")
        
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load MATH dataset"""
        dataset = load_dataset(self.dataset_name, split='test')
        
        # Filter by difficulty level if specified
        difficulty_levels = kwargs.get('difficulty', ['all'])
        if 'all' not in difficulty_levels:
            dataset = dataset.filter(lambda x: x['level'] in difficulty_levels)
        
        self.dataset = []
        for sample in dataset:
            self.dataset.append({
                'problem': sample['problem'],
                'solution': sample['solution'],
                'level': sample['level'],
                'type': sample['type'],
                'raw_sample': sample
            })
        
        # Shuffle so any sample_size truncation below takes a random subset
        random.shuffle(self.dataset)
        
        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]
    
    def extract_answer(self, solution: str) -> Optional[str]:
        """Extract the final answer from a MATH solution (content of the last \\boxed{} or \\fbox{})"""
        # Match braces explicitly so nested groups such as \boxed{\frac{1}{2}}
        # are captured in full; a flat [^{}]* pattern would stop at the first inner brace.
        answers = []
        for match in re.finditer(r'\\(?:boxed|fbox)\{', solution):
            start = match.end()
            depth = 1
            for i in range(start, len(solution)):
                if solution[i] == '{':
                    depth += 1
                elif solution[i] == '}':
                    depth -= 1
                    if depth == 0:
                        answers.append(solution[start:i])
                        break
        
        if answers:
            # Return the last boxed/fboxed answer in the solution
            return answers[-1].strip()
        
        return None
    
    def extract_model_answer(self, response: str) -> Optional[str]:
        """Extract answer from model response"""
        # Try to find boxed answer first
        answer = self.extract_answer(response)
        if answer:
            return answer
        
        # If no boxed answer, look for common patterns
        # "The answer is X"
        match = re.search(r'answer is[\s:]*([^.\n]+)', response, re.IGNORECASE)
        if match:
            return match.group(1).strip()
        
        # "Therefore, X"
        match = re.search(r'therefore[,\s]+([^.\n]+)', response, re.IGNORECASE)
        if match:
            return match.group(1).strip()
        
        return None
    
    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format MATH problem as prompt"""
        prompt = f"""Solve the following mathematics problem step by step. Show all your work and put your final answer in the format \\boxed{{answer}}.

Problem: {sample['problem']}

Solution:"""
        return prompt
    
    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single MATH sample"""
        prompt = self.format_prompt(sample)
        
        try:
            response = await api.generate_with_retry(prompt, **kwargs)
            
            # Extract correct answer
            correct_answer = self.extract_answer(sample['solution'])
            
            # Extract model's answer
            model_answer = self.extract_model_answer(response)
            
            # Compare answers using mathematical equivalence
            is_correct = False
            if correct_answer and model_answer:
                # Use the official equivalence checking
                is_correct = is_math_equiv(model_answer, correct_answer)
            
            result = {
                'problem': sample['problem'],
                'level': sample['level'],
                'type': sample['type'],
                'correct_answer': correct_answer,
                'model_answer': model_answer,
                'model_response': response,
                'is_correct': is_correct
            }
            
            return is_correct, result
            
        except Exception as e:
            result = {
                'problem': sample['problem'],
                'level': sample['level'],
                'type': sample['type'],
                'error': str(e),
                'is_correct': False
            }
            return False, result
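
# Example usage (a minimal sketch, not part of the benchmark API): `api` is assumed
# to be an object exposing the async `generate_with_retry(prompt, **kwargs)` method
# that `evaluate_sample` calls above; the runner function and `my_api` below are
# hypothetical.
#
#     import asyncio
#
#     async def run_math(api, sample_size=100):
#         benchmark = MATHBenchmark()
#         await benchmark.load_dataset(sample_size=sample_size, difficulty=['Level 5'])
#         results = []
#         for sample in benchmark.dataset:
#             is_correct, result = await benchmark.evaluate_sample(api, sample)
#             results.append(result)
#         accuracy = sum(r['is_correct'] for r in results) / len(results)
#         return accuracy, results
#
#     accuracy, results = asyncio.run(run_math(my_api))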