from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
import asyncio
import json
import os
import time
from dataclasses import dataclass
from tqdm import tqdm

@dataclass
class BenchmarkResult:
    """Container for benchmark results"""
    benchmark_name: str
    model_name: str
    total_questions: int
    correct: int
    accuracy: float
    avg_response_time: float
    raw_results: List[Dict[str, Any]]
    
class BaseBenchmark(ABC):
    """Base class for all benchmark implementations"""
    
    def __init__(self, name: str, dataset_name: Optional[str] = None):
        self.name = name
        self.dataset_name = dataset_name or name
        self.dataset = None
        self.results = []
        
    @abstractmethod
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the benchmark dataset"""
        pass
    
    @abstractmethod
    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single sample"""
        pass
    
    @abstractmethod
    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format the prompt for the model"""
        pass
    
    async def run_benchmark(self, api, sample_size: Optional[int] = None, **kwargs) -> BenchmarkResult:
        """Run the benchmark on the given API"""
        print(f"Running {self.name} benchmark on {api.model_name}...")
        
        # Load dataset
        await self.load_dataset(sample_size, **kwargs)
        
        if not self.dataset:
            raise ValueError(f"No dataset loaded for {self.name}")
        
        # Prepare samples
        samples = self.dataset if sample_size is None else self.dataset[:sample_size]
        total_samples = len(samples)
        
        # Run evaluation
        correct_count = 0
        response_times = []
        raw_results = []
        
        # Use async semaphore for concurrent requests
        concurrent_limit = kwargs.get('concurrent_requests', 5)
        semaphore = asyncio.Semaphore(concurrent_limit)
        
        async def evaluate_with_semaphore(sample, idx):
            async with semaphore:
                start_time = time.time()
                is_correct, result = await self.evaluate_sample(api, sample, **kwargs)
                end_time = time.time()
                
                result['response_time'] = end_time - start_time
                result['index'] = idx
                return is_correct, result
        
        # Create tasks for all samples
        tasks = [evaluate_with_semaphore(sample, idx) for idx, sample in enumerate(samples)]
        
        # Run with progress bar
        with tqdm(total=total_samples, desc=f"{self.name}") as pbar:
            for coro in asyncio.as_completed(tasks):
                is_correct, result = await coro
                
                if is_correct:
                    correct_count += 1
                
                response_times.append(result['response_time'])
                raw_results.append(result)
                pbar.update(1)

                # Checkpoint partial results every 10 samples so long runs can be inspected mid-flight
                if pbar.n > 0 and pbar.n % 10 == 0:
                    # Ensure results directory exists
                    results_dir = kwargs.get('output_dir', 'results')
                    os.makedirs(results_dir, exist_ok=True)
                    
                    progress_path = os.path.join(results_dir, f'{self.name}_progress.json')
                    # Sort results by index before saving
                    sorted_progress = sorted(raw_results, key=lambda x: x['index'])
                    try:
                        with open(progress_path, 'w') as f:
                            json.dump(sorted_progress, f, indent=2)
                    except Exception as e:
                        print(f"Error saving progress: {e}")
        
        # Calculate metrics
        accuracy = correct_count / total_samples if total_samples > 0 else 0
        avg_response_time = sum(response_times) / len(response_times) if response_times else 0
        
        return BenchmarkResult(
            benchmark_name=self.name,
            model_name=api.model_name,
            total_questions=total_samples,
            correct=correct_count,
            accuracy=accuracy,
            avg_response_time=avg_response_time,
            raw_results=sorted(raw_results, key=lambda x: x['index'])  # as_completed yields out of order; restore sample order
        )
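

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the benchmark suite): a minimal
# concrete subclass plus a dummy API object, showing how the three abstract
# methods fit together. "EchoBenchmark" and "DummyAPI" are hypothetical names
# invented for this example; a real API object only needs a `model_name`
# attribute and whatever interface your `evaluate_sample` calls on it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoBenchmark(BaseBenchmark):
        """Toy benchmark: the model is scored correct if it echoes the prompt."""

        async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
            data = [{"question": f"q{i}", "answer": f"q{i}"} for i in range(20)]
            self.dataset = data if sample_size is None else data[:sample_size]

        def format_prompt(self, sample: Dict[str, Any]) -> str:
            return sample["question"]

        async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
            response = await api.generate(self.format_prompt(sample))
            return response.strip() == sample["answer"], {
                "question": sample["question"],
                "response": response,
            }

    class DummyAPI:
        """Stand-in model client; only `model_name` and `generate` are used here."""
        model_name = "dummy-model"

        async def generate(self, prompt: str) -> str:
            return prompt  # echo the prompt back, so every sample scores correct

    # sample_size=8 keeps the demo below the 10-sample checkpoint threshold,
    # so no <name>_progress.json file is written by this run.
    result = asyncio.run(EchoBenchmark("echo").run_benchmark(DummyAPI(), sample_size=8))
    print(f"accuracy={result.accuracy:.2f}, avg_time={result.avg_response_time:.4f}s")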