import logging
from textwrap import dedent
from typing import Callable, Optional

from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase
from lighteval.metrics.dynamic_metrics import SampleLevelMetric
from math_verify.metric import math_metric
from math_verify.few_shots import GSM8K_FEW_SHOTS, MATH_HARD_FEW_SHOTS
import numpy as np
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig

logger = logging.getLogger(__name__)


def as_lighteval_metric(
    metric: Callable[
        [list[str], list[str]], tuple[float, Optional[tuple[list[str], list[str]]]]
    ],
) -> SampleLevelMetric:
    def sample_level_fn(
        formatted_doc: Doc, golds: list[str], predictions: list[str]
    ) -> float:
        result, extracted_predictions = metric(golds, predictions)
        if extracted_predictions is not None:
            if not formatted_doc.specific:
                formatted_doc.specific = {}
            formatted_doc.specific["extracted_predictions"] = extracted_predictions
        return result

    return SampleLevelMetric(
        metric_name="extractive_match",
        sample_level_fn=sample_level_fn,
        category=MetricCategory.GENERATIVE,
        use_case=MetricUseCase.ACCURACY,
        corpus_level_fn=np.mean,
        higher_is_better=True,
    )


def math_hard_prompt_function(x: dict, task_name: str) -> Doc:
    if x.get("__few_shots"):
        index = x["__index"]
        few_shot_doc = (
            MATH_HARD_FEW_SHOTS[index]
            if index < len(MATH_HARD_FEW_SHOTS)
            else MATH_HARD_FEW_SHOTS[-1]
        )
        answer = few_shot_doc["answer"]
        question = few_shot_doc["question"]
    else:
        answer = str(x["solution"])
        question = x["problem"]

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()
    choices = [answer]
    return Doc(query=query, choices=choices, gold_index=0)


def math_prompt_function(x: dict, task_name: str) -> Doc:
    if x.get("__few_shots"):
        index = x["__index"]
        few_shot_doc = (
            MATH_HARD_FEW_SHOTS[index]
            if index < len(MATH_HARD_FEW_SHOTS)
            else MATH_HARD_FEW_SHOTS[-1]
        )
        answer = few_shot_doc["answer"]
        question = few_shot_doc["question"]
    else:
        answer = str(x["answer"])
        question = x["problem"]

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()
    choices = [answer]
    return Doc(query=query, choices=choices, gold_index=0)


def math_aime24_prompt_function(x: dict, task_name: str) -> Doc:
    if x.get("__few_shots"):
        index = x["__index"]
        few_shot_doc = (
            MATH_HARD_FEW_SHOTS[index]
            if index < len(MATH_HARD_FEW_SHOTS)
            else MATH_HARD_FEW_SHOTS[-1]
        )
        answer = few_shot_doc["answer"]
        question = few_shot_doc["question"]
    else:
        answer = str(x["reference_solution"])
        question = x["problem"]

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()
    choices = [f" {answer}"]
    return Doc(query=query, choices=choices, gold_index=0)


def math_amc23_prompt_function(x: dict, task_name: str) -> Doc:
    if x.get("__few_shots"):
        index = x["__index"]
        few_shot_doc = (
            MATH_HARD_FEW_SHOTS[index]
            if index < len(MATH_HARD_FEW_SHOTS)
            else MATH_HARD_FEW_SHOTS[-1]
        )
        answer = few_shot_doc["answer"]
        question = few_shot_doc["question"]
    else:
        answer = str(x["answer"])
        question = x["question"]

    query = dedent(
        f"""\
        Question: {question}
        Step-by-Step Answer:\
        """
    ).strip()
    choices = [f" {answer}"]
    return Doc(query=query, choices=choices, gold_index=0)

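# GSM8K golds embed the final numeric answer after a "####" marker at the end of the
# reference solution; the prompt function below keeps only that final answer as the target.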
= f"{x['answer'].split('####')[-1].strip()}" question = x["question"] query = dedent(f"""\ Question: {question} Step-by-Step Answer:\ """).strip() choices = [f" {answer}"] return Doc(query=query, choices=choices, gold_index=0) math_hard_lighteval = [ LightevalTaskConfig( name=f"math_hard:{subset}", suite=["lighteval", "math"], prompt_function=math_hard_prompt_function, hf_repo="lighteval/MATH-Hard", hf_subset=subset, evaluation_splits=["test"], few_shots_split="train", generation_size=1024, metric=[ as_lighteval_metric( math_metric( gold_extraction_target=( LatexExtractionConfig(boxed_match_priority=0), ), pred_extraction_target=( LatexExtractionConfig(), ExprExtractionConfig(), ), ) ), ], stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], trust_dataset=True, version=0, ) for subset in [ "algebra", "counting_and_probability", "geometry", "intermediate_algebra", "number_theory", "prealgebra", "precalculus", ] ] math_500_lighteval = [ LightevalTaskConfig( name="math_500", suite=["lighteval", "math"], prompt_function=math_prompt_function, hf_repo="HuggingFaceH4/MATH-500", hf_subset="default", evaluation_splits=["test"], few_shots_split="test", generation_size=1024, metric=[ as_lighteval_metric( math_metric( gold_extraction_target=( LatexExtractionConfig(boxed_match_priority=0), ), pred_extraction_target=( LatexExtractionConfig(), ExprExtractionConfig(), ), ) ), ], stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], trust_dataset=True, version=0, ) ] aime24_lighteval = [ LightevalTaskConfig( name="aime24", suite=["lighteval", "math"], prompt_function=math_aime24_prompt_function, hf_repo="zwhe99/aime24", hf_subset="default", evaluation_splits=["test"], few_shots_split="test", generation_size=1024, metric=[ as_lighteval_metric( math_metric( gold_extraction_target=(LatexExtractionConfig(),), pred_extraction_target=( LatexExtractionConfig(), ExprExtractionConfig(), ), ) ), ], stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], trust_dataset=True, version=0, ) ] amc23_lighteval = [ LightevalTaskConfig( name="amc23", suite=["lighteval", "math"], prompt_function=math_amc23_prompt_function, hf_repo="zwhe99/amc23", hf_subset="default", hf_filter=lambda x: len(x["question"].strip()) > 0, evaluation_splits=["test"], few_shots_split="test", generation_size=1024, metric=[ as_lighteval_metric( math_metric( gold_extraction_target=(ExprExtractionConfig(),), pred_extraction_target=( LatexExtractionConfig(), ExprExtractionConfig(), ), ) ), ], stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], trust_dataset=True, version=0, ) ] gsm8k_lighteval = [ LightevalTaskConfig( name="gsm8k", suite=["lighteval", "math"], prompt_function=gsm8k_prompt_function, hf_repo="openai/gsm8k", hf_subset="main", hf_filter=lambda x: len(x["question"].strip()) > 0, evaluation_splits=["test"], few_shots_split="test", generation_size=1024, stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], metric=[ as_lighteval_metric( math_metric( gold_extraction_target=(ExprExtractionConfig(),), pred_extraction_target=( LatexExtractionConfig(), ExprExtractionConfig(), ), fallback_mode="first_match", ) ), ], ) ] TASKS_TABLE = [ *gsm8k_lighteval, *math_hard_lighteval, *math_500_lighteval, *aime24_lighteval, *amc23_lighteval, ]