"""Script for training a Unigram tokenizer.""" |
|
import argparse
import logging

import datasets
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer

from transformers import AlbertTokenizerFast


logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Train a unigram tokenizer on the wikitext dataset.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="wikitext",
        help="Name of the training dataset. Explore datasets at: hf.co/datasets.",
    )
    parser.add_argument(
        "--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help=(
            "Whether to trust the execution of code from datasets/models defined on the Hub."
            " This option should only be set to `True` for repositories you trust and in which you have read the"
            " code, as it will execute code present on the Hub on your local machine."
        ),
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1000,
        help="Number of dataset entries per batch fed to the tokenizer trainer.",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=10048,
        help="Size of the desired vocabulary.",
    )
    parser.add_argument(
        "--limit",
        default=None,
        type=int,
        help="Limit the number of training entries (used for debugging).",
    )
    parser.add_argument(
        "--export_to_hub",
        action="store_true",
        help="Whether to push the trained tokenizer to the Hugging Face Hub.",
    )

    args = parser.parse_args()
    return args


def main(args):
|
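    # Load the raw training split. `trust_remote_code` only matters for
    # datasets whose loading logic ships as a script on the Hub.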
    dataset = datasets.load_dataset(
        args.dataset_name, args.dataset_config, split="train", trust_remote_code=args.trust_remote_code
    )

    if args.limit is not None:
        max_train_samples = min(len(dataset), args.limit)
        dataset = dataset.select(range(max_train_samples))
        logger.info(f"Limiting the dataset to {args.limit} entries.")
|
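    # Stream the corpus to the trainer in fixed-size batches so the whole
    # "text" column never has to be materialized in memory at once.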
    def batch_iterator():
        for i in range(0, len(dataset), args.batch_size):
            yield dataset[i : i + args.batch_size]["text"]
|
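    # Assemble the tokenizer pipeline: a Unigram model, a normalizer that maps
    # wikitext's `` and '' quote style back to plain double quotes, and a
    # Metaspace pre-tokenizer that marks word boundaries with ▁.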
    tokenizer = Tokenizer(Unigram())
    tokenizer.normalizer = normalizers.Sequence([normalizers.Replace("``", '"'), normalizers.Replace("''", '"')])
    tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
|
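    # Reserve ALBERT-style special tokens alongside the learned vocabulary;
    # <unk> doubles as the Unigram model's unknown token.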
    trainer = UnigramTrainer(
        unk_token="<unk>",
        special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
        vocab_size=args.vocab_size,
    )
|
    logger.info("Training the tokenizer.")
    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
    logger.info("Tokenizer training complete!")
|
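    # Post-processing mirrors ALBERT's input format: single sequences become
    # "[CLS] A [SEP]" and pairs "[CLS] A [SEP] B [SEP]", with token type IDs 0
    # and 1 marking the two segments. The Metaspace decoder inverts the
    # pre-tokenizer, turning ▁ markers back into spaces when decoding.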
    cls_token_id = tokenizer.token_to_id("[CLS]")
    sep_token_id = tokenizer.token_to_id("[SEP]")
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS]:0 $A:0 [SEP]:0",
        pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", cls_token_id),
            ("[SEP]", sep_token_id),
        ],
    )
    tokenizer.decoder = decoders.Metaspace()
|
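    # Wrapping the raw tokenizers object in AlbertTokenizerFast makes it a
    # drop-in transformers tokenizer that can be pushed to the Hub.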
    if args.export_to_hub:
        logger.info("Exporting the trained tokenizer to Hub.")
        new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")
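        # Once pushed, the tokenizer can be re-loaded elsewhere with, e.g.,
        # AlbertTokenizerFast.from_pretrained("<username>/unigram-tokenizer-dataset").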
|
|
if __name__ == "__main__":
    # Configure logging so the logger.info progress messages are visible.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    main(args)