# shivansarora's picture
# Update controllable_blender/generation_methods.py
# e53591d verified
import math
from operator import attrgetter
from typing import Callable
import numpy as np
import regex
from scipy.stats import rankdata
import torch
from parlai.core.torch_generator_agent import TopKSampling, TreeSearch, _HypothesisTail, _PathSelection
from parlai.utils.torch import neginf
from .generation_utils import Reranker, Wordlist
class VocabTopKSampling(TopKSampling):
    """
    Top-k sampling restricted to a word list.

    Before each sampling step (see ``select_paths``), every beam's
    log-probabilities are masked so that only the token ids returned by
    ``wordlist.get_allowed_ids`` for that beam's partial hypothesis remain
    sampleable; every other id is set to ``neginf``.
    """

    def __init__(self,
                 k: int,
                 wordlist: Wordlist,
                 *args, **kwargs):
        """
        :param k:
            number of highest-probability tokens to sample from.
        :param wordlist:
            object whose ``get_allowed_ids(partial_hyp)`` returns the token ids
            that may follow ``partial_hyp``.
        Remaining positional and keyword arguments are forwarded to
        ``TopKSampling``.
        """
        # Pass k positionally: ``k=k`` together with a non-empty *args would
        # also bind the first positional argument to k and raise TypeError.
        super().__init__(k, *args, **kwargs)
        self.k = k
        self.wordlist = wordlist

    def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection:
        """
        Select the next vocabulary item in these beams.

        Once ``len(self.all_scores) > 1`` (presumably every step after the
        first -- confirm against TreeSearch), disallowed tokens are set to
        ``neginf`` in place on ``logprobs``. Selection is then delegated to
        ``TopKSampling.select_paths``.
        """
        if len(self.all_scores) > 1:
            vocab_size = logprobs.size(1)
            for hypid in range(self.beam_size):
                allowed_ids = self.wordlist.get_allowed_ids(self.partial_hyps[hypid])
                # Build the mask on the same device as logprobs so decoding on
                # an accelerator does not index with a CPU mask.
                disallowed = torch.ones(vocab_size, dtype=torch.bool, device=logprobs.device)
                disallowed[allowed_ids] = False
                logprobs[hypid, disallowed] = neginf(logprobs.dtype)
        return super().select_paths(logprobs, prior_scores, current_length)
