|
from typing import Dict, Iterator, List, Optional, Union |
|
|
|
from tokenizers import AddedToken, Tokenizer, decoders, trainers |
|
from tokenizers.models import WordPiece |
|
from tokenizers.normalizers import BertNormalizer |
|
from tokenizers.pre_tokenizers import BertPreTokenizer |
|
from tokenizers.processors import BertProcessing |
|
|
|
from .base_tokenizer import BaseTokenizer |
|
|
|
|
|
class BertWordPieceTokenizer(BaseTokenizer):
    """Bert WordPiece Tokenizer.

    Assembles a ``tokenizers.Tokenizer`` from a ``WordPiece`` model, a
    ``BertNormalizer``, a ``BertPreTokenizer`` and a ``WordPiece`` decoder.
    A ``BertProcessing`` post-processor is also attached when a vocabulary
    is supplied at construction time.
    """

    # Special tokens handed to the trainer when ``train`` /
    # ``train_from_iterator`` are called without ``special_tokens``. Stored as
    # a tuple so the shared default can never be mutated; a fresh list is
    # built for every trainer.
    _DEFAULT_SPECIAL_TOKENS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]")

    def __init__(
        self,
        vocab: Optional[Union[str, Dict[str, int]]] = None,
        unk_token: Union[str, AddedToken] = "[UNK]",
        sep_token: Union[str, AddedToken] = "[SEP]",
        cls_token: Union[str, AddedToken] = "[CLS]",
        pad_token: Union[str, AddedToken] = "[PAD]",
        mask_token: Union[str, AddedToken] = "[MASK]",
        clean_text: bool = True,
        handle_chinese_chars: bool = True,
        strip_accents: Optional[bool] = None,
        lowercase: bool = True,
        wordpieces_prefix: str = "##",
    ):
        """Build the tokenization pipeline.

        Args:
            vocab: Vocabulary forwarded to the ``WordPiece`` model. When
                ``None`` the model is created without one (e.g. to be
                populated later by ``train`` / ``train_from_iterator``).
            unk_token: Unknown token; also passed to the ``WordPiece`` model.
            sep_token, cls_token: Separator / classification tokens used by
                the ``BertProcessing`` post-processor.
            pad_token, mask_token: Padding / mask tokens.
            clean_text, handle_chinese_chars, strip_accents, lowercase:
                Forwarded unchanged to ``BertNormalizer``.
            wordpieces_prefix: Prefix passed to the ``WordPiece`` decoder.

        Each of the five special tokens is registered as a special token
        only if it already exists in the model's vocabulary.

        Raises:
            TypeError: if ``vocab`` is given but ``sep_token`` or
                ``cls_token`` is not found in it.
        """
        if vocab is not None:
            tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token)))
        else:
            tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))

        # Only tokens that are actually part of the vocab are marked special.
        for special in (unk_token, sep_token, cls_token, pad_token, mask_token):
            if tokenizer.token_to_id(str(special)) is not None:
                tokenizer.add_special_tokens([str(special)])

        tokenizer.normalizer = BertNormalizer(
            clean_text=clean_text,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = BertPreTokenizer()

        # NOTE(review): with vocab=None no post-processor is attached, so
        # [CLS]/[SEP] are not inserted until one is configured — confirm
        # this is intended for tokenizers trained from scratch.
        if vocab is not None:
            sep_token_id = tokenizer.token_to_id(str(sep_token))
            if sep_token_id is None:
                raise TypeError("sep_token not found in the vocabulary")
            cls_token_id = tokenizer.token_to_id(str(cls_token))
            if cls_token_id is None:
                raise TypeError("cls_token not found in the vocabulary")

            tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id))
        tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

        # Constructor arguments recorded on the BaseTokenizer.
        parameters = {
            "model": "BertWordPiece",
            "unk_token": unk_token,
            "sep_token": sep_token,
            "cls_token": cls_token,
            "pad_token": pad_token,
            "mask_token": mask_token,
            "clean_text": clean_text,
            "handle_chinese_chars": handle_chinese_chars,
            "strip_accents": strip_accents,
            "lowercase": lowercase,
            "wordpieces_prefix": wordpieces_prefix,
        }

        super().__init__(tokenizer, parameters)

    @staticmethod
    def from_file(vocab: str, **kwargs):
        """Create a tokenizer from a vocabulary file.

        Args:
            vocab: Path to a vocabulary file, read with ``WordPiece.read_file``.
            **kwargs: Forwarded to the ``BertWordPieceTokenizer`` constructor.
        """
        vocab = WordPiece.read_file(vocab)
        return BertWordPieceTokenizer(vocab, **kwargs)

    def _build_trainer(
        self,
        *,
        vocab_size: int,
        min_frequency: int,
        limit_alphabet: int,
        initial_alphabet: Optional[List[str]],
        special_tokens: Optional[List[Union[str, AddedToken]]],
        show_progress: bool,
        wordpieces_prefix: str,
    ):
        """Build the ``WordPieceTrainer`` shared by both training entry points.

        ``None`` for ``initial_alphabet`` / ``special_tokens`` resolves to a
        fresh empty list / a fresh copy of ``_DEFAULT_SPECIAL_TOKENS``.
        Caller-supplied values are passed through unchanged.
        """
        if initial_alphabet is None:
            initial_alphabet = []
        if special_tokens is None:
            special_tokens = list(self._DEFAULT_SPECIAL_TOKENS)
        return trainers.WordPieceTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            limit_alphabet=limit_alphabet,
            initial_alphabet=initial_alphabet,
            special_tokens=special_tokens,
            show_progress=show_progress,
            continuing_subword_prefix=wordpieces_prefix,
        )

    def train(
        self,
        files: Union[str, List[str]],
        vocab_size: int = 30000,
        min_frequency: int = 2,
        limit_alphabet: int = 1000,
        initial_alphabet: Optional[List[str]] = None,
        special_tokens: Optional[List[Union[str, AddedToken]]] = None,
        show_progress: bool = True,
        wordpieces_prefix: str = "##",
    ):
        """Train the model using the given files.

        Args:
            files: A single file path or a list of file paths.
            vocab_size, min_frequency, limit_alphabet, show_progress:
                Forwarded to ``WordPieceTrainer``.
            initial_alphabet: Forwarded to the trainer; ``None`` means an
                empty list.
            special_tokens: Forwarded to the trainer; ``None`` means
                ``["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]``.
            wordpieces_prefix: Passed as the trainer's
                ``continuing_subword_prefix``.
        """
        trainer = self._build_trainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            limit_alphabet=limit_alphabet,
            initial_alphabet=initial_alphabet,
            special_tokens=special_tokens,
            show_progress=show_progress,
            wordpieces_prefix=wordpieces_prefix,
        )
        if isinstance(files, str):
            files = [files]
        self._tokenizer.train(files, trainer=trainer)

    def train_from_iterator(
        self,
        iterator: Union[Iterator[str], Iterator[Iterator[str]]],
        vocab_size: int = 30000,
        min_frequency: int = 2,
        limit_alphabet: int = 1000,
        initial_alphabet: Optional[List[str]] = None,
        special_tokens: Optional[List[Union[str, AddedToken]]] = None,
        show_progress: bool = True,
        wordpieces_prefix: str = "##",
        length: Optional[int] = None,
    ):
        """Train the model using the given iterator.

        Args:
            iterator: Iterator of strings, or of iterators of strings.
            vocab_size, min_frequency, limit_alphabet, show_progress:
                Forwarded to ``WordPieceTrainer``.
            initial_alphabet: Forwarded to the trainer; ``None`` means an
                empty list.
            special_tokens: Forwarded to the trainer; ``None`` means
                ``["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]``.
            wordpieces_prefix: Passed as the trainer's
                ``continuing_subword_prefix``.
            length: Forwarded to ``Tokenizer.train_from_iterator``
                (presumably the total item count for progress reporting).
        """
        trainer = self._build_trainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            limit_alphabet=limit_alphabet,
            initial_alphabet=initial_alphabet,
            special_tokens=special_tokens,
            show_progress=show_progress,
            wordpieces_prefix=wordpieces_prefix,
        )
        self._tokenizer.train_from_iterator(
            iterator,
            trainer=trainer,
            length=length,
        )
|