import math
import os
import json
from typing import List, Set, Union, Optional
import numpy as np
import torch
from parlai.core.dict import DictionaryAgent
from transformers import RobertaForSequenceClassification, RobertaTokenizer
def cefr_to_int(cefr: str) -> int:
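    """
    Map a CEFR level label (A1-C2) to an integer index (0-5).
    """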
mapping = {
"A1": 0,
"A2": 1,
"B1": 2,
"B2": 3,
"C1": 4,
"C2": 5,
}
clean_cefr = cefr.upper().strip()
assert clean_cefr in mapping, f"CEFR must be one of {list(mapping.keys())}, not {cefr}"
return mapping[clean_cefr]
def load_wordlist(path: str) -> List[str]:
"""
Load a list of words from a text file containing one word per line
"""
vocab = []
if not path:
return vocab
assert os.path.isfile(path)
with open(path, 'r', encoding="utf-8") as vocab_file:
for row in vocab_file:
token = row.strip()
vocab.append(token)
return vocab
class Wordlist:
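    """
    Vocabulary-constraint helper.

    Given a list of allowed words and a ParlAI DictionaryAgent, precomputes
    which subword token ID sequences spell out allowed words and which token
    IDs may legally continue a partially generated word (see get_allowed_ids).
    """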
def __init__(self, allowed_words: List[str], dict_agent: DictionaryAgent):
self.dict_agent = dict_agent
# Identify IDs that represent a word boundary and those that don't
self.boundary_ids = []
self.non_boundary_ids = []
for idx, subtoken in dict_agent.ind2tok.items():
if subtoken[0] == "\u0120" or not subtoken.isalpha():
self.boundary_ids.append(idx)
else:
self.non_boundary_ids.append(idx)
# Identify token ID sequences that are allowed words
# Identify allowed continuations of sequences
self.allowed_sequences = []
self.allowed_continuations = {}
for word in allowed_words:
for word_variant in self._get_word_variants(word):
token_ids = dict_agent.txt2vec(word_variant)
self.allowed_sequences.append(repr(token_ids))
for i, idx in enumerate(token_ids[1:]):
prefix = repr(token_ids[:i + 1]) # List represented as string for lookup
if prefix not in self.allowed_continuations:
self.allowed_continuations[prefix] = []
self.allowed_continuations[prefix].append(idx)
self.allowed_sequences = set(self.allowed_sequences)
def get_allowed_ids(self, token_ids: List[int]) -> List[int]:
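        """
        For a partial generation, return the token IDs that keep it inside the
        allowed vocabulary: valid continuations of the in-progress last word,
        plus word-boundary tokens once that word is complete (or empty).
        """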
last_word = self._get_last_word(token_ids)
continuation_ids = self._get_continuation_ids(last_word)
return continuation_ids
def _is_word(self, token_ids: List[int]) -> bool:
"""
For a given sequence of token IDs, determine whether that sequence is a complete word
"""
return (token_ids == [] or repr(token_ids) in self.allowed_sequences)
def _get_continuation_ids(self, token_ids: List[int]) -> List[int]:
"""
For a given sequence of last word token IDs, determine which token IDs the word can continue with
"""
continuation_ids = []
if repr(token_ids) in self.allowed_continuations:
continuation_ids.extend(self.allowed_continuations[repr(token_ids)])
if self._is_word(token_ids) or token_ids == []:
continuation_ids.extend(self.boundary_ids)
return continuation_ids
def _get_last_word(self, token_ids: List[int]) -> List[int]:
"""
Get the sequence of token IDs after the last word boundary.
Assumes that a word boundary is denoted by punctuation or whitespace (Ġ).
"""
for i in range(-1, -len(token_ids), -1):
last_word = token_ids[i:]
check_token = self.dict_agent[last_word[0]]
if not check_token.isalpha():
return last_word[1:]
if check_token[0] == "Ġ":
return last_word
raise ValueError("Boundary token not found")
def _get_word_variants(self, word: str) -> Set[str]:
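        """
        Return simple casing variants of a word: as given, lower-cased and capitalized.
        """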
return {word, word.lower(), word.capitalize()}
class Reranker:
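    """
    CEFR-aware reranker.

    Uses a RoBERTa sequence-classification model to score how close each
    candidate response is to a target CEFR level, and precomputes a per-token
    penalty vector that down-weights tokens above that level.
    """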
def __init__(self,
cefr: int,
model: str,
tokenizer: str = "distilroberta-base",
device: Optional[str] = "cuda",
text_truncate: int = 128,
exempt_tokens: Union[str, List[int]] = "all",
                 penalty_stddev: float = 2.0,
vocab_size: int = 8008,
word_filter: Optional[List[str]] = None):
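        """
        cefr: target CEFR level as an integer (see cefr_to_int).
        exempt_tokens: "all" disables token penalties entirely; otherwise a
            list of token IDs that are never penalized.
        penalty_stddev: standard deviation of the Gaussian used to map CEFR
            level differences to penalties; a negative value disables penalties.
        vocab_size: size of the token penalty vector (should match the
            generator's vocabulary).
        word_filter: optional list of allowed words, stored on the instance
            (not used directly in this module).
        """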
self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer)
self.model = RobertaForSequenceClassification.from_pretrained(model)
self.model.to(device)
self.device = device
self.target_cefr = cefr
self.text_truncate = text_truncate
self.word_filter = word_filter
cefr_filepath = os.path.join(os.path.dirname(__file__), 'tokens_by_cefr.json')
with open(cefr_filepath, 'r') as cefr_file:
token_cefrs = json.load(cefr_file)
if exempt_tokens == "all" or penalty_stddev < 0: # No penalties
self.token_penalties = torch.tensor([[1] * vocab_size])
else:
# calculate penalties per CEFR level difference (0 = same CEFR)
normal_dist = torch.distributions.normal.Normal(0, penalty_stddev)
cefr_penalties = [math.exp(normal_dist.log_prob(torch.tensor(i))) for i in range(6)]
token_penalties = []
for i in range(vocab_size):
if i in exempt_tokens:
token_penalties.append(cefr_penalties[0])
elif str(i) in token_cefrs:
token_str, token_cefr = token_cefrs[str(i)]
penalty = cefr_penalties[int(token_cefr - self.target_cefr)]
if token_cefr <= self.target_cefr or not token_str.isalpha(): # ignore lower CEFR levels and punctuation/special tokens
penalty = cefr_penalties[0]
token_penalties.append(penalty)
else: # Assume highest CEFR level if we don't have an assigned CEFR level
token_penalties.append(cefr_penalties[int(5 - self.target_cefr)])
self.token_penalties = torch.tensor([token_penalties])
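        # Note: this constructor only builds the penalty vector (shape
        # [1, vocab_size]); applying it to a generator's token scores is left
        # to the caller.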
def get_complexity_scores(self, hyps: List[str]) -> np.ndarray:
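        """
        Score each hypothesis with the complexity model and return
        5 - |predicted score - target CEFR| per hypothesis, so higher values
        mean the text is closer to the target level.
        """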
model_inputs = self.tokenizer(hyps,
padding='max_length',
truncation=True,
max_length=self.text_truncate,
return_tensors='pt',
return_token_type_ids=True,
return_attention_mask=True)
model_output = self.model(input_ids=model_inputs["input_ids"].to(self.device),
attention_mask=model_inputs["attention_mask"].to(self.device),
token_type_ids=model_inputs["token_type_ids"].to(self.device))
complexity_scores = model_output.logits.cpu().numpy().flatten()
complexity_diffs = 5 - np.absolute(complexity_scores - self.target_cefr) # reversed so that higher score = better
return complexity_diffs
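

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how Reranker might be used to pick the
# hypothesis closest in complexity to a target CEFR level. The checkpoint path
# "path/to/cefr-regression-model" is a placeholder assumption: any
# RobertaForSequenceClassification model fine-tuned to predict a scalar CEFR
# score could stand in. Wordlist is not demonstrated here because it also
# requires a configured ParlAI DictionaryAgent.
if __name__ == "__main__":
    hypotheses = [
        "I like reading books about history.",
        "My predilection for historiographical monographs is well documented.",
    ]
    reranker = Reranker(
        cefr=cefr_to_int("B1"),
        model="path/to/cefr-regression-model",  # placeholder checkpoint path
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    scores = reranker.get_complexity_scores(hypotheses)
    # Higher score = predicted complexity closer to the target CEFR level.
    print("Scores:", scores)
    print("Best hypothesis:", hypotheses[int(np.argmax(scores))])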