from typing import Dict, List

import nltk
import torch
from nltk.tokenize import word_tokenize

from src.data_utils.config import TextProcessorConfig


class TextProcessor:
    """Tokenize raw text and convert it into fixed-length tensors of token IDs.

    Args:
        vocab: Mapping from token string to integer ID. Must contain the
            configured UNK and PAD tokens.
        config: Preprocessing configuration (casing, punctuation removal,
            special tokens, maximum sequence length).
    """

    def __init__(self, vocab: Dict[str, int], config: TextProcessorConfig):
        self.vocab = vocab
        self.config = config
        self._ensure_nltk_downloaded()

    def _ensure_nltk_downloaded(self):
        # Probe the tokenizer once; fetch the required NLTK data only if it
        # is missing (newer NLTK versions need "punkt_tab" alongside "punkt").
        try:
            word_tokenize("test")
        except LookupError:
            nltk.download("punkt")
            nltk.download("punkt_tab")

    def preprocess_text(self, text: str) -> List[str]:
        """Tokenize and preprocess a single text string.

        Args:
            text: Raw input string.

        Returns:
            List of preprocessed tokens.
        """
        if self.config.lowercase:
            text = text.lower()

        tokens = word_tokenize(text)

        if self.config.remove_punct:
            # Note: isalpha() drops numeric tokens as well as punctuation.
            tokens = [t for t in tokens if t.isalpha()]

        return tokens

    def text_to_tensor(self, text: str) -> torch.Tensor:
        """Convert raw text into a fixed-length tensor of token IDs.

        Out-of-vocabulary tokens map to the UNK ID; the sequence is padded
        with the PAD ID or truncated to ``config.max_seq_len``.

        Args:
            text: Raw input string.

        Returns:
            Long tensor of shape ``(max_seq_len,)``.
        """
        tokens = self.preprocess_text(text)
        ids = [self.vocab.get(token, self.vocab[self.config.unk_token]) for token in tokens]

        # Pad short sequences with the PAD ID; truncate long ones.
        if len(ids) < self.config.max_seq_len:
            ids = ids + [self.vocab[self.config.pad_token]] * (self.config.max_seq_len - len(ids))
        else:
            ids = ids[:self.config.max_seq_len]

        return torch.tensor(ids, dtype=torch.long)
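

# Example usage (a minimal sketch). Assumes TextProcessorConfig accepts these
# fields as constructor keywords: the attribute names match those read above,
# but the actual signature in src.data_utils.config is not confirmed here.
if __name__ == "__main__":
    config = TextProcessorConfig(
        lowercase=True,
        remove_punct=True,
        unk_token="<unk>",
        pad_token="<pad>",
        max_seq_len=8,
    )
    vocab = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}
    processor = TextProcessor(vocab, config)

    # "Hello, world!" is lowercased, tokenized, stripped of punctuation,
    # and padded to length 8: tensor([2, 3, 0, 0, 0, 0, 0, 0])
    print(processor.text_to_tensor("Hello, world!"))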