jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# referenced from
# Library Name: torchtext
# Authors: torchtext authors
# Date: 2021-12-07
# Link:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The RWTH Extended Edit Distance (EED) License
# Copyright (c) 2019, RWTH.
# All rights reserved.
# This license is derived from the Q Public License v1.0 and the Qt Non-Commercial License v1.0 which are both Copyright
# by Trolltech AS, Norway. The aim of this license is to lay down the conditions enabling you to use, modify and
# circulate the SOFTWARE, use of third-party application programs based on the Software and publication of results
# obtained through the use of modified and unmodified versions of the SOFTWARE. However, RWTH remain the authors of the
# SOFTWARE and so retain property rights and the use of all ancillary rights. The SOFTWARE is defined as all successive
# versions of EED software and their documentation that have been developed by RWTH.
#
# When you access and use the SOFTWARE, you are presumed to be aware of and to have accepted all the rights and
# obligations of the present license:
#
# 1. You are granted the non-exclusive rights set forth in this license provided you agree to and comply with any all
# conditions in this license. Whole or partial distribution of the Software, or software items that link with the
# Software, in any form signifies acceptance of this license for non-commercial use only.
# 2. You may copy and distribute the Software in unmodified form provided that the entire package, including - but not
# restricted to - copyright, trademark notices and disclaimers, as released by the initial developer of the
# Software, is distributed.
# 3. You may make modifications to the Software and distribute your modifications, in a form that is separate from the
# Software, such as patches. The following restrictions apply to modifications:
# a. Modifications must not alter or remove any copyright notices in the Software.
# b When modifications to the Software are released under this license, a non-exclusive royalty-free right is
# granted to the initial developer of the Software to distribute your modification in future versions of the
# Software provided such versions remain available under these terms in addition to any other license(s) of the
# initial developer.
# 4. You may distribute machine-executable forms of the Software or machine-executable forms of modified versions of
# the Software, provided that you meet these restrictions:
# a. You must include this license document in the distribution.
# b. You must ensure that all recipients of the machine-executable forms are also able to receive the complete
# machine-readable source code to the distributed Software, including all modifications, without any charge
# beyond the costs of data transfer, and place prominent notices in the distribution explaining this.
# c. You must ensure that all modifications included in the machine-executable forms are available under the terms
# of this license.
# 5. You may use the original or modified versions of the Software to compile, link and run application programs
# legally developed by you or by others.
# 6. You may develop application programs, reusable components and other software items, in a non-commercial setting,
# that link with the original or modified versions of the Software. These items, when distributed, are subject to
# the following requirements:
# a. You must ensure that all recipients of machine-executable forms of these items are also able to receive and use
# the complete machine-readable source code to the items without any charge beyond the costs of data transfer.
# b. You must explicitly license all recipients of your items to use and re-distribute original and modified
# versions of the items in both machine-executable and source code forms. The recipients must be able to do so
# without any charges whatsoever, and they must be able to re-distribute to anyone they choose.
# c. If an application program gives you access to functionality of the Software for development of application
# programs, reusable components or other software components (e.g. an application that is a scripting wrapper),
# usage of the application program is considered to be usage of the Software and is thus bound by this license.
# d. If the items are not available to the general public, and the initial developer of the Software requests a copy
# of the items, then you must supply one.
# 7. Users must cite the authors of the Software upon publication of results obtained through the use of original or
# modified versions of the Software by referring to the following publication:
# P. Stanchev, W. Wang, and H. Ney, β€œEED: Extended Edit Distance Measure for Machine Translation”, submitted to WMT
# 2019.
# 8. In no event shall the initial developers or copyright holders be liable for any damages whatsoever, including -
# but not restricted to - lost revenue or profits or other direct, indirect, special, incidental or consequential
# damages, even if they have been advised of the possibility of such damages, except to the extent invariable law,
# if any, provides otherwise.
# 9. You assume all risks concerning the quality or the effects of the SOFTWARE and its use. If the SOFTWARE is
# defective, you will bear the costs of all required services, corrections or repairs.
# 10. This license has the binding value of a contract.
# 11. The present license and its effects are subject to German law and the competent German Courts.
#
# The Software and this license document are provided "AS IS" with NO EXPLICIT OR IMPLICIT WARRANTY OF ANY KIND,
# INCLUDING WARRANTY OF DESIGN, ADAPTION, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
import re
import unicodedata
from collections.abc import Sequence
from math import inf
from typing import List, Optional, Union
from torch import Tensor, stack, tensor
from typing_extensions import Literal
from torchmetrics.functional.text.helper import _validate_inputs
def _distance_between_words(preds_word: str, target_word: str) -> int:
"""Distance measure used for substitutions/identity operation.
Code adapted from https://github.com/rwth-i6/ExtendedEditDistance/blob/master/EED.py.
Args:
preds_word: hypothesis word string
target_word: reference word string
Return:
0 for match, 1 for no match
"""
return int(preds_word != target_word)
def _eed_function(
hyp: str,
ref: str,
alpha: float = 2.0,
rho: float = 0.3,
deletion: float = 0.2,
insertion: float = 1.0,
) -> float:
"""Compute extended edit distance score for two lists of strings: hyp and ref.
Code adapted from: https://github.com/rwth-i6/ExtendedEditDistance/blob/master/EED.py.
Args:
hyp: A hypothesis string
ref: A reference string
alpha: optimal jump penalty, penalty for jumps between characters
rho: coverage cost, penalty for repetition of characters
deletion: penalty for deletion of character
insertion: penalty for insertion or substitution of character
Return:
Extended edit distance score as float
"""
number_of_visits = [-1] * (len(hyp) + 1)
# row[i] stores cost of cheapest path from (0,0) to (i,l) in CDER alignment grid.
row = [1.0] * (len(hyp) + 1)
row[0] = 0.0 # CDER initialisation 0,0 = 0.0, rest 1.0
next_row = [inf] * (len(hyp) + 1)
for w in range(1, len(ref) + 1):
for i in range(len(hyp) + 1):
if i > 0:
next_row[i] = min(
next_row[i - 1] + deletion,
row[i - 1] + _distance_between_words(hyp[i - 1], ref[w - 1]),
row[i] + insertion,
)
else:
next_row[i] = row[i] + 1.0
min_index = next_row.index(min(next_row))
number_of_visits[min_index] += 1
# Long Jumps
if ref[w - 1] == " ":
jump = alpha + next_row[min_index]
next_row = [min(x, jump) for x in next_row]
row = next_row
next_row = [inf] * (len(hyp) + 1)
coverage = rho * sum(x if x >= 0 else 1 for x in number_of_visits)
return min(1, (row[-1] + coverage) / (float(len(ref)) + coverage))
def _preprocess_en(sentence: str) -> str:
"""Preprocess english sentences.
Copied from https://github.com/rwth-i6/ExtendedEditDistance/blob/master/util.py.
Raises:
ValueError: If input sentence is not of a type `str`.
"""
if not isinstance(sentence, str):
raise ValueError(f"Only strings allowed during preprocessing step, found {type(sentence)} instead")
sentence = sentence.rstrip() # trailing space, tab, or newline
# Add space before interpunctions
rules_interpunction = [
(".", " ."),
("!", " !"),
("?", " ?"),
(",", " ,"),
]
for pattern, replacement in rules_interpunction:
sentence = sentence.replace(pattern, replacement)
rules_re = [
(r"\s+", r" "), # get rid of extra spaces
(r"(\d) ([.,]) (\d)", r"\1\2\3"), # 0 . 1 -> 0.1
(r"(Dr|Jr|Prof|Rev|Gen|Mr|Mt|Mrs|Ms) .", r"\1."), # Mr . -> Mr.
]
for pattern, replacement in rules_re:
sentence = re.sub(pattern, replacement, sentence)
# Add space between abbreviations
rules_interpunction = [
("e . g .", "e.g."),
("i . e .", "i.e."),
("U . S .", "U.S."),
]
for pattern, replacement in rules_interpunction:
sentence = sentence.replace(pattern, replacement)
# add space to beginning and end of string
return " " + sentence + " "
def _preprocess_ja(sentence: str) -> str:
"""Preprocess japanese sentences.
Copy from https://github.com/rwth-i6/ExtendedEditDistance/blob/master/util.py.
Raises:
ValueError: If input sentence is not of a type `str`.
"""
if not isinstance(sentence, str):
raise ValueError(f"Only strings allowed during preprocessing step, found {type(sentence)} instead")
sentence = sentence.rstrip() # trailing space, tab, newline
# characters which look identical actually are identical
return unicodedata.normalize("NFKC", sentence)
def _eed_compute(sentence_level_scores: List[Tensor]) -> Tensor:
"""Reduction for extended edit distance.
Args:
sentence_level_scores: list of sentence-level scores as floats
Return:
average of scores as a tensor
"""
if len(sentence_level_scores) == 0:
return tensor(0.0)
return sum(sentence_level_scores) / tensor(len(sentence_level_scores))
def _preprocess_sentences(
preds: Union[str, Sequence[str]],
target: Sequence[Union[str, Sequence[str]]],
language: Literal["en", "ja"],
) -> tuple[Union[str, Sequence[str]], Sequence[Union[str, Sequence[str]]]]:
"""Preprocess strings according to language requirements.
Args:
preds: An iterable of hypothesis corpus.
target: An iterable of iterables of reference corpus.
language: Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en
Return:
Tuple of lists that contain the cleaned strings for target and preds
Raises:
ValueError: If a different language than ``'en'`` or ``'ja'`` is used
ValueError: If length of target not equal to length of preds
ValueError: If objects in reference and hypothesis corpus are not strings
"""
# sanity checks
target, preds = _validate_inputs(hypothesis_corpus=preds, ref_corpus=target)
# preprocess string
if language == "en":
preprocess_function = _preprocess_en
elif language == "ja":
preprocess_function = _preprocess_ja
else:
raise ValueError(f"Expected argument `language` to either be `en` or `ja` but got {language}")
preds = [preprocess_function(pred) for pred in preds]
target = [[preprocess_function(ref) for ref in reference] for reference in target]
return preds, target
def _compute_sentence_statistics(
preds_word: str,
target_words: Union[str, Sequence[str]],
alpha: float = 2.0,
rho: float = 0.3,
deletion: float = 0.2,
insertion: float = 1.0,
) -> Tensor:
"""Compute scores for ExtendedEditDistance.
Args:
target_words: An iterable of reference words
preds_word: A hypothesis word
alpha: An optimal jump penalty, penalty for jumps between characters
rho: coverage cost, penalty for repetition of characters
deletion: penalty for deletion of character
insertion: penalty for insertion or substitution of character
Return:
best_score: best (lowest) sentence-level score as a Tensor
"""
best_score = inf
for reference in target_words:
score = _eed_function(preds_word, reference, alpha, rho, deletion, insertion)
if score < best_score:
best_score = score
return tensor(best_score)
def _eed_update(
preds: Union[str, Sequence[str]],
target: Sequence[Union[str, Sequence[str]]],
language: Literal["en", "ja"] = "en",
alpha: float = 2.0,
rho: float = 0.3,
deletion: float = 0.2,
insertion: float = 1.0,
sentence_eed: Optional[List[Tensor]] = None,
) -> List[Tensor]:
"""Compute scores for ExtendedEditDistance.
Args:
preds: An iterable of hypothesis corpus
target: An iterable of iterables of reference corpus
language: Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en
alpha: optimal jump penalty, penalty for jumps between characters
rho: coverage cost, penalty for repetition of characters
deletion: penalty for deletion of character
insertion: penalty for insertion or substitution of character
sentence_eed: list of sentence-level scores
Return:
individual sentence scores as a list of Tensors
"""
preds, target = _preprocess_sentences(preds, target, language)
if sentence_eed is None:
sentence_eed = []
# return tensor(0.0) if target or preds is empty
if 0 in (len(preds), len(target[0])):
return sentence_eed
for hypothesis, target_words in zip(preds, target):
score = _compute_sentence_statistics(hypothesis, target_words, alpha, rho, deletion, insertion)
sentence_eed.append(score)
return sentence_eed
def extended_edit_distance(
preds: Union[str, Sequence[str]],
target: Sequence[Union[str, Sequence[str]]],
language: Literal["en", "ja"] = "en",
return_sentence_level_score: bool = False,
alpha: float = 2.0,
rho: float = 0.3,
deletion: float = 0.2,
insertion: float = 1.0,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
"""Compute extended edit distance score (`ExtendedEditDistance`_) [1] for strings or list of strings.
The metric utilises the Levenshtein distance and extends it by adding a jump operation.
Args:
preds: An iterable of hypothesis corpus.
target: An iterable of iterables of reference corpus.
language: Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en
return_sentence_level_score: An indication of whether sentence-level EED score is to be returned.
alpha: optimal jump penalty, penalty for jumps between characters
rho: coverage cost, penalty for repetition of characters
deletion: penalty for deletion of character
insertion: penalty for insertion or substitution of character
Return:
Extended edit distance score as a tensor
Example:
>>> from torchmetrics.functional.text import extended_edit_distance
>>> preds = ["this is the prediction", "here is an other sample"]
>>> target = ["this is the reference", "here is another one"]
>>> extended_edit_distance(preds=preds, target=target)
tensor(0.3078)
References:
[1] P. Stanchev, W. Wang, and H. Ney, β€œEED: Extended Edit Distance Measure for Machine Translation”,
submitted to WMT 2019. `ExtendedEditDistance`_
"""
# input validation for parameters
for param_name, param in zip(["alpha", "rho", "deletion", "insertion"], [alpha, rho, deletion, insertion]):
if not isinstance(param, float) or (isinstance(param, float) and param < 0):
raise ValueError(f"Parameter `{param_name}` is expected to be a non-negative float.")
sentence_level_scores = _eed_update(preds, target, language, alpha, rho, deletion, insertion)
average = _eed_compute(sentence_level_scores)
if return_sentence_level_score:
return average, stack(sentence_level_scores)
return average