# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from:
# Link: https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
# Link: https://github.com/huggingface/datasets/blob/master/metrics/squad/squad.py
import re
import string
from collections import Counter
from typing import Any, Callable, Union

from torch import Tensor, tensor

from torchmetrics.utilities import rank_zero_warn

SINGLE_PRED_TYPE = dict[str, str]
PREDS_TYPE = Union[SINGLE_PRED_TYPE, list[SINGLE_PRED_TYPE]]
SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]]
TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, list[SINGLE_TARGET_TYPE]]
UPDATE_METHOD_SINGLE_PRED_TYPE = Union[list[dict[str, Union[str, int]]], str, dict[str, Union[list[str], list[int]]]]

# Reference layout of a single SQuAD example; embedded in error messages to show the expected target format.
SQuAD_FORMAT = {
    "answers": {"answer_start": [1], "text": ["This is a test text"]},
    "context": "This is a test context.",
    "id": "1",
    "question": "Is this a test?",
    "title": "train test",
}


def _normalize_text(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace.

    The steps are applied in a fixed order: lower-casing, punctuation removal, article removal and finally
    whitespace collapsing. Articles are removed after lower-casing so that ``"The"`` and ``"the"`` are both dropped.

    """

    def remove_articles(text: str) -> str:
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text: str) -> str:
        return " ".join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text: str) -> str:
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def _get_tokens(s: str) -> list[str]:
    """Split a sentence into separate, normalized tokens; an empty input yields an empty list."""
    return [] if not s else _normalize_text(s).split()


def _compute_f1_score(predicted_answer: str, target_answer: str) -> Tensor:
    """Compute the token-level F1 score for two sentences.

    Args:
        predicted_answer: The predicted answer string.
        target_answer: A single reference answer string.

    Return:
        Scalar tensor with the F1 score computed over the normalized tokens of both strings.

    """
    target_tokens = _get_tokens(target_answer)
    predicted_tokens = _get_tokens(predicted_answer)
    # Multiset intersection: each shared token counts at most as often as it occurs in both sentences
    common = Counter(target_tokens) & Counter(predicted_tokens)
    num_same = tensor(sum(common.values()))
    if len(target_tokens) == 0 or len(predicted_tokens) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return tensor(int(target_tokens == predicted_tokens))
    if num_same == 0:
        return tensor(0.0)
    precision = 1.0 * num_same / tensor(len(predicted_tokens))
    recall = 1.0 * num_same / tensor(len(target_tokens))
    return (2 * precision * recall) / (precision + recall)


def _compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor:
    """Compute Exact Match for two sentences: 1 if they are equal after normalization, 0 otherwise."""
    return tensor(int(_normalize_text(prediction) == _normalize_text(ground_truth)))


def _metric_max_over_ground_truths(
    metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: list[str]
) -> Tensor:
    """Calculate maximum score for a predicted answer with all reference answers.

    Args:
        metric_fn: Function scoring a ``(prediction, reference)`` pair of strings.
        prediction: The predicted answer string.
        ground_truths: Reference answers; must be non-empty.

    Raises:
        ValueError:
            If ``ground_truths`` is empty (raised by the builtin ``max``).

    """
    return max(metric_fn(prediction, truth) for truth in ground_truths)  # type: ignore[type-var]


def _squad_input_check(
    preds: PREDS_TYPE, targets: TARGETS_TYPE
) -> tuple[dict[str, str], list[dict[str, list[dict[str, list[dict[str, Any]]]]]]]:
    """Check for types and convert the input to necessary format to compute the input.

    Args:
        preds: A dictionary or list of dictionaries with the keys ``prediction_text`` and ``id``.
        targets: A dictionary or list of dictionaries with the keys ``answers`` and ``id``.

    Return:
        Tuple of a dictionary mapping each prediction ``id`` to its ``prediction_text`` and the targets nested as
        ``[{"paragraphs": [{"qas": [{"answers": [{"text": ...}, ...], "id": ...}, ...]}]}]``.

    Raises:
        KeyError:
            If a prediction lacks ``prediction_text`` or ``id``, if a target lacks ``answers`` or ``id``, or if
            ``answers`` lacks ``text``.

    """
    if isinstance(preds, dict):
        preds = [preds]

    if isinstance(targets, dict):
        targets = [targets]

    for pred in preds:
        pred_keys = pred.keys()
        if "prediction_text" not in pred_keys or "id" not in pred_keys:
            raise KeyError(
                "Expected keys in a single prediction are 'prediction_text' and 'id'. "
                "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string."
            )

    for target in targets:
        target_keys = target.keys()
        if "answers" not in target_keys or "id" not in target_keys:
            raise KeyError(
                "Expected keys in a single target are 'answers' and 'id'. "
                "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n"
                "SQuAD Format: "
                f"{SQuAD_FORMAT}"
            )

        answers: dict[str, Union[list[str], list[int]]] = target["answers"]  # type: ignore[assignment]
        if "text" not in answers:
            raise KeyError(
                "Expected keys in a 'answers' are 'text'. "
                "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n"
                "SQuAD Format: "
                f"{SQuAD_FORMAT}"
            )

    # If the same id occurs more than once, the last prediction wins
    preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds}

    def _fn_answer(tgt: Any) -> dict[str, Any]:
        # Reshape one target into the ``qas`` entry layout consumed by ``_squad_update``
        return {"answers": [{"text": txt} for txt in tgt["answers"]["text"]], "id": tgt["id"]}

    targets_dict = [{"paragraphs": [{"qas": [_fn_answer(target) for target in targets]}]}]
    return preds_dict, targets_dict


def _squad_update(
    preds: dict[str, str],
    target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]],
) -> tuple[Tensor, Tensor, Tensor]:
    """Compute F1 Score and Exact Match for a collection of predictions and references.

    Args:
        preds: A dictionary mapping an `id` to the predicted `answer`.
        target:
            A list of dictionary mapping `paragraphs` to list of dictionary mapping `qas` to a list of dictionary
            containing `id` and list of all possible `answers`.

    Return:
        Tuple containing F1 score, Exact match score and total number of examples.

    Example:
        >>> from torchmetrics.functional.text.squad import _squad_update
        >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
        >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
        >>> preds_dict = {pred["id"]: pred["prediction_text"] for pred in preds}
        >>> targets_dict = [
        ...     dict(paragraphs=[dict(qas=[dict(answers=[
        ...         {"text": txt} for txt in tgt["answers"]["text"]], id=tgt["id"]) for tgt in target
        ...     ])])
        ... ]
        >>> _squad_update(preds_dict, targets_dict)
        (tensor(1.), tensor(1.), tensor(1))

    """
    f1 = tensor(0.0)
    exact_match = tensor(0.0)
    total = tensor(0)
    for article in target:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                # Every question counts towards the total, so unanswered ones implicitly score 0
                total += 1
                if qa["id"] not in preds:
                    rank_zero_warn(f"Unanswered question {qa['id']} will receive score 0.")
                    continue
                # A question without reference answers is unanswerable: the only correct prediction is the
                # empty string. Falling back to [""] also avoids calling `max` on an empty sequence.
                ground_truths = [x["text"] for x in qa["answers"]] or [""]
                pred = preds[qa["id"]]
                exact_match += _metric_max_over_ground_truths(_compute_exact_match_score, pred, ground_truths)
                f1 += _metric_max_over_ground_truths(_compute_f1_score, pred, ground_truths)

    return f1, exact_match, total


def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> dict[str, Tensor]:
    """Aggregate the F1 Score and Exact match for the batch.

    Args:
        f1: Sum of the per-question F1 scores.
        exact_match: Sum of the per-question exact match scores.
        total: Number of questions the sums were accumulated over.

    Return:
        Dictionary containing the F1 score, Exact match score for the batch.

    """
    # Scores are reported as percentages
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return {"exact_match": exact_match, "f1": f1}


def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> dict[str, Tensor]:
    """Calculate `SQuAD Metric`_ .

    Args:
        preds: A Dictionary or List of Dictionary-s that map `id` and `prediction_text` to the respective values.

            Example prediction:

            .. code-block:: python

                {"prediction_text": "TorchMetrics is awesome", "id": "123"}

        target: A Dictionary or List of Dictionary-s that contain the `answers` and `id` in the SQuAD Format.

            Example target:

            .. code-block:: python

                {
                    'answers': {'answer_start': [1], 'text': ['This is a test answer']},
                    'id': '1',
                }

            Reference SQuAD Format:

            .. code-block:: python

                {
                    'answers': {'answer_start': [1], 'text': ['This is a test text']},
                    'context': 'This is a test context.',
                    'id': '1',
                    'question': 'Is this a test?',
                    'title': 'train test'
                }

    Return:
        Dictionary containing the F1 score, Exact match score for the batch.

    Example:
        >>> from torchmetrics.functional.text.squad import squad
        >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
        >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}]
        >>> squad(preds, target)
        {'exact_match': tensor(100.), 'f1': tensor(100.)}

    Raises:
        KeyError:
            If the required keys are missing in either predictions or targets.

    References:
        [1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin
        Lopyrev, Percy Liang `SQuAD Metric`_ .

    """
    preds_dict, target_dict = _squad_input_check(preds, target)
    f1, exact_match, total = _squad_update(preds_dict, target_dict)
    return _squad_compute(f1, exact_match, total)