# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from collections.abc import Sequence
from typing import Callable, Optional, Union

import torch
from torch import Tensor, tensor


def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:
    """Count every n-gram of order 1 up to ``n_gram`` in a tokenized text.

    Args:
        ngram_input_list: A list of tokens of a translated text or of a reference text
        n_gram: Highest n-gram order to count (all orders from 1 to ``n_gram`` are counted)

    Return:
        ngram_counter: a collections.Counter object mapping each n-gram (a tuple of tokens) to its count

    """
    ngram_counter: Counter = Counter()

    for order in range(1, n_gram + 1):
        # Slide a window of width ``order`` over the tokens; the range is empty when the text is too short.
        for start in range(len(ngram_input_list) - order + 1):
            ngram_key = tuple(ngram_input_list[start : start + order])
            ngram_counter[ngram_key] += 1

    return ngram_counter


def _tokenize_fn(sentence: str) -> Sequence[str]:
    """Tokenizes sentence into list of words.

    Args:
        sentence: A sentence separated by white space.

    Return:
        List of words

    """
    return sentence.split()


def _bleu_score_update(
    preds: Sequence[str],
    target: Sequence[Sequence[str]],
    numerator: Tensor,
    denominator: Tensor,
    preds_len: Tensor,
    target_len: Tensor,
    n_gram: int = 4,
    tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn,
) -> tuple[Tensor, Tensor]:
    """Update and returns variables required to compute the BLEU score.

    ``numerator``, ``denominator``, ``preds_len`` and ``target_len`` are accumulated in place.

    Args:
        preds: An iterable of machine translated corpus
        target: An iterable of iterables of reference corpus
        numerator: Numerator of precision score (true positives), one entry per n-gram order
        denominator: Denominator of precision score (true positives + false positives), one entry per n-gram order
        preds_len: count of words in a candidate prediction
        target_len: count of words in a reference translation
        n_gram: gram value ranged 1 to 4
        tokenizer: A function that turns sentence into list of words

    Return:
        Tuple of the updated ``preds_len`` and ``target_len``

    """
    # Empty strings are mapped to an empty token list without calling the tokenizer.
    target_: Sequence[Sequence[Sequence[str]]] = [[tokenizer(line) if line else [] for line in t] for t in target]
    preds_: Sequence[Sequence[str]] = [tokenizer(line) if line else [] for line in preds]

    for pred, targets in zip(preds_, target_):
        pred_len = len(pred)
        preds_len += pred_len

        # Effective reference length: the reference whose length is closest to the prediction's.
        # Tuples compare by (distance, length), so a tie in distance resolves to the shorter
        # reference, making the result independent of the order in which references are listed.
        _, closest_ref_len = min((abs(pred_len - len(tgt)), len(tgt)) for tgt in targets)
        target_len += closest_ref_len

        preds_counter: Counter = _count_ngram(pred, n_gram)
        # Union of counters keeps, per n-gram, the maximum count seen in any single reference.
        target_counter: Counter = Counter()
        for tgt in targets:
            target_counter |= _count_ngram(tgt, n_gram)

        # Intersection clips each predicted n-gram count to that maximum reference count.
        ngram_counter_clip = preds_counter & target_counter

        # The tuple length of an n-gram key is its order, which selects the slot to accumulate into.
        for ngram, clipped_count in ngram_counter_clip.items():
            numerator[len(ngram) - 1] += clipped_count

        for ngram, count in preds_counter.items():
            denominator[len(ngram) - 1] += count

    return preds_len, target_len


def _bleu_score_compute(
    preds_len: Tensor,
    target_len: Tensor,
    numerator: Tensor,
    denominator: Tensor,
    n_gram: int,
    weights: Sequence[float],
    smooth: bool,
) -> Tensor:
    """Compute the BLEU score.

    Args:
        preds_len: count of words in a candidate translation
        target_len: count of words in a reference translation
        numerator: Numerator of precision score (true positives)
        denominator: Denominator of precision score (true positives + false positives)
        n_gram: gram value ranged 1 to 4
        weights: Weights used for unigrams, bigrams, etc. to calculate BLEU score.
        smooth: Whether to apply smoothing

    Return:
        Tensor with BLEU Score

    """
    device = numerator.device
    # If any n-gram order has no match at all the score is zero. This check runs before
    # smoothing, so it applies with ``smooth=True`` as well.
    if min(numerator) == 0.0:
        return tensor(0.0, device=device)

    if smooth:
        ones = torch.ones(n_gram, device=device)
        precision_scores = torch.div(torch.add(numerator, ones), torch.add(denominator, ones))
        # Add-one smoothing is not applied to the unigram precision.
        precision_scores[0] = numerator[0] / denominator[0]
    else:
        precision_scores = numerator / denominator

    log_precision_scores = tensor(weights, device=device) * torch.log(precision_scores)
    geometric_mean = torch.exp(torch.sum(log_precision_scores))
    # Predictions longer than the effective reference length are not penalized.
    brevity_penalty = tensor(1.0, device=device) if preds_len > target_len else torch.exp(1 - (target_len / preds_len))
    return brevity_penalty * geometric_mean


def bleu_score(
    preds: Union[str, Sequence[str]],
    target: Sequence[Union[str, Sequence[str]]],
    n_gram: int = 4,
    smooth: bool = False,
    weights: Optional[Sequence[float]] = None,
) -> Tensor:
    """Calculate `BLEU score`_ of machine translated text with one or more references.

    Args:
        preds: An iterable of machine translated corpus
        target: An iterable of iterables of reference corpus
        n_gram: Gram value ranged from 1 to 4
        smooth: Whether to apply smoothing - see [2]
        weights:
            Weights used for unigrams, bigrams, etc. to calculate BLEU score.
            If not provided, uniform weights are used.

    Return:
        Tensor with BLEU Score

    Raises:
        ValueError: If ``preds`` and ``target`` corpus have different lengths.
        ValueError: If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
        ValueError: If ``n_gram`` is smaller than 1.

    Example:
        >>> from torchmetrics.functional.text import bleu_score
        >>> preds = ['the cat is on the mat']
        >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
        >>> bleu_score(preds, target)
        tensor(0.7598)

    References:
        [1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni,
        Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu `BLEU`_

        [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
        and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_

    """
    # Normalize inputs: a single prediction string becomes a one-element corpus, and a plain
    # string reference becomes a one-element list of references.
    preds_ = [preds] if isinstance(preds, str) else preds
    target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target]

    if len(preds_) != len(target_):
        raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}")

    if weights is not None and len(weights) != n_gram:
        raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")

    # Fail early with a clear message instead of an incidental ZeroDivisionError / RuntimeError below.
    if n_gram < 1:
        raise ValueError(f"Expected argument `n_gram` to be a positive integer, but got {n_gram}")

    if weights is None:
        weights = [1.0 / n_gram] * n_gram

    numerator = torch.zeros(n_gram)
    denominator = torch.zeros(n_gram)
    preds_len = tensor(0.0)
    target_len = tensor(0.0)

    preds_len, target_len = _bleu_score_update(
        preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn
    )

    return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, weights, smooth)