File size: 4,524 Bytes
8474f02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from .base_benchmark import BaseBenchmark
from typing import Dict, Any, Optional, Tuple
from datasets import load_dataset
import re
from .prompt_templates import get_gsm8k_cot_prompt

class GSM8KBenchmark(BaseBenchmark):
    """GSM8K (Grade School Math 8K) benchmark.

    Loads the ``main`` configuration / ``test`` split of GSM8K, prompts the
    model with a chain-of-thought template, and scores each sample by exact
    string match between the normalized reference answer (the number after
    ``####`` in the solution) and the number extracted from the model output.
    """

    # Compiled once for the class instead of on every call.
    # Reference solutions (and models that imitate them) end with "#### <number>".
    _HASH_ANSWER_RE = re.compile(r'#### ([\-0-9\.\,]+)')
    # Strict CoT pattern; an optional "$" is allowed so "The answer is $18."
    # is not pushed down to the weaker last-number heuristic.
    _ANSWER_IS_RE = re.compile(r'The answer is \$?([\-0-9\.\,]+)', re.IGNORECASE)
    # lm-eval's flexible-extract pattern: a run of 2+ digit/$/./, characters,
    # or a plain (optionally negative) integer. findall yields 2-tuples with
    # exactly one non-empty group per match.
    _FLEXIBLE_NUMBER_RE = re.compile(r'(-?[$0-9.,]{2,})|(-?[0-9]+)')

    def __init__(self):
        super().__init__(name="GSM8K", dataset_name="gsm8k")

    async def load_dataset(self, sample_size: Optional[int] = None, seed: Optional[int] = None, **kwargs):
        """Load the GSM8K test split into ``self.dataset``, shuffled.

        Args:
            sample_size: Keep only the first ``sample_size`` samples after
                shuffling. A falsy value (``None`` or ``0``) keeps the full split.
            seed: If given, shuffle with a private ``random.Random(seed)`` so the
                selected subset is reproducible. If ``None``, the global
                ``random`` module state is used (so an earlier ``random.seed()``
                by the caller still applies).
            **kwargs: Accepted for interface compatibility; unused here.
        """
        import random

        # Inside a method body the bare name resolves to the module-level
        # ``load_dataset`` imported from ``datasets``, not to this method.
        dataset = load_dataset(self.dataset_name, 'main', split='test')

        self.dataset = [
            {
                'question': sample['question'],
                'answer': sample['answer'],
                'raw_sample': sample,
            }
            for sample in dataset
        ]

        if seed is None:
            random.shuffle(self.dataset)
        else:
            random.Random(seed).shuffle(self.dataset)

        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    @staticmethod
    def _normalize_number(text: str) -> Optional[str]:
        """Normalize a captured number string for exact-match comparison.

        Removes ``$`` and thousands separators, and strips trailing periods
        (sentence punctuation such as "42." -> "42"). Trailing zeros are kept,
        so "18.00" stays "18.00" (exact-match semantics are unchanged).

        Returns:
            The normalized string, or ``None`` if it does not parse as a number.
        """
        cleaned = text.replace('$', '').replace(',', '').rstrip('.')
        try:
            float(cleaned)
        except ValueError:
            return None
        return cleaned

    def extract_answer_from_solution(self, solution: str) -> Optional[str]:
        """Extract the reference answer from a GSM8K solution string.

        GSM8K solutions end with ``#### <number>``. The number is normalized
        with the same rules applied to model output so both sides compare
        consistently.

        Returns:
            The normalized answer string, or ``None`` if no valid
            ``####`` answer is present.
        """
        match = self._HASH_ANSWER_RE.search(solution)
        if match:
            return self._normalize_number(match.group(1))
        return None

    def extract_number_from_response(self, response: str) -> Optional[str]:
        """Extract the final numerical answer from a model response.

        Patterns are tried in order (mirroring lm-eval's strict-match then
        flexible-extract filters):

        1. ``The answer is X`` (chain-of-thought convention; optional ``$``).
        2. ``#### X`` (the model imitating the GSM8K solution format).
        3. The last number-like token in the whole response.

        For 1 and 2, the first occurrence is used; if its capture is not a
        valid number, extraction falls through to the next pattern. For 3,
        tokens are scanned from the end and the last *valid* number wins, so a
        trailing ellipsis ("...") no longer hides an earlier answer.

        Returns:
            The normalized number string, or ``None`` if nothing numeric is found.
        """
        for pattern in (self._ANSWER_IS_RE, self._HASH_ANSWER_RE):
            match = pattern.search(response)
            if match:
                normalized = self._normalize_number(match.group(1))
                if normalized is not None:
                    return normalized

        for strict_group, int_group in reversed(self._FLEXIBLE_NUMBER_RE.findall(response)):
            normalized = self._normalize_number(strict_group or int_group)
            if normalized is not None:
                return normalized

        return None

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format a GSM8K question as a few-shot chain-of-thought prompt."""
        # Use the standard CoT prompt from lm-eval
        return get_gsm8k_cot_prompt(sample['question'])

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Generate a response for one sample and score it.

        Args:
            api: Client exposing an awaitable ``generate_with_retry(prompt, **kwargs)``.
            sample: A record built by ``load_dataset`` (uses 'question' and 'answer').
            **kwargs: Forwarded unchanged to ``api.generate_with_retry``.

        Returns:
            ``(is_correct, result)``; ``result`` records the question, both
            extracted answers, the raw response and the reference solution.
            If generation or extraction raises, ``(False, result)`` is returned
            with an ``'error'`` entry instead, so one failing sample does not
            abort the whole run.
        """
        prompt = self.format_prompt(sample)

        try:
            response = await api.generate_with_retry(prompt, **kwargs)
            correct_answer = self.extract_answer_from_solution(sample['answer'])
            model_answer = self.extract_number_from_response(response)
        except Exception as e:
            result = {
                'question': sample['question'],
                'error': str(e),
                'is_correct': False
            }
            return False, result

        # Exact match on normalized strings; an unparseable reference can
        # never be matched.
        is_correct = correct_answer is not None and correct_answer == model_answer

        result = {
            'question': sample['question'],
            'correct_answer': correct_answer,
            'model_answer': model_answer,
            'model_response': response,
            'is_correct': is_correct,
            'solution': sample['answer']
        }
        return is_correct, result