import torch.nn as nn | |
from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel | |
class RobertaEmbeddings(BertEmbeddings): | |
def __init__(self, config): | |
super().__init__(config) | |
self.pad_token_id = config.pad_token_id | |
self.position_embeddings = nn.Embedding( | |
config.max_position_embeddings, config.hidden_size, config.pad_token_id | |
) | |
class RobertaModel(BertModel): | |
def __init__(self, config, add_pooling_layer=True): | |
super().__init__(self, config) | |