# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from collections.abc import Sequence
from typing import Callable, Optional, Union

import torch
from torch import Tensor, tensor


def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:
"""Count how many times each word appears in a given text with ngram.
Args:
ngram_input_list: A list of translated text or reference texts
n_gram: gram value ranged 1 to 4
Return:
ngram_counter: a collections.Counter object of ngram
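
    Example (an illustrative sketch with a toy sentence; unigram and bigram counts share one counter):
        >>> counts = _count_ngram("the cat sat on the mat".split(), 2)
        >>> counts[("the",)], counts[("the", "cat")], counts[("cat", "sat")]
        (2, 1, 1)
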
"""
ngram_counter: Counter = Counter()
for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j : (i + j)])
ngram_counter[ngram_key] += 1
    return ngram_counter


def _tokenize_fn(sentence: str) -> Sequence[str]:
"""Tokenizes sentence into list of words.
Args:
sentence: A sentence separated by white space.
Return:
List of words
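
    Example (illustrative; this is plain whitespace splitting):
        >>> _tokenize_fn("the cat is on the mat")
        ['the', 'cat', 'is', 'on', 'the', 'mat']
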
"""
    return sentence.split()


def _bleu_score_update(
preds: Sequence[str],
target: Sequence[Sequence[str]],
numerator: Tensor,
denominator: Tensor,
preds_len: Tensor,
target_len: Tensor,
n_gram: int = 4,
tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn,
) -> tuple[Tensor, Tensor]:
"""Update and returns variables required to compute the BLEU score.
Args:
preds: An iterable of machine translated corpus
target: An iterable of iterables of reference corpus
numerator: Numerator of precision score (true positives)
denominator: Denominator of precision score (true positives + false positives)
preds_len: count of words in a candidate prediction
target_len: count of words in a reference translation
target: count of words in a reference translation
n_gram: gram value ranged 1 to 4
tokenizer: A function that turns sentence into list of words
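
    Example (an illustrative sketch of the in-place update, starting from freshly zeroed state tensors):
        >>> numerator, denominator = torch.zeros(2), torch.zeros(2)
        >>> preds_len, target_len = tensor(0.0), tensor(0.0)
        >>> _bleu_score_update(
        ...     ["the cat sat"], [["the cat sat on the mat"]], numerator, denominator, preds_len, target_len, n_gram=2
        ... )
        (tensor(3.), tensor(6.))
        >>> numerator
        tensor([3., 2.])
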
"""
target_: Sequence[Sequence[Sequence[str]]] = [[tokenizer(line) if line else [] for line in t] for t in target]
preds_: Sequence[Sequence[str]] = [tokenizer(line) if line else [] for line in preds]
    for pred, targets in zip(preds_, target_):
        preds_len += len(pred)
        # Effective reference length: use the reference whose length is closest to the prediction length
        target_len_list = [len(tgt) for tgt in targets]
        target_len_diff = [abs(len(pred) - x) for x in target_len_list]
        target_len += target_len_list[target_len_diff.index(min(target_len_diff))]
        preds_counter: Counter = _count_ngram(pred, n_gram)
        target_counter: Counter = Counter()
        for tgt in targets:
            # Counter union keeps, for every n-gram, the maximum count observed across the references
            target_counter |= _count_ngram(tgt, n_gram)
        # Clip candidate n-gram counts by the reference maxima (modified n-gram precision)
        ngram_counter_clip = preds_counter & target_counter
        for counter_clip in ngram_counter_clip:
            numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
        for counter in preds_counter:
            denominator[len(counter) - 1] += preds_counter[counter]
    return preds_len, target_len


def _bleu_score_compute(
preds_len: Tensor,
target_len: Tensor,
numerator: Tensor,
denominator: Tensor,
n_gram: int,
weights: Sequence[float],
smooth: bool,
) -> Tensor:
"""Compute the BLEU score.
Args:
preds_len: count of words in a candidate translation
target_len: count of words in a reference translation
numerator: Numerator of precision score (true positives)
denominator: Denominator of precision score (true positives + false positives)
n_gram: gram value ranged 1 to 4
weights: Weights used for unigrams, bigrams, etc. to calculate BLEU score.
smooth: Whether to apply smoothing
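
    Example (an illustrative sketch reusing the toy statistics from the ``_bleu_score_update`` example;
    the precisions are all 1, so the score reduces to the brevity penalty ``exp(1 - 6/3)``):
        >>> numerator, denominator = tensor([3.0, 2.0]), tensor([3.0, 2.0])
        >>> _bleu_score_compute(tensor(3.0), tensor(6.0), numerator, denominator, 2, [0.5, 0.5], False)
        tensor(0.3679)
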
"""
device = numerator.device
if min(numerator) == 0.0:
return tensor(0.0, device=device)
    if smooth:
        # Add-one smoothing of the n-gram precisions; the unigram precision is left unsmoothed
        precision_scores = torch.div(
            torch.add(numerator, torch.ones(n_gram, device=device)),
            torch.add(denominator, torch.ones(n_gram, device=device)),
        )
        precision_scores[0] = numerator[0] / denominator[0]
else:
precision_scores = numerator / denominator
log_precision_scores = tensor(weights, device=device) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
    # Brevity penalty: 1 when the candidate is longer than the effective reference, exp(1 - target_len / preds_len) otherwise
    brevity_penalty = tensor(1.0, device=device) if preds_len > target_len else torch.exp(1 - (target_len / preds_len))
    return brevity_penalty * geometric_mean


def bleu_score(
preds: Union[str, Sequence[str]],
target: Sequence[Union[str, Sequence[str]]],
n_gram: int = 4,
smooth: bool = False,
weights: Optional[Sequence[float]] = None,
) -> Tensor:
"""Calculate `BLEU score`_ of machine translated text with one or more references.
Args:
preds: An iterable of machine translated corpus
target: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4
smooth: Whether to apply smoothing - see [2]
weights:
Weights used for unigrams, bigrams, etc. to calculate BLEU score.
If not provided, uniform weights are used.
Return:
Tensor with BLEU Score
Raises:
ValueError: If ``preds`` and ``target`` corpus have different lengths.
ValueError: If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
Example:
>>> from torchmetrics.functional.text import bleu_score
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu_score(preds, target)
tensor(0.7598)
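
        A perfect match against a single identical reference yields a score of 1 (an illustrative sanity check):

        >>> bleu_score(["the cat is on the mat"], [["the cat is on the mat"]])
        tensor(1.)
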
    References:
        [1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni,
        Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu `BLEU`_

        [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
        and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_

"""
preds_ = [preds] if isinstance(preds, str) else preds
target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target]
if len(preds_) != len(target_):
raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}")
if weights is not None and len(weights) != n_gram:
raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
if weights is None:
weights = [1.0 / n_gram] * n_gram
numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
preds_len = tensor(0.0)
target_len = tensor(0.0)
preds_len, target_len = _bleu_score_update(
preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn
)
return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, weights, smooth)