YuvrajSingh9886 committed
Commit e67a5e8 · verified · 1 Parent(s): 8f3cc16

Upload 3 files

Files changed (3)
  1. config.py +43 -0
  2. model.py +489 -0
  3. tokenizer.py +21 -0
config.py ADDED
@@ -0,0 +1,43 @@
+import argparse
+from dataclasses import dataclass
+import torch
+
+@dataclass
+class ModelArgs:
+    pre_trained_model_dir: str = None
+    fine_tuned_model_dir: str = None
+    epochs: int = 4
+    beta: float = 0.1
+    block_size: int = 256
+    batch_size: int = 128
+    inference = None
+    embeddings_dims: int = 512
+    attn_dropout: float = 0.1
+    no_of_heads: int = 8
+    dropout: float = 0.1
+    val_epochs: int = 2
+    max_lr: float = 6e-4
+    no_of_decoder_layers: int = 16
+    weight_decay_optim: float = 0.1
+    beta_1: float = 0.9
+    beta_2: float = 0.95
+    clip: float = 1.0
+    device: str = 'cuda'
+    no_kv_heads: int = 2
+    vocab_size: int = 50304
+    eps: float = 1e-5
+    dtype: str = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
+    save_checkpoint_dir: str = "checkpoints"
+    prompt: str = "Once upon a time"
+
+    save_checkpoint_iter: int = 50
+    total_iters: int = 20000
+    eval_iters: int = 50
+    eval_check: int = 100
+    warmup_iters: int = 700
+    min_lr: float = 0.1 * max_lr
+    lr_decay_iters: int = 20000
+    total_batch_size: int = 524288
+    micro_batch_size: int = batch_size
+    # Evaluated at import time; assumes at least one CUDA device is visible.
+    gradient_accumulation_steps: int = total_batch_size // (micro_batch_size * (block_size * torch.cuda.device_count()))
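For reference, a minimal sketch of how these defaults compose at import time. This snippet is illustrative and not part of the upload; it assumes exactly one CUDA device is visible, since gradient_accumulation_steps divides by torch.cuda.device_count() when the class body runs.

from config import ModelArgs

args = ModelArgs()
print(args.min_lr)                      # 0.1 * 6e-4 = 6e-5
print(args.micro_batch_size)            # 128, same as batch_size
# With one device: 524288 // (128 * 256 * 1) = 16 micro-steps accumulated per optimizer step
print(args.gradient_accumulation_steps)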
model.py ADDED
@@ -0,0 +1,489 @@
+from config import ModelArgs
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Normalization(nn.Module):
+    def __init__(
+        self,
+        embeddings_dims: int = ModelArgs.embeddings_dims
+    ):
+        super().__init__()
+        self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)
+
+    def forward(self, x):
+        x = self.rmsnorm_layer(x)
+        return x
+
+
+# import numpy as np
+class RotaryEmbeddings(nn.Module):
+    def __init__(
+        self,
+        device,
+        embeddings_dims: int = ModelArgs.embeddings_dims,
+        block_size: int = ModelArgs.block_size,
+        batch_size: int = ModelArgs.batch_size
+    ):
+        super().__init__()
+
+        self.embeddings_dims = embeddings_dims
+        self.block_size = block_size
+        self.batch_size = batch_size
+        self.theta = 0
+        self.device = device
+
+    def apply_rope(self, seq):
+        batch_size, seq_len, embeds_dims = seq.shape
+
+        # Standard RoPE: the rotation angle depends on the token position and the paired
+        # dimension index, angle[pos, i] = pos * 10000 ** (-2i / d).
+        positions = torch.arange(seq_len, dtype=torch.float32, device=self.device).unsqueeze(1)       # (seq_len, 1)
+        dims = torch.arange(0, embeds_dims, 2, dtype=torch.float32, device=self.device).unsqueeze(0)  # (1, embeds_dims // 2)
+        theta = 10000 ** (-dims / embeds_dims)
+        angles = positions * theta                                        # (seq_len, embeds_dims // 2)
+        x_reshaped = seq.view(batch_size, seq_len, embeds_dims // 2, 2)   # split into (even, odd) pairs
+
+        cos_angles = torch.cos(angles)
+        sin_angles = torch.sin(angles)
+
+        # Rotate each (even, odd) pair by its angle.
+        out = torch.stack(
+            [x_reshaped[..., 0] * cos_angles - x_reshaped[..., 1] * sin_angles,
+             x_reshaped[..., 1] * cos_angles + x_reshaped[..., 0] * sin_angles],
+            dim=-1,
+        )
+        out = out.view(batch_size, seq_len, embeds_dims)
+        return out
+
+    def forward(self, x):
+        res = self.apply_rope(x)
+        return res
+
+
+class RotaryAttentionHead(nn.Module):
+    def __init__(
+        self,
+        device,
+        embeddings_dims: int = ModelArgs.embeddings_dims,
+        no_of_heads: int = ModelArgs.no_of_heads,
+        attn_dropout: float = ModelArgs.attn_dropout
+    ):
+        super().__init__()
+        self.head_size = embeddings_dims // no_of_heads
+        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
+        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
+        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
+        self.rope = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
+        self.dropout = nn.Dropout(p=attn_dropout)
+        self.device = device
+
+    def forward(self, x):
+        batch, block_size, embeddings_dims = x.shape
+        query = self.query(x)
+        key = self.key(x)
+        values = self.value(x)
+        rotary_q = self.rope(query)
+        rotary_k = self.rope(key)
+
+        # Causal single-head attention: (B, T, head_size) -> (B, T, head_size).
+        masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
+        weights = rotary_q @ rotary_k.transpose(-2, -1)                    # (B, T, T)
+        weights_masked = weights.masked_fill(masked == 0, float('-inf'))
+        scaled_weights = weights_masked / torch.sqrt(torch.tensor(key.shape[-1], dtype=torch.float32))
+        scaled_weights = F.softmax(scaled_weights, dim=-1)
+        value = scaled_weights @ values
+        out = self.dropout(value)
+        return out
+
+
+# # import numpy as np
+# class RotaryEmbeddings(nn.Module):
+#     def __init__(
+#         self,
+#         device,
+#         embeddings_dims: int = ModelArgs.embeddings_dims,
+#         block_size: int = ModelArgs.block_size,
+#         batch_size: int = ModelArgs.batch_size
+#     ):
+#         super().__init__()
+
+#         self.embeddings_dims = embeddings_dims
+#         self.block_size = block_size
+#         self.batch_size = batch_size
+#         self.theta = 0
+
+#         # def init_matrix(self, seq_len):
+#         #     self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
+#         #     for pos in range(seq_len):
+#         #         for j in range(1, self.embeddings_dims // 2):
+#         #             self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
+#         #             self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
+#         #             self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
+#         #             self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
+#         #             self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
+#         #     return self.matrix
+#         self.device = device
+
+#     def init_matrix(self, seq_len):
+#         self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
+
+#         positions = torch.arange(0 , seq_len, 2, dtype=torch.float32, device = self.device).unsqueeze(1)
+#         # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
+#         theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
+#         angles = positions * theta
+
+#         cos_angles = torch.cos(angles)
+#         sin_angles = torch.sin(angles)
+
+#         indices = torch.arange(seq_len, dtype=torch.int64, device = self.device)
+#         even_indices = indices[::2]
+#         odd_indices = indices[1::2]
+
+#         self.matrix[:, even_indices, even_indices] = cos_angles
+#         self.matrix[:, odd_indices, odd_indices] = sin_angles
+#         self.matrix[:, odd_indices, even_indices] = -sin_angles
+#         self.matrix[:, even_indices, odd_indices] = cos_angles
+
+#         return self.matrix
+
+#     def forward(self, x):
+#         # B,T,C = x.shape
+#         if(x > self.block_size or x < self.block_size):
+#             matrix = self.init_matrix(x)
+#             return matrix
+#         else:
+#             matrix = self.init_matrix(self.block_size)
+#             return matrix
+
+
+# class RotaryAttentionHead(nn.Module):
+#     def __init__(
+#         self,
+#         device,
+#         embeddings_dims: int = ModelArgs.embeddings_dims,
+#         no_of_heads: int = ModelArgs.no_of_heads,
+#         attn_dropout: int = ModelArgs.attn_dropout
+#     ):
+#         super().__init__()
+#         self.head_size = embeddings_dims // no_of_heads
+#         self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
+#         self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
+#         self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
+#         self.rotary_matrix = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
+#         self.dropout = nn.Dropout(p = attn_dropout)
+#         self.device = device
+
+#     def forward(self,x):
+#         batch, block_size, embeddings_dims = x.shape
+#         query = self.query(x)
+#         key = self.key(x)
+#         values = self.value(x)
+#         matrix = self.rotary_matrix(block_size)
+
+#         masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
+#         rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
+#         rotary_key = matrix @ key.permute(1,2,0)     # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
+#         weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1) # (B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
+#         weights_masked = weights.masked_fill(masked == 0, float('-inf'))
+#         scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
+#         scaled_weights = F.softmax(scaled_weights, dim=-1)
+#         value = scaled_weights @ values
+#         out = self.dropout(value)
+#         return out
+
+
+class MQA(nn.Module):
+    def __init__(
+        self,
+        device,
+        no_of_q_heads: int,
+        embeddings_dims: int = ModelArgs.embeddings_dims,
+        block_size: int = ModelArgs.block_size,
+    ):
+        super().__init__()
+
+        # Number of query heads that share this module's single K/V pair (one K/V per pair of query heads).
+        self.no_of_kv_heads = 2
+        self.head_size = embeddings_dims // no_of_q_heads
+        self.rotary = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
+        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device=device)
+        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device=device)
+        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
+        self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
+        self.device = device
+        self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, device=self.device) for _ in range(self.no_of_kv_heads)])
+
+    def scaled_dot_product(self, q, k, v, block_size):
+        q = self.rotary(q)
+        masked_table = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
+        weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
+        masked_values = weights.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
+        weights_normalized = nn.functional.softmax(masked_values, dim=-1)  # normalize over the key positions for each query token
+        weights_normalized = self.dropout(weights_normalized)
+        out = weights_normalized @ v
+        return out
+
+    def forward(self, x):
+        batch, block_size, embeddings_dims = x.shape
+
+        key = self.key(x)
+        values = self.value(x)
+        rotary_key = self.rotary(key)
+        # Each query head attends against the single shared (rotated) key and value.
+        multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size) for query in self.multi_query], dim=-1)
+
+        linear_layer = self.linear_layer(multi_query_concat)
+        return linear_layer
+
+
+class GQA(nn.Module):
+    def __init__(
+        self,
+        device,
+        embeddings_dims: int = ModelArgs.embeddings_dims,
+        block_size: int = ModelArgs.block_size,
+        mqa_heads: int = ModelArgs.no_kv_heads
+    ):
+        super().__init__()
+
+        self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
+        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
+        self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
+        self.device = device
+        self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device=self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
+
+    def forward(self, x):
+        batch, block_size, embeddings_dims = x.shape
+
+        # Grouped-query attention: each MQA group shares one K/V pair; the group outputs are
+        # concatenated and projected back to the model width.
+        grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
+
+        linear_layer = self.linear_layer(grouped_query_concat)
+        out = self.dropout(linear_layer)
+        return out
+
+
+class Swish(nn.Module):
+    def __init__(
+        self,
+        device,
+        block_size: int = ModelArgs.block_size,
+        embeddings_dims: int = ModelArgs.embeddings_dims
+    ):
+        super().__init__()
+
+        self.sig = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        swish = x * self.sig(x)
+        return swish
+
+
+class SWiGLU(nn.Module):
+    def __init__(
+        self,
+        device,
+        block_size: int = ModelArgs.block_size,
+        embeddings_dims: int = ModelArgs.embeddings_dims
+    ):
+        super().__init__()
+        self.hidden_dims = int(2 * (4 * embeddings_dims) / 3)
+        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
+        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device=device)
+        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device=device)
+        self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device=device)
+
+    def forward(self, x):
+        # SwiGLU(x) = (Swish(x W1) * x W2) W3
+        swish_res = self.swish(self.linear_layer1(x))
+        x_V = self.linear_layer2(x)
+        res = torch.mul(swish_res, x_V)
+        out = self.linear_layer3(res)
+        return out
+
+
+class FFN(nn.Module):
+    def __init__(
+        self,
+        device,
+        embeddings_dims: int = ModelArgs.embeddings_dims,
+        block_size: int = ModelArgs.block_size,
+        vocab_size: int = ModelArgs.vocab_size,
+        dropout = ModelArgs.dropout
+    ):
+        super().__init__()
+
+        self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, x):
+        x = self.swiglue(x)
+        x = self.dropout(x)
+        return x
+
+
+class DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        device,
+        embeddings_dims: int = ModelArgs.embeddings_dims,
+        dropout = ModelArgs.dropout,
+        block_size: int = ModelArgs.block_size,
+        vocab_size: int = ModelArgs.vocab_size,
+    ):
+        super().__init__()
+
+        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device=device)
+        self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2, device=device)
+        self.norm1 = Normalization(embeddings_dims=embeddings_dims)
+        self.norm2 = Normalization(embeddings_dims=embeddings_dims)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, x):
+        # Pre-norm residual blocks: attention, then feed-forward.
+        x = x + self.gqa(self.norm1(x))
+        x = x + self.feedforward_network(self.norm2(x))
+        return x
+
+
+class Llama(nn.Module):
+    def __init__(
+        self,
+        device,
+        embeddings_dims: int = ModelArgs.embeddings_dims,
+        no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
+        block_size: int = ModelArgs.block_size,
+        vocab_size: int = ModelArgs.vocab_size,
+        dropout = ModelArgs.dropout
+    ):
+        super().__init__()
+
+        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device=device)
+        self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device=device) for _ in range(no_of_decoder_layers)])
+        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device=device)
+        self.dropout = nn.Dropout(p=dropout)
+        # self.norm = Normalization(embeddings_dims)
+
+        # Weight tying between the token embedding and the output projection.
+        self.embeddings.weight = self.linear_layer.weight
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, x):
+        x = self.embeddings(x)
+        x = self.dropout(x)
+        x = self.decoder(x)
+        # x = self.norm(x)
+        x = self.linear_layer(x)
+        return x
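As a quick smoke test of the model wiring, a minimal sketch (illustrative, not part of the upload). It assumes config.py and model.py are importable as uploaded and that at least one CUDA device is visible, since config.py divides by torch.cuda.device_count() at import time.

import torch
from config import ModelArgs
from model import Llama

# .to() also moves submodules created on the default device (e.g. the RMSNorm weights).
model = Llama(device=ModelArgs.device).to(ModelArgs.device)

# A dummy batch of token ids with shape (batch, block_size).
tokens = torch.randint(0, ModelArgs.vocab_size, (2, ModelArgs.block_size), device=ModelArgs.device)
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 256, 50304]): one logit per vocabulary entry per position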
tokenizer.py ADDED
@@ -0,0 +1,21 @@
+from transformers import AutoTokenizer
+import os
+
+
+class Tokenizer:
+
+    def __init__(self) -> None:
+        # The public GPT-2 checkpoint does not require an auth token ('...' is a redacted placeholder).
+        self.tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token = '...')
+
+        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+    def ready_tokenizer(self):
+        return self.tokenizer
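A short usage sketch (illustrative, not part of the upload): the wrapper returns a standard GPT-2 tokenizer with an added '[PAD]' token, which brings the tokenizer to 50,258 entries, still below the padded vocab_size of 50,304 used in config.py.

from tokenizer import Tokenizer

tok = Tokenizer().ready_tokenizer()
batch = tok(["Once upon a time"], return_tensors="pt", padding=True)
print(batch["input_ids"].shape)
print(tok.decode(batch["input_ids"][0]))
print(len(tok))  # 50258 after adding the [PAD] token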