import asyncio
import json
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm


@dataclass
class BenchmarkResult:
"""Container for benchmark results"""
benchmark_name: str
model_name: str
total_questions: int
correct: int
accuracy: float
avg_response_time: float
raw_results: List[Dict[str, Any]]


class BaseBenchmark(ABC):
    """Base class for all benchmark implementations"""

    def __init__(self, name: str, dataset_name: Optional[str] = None):
        self.name = name
        self.dataset_name = dataset_name or name
        self.dataset = None
        self.results = []

    @abstractmethod
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the benchmark dataset"""
        pass

    @abstractmethod
    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single sample"""
        pass

    @abstractmethod
    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format the prompt for the model"""
        pass

    async def run_benchmark(self, api, sample_size: Optional[int] = None, **kwargs) -> BenchmarkResult:
"""Run the benchmark on the given API"""
print(f"Running {self.name} benchmark on {api.model_name}...")
# Load dataset
await self.load_dataset(sample_size, **kwargs)
if not self.dataset:
raise ValueError(f"No dataset loaded for {self.name}")

        # Prepare samples (slice again defensively in case load_dataset ignored sample_size)
        samples = self.dataset if sample_size is None else self.dataset[:sample_size]
total_samples = len(samples)
# Run evaluation
correct_count = 0
response_times = []
raw_results = []
# Use async semaphore for concurrent requests
concurrent_limit = kwargs.get('concurrent_requests', 5)
semaphore = asyncio.Semaphore(concurrent_limit)
async def evaluate_with_semaphore(sample, idx):
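            # Limit concurrency with the shared semaphore, time each call, and
            # tag the result with its original sample index for later sorting.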
async with semaphore:
start_time = time.time()
is_correct, result = await self.evaluate_sample(api, sample, **kwargs)
end_time = time.time()
result['response_time'] = end_time - start_time
result['index'] = idx
return is_correct, result
# Create tasks for all samples
tasks = [evaluate_with_semaphore(sample, idx) for idx, sample in enumerate(samples)]

        # Run with progress bar
with tqdm(total=total_samples, desc=f"{self.name}") as pbar:
for coro in asyncio.as_completed(tasks):
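                # as_completed yields results in completion order, not dataset
                # order; the 'index' field restores the original order later.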
is_correct, result = await coro
if is_correct:
correct_count += 1
response_times.append(result['response_time'])
raw_results.append(result)
pbar.update(1)
# --- START: REAL-TIME PROGRESS SAVING ---
# Every 10 samples, save the progress to a file
if pbar.n > 0 and pbar.n % 10 == 0:
# Ensure results directory exists
results_dir = kwargs.get('output_dir', 'results')
os.makedirs(results_dir, exist_ok=True)
progress_path = os.path.join(results_dir, f'{self.name}_progress.json')
# Sort results by index before saving
sorted_progress = sorted(raw_results, key=lambda x: x['index'])
try:
with open(progress_path, 'w') as f:
json.dump(sorted_progress, f, indent=2)
except Exception as e:
print(f"Error saving progress: {e}")
# --- END: REAL-TIME PROGRESS SAVING ---

        # Calculate metrics
        accuracy = correct_count / total_samples if total_samples > 0 else 0.0
        avg_response_time = sum(response_times) / len(response_times) if response_times else 0.0
return BenchmarkResult(
benchmark_name=self.name,
model_name=api.model_name,
total_questions=total_samples,
correct=correct_count,
accuracy=accuracy,
avg_response_time=avg_response_time,
raw_results=sorted(raw_results, key=lambda x: x['index'])
        )
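

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). DummyAPI and DummyBenchmark below are
# hypothetical stand-ins, not part of this module: they only show how a
# concrete benchmark is expected to subclass BaseBenchmark and how
# run_benchmark is driven with asyncio. A real subclass would load an actual
# dataset and call a real model API client.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyAPI:
        """Hypothetical API client exposing the interface run_benchmark expects."""
        model_name = "dummy-model"

        async def generate(self, prompt: str) -> str:
            # Always answers "4"; a real client would call a model endpoint.
            return "4"

    class DummyBenchmark(BaseBenchmark):
        async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
            # Tiny in-memory dataset instead of a real benchmark download.
            self.dataset = [{"question": "2 + 2 = ?", "answer": "4"}]

        def format_prompt(self, sample: Dict[str, Any]) -> str:
            return f"Answer with a single number: {sample['question']}"

        async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
            response = await api.generate(self.format_prompt(sample))
            is_correct = sample["answer"] in response
            return is_correct, {"question": sample["question"], "response": response}

    async def _demo():
        result = await DummyBenchmark("dummy").run_benchmark(DummyAPI())
        print(f"{result.benchmark_name}: accuracy={result.accuracy:.2%}")

    asyncio.run(_demo())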