import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """A single transformer block: multi-head self-attention followed by a position-wise feed-forward network."""

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        # batch_first=True so inputs are shaped (batch, seq_len, embed_size),
        # matching how MathTransformer calls this block.
        self.attention = nn.MultiheadAttention(embed_size, heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query):
        # Self-attention with a residual connection, then layer norm and dropout.
        attn_output, _ = self.attention(query, key, value)
        x = self.dropout(self.norm1(attn_output + query))
        # Position-wise feed-forward with a second residual connection.
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class MathTransformer(nn.Module):
    """Token and learned positional embeddings, a stack of TransformerBlocks, and a linear head over the vocabulary."""

    def __init__(self, vocab_size, embed_size=128, num_layers=2, heads=4,
                 forward_expansion=4, max_length=512, dropout=0.1):
        super().__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_size, heads, dropout, forward_expansion)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len) tensor of token ids.
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            out = layer(out, out, out)
        # Project back to vocabulary logits: (batch, seq_len, vocab_size).
        out = self.fc_out(out)
        return out
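

# A minimal smoke test sketch, assuming an illustrative vocabulary size of 1000
# and a batch of random token ids; the actual tokenizer and training data are
# not defined in this module.
if __name__ == "__main__":
    model = MathTransformer(vocab_size=1000)
    tokens = torch.randint(0, 1000, (2, 32))  # (batch=2, seq_len=32) token ids
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 32, 1000])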