Initialize rope embeddings properly for the entropy model (#72)
bytelatent/base_transformer.py CHANGED

@@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
         return h
 
-    def reset_parameters(self):
-        # Either use fixed base std or sqrt model dim
-        self.rope_embeddings.reset_parameters()
-
     def init_weights(self):
-        self.reset_parameters()
+        self.rope_embeddings.reset_parameters()
         for depth, layer in enumerate(self.layers):
             factor = {
                 InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
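For context, here is a minimal sketch of what a rotary-embedding reset_parameters() typically rebuilds; the actual RotaryEmbedding in bytelatent may differ, and the class name, theta default, and buffer layout below are assumptions. The point of the change above is that init_weights() now guarantees this precomputed table is populated, instead of relying on a separate reset_parameters() call that the entropy-model path might never make.

import torch
import torch.nn as nn


class RotaryEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for the real RotaryEmbedding in bytelatent."""

    def __init__(self, dim: int, max_seqlen: int, theta: float = 10000.0):
        super().__init__()
        self.dim, self.max_seqlen, self.theta = dim, max_seqlen, theta
        self.register_buffer(
            "freqs_cis", torch.empty(max_seqlen, dim // 2, dtype=torch.complex64)
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Recompute the complex rotation table that forward() hands to each layer
        # as freq_cis. If this is never called (for example after meta-device
        # construction), the buffer stays uninitialized and corrupts attention.
        freqs = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim))
        angles = torch.outer(torch.arange(self.max_seqlen).float(), freqs)
        self.freqs_cis.copy_(torch.polar(torch.ones_like(angles), angles))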
bytelatent/transformer.py CHANGED

@@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
         return logits
 
     def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        super().reset_parameters()
-        init_std = init_std or (self.dim ** (-0.5))
         self.norm.reset_parameters()
+
+    def init_weights(self):
+        self.reset_parameters()
+        init_std = self.dim ** (-0.5)
         nn.init.trunc_normal_(
             self.tok_embeddings.weight,
             mean=0.0,
@@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
             a=-3 * init_std,
             b=3 * init_std,
         )
+        super().init_weights()
+
         if not self.weight_tying:
             nn.init.trunc_normal_(
                 self.output.weight,
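Taken together, the two diffs separate module-local reset_parameters() from the recursive init_weights() entry point, with super().init_weights() pulling in the rope reset. Below is a self-contained toy sketch of the resulting call order; the class names, dimensions, and stand-in modules are assumptions for illustration, not the actual bytelatent classes.

import torch.nn as nn


class ToyBase(nn.Module):
    def __init__(self, dim: int, n_layers: int):
        super().__init__()
        # Stand-in for the rotary embedding module sketched earlier.
        self.rope_embeddings = nn.Embedding(1024, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

    def init_weights(self):
        # Rope state is rebuilt here, so callers only ever need init_weights().
        nn.init.normal_(self.rope_embeddings.weight, std=0.02)
        for depth, layer in enumerate(self.layers):
            factor = (2 * (depth + 1)) ** 0.5  # depth-scaled std, as in the diff
            nn.init.trunc_normal_(layer.weight, std=0.02 / factor)


class ToyLM(ToyBase):
    def __init__(self, dim: int, n_layers: int, vocab: int):
        super().__init__(dim, n_layers)
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.tok_embeddings = nn.Embedding(vocab, dim)

    def reset_parameters(self, init_std=None):
        # Only local, non-recursive state lives here after the change.
        self.norm.reset_parameters()

    def init_weights(self):
        self.reset_parameters()
        init_std = self.dim ** (-0.5)
        nn.init.trunc_normal_(
            self.tok_embeddings.weight, mean=0.0, std=init_std,
            a=-3 * init_std, b=3 * init_std,
        )
        super().init_weights()  # also (re)initializes the rope embeddings


model = ToyLM(dim=64, n_layers=2, vocab=256)
model.init_weights()

With this ordering, a single model.init_weights() call initializes the norm, token embeddings, per-layer weights, and the rope table in one pass, which is the behaviour the commit title describes for the entropy model.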