import math
import os
import json
from typing import List, Set, Union, Optional

import numpy as np
import torch
from parlai.core.dict import DictionaryAgent
from transformers import RobertaForSequenceClassification, RobertaTokenizer

def cefr_to_int(cefr: str) -> int:
    """
    Convert a CEFR level label ("A1" through "C2") to an integer from 0 to 5.
    """
    mapping = {
        "A1": 0,
        "A2": 1,
        "B1": 2,
        "B2": 3,
        "C1": 4,
        "C2": 5,
    }
    clean_cefr = cefr.upper().strip()
    assert clean_cefr in mapping, f"CEFR must be one of {list(mapping.keys())}, not {cefr}"

    return mapping[clean_cefr]
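
# Illustrative example (not in the original source): cefr_to_int("b2") == 3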


def load_wordlist(path: str) -> List[str]:
    """
    Load a list of words from a text file containing one word per line.
    """
    vocab = []

    if not path:
        return vocab

    assert os.path.isfile(path), f"Wordlist file not found: {path}"

    with open(path, 'r', encoding="utf-8") as vocab_file:
        for row in vocab_file:
            token = row.strip()
            if token:      # skip blank lines
                vocab.append(token)

    return vocab


class Wordlist:
    """
    Supports restricting generation to a list of allowed words by tracking which
    subword token IDs may legally continue the word currently being generated.
    """

    def __init__(self, allowed_words: List[str], dict_agent: DictionaryAgent):
        self.dict_agent = dict_agent

        # Identify IDs that represent a word boundary and those that don't
        self.boundary_ids = []
        self.non_boundary_ids = []

        for idx, subtoken in dict_agent.ind2tok.items():
            if subtoken[0] == "\u0120" or not subtoken.isalpha():      # "\u0120" is the BPE "Ġ" word-start marker
                self.boundary_ids.append(idx)
            else:
                self.non_boundary_ids.append(idx)

        # Identify token ID sequences that are allowed words
        # Identify allowed continuations of sequences
        self.allowed_sequences = []
        self.allowed_continuations = {}
        for word in allowed_words:
            for word_variant in self._get_word_variants(word):
                token_ids = dict_agent.txt2vec(word_variant)
                self.allowed_sequences.append(repr(token_ids))

                for i, idx in enumerate(token_ids[1:]):
                    prefix = repr(token_ids[:i + 1])      # List represented as string for lookup
                    if prefix not in self.allowed_continuations:
                        self.allowed_continuations[prefix] = []
                    self.allowed_continuations[prefix].append(idx)

        self.allowed_sequences = set(self.allowed_sequences)
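        # Hypothetical illustration (token IDs are made up, not from the real
        # dictionary): if "reading" tokenizes to [1203, 278], allowed_sequences
        # contains "[1203, 278]" and allowed_continuations maps "[1203]" -> [278, ...].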


    def get_allowed_ids(self, token_ids: List[int]) -> List[int]:
        """
        For the token IDs generated so far, return the token IDs that are allowed to come next.
        """
        last_word = self._get_last_word(token_ids)
        continuation_ids = self._get_continuation_ids(last_word)

        return continuation_ids


    def _is_word(self, token_ids: List[int]) -> bool:
        """
        For a given sequence of token IDs, determine whether that sequence is a complete word
        """
        return (token_ids == [] or repr(token_ids) in self.allowed_sequences)


    def _get_continuation_ids(self, token_ids: List[int]) -> List[int]:
        """
        For a given sequence of last word token IDs, determine which token IDs the word can continue with
        """
        continuation_ids = []
        if repr(token_ids) in self.allowed_continuations:
            continuation_ids.extend(self.allowed_continuations[repr(token_ids)])

        if self._is_word(token_ids) or token_ids == []:
            continuation_ids.extend(self.boundary_ids)

        return continuation_ids


    def _get_last_word(self, token_ids: List[int]) -> List[int]:
        """
        Get the sequence of token IDs after the last word boundary.
        Assumes that a word boundary is denoted by punctuation or whitespace (Ġ).
        Raises ValueError if no boundary token is found in the sequence.
        """
        for i in range(-1, -len(token_ids), -1):
            last_word = token_ids[i:]
            check_token = self.dict_agent[last_word[0]]

            if not check_token.isalpha():
                return last_word[1:]

            if check_token[0] == "Ġ":
                return last_word

        raise ValueError("Boundary token not found")


    def _get_word_variants(self, word: str) -> Set[str]:
        """
        Return casing variants of a word: as given, lower-cased, and capitalized.
        """
        return {word, word.lower(), word.capitalize()}


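# Illustrative sketch (not part of the original module): one way a Wordlist could
# be used to mask a decoder's next-token logits during constrained generation.
# `dict_agent`, `generated_ids`, and `logits` stand in for objects supplied by
# the generation loop.
#
#     wordlist = Wordlist(load_wordlist("allowed_words.txt"), dict_agent)
#     allowed = wordlist.get_allowed_ids(generated_ids)
#     mask = torch.full_like(logits, float("-inf"))
#     mask[allowed] = 0.0
#     constrained_logits = logits + mask
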

class Reranker:
    """
    Scores generated hypotheses for language complexity with a RoBERTa CEFR
    classifier, and pre-computes per-token penalties for vocabulary tokens above
    the target CEFR level.
    """

    def __init__(self,
                 cefr: int,
                 model: str,
                 tokenizer: str = "distilroberta-base",
                 device: Optional[str] = "cuda",
                 text_truncate: int = 128,
                 exempt_tokens: Union[str, List[int]] = "all",
                 penalty_stddev: float = 2,
                 vocab_size: int = 8008,
                 word_filter: Optional[List[str]] = None):

        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer)
        self.model = RobertaForSequenceClassification.from_pretrained(model)
        self.model.to(device)
        self.device = device

        self.target_cefr = cefr
        self.text_truncate = text_truncate
        self.word_filter = word_filter

        # tokens_by_cefr.json maps a token ID (as a string) to (token string, CEFR level)
        cefr_filepath = os.path.join(os.path.dirname(__file__), 'tokens_by_cefr.json')
        with open(cefr_filepath, 'r') as cefr_file:
            token_cefrs = json.load(cefr_file)

        if exempt_tokens == "all" or penalty_stddev < 0:      # No penalties
            self.token_penalties = torch.tensor([[1.0] * vocab_size])
        else:
            # calculate penalties per CEFR level difference (0 = same CEFR)
            normal_dist = torch.distributions.normal.Normal(0, penalty_stddev)
            cefr_penalties = [math.exp(normal_dist.log_prob(torch.tensor(i))) for i in range(6)]
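            # With the default penalty_stddev of 2 this works out to roughly
            # [0.199, 0.176, 0.121, 0.065, 0.027, 0.009]: the Gaussian density at a
            # distance of 0..5 CEFR levels, so tokens further above the target level
            # receive smaller values. (Illustrative numbers, not from the source.)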

            token_penalties = []
            for i in range(vocab_size):
                if i in exempt_tokens:
                    token_penalties.append(cefr_penalties[0])

                elif str(i) in token_cefrs:
                    token_str, token_cefr = token_cefrs[str(i)]

                    if token_cefr <= self.target_cefr or not token_str.isalpha():
                        # ignore lower CEFR levels and punctuation/special tokens
                        penalty = cefr_penalties[0]
                    else:
                        penalty = cefr_penalties[int(token_cefr - self.target_cefr)]

                    token_penalties.append(penalty)

                else:       # Assume highest CEFR level if we don't have an assigned CEFR level
                    token_penalties.append(cefr_penalties[int(5 - self.target_cefr)])

            self.token_penalties = torch.tensor([token_penalties])

    def get_complexity_scores(self, hyps: List[str]) -> np.ndarray:
        """
        Score each hypothesis with the CEFR classifier and return a value per
        hypothesis, where higher means closer to the target CEFR level (max 5).
        """
        model_inputs = self.tokenizer(hyps,
                                      padding='max_length',
                                      truncation=True,
                                      max_length=self.text_truncate,
                                      return_tensors='pt',
                                      return_token_type_ids=True,
                                      return_attention_mask=True)

        model_output = self.model(input_ids=model_inputs["input_ids"].to(self.device),
                                  attention_mask=model_inputs["attention_mask"].to(self.device),
                                  token_type_ids=model_inputs["token_type_ids"].to(self.device))

        complexity_scores = model_output.logits.detach().cpu().numpy().flatten()      # detach so .numpy() works outside torch.no_grad()
        complexity_diffs = 5 - np.absolute(complexity_scores - self.target_cefr)      # reversed so that higher score = better

        return complexity_diffs
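

# Illustrative sketch (not part of the original module): reranking candidate
# responses by closeness to a target CEFR level. "path/to/cefr-classifier" is a
# placeholder for a RoBERTa sequence-classification checkpoint fine-tuned to
# predict CEFR complexity.
#
#     reranker = Reranker(cefr=cefr_to_int("B1"), model="path/to/cefr-classifier", device="cpu")
#     hyps = ["I like to read books.", "I am partial to perusing dense literature."]
#     scores = reranker.get_complexity_scores(hyps)
#     best = hyps[int(np.argmax(scores))]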