class RerankedTopKSampling(TreeSearch):
    """
    Top-k sampling with per-token penalties and complexity-based reranking.

    ``select_paths`` samples each next token from the ``k`` most likely
    candidates, after multiplying their probabilities by
    ``reranker.token_penalties``.

    ``get_rescored_finished`` orders finished hypotheses by the sum of two
    ranks:
    - the rank of their length-normalised log-probability, and
    - the rank of ``reranker.get_complexity_scores`` for their text.
    """

    # Runs of lower-case letters preceded by a non-letter character, i.e.
    # non-capitalised words; checked against ``reranker.word_filter``.
    _LOWERCASE_WORD_RE = regex.compile(r"(?<=[^\p{L}])\p{Ll}+")

    def __init__(self,
                 k: int,
                 reranker: Reranker,
                 tokenids_to_text: Callable,
                 *args, **kwargs):
        """
        :param k:
            number of highest-probability tokens to sample from at each step.
        :param reranker:
            provides ``token_penalties`` (multiplicative weights indexed by
            token id), ``get_complexity_scores(texts)`` and ``word_filter``.
        :param tokenids_to_text:
            converts a hypothesis' token ids into text.
        Remaining arguments are forwarded to ``TreeSearch``.
        """
        super().__init__(*args, **kwargs)
        self.k = k
        self.reranker = reranker
        self.tokenids_to_text = tokenids_to_text

    def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection:
        """
        Select the next vocabulary item in these beams.
        Adapted from top-k sampling https://github.com/facebookresearch/ParlAI/blob/054a0fff8183e357727dc7a91682496734badb7f/parlai/core/torch_generator_agent.py

        The softmax over each row's top-k values is multiplied by the
        reranker's token penalties before sampling. The returned scores and
        verbose token details use the unpenalised top-k values and
        probabilities.
        """
        values, indices = logprobs.topk(self.k, dim=-1)
        probs = torch.softmax(values, dim=-1)
        n_rows = logprobs.size(0)

        penalties = self.reranker.token_penalties.to(probs.device)
        if penalties.dim() == 1:
            penalties = penalties.unsqueeze(0)
        # One penalty row per hypothesis. -1 keeps the penalty vector's own
        # length instead of assuming a fixed vocabulary size.
        penalties = penalties.expand(n_rows, -1)
        topk_penalties = torch.gather(penalties, -1, indices)
        choices = torch.multinomial(probs * topk_penalties, 1)[:, 0]

        hyp_ids = torch.arange(n_rows, device=logprobs.device)
        tok_ids = indices[hyp_ids, choices]
        scores = values[hyp_ids, choices]
        best_scores = prior_scores.expand_as(scores) + scores

        token_details = None
        if self.verbose:
            tok_logprobs = probs[hyp_ids, choices].log().view(-1).cpu().numpy()
            tok_ranks = choices.view(-1).cpu().numpy()
            token_details = [
                {"token_logprob": tok_logprob, "token_rank": int(tok_rank)}
                for tok_logprob, tok_rank in zip(tok_logprobs, tok_ranks)
            ]

        return _PathSelection(
            hypothesis_ids=hyp_ids,
            token_ids=tok_ids,
            scores=best_scores,
            token_details=token_details,
        )

    def _fails_text_filters(self, text: str) -> bool:
        """
        Return True if ``text`` should be demoted below every other hypothesis.

        ``text`` fails if either:
        - it contains "u/" or "r/" (fix for Reddit language, see paper
          appendix); or
        - ``reranker.word_filter`` is set and some non-capitalised word in
          ``text`` is not in it.
        """
        if "u/" in text or "r/" in text:
            return True
        word_filter = self.reranker.word_filter
        if not word_filter:
            return False
        return any(
            word not in word_filter
            for word in self._LOWERCASE_WORD_RE.findall(text)
        )

    def get_rescored_finished(self, n_best=None):
        """
        Adapted version of code taken from https://github.com/facebookresearch/ParlAI/blob/054a0fff8183e357727dc7a91682496734badb7f/parlai/core/torch_generator_agent.py
        Adds complexity scoring and reranking.

        Each finished hypothesis is scored as
        ``rank(length-normalised log-prob) + rank(complexity score)``.
        - Ranks come from ``scipy.stats.rankdata``, so ties share the
          average rank.
        - Length normalisation follows the Google NMT paper.
        - Hypotheses rejected by ``_fails_text_filters`` get -1. Every
          combined rank is at least 2, so they sort after all others.

        :param n_best:
            number of finalized hypotheses to return
        :return:
            list of (tokens, score, token_metadata) 3-tuples, sorted by
            descending score, where:
            - tokens is a tensor of token ids
            - score is the rerank score described above, as a 0-d tensor
            - token_metadata is ``[tok.token_details for tok in
              reversed(hyp_data)]`` when ``self.verbose``, else None
        """
        # if we never actually finished, force one
        if not self.finished:
            self.outputs[-1][0] = self.eos
            self.finished.append(
                _HypothesisTail(
                    timestep=len(self.outputs) - 1,
                    hypid=0,
                    score=self.all_scores[-1][0],
                    tokenid=self.outputs[-1][0],
                    token_details=self.token_details[0][-1]
                    if self.token_details is not None
                    else None,
                )
            )

        # Decode every hypothesis and compute its length-normalised score.
        hyps_str = []
        normalised_scores = []
        for finished_item in self.finished:
            token_ids = self._get_pretty_hypothesis(self._get_hyp_from_finished(finished_item))
            hyps_str.append(self.tokenids_to_text(token_ids))
            current_length = finished_item.timestep + 1
            # these weights are from Google NMT paper
            length_penalty = math.pow((1 + current_length) / 6, self.length_penalty)
            normalised_scores.append((finished_item.score / length_penalty).item())

        # rankdata ranks in ascending order: higher values get higher ranks.
        complexity_ranks = rankdata(self.reranker.get_complexity_scores(hyps_str))
        combined_ranks = complexity_ranks + rankdata(normalised_scores)

        rescored_finished = []
        for finished_item, hyp_str, combined_rank in zip(self.finished, hyps_str, combined_ranks):
            rerank_score = -1.0 if self._fails_text_filters(hyp_str) else float(combined_rank)
            rescored_finished.append(
                _HypothesisTail(
                    timestep=finished_item.timestep,
                    hypid=finished_item.hypid,
                    # Use the rerank score so the sort below applies the
                    # complexity reranking. Keep it a tensor on the same
                    # dtype/device as the original score, so downstream
                    # tensor operations and .item() keep working.
                    score=torch.tensor(
                        rerank_score,
                        dtype=finished_item.score.dtype,
                        device=finished_item.score.device,
                    ),
                    tokenid=finished_item.tokenid,
                    token_details=finished_item.token_details,
                )
            )

        # Note: beam size is almost always pretty small, so sorting is cheap enough
        srted = sorted(rescored_finished, key=attrgetter('score'), reverse=True)
        if n_best is not None:
            srted = srted[:n_best]

        n_best_list = []
        for hyp in srted:
            hyp_data = self._get_hyp_from_finished(hyp)
            token_ids = self._get_pretty_hypothesis(hyp_data)
            token_metadata = (
                [tok.token_details for tok in reversed(hyp_data)]
                if self.verbose
                else None
            )
            n_best_list.append((token_ids, hyp.score, token_metadata))

        # check that there is at least one finished candidate
        # and assert that each of them contains only one EOS
        assert (
            len(n_best_list) >= 1
        ), f'TreeSearch returned {len(n_best_list)} candidates, must be >= 1'
        for (pred, score, _) in n_best_list:
            assert (pred == self.eos).sum() == 1, (
                f'TreeSearch returned a finalized hypo with multiple end tokens '
                f'with score {score.item():.2f}'
            )

        return n_best_list