from typing import Dict, List

import nltk
import torch
from nltk.tokenize import word_tokenize

from src.data_utils.config import TextProcessorConfig


class TextProcessor:
    """
    Main text preprocessor class.

    Args:
        vocab: Vocabulary mapping token strings to integer ids.
        config: Configuration object controlling preprocessing behavior.
    """

    def __init__(self, vocab: Dict[str, int], config: TextProcessorConfig):
        self.vocab = vocab
        self.config = config
        self._ensure_nltk_downloaded()

    def _ensure_nltk_downloaded(self):
        # Download the Punkt tokenizer models on first use if they are missing.
        try:
            word_tokenize("test")
        except LookupError:
            nltk.download("punkt")
            nltk.download("punkt_tab")  # required by newer NLTK releases

    def preprocess_text(self, text: str) -> List[str]:
        """
        Tokenize and preprocess a single text string.

        Args:
            text: Raw input string.

        Returns:
            List of preprocessed tokens.
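
        Example (illustrative sketch, assuming ``lowercase`` and
        ``remove_punct`` are both enabled in the config):
            >>> processor.preprocess_text("Hello, world!")
            ['hello', 'world']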
"""
if self.config.lowercase:
text = text.lower()
tokens = word_tokenize(text)
if self.config.remove_punct:
tokens = [t for t in tokens if t.isalpha()]
return tokens

    def text_to_tensor(self, text: str) -> torch.Tensor:
        """
        Convert a raw text string to a fixed-length tensor of token ids.

        Args:
            text: Raw input string.

        Returns:
            1-D ``torch.long`` tensor of ids, padded or truncated to
            ``config.max_seq_len``.
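
        Example (illustrative sketch, assuming a hypothetical vocab
        ``{"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}``,
        ``max_seq_len=4``, and lowercasing/punctuation removal enabled):
            >>> processor.text_to_tensor("Hello, world!")
            tensor([2, 3, 0, 0])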
"""
tokens = self.preprocess_text(text)
ids = [self.vocab.get(token, self.vocab[self.config.unk_token]) for token in tokens]
# Pad or truncate
if len(ids) < self.config.max_seq_len:
ids = ids + [self.vocab[self.config.pad_token]] * (self.config.max_seq_len - len(ids))
else:
ids = ids[:self.config.max_seq_len]
return torch.tensor(ids, dtype=torch.long)
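

# Example usage: a minimal sketch. The `TextProcessorConfig` constructor
# arguments and the toy vocabulary below are illustrative assumptions,
# not taken from src.data_utils.config.
if __name__ == "__main__":
    config = TextProcessorConfig(
        lowercase=True,
        remove_punct=True,
        max_seq_len=8,
        unk_token="<unk>",
        pad_token="<pad>",
    )
    vocab = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}
    processor = TextProcessor(vocab, config)
    print(processor.preprocess_text("Hello, world!"))  # ['hello', 'world']
    print(processor.text_to_tensor("Hello, world!"))   # tensor([2, 3, 0, 0, 0, 0, 0, 0])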