|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
import string |
|
from collections import Counter |
|
from typing import Any, Callable, Union |
|
|
|
from torch import Tensor, tensor |
|
|
|
from torchmetrics.utilities import rank_zero_warn |
|
|
|
# A single prediction; ``_squad_input_check`` requires it to contain the keys "prediction_text" and "id".
SINGLE_PRED_TYPE = dict[str, str]

# Either one prediction or a list of predictions.
PREDS_TYPE = Union[SINGLE_PRED_TYPE, list[SINGLE_PRED_TYPE]]

# A single target; ``_squad_input_check`` requires the keys "answers" and "id", where "answers" is itself a
# dictionary that contains the key "text".
SINGLE_TARGET_TYPE = dict[
    str,
    Union[str, dict[str, Union[list[str], list[int]]]],
]

# Either one target or a list of targets.
TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, list[SINGLE_TARGET_TYPE]]

# NOTE(review): not referenced in the visible part of this module; presumably consumed by the class-based
# metric's ``update`` method -- confirm against its caller.
UPDATE_METHOD_SINGLE_PRED_TYPE = Union[
    list[dict[str, Union[str, int]]],
    str,
    dict[str, Union[list[str], list[int]]],
]
|
|
|
# Reference example of one SQuAD-format record. It is interpolated into the ``KeyError`` messages raised by
# ``_squad_input_check``, so the key order below is part of the emitted text.
SQuAD_FORMAT = {
    "answers": {
        "answer_start": [1],
        "text": ["This is a test text"],
    },
    "context": "This is a test context.",
    "id": "1",
    "question": "Is this a test?",
    "title": "train test",
}
|
|
|
|
|
def _normalize_text(s: str) -> str: |
|
"""Lower text and remove punctuation, articles and extra whitespace.""" |
|
|
|
def remove_articles(text: str) -> str: |
|
return re.sub(r"\b(a|an|the)\b", " ", text) |
|
|
|
def white_space_fix(text: str) -> str: |
|
return " ".join(text.split()) |
|
|
|
def remove_punc(text: str) -> str: |
|
exclude = set(string.punctuation) |
|
return "".join(ch for ch in text if ch not in exclude) |
|
|
|
def lower(text: str) -> str: |
|
return text.lower() |
|
|
|
return white_space_fix(remove_articles(remove_punc(lower(s)))) |
|
|
|
|
|
def _get_tokens(s: str) -> list[str]: |
|
"""Split a sentence into separate tokens.""" |
|
return [] if not s else _normalize_text(s).split() |
|
|
|
|
|
def _compute_f1_score(predicted_answer: str, target_answer: str) -> Tensor:
    """Compute the token-level F1 score between a predicted and a reference answer.

    Both answers are tokenized with ``_get_tokens``. Precision is the number of shared tokens (counted with
    multiplicity) divided by the number of predicted tokens, recall divides by the number of reference tokens, and
    the result is their harmonic mean.

    Args:
        predicted_answer: The predicted answer string.
        target_answer: The reference answer string.

    Return:
        A scalar tensor holding the F1 score. If either answer has no tokens, the score is 1 when both have no
        tokens and 0 otherwise.

    """
    reference_tokens = _get_tokens(target_answer)
    candidate_tokens = _get_tokens(predicted_answer)

    if not reference_tokens or not candidate_tokens:
        # At least one side is a "no answer": full credit only if both sides agree on that.
        return tensor(int(reference_tokens == candidate_tokens))

    # Multiset intersection keeps, per token, the smaller of the two occurrence counts.
    overlap = Counter(reference_tokens) & Counter(candidate_tokens)
    num_same = tensor(sum(overlap.values()))
    if num_same == 0:
        return tensor(0.0)

    precision = 1.0 * num_same / tensor(len(candidate_tokens))
    recall = 1.0 * num_same / tensor(len(reference_tokens))
    return (2 * precision * recall) / (precision + recall)
|
|
|
|
|
def _compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor:
    """Compute the exact-match score between a predicted and a reference answer.

    Return:
        A scalar integer tensor that is 1 if both strings are equal after ``_normalize_text`` and 0 otherwise.

    """
    normalized_prediction = _normalize_text(prediction)
    normalized_truth = _normalize_text(ground_truth)
    is_match = normalized_prediction == normalized_truth
    return tensor(int(is_match))
