Spaces:
Running
on
Zero
Running
on
Zero
IceClear
committed on
Commit
·
aebdeba
1
Parent(s):
39245bd
restore norm
Browse files
- configs_3b/main.yaml +3 -3
- configs_7b/main.yaml +2 -2
- models/dit_v2/normalization.py +1 -1
configs_3b/main.yaml
CHANGED
@@ -11,7 +11,7 @@ dit:
|
|
11 |
vid_in_channels: 33
|
12 |
vid_out_channels: 16
|
13 |
vid_dim: 2560
|
14 |
-
vid_out_norm:
|
15 |
txt_in_dim: 5120
|
16 |
txt_in_norm: fusedln
|
17 |
txt_dim: ${.vid_dim}
|
@@ -19,11 +19,11 @@ dit:
|
|
19 |
heads: 20
|
20 |
head_dim: 128 # llm-like
|
21 |
expand_ratio: 4
|
22 |
-
norm:
|
23 |
norm_eps: 1.0e-05
|
24 |
ada: single
|
25 |
qk_bias: False
|
26 |
-
qk_norm:
|
27 |
patch_size: [ 1,2,2 ]
|
28 |
num_layers: 32 # llm-like
|
29 |
mm_layers: 10
|
|
|
11 |
vid_in_channels: 33
|
12 |
vid_out_channels: 16
|
13 |
vid_dim: 2560
|
14 |
+
vid_out_norm: fusedrms
|
15 |
txt_in_dim: 5120
|
16 |
txt_in_norm: fusedln
|
17 |
txt_dim: ${.vid_dim}
|
|
|
19 |
heads: 20
|
20 |
head_dim: 128 # llm-like
|
21 |
expand_ratio: 4
|
22 |
+
norm: fusedrms
|
23 |
norm_eps: 1.0e-05
|
24 |
ada: single
|
25 |
qk_bias: False
|
26 |
+
qk_norm: fusedrms
|
27 |
patch_size: [ 1,2,2 ]
|
28 |
num_layers: 32 # llm-like
|
29 |
mm_layers: 10
|
configs_7b/main.yaml
CHANGED
@@ -17,12 +17,12 @@ dit:
|
|
17 |
heads: 24
|
18 |
head_dim: 128 # llm-like
|
19 |
expand_ratio: 4
|
20 |
-
norm:
|
21 |
norm_eps: 1e-5
|
22 |
ada: single
|
23 |
qk_bias: False
|
24 |
qk_rope: True
|
25 |
-
qk_norm:
|
26 |
patch_size: [ 1,2,2 ]
|
27 |
num_layers: 36 # llm-like
|
28 |
shared_mlp: False
|
|
|
17 |
heads: 24
|
18 |
head_dim: 128 # llm-like
|
19 |
expand_ratio: 4
|
20 |
+
norm: fusedrms
|
21 |
norm_eps: 1e-5
|
22 |
ada: single
|
23 |
qk_bias: False
|
24 |
qk_rope: True
|
25 |
+
qk_norm: fusedrms
|
26 |
patch_size: [ 1,2,2 ]
|
27 |
num_layers: 36 # llm-like
|
28 |
shared_mlp: False
|
models/dit_v2/normalization.py
CHANGED
@@ -30,7 +30,7 @@ def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type:
|
|
30 |
return nn.LayerNorm(
|
31 |
normalized_shape=dim,
|
32 |
eps=eps,
|
33 |
-
elementwise_affine=
|
34 |
)
|
35 |
|
36 |
if norm_type == "rms":
|
|
|
30 |
return nn.LayerNorm(
|
31 |
normalized_shape=dim,
|
32 |
eps=eps,
|
33 |
+
elementwise_affine=elementwise_affine,
|
34 |
)
|
35 |
|
36 |
if norm_type == "rms":
|