File size: 527 Bytes
e0be88b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel
class RobertaEmbeddings(BertEmbeddings):
    """BERT embeddings with a position table that carries a padding index.

    After the parent constructor runs, ``position_embeddings`` is replaced
    by a new ``nn.Embedding`` whose ``padding_idx`` is the configured pad
    token id.
    """

    def __init__(self, config):
        """Build the embeddings from ``config``.

        Args:
            config: Model configuration. This constructor reads
                ``pad_token_id``, ``max_position_embeddings`` and
                ``hidden_size`` from it; the parent constructor receives
                the same object.
        """
        super().__init__(config)

        pad_id = config.pad_token_id
        self.pad_token_id = pad_id

        # Overrides the attribute of the same name set up by the parent.
        # Per the PyTorch docs, the row at ``padding_idx`` starts as zeros
        # and receives no gradient updates.
        # NOTE(review): no forward() override is visible here, so position
        # ids are whatever the parent produces -- confirm that they never
        # legitimately equal the pad token id.
        self.position_embeddings = nn.Embedding(
            num_embeddings=config.max_position_embeddings,
            embedding_dim=config.hidden_size,
            padding_idx=pad_id,
        )
class RobertaModel(BertModel):
    """RoBERTa encoder implemented as a thin subclass of ``BertModel``."""

    def __init__(self, config, add_pooling_layer=True):
        """Initialise the model by delegating to ``BertModel``.

        Args:
            config: Model configuration, forwarded unchanged to the parent
                constructor.
            add_pooling_layer: Forwarded to the parent constructor's
                parameter of the same name. Defaults to ``True``.
        """
        # ``super().__init__`` already binds ``self``; passing it again
        # shifted every argument by one (the parent saw ``config=self`` and
        # ``add_pooling_layer=config``) and dropped the caller's flag.
        super().__init__(config, add_pooling_layer=add_pooling_layer)
        # NOTE(review): the embeddings created by the parent are left in
        # place; ``RobertaEmbeddings`` from this file is not installed
        # here -- confirm whether ``self.embeddings`` should be replaced.
|