File size: 4,934 Bytes
8474f02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Evaluation utilities matching standard implementations"""

import re
from typing import Optional, Union
import numpy as np
try:
    import sympy
    from sympy.parsing.latex import parse_latex
    SYMPY_AVAILABLE = True
except ImportError:
    SYMPY_AVAILABLE = False

def normalize_math_answer(answer: str) -> str:
    r"""Normalize a math answer string so equivalent answers compare equal.

    The following steps are applied in order:

    1. Keep only the text after the last ``=``.
    2. Strip surrounding whitespace, then surrounding ``$`` delimiters.
    3. Unwrap ``\text{...}`` and ``\textbf{...}``.
    4. Expand the shorthand ``\fracab`` to ``\frac{a}{b}``.
    5. Drop thousands-separator commas (``1,234`` -> ``1234``).
    6. Delete a fixed list of unit / filler words.
    7. Collapse runs of whitespace into single spaces.

    Args:
        answer: Raw answer text. Falsy input (``""`` or ``None``) is allowed.

    Returns:
        The normalized answer, or ``""`` when ``answer`` is falsy.
    """
    if not answer:
        return ""

    # Keep only the right-hand side of an equation ("x = 5" -> " 5").
    if '=' in answer:
        answer = answer.split('=')[-1]

    # Remove surrounding whitespace and math-mode dollar signs.
    answer = answer.strip()
    answer = answer.strip('$')

    # Unwrap \text{...} and \textbf{...}, keeping their contents.
    answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\textbf\{([^}]*)\}', r'\1', answer)

    # Fix \fracab -> \frac{a}{b}
    answer = re.sub(r'\\frac([0-9a-zA-Z])([0-9a-zA-Z])', r'\\frac{\1}{\2}', answer)

    # Remove only commas that act as thousands separators, i.e. a comma
    # between a digit and a group of exactly three digits. Commas that
    # separate elements of a tuple, list or interval such as "(1,2)" carry
    # meaning and must survive. A pair like "(100,200)" is inherently
    # ambiguous and is still read as one grouped number.
    answer = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', answer)

    # Remove unit / filler words (plain substring removal, order matters).
    removable_words = (
        'square', 'units', 'integers', 'dollars', 'mph', 'inches', 'feet',
        'minutes', 'cm', 'gm', 'pounds', 'meters', 'meals', 'edges',
        'students', 'childrentickets', 'multiples', 'hours', 'degrees',
        'ounces', 'bits', 'factorization', 'greenmarbles', 'redmarbles',
        'bluemarbles',
    )
    for word in removable_words:
        answer = answer.replace(word, '')

    # Collapse any whitespace runs left behind by the removals above.
    return ' '.join(answer.split())

def extract_answer_gsm8k(response: str) -> Optional[float]:
    """Return the last number that appears in a GSM8K-style response.

    Thousands separators are removed first so that ``"1,234"`` is read as
    the single number ``1234`` rather than as ``1`` followed by ``234``.

    Args:
        response: Model output text. Falsy input (``""`` or ``None``) is
            allowed.

    Returns:
        The final number in ``response`` as a float, or ``None`` when the
        text contains no number.
    """
    if not response:
        return None

    # A comma between a digit and a group of exactly three digits is a
    # thousands separator; other commas (e.g. "3, 4 and 5") are left alone.
    text = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', response)

    # Optional sign, optional integer part, optional dot, required digits.
    numbers = re.findall(r'[-+]?\d*\.?\d+', text)
    if not numbers:
        return None

    try:
        return float(numbers[-1])
    except ValueError:
        # Defensive only: every regex match should already be float-parsable.
        return None

def extract_answer_mmlu(response: str) -> Optional[str]:
    """Extract a multiple-choice letter (A-D) from an MMLU response.

    The following patterns are tried in order, and the first hit wins:

    1. The whole (stripped) response is a single letter ``A``-``D``.
    2. The response starts with a letter followed by ``)``, ``.`` or
       whitespace, e.g. ``"C) Paris"``.
    3. The phrase ``answer is X`` (case-insensitive, optional opening
       parenthesis), where ``X`` is a standalone letter.
    4. The first standalone upper-case letter ``A``-``D`` anywhere.

    Letters must stand alone in steps 3 and 4 so that capitals inside
    ordinary words ("Answer", "Because") or the first letter of a word
    ("answer is correct") are not mistaken for a choice.

    Args:
        response: Model output text.

    Returns:
        The upper-case answer letter, or ``None`` if none is found.
    """
    response = response.strip()

    # Exactly one character and it is a valid choice.
    if len(response) == 1 and response in 'ABCD':
        return response

    # Leading letter followed by parenthesis, period or whitespace.
    match = re.search(r'^([ABCD])[).\s]', response)
    if match:
        return match.group(1)

    # "answer is X" / "answer is (X)"; \b rejects words such as "correct".
    match = re.search(r'answer is \(?([ABCD])\b', response, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Last resort: first standalone upper-case choice letter.
    match = re.search(r'\b([ABCD])\b', response)
    if match:
        return match.group(1)

    return None

def calculate_accuracy_with_confidence(results: list, z: float = 1.96) -> dict:
    """Compute accuracy and a Wilson score confidence interval.

    Args:
        results: Sequence of dict-like records. A record counts as correct
            when ``record.get('is_correct', False)`` is truthy.
        z: Standard-normal quantile for the interval. The default ``1.96``
            corresponds to a 95% confidence level.

    Returns:
        A dict with the keys ``'accuracy'`` (float), ``'correct'`` (int),
        ``'total'`` (int) and ``'confidence_interval'`` (a
        ``(lower, upper)`` tuple of plain Python floats clamped to
        ``[0, 1]``). For empty input every value is zero.
    """
    correct = sum(1 for r in results if r.get('is_correct', False))
    total = len(results)

    # No data: avoid dividing by zero and report a degenerate interval.
    if total == 0:
        return {
            'accuracy': 0.0,
            'correct': 0,
            'total': 0,
            'confidence_interval': (0.0, 0.0),
        }

    accuracy = correct / total

    # Wilson score interval for a binomial proportion.
    z_sq = z * z
    denominator = 1 + z_sq / total
    center = (accuracy + z_sq / (2 * total)) / denominator
    spread = np.sqrt(accuracy * (1 - accuracy) / total + z_sq / (4 * total ** 2))
    margin = z * spread / denominator

    # Convert to built-in floats so both bounds always have the same type,
    # then clamp against rounding that could step just outside [0, 1].
    lower = max(0.0, float(center - margin))
    upper = min(1.0, float(center + margin))

    return {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'confidence_interval': (lower, upper),
    }

def is_math_equiv(pred: str, gold: str) -> bool:
    """Check whether two math answers are equivalent.

    Both answers are passed through ``normalize_math_answer``. If the
    normalized strings are equal the answers match. Otherwise, when SymPy
    is available, both are parsed (as LaTeX first, then as plain SymPy
    expressions) and compared by simplifying their difference.

    Args:
        pred: Predicted answer text.
        gold: Reference answer text.

    Returns:
        ``True`` if the answers match textually or symbolically, ``False``
        otherwise, including when SymPy is missing or parsing fails.
    """
    pred_norm = normalize_math_answer(pred)
    gold_norm = normalize_math_answer(gold)

    # Cheap path: identical after normalization.
    if pred_norm == gold_norm:
        return True

    # The strings are known to differ from here on, so without a symbolic
    # check there is nothing left that could make them equal.
    if not SYMPY_AVAILABLE:
        return False

    try:
        try:
            pred_expr = parse_latex(pred_norm)
            gold_expr = parse_latex(gold_norm)
        except Exception:
            # LaTeX parsing failed (or its backend is missing); retry both
            # sides as plain SymPy expressions.
            # NOTE(security): sympify evaluates its string argument with
            # eval(); model output is untrusted text, so run this only in a
            # sandboxed evaluation environment.
            pred_expr = sympy.sympify(pred_norm)
            gold_expr = sympy.sympify(gold_norm)

        diff = sympy.simplify(pred_expr - gold_expr)
        # is_zero is three-valued (True / False / None); coerce so that an
        # undecidable result is reported as "not equivalent", not as None.
        return bool(diff == 0 or diff.is_zero)
    except Exception:
        # Unparsable or incomparable input: the strings already differ.
        return False

def is_gsm8k_correct(pred: str, gold: str) -> bool:
    """Check whether a GSM8K prediction matches the reference answer.

    Identical inputs match immediately. Otherwise both values are converted
    to floats and compared with a tiny absolute tolerance. Thousands
    separators in string inputs are ignored, so ``"1,234"`` equals
    ``"1234"``.

    Args:
        pred: Predicted answer, as a string or a number.
        gold: Reference answer, as a string or a number.

    Returns:
        ``True`` when the values are identical or numerically equal within
        ``1e-9``; ``False`` otherwise, including when either value cannot
        be interpreted as a number.
    """
    if pred == gold:
        return True

    def _to_number(value) -> float:
        # Strip whitespace and thousands separators from text before
        # parsing; non-string values go straight to float().
        if isinstance(value, str):
            value = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', value.strip())
        return float(value)

    try:
        # GSM8K uses exact match, but we allow tiny floating point errors.
        return abs(_to_number(pred) - _to_number(gold)) < 1e-9
    except (TypeError, ValueError, OverflowError):
        # None, non-numeric text, or a value too large for a float.
        return False