File size: 4,616 Bytes
3d18a82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from parlai.core.opt import Opt
from parlai.utils.typing import TShared
from parlai.agents.transformer.transformer import TransformerGeneratorAgent

from .generation_methods import VocabTopKSampling, RerankedTopKSampling
from .generation_utils import Wordlist, Reranker, load_wordlist, cefr_to_int

class ControllableBlender(TransformerGeneratorAgent):
    """
    Transformer generator agent with complexity-controlled decoding.

    The decoding strategy is selected by ``opt['inference']``:

    * ``'vocab'``  -- top-k sampling through ``VocabTopKSampling``, driven by
      a ``Wordlist`` built from the file at ``opt['wordlist_path']``.
    * ``'rerank'`` -- top-k sampling through ``RerankedTopKSampling``, driven
      by a ``Reranker`` configured from the ``rerank_*`` options.

    Any other value of ``opt['inference']`` is rejected in ``__init__``.
    """

    def __init__(self, opt: Opt, shared: TShared = None):
        """
        Set up the agent and the helper object for the chosen inference method.

        When ``shared`` already carries the helper exported by ``share()``
        (``'wordlist'`` or ``'reranker'``), that object is reused instead of
        being rebuilt for this instance.

        :param opt: agent options; ``opt['inference']`` must be ``'vocab'``
            or ``'rerank'``.
        :param shared: shared state from a parent instance, or ``None``.
        :raises AssertionError: if an option required by the chosen method
            is missing or empty.
        :raises ValueError: if ``opt['inference']`` is neither ``'vocab'``
            nor ``'rerank'``.
        """
        super().__init__(opt, shared)

        inference = opt.get("inference", None)

        if inference == "vocab":
            if shared is not None and "wordlist" in shared:
                # Reuse the parent's word list rather than re-reading the file.
                self.wordlist = shared["wordlist"]
            else:
                self.wordlist = self._build_wordlist(opt)

        elif inference == "rerank":
            if shared is not None and "reranker" in shared:
                # Reuse the parent's reranker rather than constructing another.
                self.reranker = shared["reranker"]
            else:
                self.reranker = self._build_reranker(opt)

        else:
            raise ValueError(f"Inference method {inference} does not exist. "
                             f"Please use 'vocab' or 'rerank'.")

    def _build_wordlist(self, opt: Opt):
        """
        Build the ``Wordlist`` used by the ``'vocab'`` inference method.
        """
        wordlist_path = opt.get("wordlist_path", None)
        assert wordlist_path, "Please provide path to vocab list, in order to use inference method 'vocab'"

        allowed_words = load_wordlist(wordlist_path)
        return Wordlist(allowed_words, self.dict)

    def _build_reranker(self, opt: Opt):
        """
        Build the ``Reranker`` used by the ``'rerank'`` inference method.
        """
        cefr = opt.get("rerank_cefr", None)
        assert cefr, "Please provide CEFR level, in order to use inference method 'rerank'"

        rerank_tokenizer = opt.get("rerank_tokenizer", None)
        rerank_model = opt.get("rerank_model", None)
        assert rerank_model, "Please provide path to directory containing model weights, in order to use inference method 'rerank'"

        device = opt.get("complexity_model_device", None)
        penalty_stddev = opt.get("penalty_stddev", None)
        text_truncate = opt.get("text_truncate", None)

        # An optional word filter is only loaded when a path is configured.
        filter_path = opt.get("filter_path", "")
        word_filter = load_wordlist(filter_path) if filter_path else None

        if penalty_stddev is not None and penalty_stddev < 0:
            # A negative standard deviation switches the exemption to the
            # special value "all".
            exempt_tokens = "all"
        else:
            # Dictionary indices of the null/start/end/unk tokens. A missing
            # ``penalty_stddev`` (None) also takes this path and is forwarded
            # to the Reranker unchanged.
            # NOTE(review): confirm how Reranker treats penalty_stddev=None.
            special_tokens = (
                self.dict.null_token,
                self.dict.start_token,
                self.dict.end_token,
                self.dict.unk_token,
            )
            exempt_tokens = [self.dict.tok2ind.get(tok) for tok in special_tokens]

        return Reranker(cefr=cefr_to_int(cefr),
                        model=rerank_model,
                        tokenizer=rerank_tokenizer,
                        device=device,
                        text_truncate=text_truncate,
                        exempt_tokens=exempt_tokens,
                        penalty_stddev=penalty_stddev,
                        vocab_size=len(self.dict),
                        word_filter=word_filter)

    def _treesearch_factory(self, device, verbose=False):
        """
        Return the tree-search object matching ``opt['inference']``.

        ``'vocab'`` and ``'rerank'`` map to the custom top-k samplers; every
        other method is delegated to the parent implementation.

        :param device: device passed through to the tree-search object.
        :param verbose: verbosity flag passed through to the tree-search object.
        """
        method = self.opt.get('inference', 'greedy')
        if method not in ('vocab', 'rerank'):
            return super()._treesearch_factory(device, verbose=verbose)

        # Arguments shared by both custom samplers.
        common_kwargs = dict(
            k=self.opt.get('topk', 40),
            beam_size=self.opt.get('beam_size', 1),
            min_length=self.beam_min_length,
            block_ngram=self.beam_block_ngram,
            context_block_ngram=self.beam_context_block_ngram,
            length_penalty=self.opt.get('beam_length_penalty', 0.65),
            padding_token=self.NULL_IDX,
            bos_token=self.START_IDX,
            eos_token=self.END_IDX,
            device=device,
            verbose=verbose,
        )

        if method == 'vocab':
            return VocabTopKSampling(wordlist=self.wordlist, **common_kwargs)

        return RerankedTopKSampling(
            reranker=self.reranker,
            tokenids_to_text=self._v2t,
            **common_kwargs,
        )

    def share(self):
        """
        Share internal states between parent and child instances.

        Adds ``'wordlist'`` and/or ``'reranker'`` to the parent's shared
        dict when this instance has them, so ``__init__`` can reuse them.
        """
        shared = super().share()
        if hasattr(self, 'wordlist'):
            shared['wordlist'] = self.wordlist
        if hasattr(self, 'reranker'):
            shared['reranker'] = self.reranker
        return shared