|
import time |
|
import warnings |
|
from abc import ABC |
|
from collections import OrderedDict |
|
from copy import deepcopy |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
from torch.nn import functional as F |
|
|
|
from ..pytorch_utils import isin_mps_friendly |
|
from ..tokenization_utils_base import PreTrainedTokenizerBase |
|
from ..utils import add_start_docstrings, logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
# Cache of stop-string preprocessing results, keyed by (vocab tokens, vocab indices, stop strings). It is used as a
# small LRU cache: `StopStringCriteria.clean_and_embed_tokens_with_cache` evicts the oldest entry beyond 8 entries.
STOP_STRING_EMBEDDING_CACHE = OrderedDict()
|
|
|
|
|
STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): |
|
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax |
|
            or scores for each vocabulary token after SoftMax. If this stopping criterion depends on the `scores` input,
|
make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. |
|
kwargs (`Dict[str, Any]`, *optional*): |
|
Additional stopping criteria specific kwargs. |
|
|
|
Return: |
|
        `torch.BoolTensor` of shape `(batch_size,)`, where `True` indicates we stop generation for a particular row,
        and `False` indicates we should continue.
|
|
|
""" |
|
|
|
|
|
class StoppingCriteria(ABC): |
|
"""Abstract base class for all stopping criteria that can be applied during generation. |
|
|
|
    If your stopping criterion depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
|
output_scores=True` to `generate`. |
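
    A minimal sketch of a custom criterion (the class name and the logic are purely illustrative and not part of the
    library API):

    ```python
    >>> import torch
    >>> from transformers import StoppingCriteria

    >>> class StopOnTokenBudget(StoppingCriteria):
    ...     def __init__(self, budget: int):
    ...         self.budget = budget
    ...     def __call__(self, input_ids, scores, **kwargs):
    ...         # one boolean per row: True once the sequence reaches the token budget
    ...         is_done = input_ids.shape[-1] >= self.budget
    ...         return torch.full((input_ids.shape[0],), is_done, dtype=torch.bool, device=input_ids.device)
    ```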
|
""" |
|
|
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: |
|
raise NotImplementedError("StoppingCriteria needs to be subclassed") |
|
|
|
|
|
class MaxLengthCriteria(StoppingCriteria): |
|
""" |
|
    This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
    in mind that, for decoder-only transformers, this count includes the initial prompt tokens.
|
|
|
Args: |
|
max_length (`int`): |
|
The maximum length that the output sequence can have in number of tokens. |
|
max_position_embeddings (`int`, *optional*): |
|
The maximum model length, as defined by the model's `config.max_position_embeddings` attribute. |
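
    Examples:

    A minimal sketch of calling the criterion directly (in practice it is usually passed to `generate` through the
    `stopping_criteria` argument). `scores` is not used by this criterion, so `None` is passed here:

    ```python
    >>> import torch
    >>> from transformers import MaxLengthCriteria, StoppingCriteriaList

    >>> criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
    >>> criteria(torch.ones((1, 19), dtype=torch.long), None)  # 19 tokens so far: keep going
    tensor([False])
    >>> criteria(torch.ones((1, 20), dtype=torch.long), None)  # reached `max_length`: stop
    tensor([True])
    ```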
|
""" |
|
|
|
def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None): |
|
self.max_length = max_length |
|
self.max_position_embeddings = max_position_embeddings |
|
|
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: |
|
cur_len = input_ids.shape[-1] |
|
is_done = cur_len >= self.max_length |
|
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings: |
|
logger.warning_once( |
|
"This is a friendly reminder - the current text generation call will exceed the model's predefined " |
|
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe " |
|
"exceptions, performance degradation, or nothing at all." |
|
) |
|
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) |
|
|
|
|
|
class MaxTimeCriteria(StoppingCriteria): |
|
""" |
|
    This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
    time will start being counted when you initialize this criterion. You can override this by passing an
    `initial_timestamp`.
|
|
|
Args: |
|
max_time (`float`): |
|
The maximum allowed time in seconds for the generation. |
|
        initial_timestamp (`float`, *optional*, defaults to `time.time()`):
|
The start of the generation allowed time. |
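
    Examples:

    A minimal sketch of calling the criterion directly; `scores` is not used by this criterion, so `None` is passed:

    ```python
    >>> import time
    >>> import torch
    >>> from transformers import MaxTimeCriteria, StoppingCriteriaList

    >>> # Allow at most 5 seconds of generation, counted from construction time
    >>> criteria = StoppingCriteriaList([MaxTimeCriteria(max_time=5.0)])
    >>> criteria(torch.ones((1, 4), dtype=torch.long), None)  # well within the budget
    tensor([False])

    >>> # Passing `initial_timestamp` lets you count from an earlier point in time
    >>> expired = StoppingCriteriaList([MaxTimeCriteria(max_time=5.0, initial_timestamp=time.time() - 10.0)])
    >>> expired(torch.ones((1, 4), dtype=torch.long), None)
    tensor([True])
    ```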
|
""" |
|
|
|
def __init__(self, max_time: float, initial_timestamp: Optional[float] = None): |
|
self.max_time = max_time |
|
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp |
|
|
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: |
|
is_done = time.time() - self.initial_timestamp > self.max_time |
|
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) |
|
|
|
|
|
class StopStringCriteria(StoppingCriteria): |
|
""" |
|
This class can be used to stop generation whenever specific string sequences are generated. It preprocesses |
|
the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings. |
|
|
|
Generation is stopped as soon as a token is generated that completes any of the stop strings. |
|
We want to catch any instance in which the stop string would be present in the decoded output, which means |
|
we must also catch cases with "overhangs" off one or both ends. To make this more concrete, for the stop string |
|
"stop", any of the following token sequences would trigger the match: |
|
|
|
- ["st", "op"] |
|
- ["stop"] |
|
- ["st", "opera"] |
|
- ["sto", "pper"] |
|
- ["las", "topper"] |
|
- ["s", "to", "pped"] |
|
|
|
Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other |
|
words, these sequences will not trigger a match: |
|
|
|
- ["stop", "at"] |
|
- ["st", "op", "at"] |
|
- ["st", "opera", "tion"] |
|
|
|
The reason these are not a match is that the stop string does not overlap with the final token. If you can remove |
|
one or more tokens from the end of the sequence without destroying the stop string, then this criterion will not |
|
match that stop string. This is by design; because this check is run after each token is generated, we can't miss a |
|
valid stop string if one is generated, but we don't want to halt generation just because the stop string exists |
|
somewhere in the past input_ids. |
|
|
|
How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match |
|
process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible, |
|
with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use |
|
with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations. |
|
|
|
    The key to the algorithm is a simple observation: because the stop string must overlap with the end of the token sequence, we can start at
|
the end of the sequence and work backwards. Specifically, we check that there is an overlap between the start of |
|
the final token and the end of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for |
|
some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this |
|
property: |
|
|
|
- ["st", "op"] (overlap is "op", overlap length == 2) |
|
- ["stop"] (overlap is "stop", overlap length == 4) |
|
- ["st", "opera"] (overlap is "op", overlap length == 2) |
|
- ["sto", "pper"] (overlap is "p", overlap length == 1) |
|
- ["las", "topper"] (overlap is "top", overlap length == 3) |
|
- ["s", "to", "pped"] (overlap is "p", overlap length == 1) |
|
|
|
It's impossible to construct a matching sequence that does not have this property (feel free to verify this |
|
yourself). However, although this overlap between the start of the final token and the end of the stop string is |
|
necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is |
|
consistent with the stop string. |
|
|
|
How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an |
|
overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already |
|
matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to" |
|
matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the |
|
remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so |
|
we have matched the entire stop string. |
|
|
|
How does it work when the tokens run off the start of the stop string, though? Let's consider the example of |
|
["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore, |
|
the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to |
|
match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This |
|
matches the stop string, and so the entire string is matched. |
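
    To make the procedure concrete, here is a plain-Python sketch of the backwards match just described (purely
    illustrative; the actual implementation below uses tensor operations instead of string methods):

    ```python
    def matches_stop_string(tokens, stop_string):
        # Try every possible overlap between the start of the final token and the end of the stop string
        final = tokens[-1]
        for i in range(1, min(len(final), len(stop_string)) + 1):
            if final[:i] != stop_string[-i:]:
                continue
            matched, ok = i, True
            for token in reversed(tokens[:-1]):
                if matched >= len(stop_string):
                    break
                remainder = stop_string[: len(stop_string) - matched]
                if len(token) >= len(remainder):
                    # The token may run off the start of the stop string, as in ["las", "topper"]
                    ok, matched = token.endswith(remainder), len(stop_string)
                else:
                    ok, matched = remainder.endswith(token), matched + len(token)
                if not ok:
                    break
            if ok and matched >= len(stop_string):
                return True
        return False
    ```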
|
|
|
How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary |
|
information for all tokens! For every token, we compute: |
|
- Its overlap with the end of the stop string, if any |
|
- The positions inside the stop string where the token matches, including matches that run off the start. |
|
- The total length of the token |
|
|
|
For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions, |
|
and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position |
|
of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap, |
|
a single internal matching position of 3 (again counting from the end) and a length of 1. |
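
    A short sketch of how these three per-token values could be computed (simplified: it ignores the corner case of a
    stop string buried inside a single oversized token, which the real preprocessing below does handle):

    ```python
    def token_match_info(token, stop_string):
        # Overlaps between the start of the token and the end of the stop string
        end_overlaps = [
            i for i in range(1, min(len(token), len(stop_string)) + 1) if token[:i] == stop_string[-i:]
        ]
        # Positions, counted from the end of the stop string, where the token (or its tail, for tokens that
        # run off the start) can validly appear
        valid_positions = []
        for pos in range(1, len(stop_string)):
            segment = stop_string[: len(stop_string) - pos][-len(token) :]
            if token.endswith(segment):
                valid_positions.append(pos)
        return end_overlaps, valid_positions, len(token)

    token_match_info("pped", "stop")  # ([1], [], 4)
    token_match_info("to", "stop")    # ([], [1], 2)
    token_match_info("s", "stop")     # ([], [3], 1)
    ```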
|
|
|
As long as we have this information, we can execute the algorithm above without any string comparison |
|
operations. We simply perform the following steps: |
|
    - Check if the final token has an end-overlap with the stop string
|
- Continue backwards, keeping track of how much of the stop string we've matched so far |
|
- At each point, check if the next token has the current position as one of its valid positions |
|
- Continue until either a match fails, or we completely match the whole stop string |
|
|
|
Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match. |
|
    We have matched 1 character so far, so we check that the next token, "to", has 1 as a valid position (again,
|
counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched |
|
3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length |
|
to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the |
|
entire stop string. |
|
|
|
In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have |
|
matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we |
|
allow tokens to match positions that run off the start of the stop string. We add its length to the position |
|
tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though - |
|
this also counts as a match of the stop string. We have matched the entire stop string. |
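
    Here is the same matching loop written against those precomputed values, with no string comparisons at run time
    (again purely illustrative, reusing the hypothetical `token_match_info` helper sketched above):

    ```python
    def matches_precomputed(tokens, stop_string):
        infos = [token_match_info(t, stop_string) for t in tokens]  # (end_overlaps, valid_positions, length)
        end_overlaps, _, _ = infos[-1]
        for matched in end_overlaps:  # each possible end overlap starts a candidate match
            ok = True
            for _, valid_positions, length in reversed(infos[:-1]):
                if matched >= len(stop_string):
                    break
                if matched not in valid_positions:
                    ok = False
                    break
                matched += length  # this may overshoot the stop string length, which still counts as a match
            if ok and matched >= len(stop_string):
                return True
        return False
    ```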
|
|
|
|
|
Args: |
|
tokenizer (`PreTrainedTokenizer`): |
|
The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences) |
|
stop_strings (`Union[str, List[str]]`): |
|
A list of strings that should end generation. If a string is passed, it will be treated like a |
|
list with a single element. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2") |
|
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2") |
|
>>> inputs = tokenizer("The biggest states in the USA by land area:", return_tensors="pt") |
|
|
|
>>> gen_out = model.generate(**inputs) |
|
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) |
|
The biggest states in the USA by land area: |
|
- Alaska |
|
- Texas |
|
- California |
|
|
|
>>> # Passing one or more stop strings will halt generation after those strings are emitted |
|
>>> # Note that generating with stop strings requires you to pass the tokenizer too |
|
>>> gen_out = model.generate(**inputs, stop_strings=["Texas"], tokenizer=tokenizer) |
|
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) |
|
The biggest states in the USA by land area: |
|
- Alaska |
|
- Texas |
|
``` |
|
""" |
|
|
|
def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]): |
|
if isinstance(stop_strings, str): |
|
stop_strings = [stop_strings] |
|
self.stop_strings: Tuple[str, ...] = tuple(stop_strings) |
|
vocab = tokenizer.get_vocab() |
|
token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values()) |
|
self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache( |
|
token_list, token_indices, self.stop_strings, tokenizer |
|
) |
|
|
|
        # Maximum number of characters in any stop string: since every token is at least one character long, the stop
        # string can never span more than this many tokens at the end of the sequence.
        self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
|
self.num_stop_strings = len(self.stop_strings) |
|
self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32) |
|
|
|
def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer): |
|
|
|
if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE: |
|
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[ |
|
                (token_list, token_indices, stop_strings)
|
] |
|
STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings)) |
|
else: |
|
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer) |
|
embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec( |
|
clean_token_list, clean_token_indices, stop_strings |
|
) |
|
STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = ( |
|
embedding_vec, |
|
max_valid_positions, |
|
max_valid_end_lens, |
|
) |
|
if len(STOP_STRING_EMBEDDING_CACHE) > 8: |
|
STOP_STRING_EMBEDDING_CACHE.popitem(last=False) |
|
return embedding_vec, max_valid_positions, max_valid_end_lens |
|
|
|
@staticmethod |
|
def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): |
|
""" |
|
This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string |
|
        it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
|
tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix |
|
space addition/removal. To work around this, we add a static prefix to the start of the token, then remove |
|
it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string(). |
|
""" |
|
vocab = tokenizer.get_vocab() |
|
clean_token_list = [] |
|
clean_token_indices = [] |
|
sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"] |
|
tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base] |
|
for token, token_idx in vocab.items(): |
|
token_string = tokenizer.convert_tokens_to_string(tokens_base + [token]) |
|
token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :] |
|
clean_token_list.append(token_string) |
|
clean_token_indices.append(token_idx) |
|
return tuple(clean_token_list), tuple(clean_token_indices) |
|
|
|
@staticmethod |
|
def _stop_string_get_matching_positions( |
|
token_list, token_indices, stop_strings |
|
) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: |
|
"""This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can |
|
validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the |
|
token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters |
|
from the end of the stop string that overlap with the start of the token, which can have more than one value. |
|
|
|
The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full |
|
explanation of what these values are for!""" |
|
|
|
token_valid_positions = {} |
|
token_end_overlaps = {} |
|
for stop_string in stop_strings: |
|
reversed_stop_string = stop_string[::-1] |
|
token_valid_positions[stop_string] = {} |
|
token_end_overlaps[stop_string] = {} |
|
for token, tok_idx in zip(token_list, token_indices): |
|
reversed_token = token[::-1] |
|
matching_positions = [] |
|
possible_end_lengths = [] |
|
for i in range(1 - len(token), len(stop_string)): |
|
if i < 0: |
|
tok = reversed_token[-i:] |
|
i = 0 |
|
else: |
|
tok = reversed_token |
|
stop = reversed_stop_string[i : i + len(tok)] |
|
if tok.startswith(stop): |
|
if i == 0: |
|
possible_end_lengths.append(min(len(tok), len(stop))) |
|
else: |
|
matching_positions.append(i) |
|
|
|
if matching_positions: |
|
token_valid_positions[stop_string][tok_idx] = matching_positions |
|
if possible_end_lengths: |
|
token_end_overlaps[stop_string][tok_idx] = possible_end_lengths |
|
return token_valid_positions, token_end_overlaps |
|
|
|
@staticmethod |
|
    def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Tuple[torch.Tensor, int, int]:
|
"""This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs |
|
them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values |
|
that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!""" |
|
token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions( |
|
token_list, token_indices, stop_strings |
|
) |
|
all_valid_positions = [len(val) for positions in token_valid_positions.values() for val in positions.values()] |
|
|
|
|
|
max_valid_positions = max(all_valid_positions) if all_valid_positions else 1 |
|
|
|
valid_end_lens = [len(val) for positions in token_end_overlaps.values() for val in positions.values()] |
|
if not valid_end_lens: |
|
raise ValueError( |
|
"Stop string preprocessing was unable to identify tokens matching one or more of the " |
|
"supplied stop string(s). This is most often caused by the stop " |
|
"strings containing unusual characters that are not in the tokenizer vocabulary." |
|
) |
|
max_valid_end_lens = max(valid_end_lens) |
|
vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 |
|
gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1) |
|
|
|
for i, stop_string in enumerate(stop_strings): |
|
positions = token_valid_positions[stop_string] |
|
end_lens = token_end_overlaps[stop_string] |
|
|
|
|
|
|
|
for token_idx, valid_positions in positions.items(): |
|
gather_vec[token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions)] = ( |
|
valid_positions |
|
) |
|
for token_idx, possible_end_lens in end_lens.items(): |
|
gather_vec[ |
|
token_idx, |
|
max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions |
|
* len(stop_strings) |
|
+ max_valid_end_lens * i |
|
+ len(possible_end_lens), |
|
] = possible_end_lens |
|
for token, token_idx in zip(token_list, token_indices): |
|
gather_vec[token_idx, -1] = len(token) |
|
|
|
gather_vec = torch.tensor(gather_vec, dtype=torch.int32) |
|
|
|
return gather_vec, max_valid_positions, max_valid_end_lens |
|
|
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor: |
|
self.embedding_vec = self.embedding_vec.to(input_ids.device) |
|
self.target_lens = self.target_lens.to(input_ids.device) |
|
|
|
|
|
        # Only the most recent tokens can matter: each token is at least one character, so a stop string can span at
        # most `maximum_token_len` tokens.
        input_ids = input_ids[:, -self.maximum_token_len :]
|
|
|
|
|
        # Flip the sequence so that index 0 is the most recent token; the match is checked backwards from the end
        flipped_ids = torch.flip(input_ids, (1,))
|
|
|
|
|
max_valid_positions = self.max_valid_positions |
|
|
|
|
|
        # Look up the precomputed valid positions, end overlaps and character lengths for each token
        embedded = F.embedding(flipped_ids, self.embedding_vec)
|
|
|
|
|
        # Valid internal positions (per stop string) for every token except the most recent one
        valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten(
|
-1, (self.num_stop_strings, -1) |
|
) |
|
|
|
|
|
        # Possible end overlaps (per stop string) for the most recent token only
        end_lengths = embedded[:, :1, max_valid_positions * self.num_stop_strings : -1].unflatten(
|
-1, (self.num_stop_strings, -1) |
|
) |
|
|
|
        # Character length of each earlier token (broadcast across stop strings and end-overlap hypotheses)
        lengths = embedded[:, 1:, None, -1:]
|
|
|
|
|
lengths = lengths.expand((-1, -1, end_lengths.shape[-2], end_lengths.shape[-1])) |
|
        # Stack the end overlaps of the final token on top of the lengths of the earlier tokens
        lengths_with_ends = torch.cat([end_lengths, lengths], dim=1)
|
|
|
|
|
        # Running total of matched characters for each end-overlap hypothesis, accumulated backwards through the tokens
        cumsum = lengths_with_ends.cumsum(dim=1)
|
|
|
|
|
|
|
        # A match can only start if the final token overlaps the end of the stop string
        initial_match = end_lengths > 0
|
|
|
|
|
|
|
        # An earlier token continues the match if the number of characters matched so far is one of its valid positions
        later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2)
|
|
|
|
|
        # Combine the start-of-match check with the continuation checks
        match = torch.cat([initial_match, later_match], dim=1)
|
|
|
|
|
        # Once a single token fails to match, every position further back is invalidated
        mask = (~match).cumsum(dim=1, dtype=torch.int32)
|
mask = mask == 0 |
|
|
|
|
|
|
|
        # A stop string is matched if, before the mask cuts off, the running total reaches the string's length
        string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :]
|
|
|
|
|
        # Report True for a row if any of its stop strings matched
        return torch.any(string_matches, dim=-1)
|
|
|
|
|
class EosTokenCriteria(StoppingCriteria): |
|
""" |
|
This class can be used to stop generation whenever the "end-of-sequence" token is generated. |
|
By default, it uses the `model.generation_config.eos_token_id`. |
|
|
|
Args: |
|
eos_token_id (`Union[int, List[int], torch.Tensor]`): |
|
The id(s) of the *end-of-sequence* token. |
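
    Examples:

    A minimal sketch of calling the criterion directly (`scores` is not used by this criterion, so `None` is passed):

    ```python
    >>> import torch
    >>> from transformers import EosTokenCriteria

    >>> criteria = EosTokenCriteria(eos_token_id=2)
    >>> is_done = criteria(torch.tensor([[5, 7, 2], [5, 7, 9]]), None)  # only the first row ends in the EOS token
    >>> is_done.tolist()
    [True, False]
    ```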
|
""" |
|
|
|
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]): |
|
if not isinstance(eos_token_id, torch.Tensor): |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
eos_token_id = torch.tensor(eos_token_id) |
|
self.eos_token_id = eos_token_id |
|
|
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: |
|
self.eos_token_id = self.eos_token_id.to(input_ids.device) |
|
is_done = isin_mps_friendly(input_ids[:, -1], self.eos_token_id) |
|
return is_done |
|
|
|
|
|
class ConfidenceCriteria(StoppingCriteria): |
|
""" |
|
    This class can be used to stop generation whenever the assistant model's confidence in its prediction for the
    current token is lower than the threshold `model.generation_config.assistant_confidence_threshold`, even if the
    number of speculative tokens (defined by `num_assistant_tokens`) has not yet been reached.
|
|
|
Args: |
|
assistant_confidence_threshold (`float`): |
|
The value of the threshold. |
|
""" |
|
|
|
def __init__(self, assistant_confidence_threshold): |
|
self.assistant_confidence_threshold = assistant_confidence_threshold |
|
|
|
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Confidence of the assistant model in the token it just produced (only the first row of the batch is checked)
        probs = scores[-1].softmax(-1)
        p = probs[0, input_ids[0, -1]].item()
        return p < self.assistant_confidence_threshold
|
|
|
|
|
class StoppingCriteriaList(list): |
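    """
    A list of [`StoppingCriteria`] instances. Calling it runs every criterion on the batch and ORs the results, so a
    row is marked as finished as soon as any single criterion returns `True` for it.
    """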
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: |
|
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool) |
|
for criteria in self: |
|
is_done = is_done | criteria(input_ids, scores, **kwargs) |
|
return is_done |
|
|
|
@property |
|
def max_length(self) -> Optional[int]: |
|
for stopping_criterium in self: |
|
if isinstance(stopping_criterium, MaxLengthCriteria): |
|
return stopping_criterium.max_length |
|
return None |
|
|
|
|
|
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList: |
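    """
    Checks that a custom `stopping_criteria` list is compatible with the requested `max_length`: a warning is raised
    if the list already contains a [`MaxLengthCriteria`] with a different `max_length`, and a [`MaxLengthCriteria`] is
    appended (to a deep copy of the list) if none is present.
    """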
|
stopping_max_length = stopping_criteria.max_length |
|
new_stopping_criteria = deepcopy(stopping_criteria) |
|
if stopping_max_length is not None and stopping_max_length != max_length: |
|
warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) |
|
elif stopping_max_length is None: |
|
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) |
|
return new_stopping_criteria |
|
|