# NOTE(review): removed scraped-output residue ("Spaces:" / "Runtime error" lines)
# left over from extraction; it was not part of the program.
from typing import Dict

from transformers import PreTrainedModel, PreTrainedTokenizer
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Adds the given special tokens to the tokenizer, resizes the model's token
    embeddings to match the new vocabulary size, and initializes each new
    embedding row to the mean of the pre-existing embeddings (a better starting
    point than random initialization).

    Note: This is the unoptimized version that may make your embedding size
    not be divisible by 64.

    Args:
        special_tokens_dict: Mapping accepted by
            ``tokenizer.add_special_tokens`` (e.g. ``{"pad_token": "[PAD]"}``).
        tokenizer: Tokenizer to receive the special tokens; mutated in place.
        model: Model whose input/output embedding matrices are resized and
            re-initialized in place.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        # New rows were appended at the end of the embedding matrix; fill them
        # with the mean of the original rows.
        input_embeddings = model.get_input_embeddings().weight.data
        input_embeddings[-num_new_tokens:] = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )

        # Models with tied input/output embeddings may return None here; the
        # input-embedding update above already covers them.
        output_module = model.get_output_embeddings()
        if output_module is not None:
            output_embeddings = output_module.weight.data
            output_embeddings[-num_new_tokens:] = output_embeddings[
                :-num_new_tokens
            ].mean(dim=0, keepdim=True)