Upload 3 files
- config.py +43 -0
- model.py +489 -0
- tokenizer.py +21 -0
config.py
ADDED
@@ -0,0 +1,43 @@
import argparse
from dataclasses import dataclass
import torch


@dataclass
class ModelArgs:
    # Model and training hyperparameters shared by model.py and the training scripts.
    pre_trained_model_dir: str = None
    fine_tuned_model_dir: str = None
    epochs: int = 4
    beta: float = 0.1
    block_size: int = 256
    batch_size: int = 128
    inference = None
    embeddings_dims: int = 512
    attn_dropout: float = 0.1
    no_of_heads: int = 8
    dropout: float = 0.1
    val_epochs: int = 2
    max_lr: float = 6e-4
    no_of_decoder_layers: int = 16
    weight_decay_optim: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.95
    clip: float = 1.0
    device: str = 'cuda'
    no_kv_heads: int = 2
    vocab_size: int = 50304
    eps: float = 1e-5
    dtype: str = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
    save_checkpoint_dir: str = "checkpoints"
    prompt: str = "Once upon a time"

    save_checkpoint_iter: int = 50
    total_iters: int = 20000
    eval_iters: int = 50
    eval_check: int = 100
    warmup_iters: int = 700
    min_lr: float = 0.1 * max_lr
    lr_decay_iters: int = 20000
    total_batch_size: int = 524288
    micro_batch_size: int = batch_size
    # Micro-batches accumulated per optimizer step; max(..., 1) guards against
    # torch.cuda.device_count() returning 0 on CPU-only machines, which would divide by zero at import time.
    gradient_accumulation_steps: int = total_batch_size // (micro_batch_size * block_size * max(torch.cuda.device_count(), 1))
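To make the batch arithmetic above concrete, here is a short, hypothetical usage sketch (not part of the uploaded files): with the defaults, each micro-batch covers batch_size * block_size = 128 * 256 = 32768 tokens, so reaching the 524288-token global batch takes 16 accumulation steps per device.

from config import ModelArgs

args = ModelArgs()
tokens_per_micro_batch = args.micro_batch_size * args.block_size  # 128 * 256 = 32768
print(tokens_per_micro_batch)             # 32768
print(args.gradient_accumulation_steps)   # 16 on a single-GPU (or CPU-only) machine
print(args.dtype)                         # 'bfloat16' where supported, otherwise 'float16'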
model.py
ADDED
@@ -0,0 +1,489 @@
from config import ModelArgs
import torch
import torch.nn as nn
import torch.nn.functional as F


# RMSNorm wrapper: the normalization used throughout the model.
class Normalization(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)

    def forward(self, x):
        x = self.rmsnorm_layer(x)
        return x


# import numpy as np
# Rotary positional embeddings (RoPE), applied to the query/key projections.
class RotaryEmbeddings(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        batch_size: int = ModelArgs.batch_size
    ):
        super().__init__()

        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.batch_size = batch_size
        self.theta = 0
        self.device = device

        # self.d_model = embeddings_dims
        # self.i = torch.arange(0, embeddings_dims, dtype=torch.float32)
        # # self.pos = torch.arange(0, block_size, dtype=torch.float32)
        # self.exp = ((2 * self.i)) / self.d_model
        # self.theta = 10000 ** self.exp
        # # print(self.theta.shape)
        # self.x_reshaped = torch.randn(batch_size, block_size, embeddings_dims, dtype=torch.float32, device=device)

        # self.cos = torch.cos((self.i / self.theta))
        # self.sin = torch.sin((self.i / self.theta))

        # self.even = self.sin[::2]
        # self.odd = self.cos[1::2]

        # # self.block = torch.empty((odd.size(0) + even.size(0),), dtype=self.even.dtype)
        # self.x_reshaped[..., :, ::2] = self.even
        # self.x_reshaped[..., :, 1::2] = self.odd

    def apply_rope(self, seq):
        batch_size, seq_len, embeds_dims = seq.shape
        # print(seq.shape)
        # print(self.embeddings_dims)
        # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device=self.device)

        positions = torch.arange(0, embeds_dims, 2, dtype=torch.float32, device=self.device).unsqueeze(0)
        # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
        theta = 10000 ** (-2 * (positions) / embeds_dims)
        angles = positions * theta
        angles = angles.expand(seq_len, -1)  # because this thing needs to be applied to every sequence in the batch but with embeds dims halved
        x_reshaped = seq.view(batch_size, seq_len, embeds_dims // 2, 2)

        cos_angles = torch.cos(angles)
        sin_angles = torch.sin(angles)
        # print(cos_angles.shape)
        # print(sin_angles.shape)
        # print(x_reshaped.shape)
        # indices = torch.arange(self.embeddings_dims, dtype=torch.int64, device=self.device)

        # Rotate each (even, odd) channel pair by its angle.
        out = torch.stack([x_reshaped[..., 0] * cos_angles - (x_reshaped[..., 1] * sin_angles), x_reshaped[..., 1] * cos_angles + x_reshaped[..., 0] * sin_angles], dim=-1)
        out = out.view(batch_size, seq_len, embeds_dims)
        return out

    def forward(self, x):
        # print("X shape: ", x.shape)
        # print("X is: ", x)
        # B,T,C = x.shape
        # print("MATRIX:",x)
        # if(x > self.block_size or x < self.block_size):
        #     matrix = self.init_matrix(x)
        #     return matrix
        # else:
        #     matrix = self.init_matrix(self.block_size)
        #     return matrix
        # if(ModelArgs.inference):
        res = self.apply_rope(x)
        return res
        # else:
        #     return self.x_reshaped


# Single attention head with RoPE (kept for reference; the Llama model below uses MQA/GQA instead).
class RotaryAttentionHead(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        no_of_heads: int = ModelArgs.no_of_heads,
        attn_dropout: int = ModelArgs.attn_dropout
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.rope = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
        self.dropout = nn.Dropout(p=attn_dropout)
        self.device = device

    def forward(self, x):
        # print(x.shape)
        # print("X is: ", x)
        batch, block_size, embeddings_dims = x.shape
        query = self.query(x)
        # print(query)
        key = self.key(x)
        values = self.value(x)
        # matrix = self.rotary_matrix(block_size)
        rotary_q = self.rope(query)
        rotary_k = self.rope(key)

        # print(matrix.shape)
        # print(query.shape)
        masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
        # rotary_query = matrix @ query.permute(1,2,0)  # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
        # rotary_key = matrix @ key.permute(1,2,0)  # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
        weights = rotary_q.permute(2, 0, 1) @ rotary_k.permute(2, 0, 1).transpose(-2, -1)  # (B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
        weights_masked = weights.masked_fill(masked == 0, float('-inf'))
        scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
        scaled_weights = F.softmax(scaled_weights, dim=-1)
        value = scaled_weights @ values
        out = self.dropout(value)
        return out


# Earlier matrix-based RoPE and attention-head implementation, kept commented out.
# # import numpy as np
# class RotaryEmbeddings(nn.Module):
#     def __init__(
#         self,
#         device,
#         embeddings_dims: int = ModelArgs.embeddings_dims,
#         block_size: int = ModelArgs.block_size,
#         batch_size: int = ModelArgs.batch_size
#     ):
#         super().__init__()
#
#         self.embeddings_dims = embeddings_dims
#         self.block_size = block_size
#         self.batch_size = batch_size
#         self.theta = 0
#
#         # def init_matrix(self, seq_len):
#         #     self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
#         #     for pos in range(seq_len):
#         #         for j in range(1, self.embeddings_dims // 2):
#         #             self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
#         #             self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
#         #             self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos*self.theta))
#         #             self.matrix[pos, 2*j, 2*j] = -np.cos((pos*self.theta))
#         #             self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos*self.theta))
#         #     return self.matrix
#         self.device = device
#
#     def init_matrix(self, seq_len):
#         self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device=self.device)
#
#         positions = torch.arange(0, seq_len, 2, dtype=torch.float32, device=self.device).unsqueeze(1)
#         # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
#         theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
#         angles = positions * theta
#
#         cos_angles = torch.cos(angles)
#         sin_angles = torch.sin(angles)
#
#         indices = torch.arange(seq_len, dtype=torch.int64, device=self.device)
#         # print(indices)
#         # print(indices.shape)
#         # print(indices[::2])
#         even_indices = indices[::2]
#         odd_indices = indices[1::2]
#
#         self.matrix[:, even_indices, even_indices] = cos_angles
#         self.matrix[:, odd_indices, odd_indices] = sin_angles
#         self.matrix[:, odd_indices, even_indices] = -sin_angles
#         self.matrix[:, even_indices, odd_indices] = cos_angles
#
#         return self.matrix
#
#     def forward(self, x):
#         # B,T,C = x.shape
#         # print("MATRIX:",x)
#         if(x > self.block_size or x < self.block_size):
#             matrix = self.init_matrix(x)
#             return matrix
#         else:
#             matrix = self.init_matrix(self.block_size)
#             return matrix
#
#
# class RotaryAttentionHead(nn.Module):
#     def __init__(
#         self,
#         device,
#         embeddings_dims: int = ModelArgs.embeddings_dims,
#         no_of_heads: int = ModelArgs.no_of_heads,
#         attn_dropout: int = ModelArgs.attn_dropout
#     ):
#         super().__init__()
#         self.head_size = embeddings_dims // no_of_heads
#         self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
#         self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
#         self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
#         self.rotary_matrix = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
#         self.dropout = nn.Dropout(p=attn_dropout)
#         self.device = device
#
#     def forward(self, x):
#         # print(x.shape)
#         batch, block_size, embeddings_dims = x.shape
#         query = self.query(x)
#         # print(query)
#         key = self.key(x)
#         values = self.value(x)
#         matrix = self.rotary_matrix(block_size)
#
#         # print(matrix.shape)
#         # print(query.shape)
#         masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
#         rotary_query = matrix @ query.permute(1,2,0)  # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
#         rotary_key = matrix @ key.permute(1,2,0)  # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
#         weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)  # (B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
#         weights_masked = weights.masked_fill(masked == 0, float('-inf'))
#         scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
#         scaled_weights = F.softmax(scaled_weights, dim=-1)
#         value = scaled_weights @ values
#         out = self.dropout(value)
#         return out


# Multi-query attention: a pair of query heads shares a single key/value projection.
class MQA(nn.Module):
    def __init__(
        self,
        device,
        no_of_q_heads: int,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
    ):
        super().__init__()

        # self.no_of_q_heads = no_of_heads // no_of_kv_heads
        # self.no_of_q_heads = no_of_q_heads
        self.no_of_kv_heads = 2  # I want to have a kv for each pair of query heads
        self.head_size = embeddings_dims // no_of_q_heads
        # self.kv_head_size = (embeddings_dims // self.no_of_kv_heads) * 2
        self.rotary = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
        # self.rotary_k = RotaryEmbeddings(embeddings_dims=self.kv_head_size, device=device)
        # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device=device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device=device)
        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.device = device
        self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, device=self.device) for _ in range(self.no_of_kv_heads)])

    def scaled_dot_product(self, q, k, v, block_size):
        # masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
        q = self.rotary(q)
        masked_table = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
        # rotary_query = matrix @ q.permute(1,2,0)  # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
        # rotary_key = matrix @ k.permute(1,2,0)  # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
        # print("Query: ", q.shape)
        # print("Keys: ", k.shape)
        # print(q.permute(2,0,1).shape)
        # print(k.permute(2,0,1).transpose(-2, -1).shape)
        # weights = q.permute(2,0,1) @ k.permute(2,0,1).transpose(-2, -1)  # (B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
        # weights = q @ k.permute(2,1,0)
        # print(weights.shape)
        # print(masked.shape)
        weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
        masked_values = weights.masked_fill(masked_table[:block_size, :block_size] == 0, float('-inf'))
        weights_normalized = nn.functional.softmax(masked_values, dim=-1)  # Normalize along the embeddings dimension for all the tokens
        weights_normalized = self.dropout(weights_normalized)
        out = weights_normalized @ v
        return out

    def forward(self, x):
        # print("MQA: ", x.shape)
        batch, block_size, embeddings_dims = x.shape

        # query = self.query(x)
        # matrix = self.rotary_matrix(block_size)

        key = self.key(x)
        values = self.value(x)
        # print("Keys: ", key.shape)
        # print("Values: ", values.shape)
        # rotary_value = self.rotary(values)
        rotary_key = self.rotary(key)
        multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size) for query in self.multi_query], dim=-1)
        # print("Multi query: ", multi_query_concat.shape)

        linear_layer = self.linear_layer(multi_query_concat)
        # out = self.dropout(linear_layer)
        return linear_layer


# Grouped-query attention: a group of MQA blocks whose outputs are concatenated and projected back.
class GQA(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        # no_of_q_heads: int = ModelArgs.no_of_heads,
        mqa_heads: int = ModelArgs.no_kv_heads
    ):
        super().__init__()

        # self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
        # self.head_dim = embeddings_dims // self.no_kv_heads
        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.device = device
        self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device=self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
        # self.mqa = MQA(no_of_q_heads=self.no_of_q_heads, device=self.device, embeddings_dims=embeddings_dims, block_size=block_size)

    def forward(self, x):
        batch, block_size, embeddings_dims = x.shape

        # res = self.mqa(x)
        grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)

        linear_layer = self.linear_layer(grouped_query_concat)  # Basically MQA is made into GQA with no_of_q_heads and this class right here is just to consolidate everything into one
        out = self.dropout(linear_layer)
        return out


# Swish / SiLU activation: x * sigmoid(x).
class Swish(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()

        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        swish = x * self.sig(x)
        return swish


# SwiGLU gate: swish(W1 x) * (W2 x), projected back down with W3.
class SWiGLU(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.hidden_dims = int(2 * (4 * embeddings_dims) / 3)
        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device=device)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device=device)
        self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device=device)

    def forward(self, x):
        swish_res = self.swish(self.linear_layer1(x))
        x_V = self.linear_layer2(x)
        res = torch.mul(swish_res, x_V)
        out = self.linear_layer3(res)
        return out


# Position-wise feed-forward network built on SwiGLU.
class FFN(nn.Module):
    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 dropout=ModelArgs.dropout
                 ):
        super().__init__()

        # self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device=device)
        self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.swiglue(x)
        # x = self.linear_layer(x)
        x = self.dropout(x)
        return x


# Pre-norm decoder block: RMSNorm -> GQA and RMSNorm -> FFN, each with a residual connection.
class DecoderLayer(nn.Module):
    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 dropout=ModelArgs.dropout,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 ):
        super().__init__()

        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device=device)
        self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2, device=device)
        # self.norm = Normalization(embeddings_dims=embeddings_dims)
        self.norm1 = Normalization(embeddings_dims=embeddings_dims)
        self.norm2 = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = x + self.gqa(self.norm1(x))
        x = x + self.feedforward_network(self.norm2(x))
        return x


# Decoder-only LLaMA-style language model with tied input/output embeddings.
class Llama(nn.Module):
    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 dropout=ModelArgs.dropout
                 ):
        super().__init__()

        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device=device)
        self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device=device) for _ in range(no_of_decoder_layers)])
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device=device)
        self.dropout = nn.Dropout(p=dropout)
        # self.norm = Normalization(embeddings_dims)

        # weight tying
        self.embeddings.weight = self.linear_layer.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        x = self.embeddings(x)
        x = self.dropout(x)
        x = self.decoder(x)
        # x = self.norm(x)
        x = self.linear_layer(x)
        # out = self.norm(x)
        return x
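As a rough smoke test of the wiring above (a sketch, not part of the upload), the model can be instantiated with deliberately small, hypothetical sizes and run on random token ids. This assumes config.py and model.py are importable and a PyTorch version that provides torch.nn.RMSNorm.

import torch
from model import Llama

device = 'cpu'  # tiny illustrative sizes only; the ModelArgs defaults target GPU training
model = Llama(device=device, embeddings_dims=64, no_of_decoder_layers=2,
              block_size=16, vocab_size=1000, dropout=0.1)

tokens = torch.randint(0, 1000, (2, 16), device=device)  # (batch, sequence length)
logits = model(tokens)                                    # (batch, sequence length, vocab_size)
print(logits.shape)  # torch.Size([2, 16, 1000])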
tokenizer.py
ADDED
@@ -0,0 +1,21 @@
from transformers import AutoTokenizer
import os


class Tokenizer:

    def __init__(self) -> None:
        # Token value is elided in the upload; the public gpt2 checkpoint loads without authentication.
        self.tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token='...')

        # GPT-2 ships without a padding token, so add one explicitly.
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    def ready_tokenizer(self):
        return self.tokenizer
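A brief usage sketch for the wrapper above (not part of the upload), assuming transformers is installed and tokenizer.py is on the import path:

from tokenizer import Tokenizer

tok = Tokenizer().ready_tokenizer()
batch = tok("Once upon a time", return_tensors="pt",
            padding="max_length", max_length=8)
print(batch["input_ids"].shape)  # torch.Size([1, 8]) after padding with the new [PAD] token
print(tok.pad_token)             # '[PAD]'

Note that adding the pad token brings the vocabulary to 50,258 entries, which still fits inside the 50,304-slot embedding defined in config.py.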