from parlai.core.opt import Opt
from parlai.utils.typing import TShared
from parlai.agents.transformer.transformer import TransformerGeneratorAgent

from .generation_methods import VocabTopKSampling, RerankedTopKSampling
from .generation_utils import Wordlist, Reranker, load_wordlist, cefr_to_int


class ControllableBlender(TransformerGeneratorAgent):
    """
    Transformer generator agent with two extra decoding methods.

    The method is selected by ``opt['inference']``:

    * ``'vocab'``  -- decodes with ``VocabTopKSampling``, built with a
      ``Wordlist`` loaded from ``opt['wordlist_path']``.
    * ``'rerank'`` -- decodes with ``RerankedTopKSampling``, built with a
      ``Reranker`` configured from the ``rerank_*`` / ``penalty_stddev`` /
      ``filter_path`` options.

    Any other value makes the constructor raise ``ValueError``.
    """

    def __init__(self, opt: Opt, shared: TShared = None):
        """
        Set up the agent and the helper object needed by the chosen method.

        :param opt: agent options; ``opt['inference']`` must be ``'vocab'``
            or ``'rerank'``.
        :param shared: state produced by :meth:`share`, or ``None`` for a
            freshly created agent. When it already carries a ``wordlist`` /
            ``reranker`` that object is reused instead of being rebuilt.
        :raises ValueError: if ``opt['inference']`` is neither ``'vocab'``
            nor ``'rerank'``.
        :raises AssertionError: if a required option of the chosen method is
            missing (only checked when the helper has to be built).
        """
        super().__init__(opt, shared)

        method = opt.get("inference", None)
        if method == "vocab":
            # share() exports the word list, so clones should pick it up
            # rather than re-reading the file from disk.
            if shared and "wordlist" in shared:
                self.wordlist = shared["wordlist"]
            else:
                self.wordlist = self._build_wordlist(opt)
        elif method == "rerank":
            # Same for the reranker: building it again for every clone would
            # construct a second Reranker from the same options.
            if shared and "reranker" in shared:
                self.reranker = shared["reranker"]
            else:
                self.reranker = self._build_reranker(opt)
        else:
            raise ValueError(
                f"Inference method {opt.get('inference', None)} does not exist. "
                f"Please use 'vocab' or 'rerank'."
            )

    def _build_wordlist(self, opt: Opt):
        """Load the allowed-word file named by ``opt['wordlist_path']`` into a ``Wordlist``."""
        wordlist_path = opt.get("wordlist_path", None)
        assert wordlist_path, "Please provide path to vocab list, in order to use inference method 'vocab'"
        allowed_words = load_wordlist(wordlist_path)
        return Wordlist(allowed_words, self.dict)

    def _build_reranker(self, opt: Opt):
        """Create the ``Reranker`` described by the rerank-related options in ``opt``."""
        cefr = opt.get("rerank_cefr", None)
        assert cefr, "Please provide CEFR level, in order to use inference method 'rerank'"

        rerank_tokenizer = opt.get("rerank_tokenizer", None)
        rerank_model = opt.get("rerank_model", None)
        assert rerank_model, "Please provide path to directory containing model weights, in order to use inference method 'rerank'"

        device = opt.get("complexity_model_device", None)
        penalty_stddev = opt.get("penalty_stddev", None)
        text_truncate = opt.get("text_truncate", None)

        # The word filter is optional; an empty/missing path means "no filter".
        word_filter = None
        filter_path = opt.get("filter_path", "")
        if filter_path:
            word_filter = load_wordlist(filter_path)

        # Ids of the dictionary's special tokens (null, start, end, unk).
        special_tokens = (
            self.dict.null_token,
            self.dict.start_token,
            self.dict.end_token,
            self.dict.unk_token,
        )
        exempt_tokens = [self.dict.tok2ind.get(token) for token in special_tokens]

        # A negative stddev switches the exemption to the sentinel "all".
        # The option defaults to None, which must not be compared with 0
        # (that raised TypeError before the reranker was even created).
        # NOTE(review): None is forwarded to Reranker unchanged -- confirm
        # that Reranker accepts a missing penalty_stddev.
        if penalty_stddev is not None and penalty_stddev < 0:
            exempt_tokens = "all"

        return Reranker(
            cefr=cefr_to_int(cefr),
            model=rerank_model,
            tokenizer=rerank_tokenizer,
            device=device,
            text_truncate=text_truncate,
            exempt_tokens=exempt_tokens,
            penalty_stddev=penalty_stddev,
            vocab_size=len(self.dict),
            word_filter=word_filter,
        )

    def _treesearch_factory(self, device, verbose=False):
        """
        Return the tree-search object for the configured inference method.

        ``'vocab'`` yields a ``VocabTopKSampling`` and ``'rerank'`` a
        ``RerankedTopKSampling``; every other method is delegated to the
        parent implementation.

        :param device: device forwarded to the search object.
        :param verbose: verbosity flag forwarded to the search object.
        """
        method = self.opt.get('inference', 'greedy')
        if method not in ('vocab', 'rerank'):
            return super()._treesearch_factory(device, verbose=verbose)

        # Keyword arguments that both custom samplers receive identically.
        common_kwargs = dict(
            k=self.opt.get('topk', 40),
            beam_size=self.opt.get('beam_size', 1),
            min_length=self.beam_min_length,
            block_ngram=self.beam_block_ngram,
            context_block_ngram=self.beam_context_block_ngram,
            length_penalty=self.opt.get('beam_length_penalty', 0.65),
            padding_token=self.NULL_IDX,
            bos_token=self.START_IDX,
            eos_token=self.END_IDX,
            device=device,
            verbose=verbose,
        )
        if method == 'vocab':
            return VocabTopKSampling(wordlist=self.wordlist, **common_kwargs)
        return RerankedTopKSampling(
            reranker=self.reranker,
            tokenids_to_text=self._v2t,
            **common_kwargs,
        )

    def share(self):
        """
        Share internal states between parent and child instances.

        Extends the parent's shared dict with ``wordlist`` and/or
        ``reranker`` when this instance has them, so that clones created
        from the result can reuse those objects.
        """
        shared = super().share()
        for attr in ('wordlist', 'reranker'):
            if hasattr(self, attr):
                shared[attr] = getattr(self, attr)
        return shared