|
|
|
|
|
def _metric_max_over_ground_truths( |
|
metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: list[str] |
|
) -> Tensor: |
|
"""Calculate maximum score for a predicted answer with all reference answers.""" |
|
return max(metric_fn(prediction, truth) for truth in ground_truths) |
|
|
|
|
|
def _squad_input_check(
    preds: PREDS_TYPE, targets: TARGETS_TYPE
) -> tuple[dict[str, str], list[dict[str, list[dict[str, list[dict[str, Any]]]]]]]:
    """Validate predictions and targets and convert them to the layout consumed by ``_squad_update``.

    Args:
        preds: A single prediction dictionary or a list of them. Each must contain ``prediction_text`` and ``id``.
        targets: A single target dictionary or a list of them. Each must contain ``id`` and ``answers``, and
            ``answers`` must contain ``text``.

    Return:
        A tuple ``(preds_dict, targets_dict)``. ``preds_dict`` maps every prediction ``id`` to its
        ``prediction_text``. ``targets_dict`` is a one-element list of the form
        ``[{"paragraphs": [{"qas": [{"answers": [{"text": ...}, ...], "id": ...}, ...]}]}]``.

    Raises:
        KeyError:
            If a prediction lacks ``prediction_text`` or ``id``, if a target lacks ``answers`` or ``id``, or if
            the ``answers`` entry of a target lacks ``text``.

    """
    # Accept a single dictionary as shorthand for a one-element list.
    if isinstance(preds, dict):
        preds = [preds]
    if isinstance(targets, dict):
        targets = [targets]

    for pred in preds:
        if "prediction_text" not in pred or "id" not in pred:
            raise KeyError(
                "Expected keys in a single prediction are 'prediction_text' and 'id'. "
                "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string."
            )

    for target in targets:
        if "answers" not in target or "id" not in target:
            raise KeyError(
                "Expected keys in a single target are 'answers' and 'id'. "
                "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n"
                "SQuAD Format: "
                f"{SQuAD_FORMAT}"
            )

        if "text" not in target["answers"]:
            raise KeyError(
                "Expected keys in a 'answers' are 'text'. "
                "Please make sure that 'answers' maps to a `SQuAD` format dictionary.\n"
                "SQuAD Format: "
                f"{SQuAD_FORMAT}"
            )

    preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds}
    # Every reference text becomes its own ``{"text": ...}`` entry; all questions are placed in a single
    # paragraph of a single article.
    qas = [
        {"answers": [{"text": txt} for txt in target["answers"]["text"]], "id": target["id"]}
        for target in targets
    ]
    targets_dict = [{"paragraphs": [{"qas": qas}]}]
    return preds_dict, targets_dict
|
|
|
|
|
def _squad_update( |
|
preds: dict[str, str], |
|
target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]], |
|
) -> tuple[Tensor, Tensor, Tensor]: |
|
"""Compute F1 Score and Exact Match for a collection of predictions and references. |
|
|
|
Args: |
|
preds: A dictionary mapping an `id` to the predicted `answer`. |
|
target: |
|
A list of dictionary mapping `paragraphs` to list of dictionary mapping `qas` to a list of dictionary |
|
containing `id` and list of all possible `answers`. |
|
|
|
Return: |
|
Tuple containing F1 score, Exact match score and total number of examples. |
|
|
|
Example: |
|
>>> from torchmetrics.functional.text.squad import _squad_update |
|
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] |
|
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] |
|
>>> preds_dict = {pred["id"]: pred["prediction_text"] for pred in preds} |
|
>>> targets_dict = [ |
|
... dict(paragraphs=[dict(qas=[dict(answers=[ |
|
... {"text": txt} for txt in tgt["answers"]["text"]], id=tgt["id"]) for tgt in target |
|
... ])]) |
|
... ] |
|
>>> _squad_update(preds_dict, targets_dict) |
|
(tensor(1.), tensor(1.), tensor(1)) |
|
|
|
""" |
|
f1 = tensor(0.0) |
|
exact_match = tensor(0.0) |
|
total = tensor(0) |
|
for article in target: |
|
for paragraph in article["paragraphs"]: |
|
for qa in paragraph["qas"]: |
|
total += 1 |
|
if qa["id"] not in preds: |
|
rank_zero_warn(f"Unanswered question {qa['id']} will receive score 0.") |
|
continue |
|
ground_truths = [x["text"] for x in qa["answers"]] |
|
pred = preds[qa["id"]] |
|
exact_match += _metric_max_over_ground_truths(_compute_exact_match_score, pred, ground_truths) |
|
f1 += _metric_max_over_ground_truths(_compute_f1_score, pred, ground_truths) |
|
|
|
return f1, exact_match, total |
|
|
|
|
|
def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> dict[str, Tensor]: |
|
"""Aggregate the F1 Score and Exact match for the batch. |
|
|
|
Return: |
|
Dictionary containing the F1 score, Exact match score for the batch. |
|
|
|
""" |
|
exact_match = 100.0 * exact_match / total |
|
f1 = 100.0 * f1 / total |
|
return {"exact_match": exact_match, "f1": f1} |
|
|
|
|
|
def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> dict[str, Tensor]:
    """Calculate `SQuAD Metric`_ .

    Args:
        preds: A dictionary or a list of dictionaries, each mapping `id` and `prediction_text` to their values.

            Example prediction:

            .. code-block:: python

                {"prediction_text": "TorchMetrics is awesome", "id": "123"}

        target: A dictionary or a list of dictionaries, each containing `answers` and `id` in the SQuAD Format.

            Example target:

            .. code-block:: python

                {
                    'answers': {'answer_start': [1], 'text': ['This is a test answer']},
                    'id': '1',
                }

            Reference SQuAD Format:

            .. code-block:: python

                {
                    'answers': {'answer_start': [1], 'text': ['This is a test text']},
                    'context': 'This is a test context.',
                    'id': '1',
                    'question': 'Is this a test?',
                    'title': 'train test'
                }

    Return:
        Dictionary with the keys ``exact_match`` and ``f1`` holding the respective scores for the batch, both
        expressed as percentages.

    Example:
        >>> from torchmetrics.functional.text.squad import squad
        >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
        >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}]
        >>> squad(preds, target)
        {'exact_match': tensor(100.), 'f1': tensor(100.)}

    Raises:
        KeyError:
            If the required keys are missing in either predictions or targets.

    References:
        [1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin
        Lopyrev, Percy Liang `SQuAD Metric`_ .

    """
    # Validate and reshape the inputs, accumulate the raw sums, then convert them to percentages.
    preds_dict, target_dict = _squad_input_check(preds, target)
    accumulated = _squad_update(preds_dict, target_dict)
    return _squad_compute(*accumulated)
|
|