import math
from operator import attrgetter
from typing import Callable, List, Optional
import numpy as np
import regex
from scipy.stats import rankdata
import torch
from parlai.core.torch_generator_agent import (
    TopKSampling,
    TreeSearch,
    _HypothesisTail,
    _PathSelection,
    _PathSelectionTokenDetails,
)
from parlai.utils.torch import neginf
from .generation_utils import Reranker, Wordlist
class VocabTopKSampling(TopKSampling):
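    """
    Top-k sampling constrained by a word list.

    After the first step, token ids that the word list does not allow as a
    continuation of the current partial hypothesis are masked to -inf before
    the usual top-k selection runs.
    """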
def __init__(self,
k: int,
wordlist: Wordlist,
*args, **kwargs):
super().__init__(k=k, *args, **kwargs)
self.k = k
self.wordlist = wordlist
def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection:
"""
Select the next vocabulary item in these beams.
"""
if len(self.all_scores) > 1:
for hypid in range(self.beam_size):
allowed_ids = self.wordlist.get_allowed_ids(self.partial_hyps[hypid])
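                # Build a boolean mask over the vocabulary: True everywhere
                # except at the ids the word list allows for this hypothesis,
                # so disallowed tokens are pushed to -inf before top-k.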
                neginf_assign = torch.ones(
                    logprobs.shape[1], dtype=torch.bool, device=logprobs.device
                )
                neginf_assign[allowed_ids] = False
                logprobs[hypid, neginf_assign] = neginf(logprobs.dtype)
return super().select_paths(logprobs, prior_scores, current_length)
class RerankedTopKSampling(TreeSearch):
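    """
    Top-k sampling with reranking.

    At each step the top-k sampling distribution is re-weighted by per-token
    penalties from a reranker; finished hypotheses are additionally rescored
    by combining model-score ranks with complexity ranks.
    """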
def __init__(self,
k: int,
reranker: Reranker,
tokenids_to_text: Callable,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.k = k
self.reranker = reranker
self.tokenids_to_text = tokenids_to_text
def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection:
"""
Select the next vocabulary item in these beams.
Adapted from top-k sampling https://github.com/facebookresearch/ParlAI/blob/054a0fff8183e357727dc7a91682496734badb7f/parlai/core/torch_generator_agent.py
"""
values, indices = logprobs.topk(self.k, dim=-1)
probs = torch.softmax(values, dim=-1)
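        # Fetch per-token penalties from the reranker and broadcast them to
        # one row per beam, so each beam's top-k probabilities can be
        # re-weighted before sampling.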
all_penalties = self.reranker.token_penalties.to(probs.device)
if all_penalties.dim() == 1:
all_penalties = all_penalties.unsqueeze(0)
all_penalties = all_penalties.repeat(self.beam_size, 1)
penalties = torch.gather(all_penalties, -1, indices)
penalised_probs = torch.mul(probs, penalties)
choices = torch.multinomial(penalised_probs, 1)[:, 0]
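        # Map each beam's sampled choice back to its vocabulary id and to the
        # original (unpenalised) top-k log-probability.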
hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device)
tok_ids = indices[hyp_ids, choices]
scores = values[hyp_ids, choices]
best_scores = prior_scores.expand_as(scores) + scores
token_details: Optional[List[_PathSelectionTokenDetails]] = None
if self.verbose:
tok_logprobs = probs[hyp_ids, choices].log().view(-1).cpu().numpy()
tok_ranks = choices.view(-1).cpu().numpy()
token_details = []
for tok_logprob, tok_rank in zip(tok_logprobs, tok_ranks):
token_details.append(
{"token_logprob": tok_logprob, "token_rank": int(tok_rank)}
)
return _PathSelection(
hypothesis_ids=hyp_ids,
token_ids=tok_ids,
scores=best_scores,
token_details=token_details,
)
def get_rescored_finished(self, n_best=None):
"""
        Adapted from
        https://github.com/facebookresearch/ParlAI/blob/054a0fff8183e357727dc7a91682496734badb7f/parlai/core/torch_generator_agent.py;
        adds complexity scoring and reranking.
Original description:
Return finished hypotheses according to adjusted scores.
Score adjustment is done according to the Google NMT paper, which
penalizes long utterances.
:param n_best:
number of finalized hypotheses to return
:return:
list of (tokens, score, token_metadata) 3-tuples, in sorted order, where:
- tokens is a tensor of token ids
- score is the adjusted log probability of the entire utterance
- token_metadata dictionary:
token_logprobs -> a tensor of conditional log probabilities of tokens
                token_ranks -> a tensor of ranks of tokens in the vocabulary, by probability, when sampled
"""
# if we never actually finished, force one
if not self.finished:
self.outputs[-1][0] = self.eos
self.finished.append(
_HypothesisTail(
timestep=len(self.outputs) - 1,
hypid=0,
score=self.all_scores[-1][0],
tokenid=self.outputs[-1][0],
token_details=self.token_details[0][-1]
if self.token_details is not None
else None,
)
)
        # Convert each finished hypothesis to text and compute its
        # length-normalised model score.
        hyps_str = []
        original_scores = []
        for finished_item in self.finished:
            token_ids = self._get_pretty_hypothesis(
                self._get_hyp_from_finished(finished_item)
            )
            hyps_str.append(self.tokenids_to_text(token_ids))
            current_length = finished_item.timestep + 1
            # these weights are from the Google NMT paper
            length_penalty = math.pow((1 + current_length) / 6, self.length_penalty)
            original_scores.append(finished_item.score.cpu() / length_penalty)
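        # Rank hypotheses separately by complexity and by length-normalised
        # model score; the sum of the two rank vectors is the combined
        # reranking score (sorted in descending order below).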
complexity_scores = self.reranker.get_complexity_scores(hyps_str)
complexity_ranks = rankdata(complexity_scores)
original_ranks = rankdata(original_scores)
combined_ranks = complexity_ranks + original_ranks
rescored_finished = []
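        # Demote hypotheses containing Reddit-style artefacts or words outside
        # the word filter by assigning them a sentinel score of -1.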
        for i, finished_item in enumerate(self.finished):
            score = combined_ranks[i]
            if "u/" in hyps_str[i] or "r/" in hyps_str[i]:  # Fix for Reddit language, see paper appendix
                score = np.array(-1, dtype=combined_ranks.dtype)
            if self.reranker.word_filter:
                # Find all non-capitalised words; if any falls outside the
                # word filter, demote the whole hypothesis.
                for word in regex.findall(r"(?<=[^\p{L}])\p{Ll}+", hyps_str[i]):
                    if word not in self.reranker.word_filter:
                        score = np.array(-1, dtype=combined_ranks.dtype)
                        break
            rescored_finished.append(
                _HypothesisTail(
                    timestep=finished_item.timestep,
                    hypid=finished_item.hypid,
                    # Use the combined rank (not the raw model score), so the
                    # reranking actually takes effect when sorting below.
                    score=score,
                    tokenid=finished_item.tokenid,
                    token_details=finished_item.token_details,
                )
            )
# Note: beam size is almost always pretty small, so sorting is cheap enough
srted = sorted(rescored_finished, key=attrgetter('score'), reverse=True)
if n_best is not None:
srted = srted[:n_best]
n_best_list = []
for hyp in srted:
hyp_data = self._get_hyp_from_finished(hyp)
token_ids = self._get_pretty_hypothesis(hyp_data)
token_metadata = (
[tok.token_details for tok in reversed(hyp_data)]
if self.verbose
else None
)
n_best_list.append((token_ids, hyp.score, token_metadata))
# check that there is at least one finished candidate
# and assert that each of them contains only one EOS
assert (
len(n_best_list) >= 1
), f'TreeSearch returned {len(n_best_list)} candidates, must be >= 1'
for (pred, score, _) in n_best_list:
assert (pred == self.eos).sum() == 1, (
f'TreeSearch returned a finalized hypo with multiple end tokens '
f'with score {score.item():.2f}'
)
return n_best_list
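# A minimal usage sketch (hypothetical wiring; the exact TreeSearch
# constructor arguments depend on the ParlAI version, and in practice the
# agent's _treesearch_factory builds the search object):
#
#     beam = RerankedTopKSampling(
#         k=40,
#         reranker=reranker,              # a Reranker from .generation_utils
#         tokenids_to_text=agent._v2t,    # agent detokeniser (assumed)
#         beam_size=10,
#         ...
#     )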