update script
Browse files- convert_flax.py +10 -23
convert_flax.py
CHANGED
@@ -127,20 +127,8 @@ def convert_to_hf(path: Path):
|
|
127 |
print(f"{num_layers=}")
|
128 |
print(f"{num_siglip_layers=}")
|
129 |
|
130 |
-
def load_params(*keys: tuple[str, ...], prefix: str | None = None):
|
131 |
-
# load params with specific keys and params starts with prefix
|
132 |
-
f1 = lambda k: tuple(subkey.key for subkey in k) in keys
|
133 |
-
f2 = lambda k: k[0].key.startswith(prefix)
|
134 |
-
|
135 |
-
# set to None to not load that weights
|
136 |
-
pytree = jax.tree.map_with_path(lambda k, v: v if f1(k) or f2(k) else None, metadata)
|
137 |
-
return ckpt.restore(path, pytree)
|
138 |
-
|
139 |
# NOTE: all gemma3 models use tied embeddings, even for the 27B version.
|
140 |
-
params =
|
141 |
-
("transformer/final_norm", "scale"),
|
142 |
-
prefix="transformer/embedder",
|
143 |
-
)
|
144 |
state_dict = dict()
|
145 |
|
146 |
if num_siglip_layers > 0:
|
@@ -164,7 +152,6 @@ def convert_to_hf(path: Path):
|
|
164 |
|
165 |
for layer_idx in range(num_layers):
|
166 |
jax_prefix = f"transformer/layer_{layer_idx}/"
|
167 |
-
params = load_params(prefix=jax_prefix)
|
168 |
|
169 |
state_dict = dict()
|
170 |
prefix = f"{gemma_prefix}model.layers.{layer_idx}."
|
@@ -200,7 +187,6 @@ def convert_to_hf(path: Path):
|
|
200 |
|
201 |
# vision tower
|
202 |
if num_siglip_layers > 0:
|
203 |
-
params = load_params(prefix=SIGLIP_PREFIX)
|
204 |
siglip_state_dict = convert_siglip(params, num_siglip_layers)
|
205 |
for k, v in siglip_state_dict.items():
|
206 |
state_dict[f"vision_tower.vision_model.{k}"] = v
|
@@ -272,21 +258,22 @@ if __name__ == "__main__":
|
|
272 |
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
273 |
for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
|
274 |
sub_state_dict = convert_awq(sub_state_dict)
|
|
|
275 |
|
276 |
-
|
277 |
-
state_dict[k] = v
|
278 |
-
size += v.nbytes
|
279 |
-
|
280 |
-
total_size += v.nbytes
|
281 |
-
weight_map[k] = filename
|
282 |
-
|
283 |
-
if size > 5e9:
|
284 |
save_file(state_dict, args.save_dir / filename)
|
285 |
state_dict = dict()
|
286 |
size = 0
|
287 |
shard_idx += 1
|
288 |
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
save_file(state_dict, args.save_dir / filename)
|
291 |
json.dump(
|
292 |
dict(metadata=dict(total_size=total_size), weight_map=weight_map),
|
|
|
127 |
print(f"{num_layers=}")
|
128 |
print(f"{num_siglip_layers=}")
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
# NOTE: all gemma3 models use tied embeddings, even for the 27B version.
|
131 |
+
params = ckpt.restore(path)
|
|
|
|
|
|
|
132 |
state_dict = dict()
|
133 |
|
134 |
if num_siglip_layers > 0:
|
|
|
152 |
|
153 |
for layer_idx in range(num_layers):
|
154 |
jax_prefix = f"transformer/layer_{layer_idx}/"
|
|
|
155 |
|
156 |
state_dict = dict()
|
157 |
prefix = f"{gemma_prefix}model.layers.{layer_idx}."
|
|
|
187 |
|
188 |
# vision tower
|
189 |
if num_siglip_layers > 0:
|
|
|
190 |
siglip_state_dict = convert_siglip(params, num_siglip_layers)
|
191 |
for k, v in siglip_state_dict.items():
|
192 |
state_dict[f"vision_tower.vision_model.{k}"] = v
|
|
|
258 |
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
259 |
for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
|
260 |
sub_state_dict = convert_awq(sub_state_dict)
|
261 |
+
new_size = sum(v.nbytes for v in sub_state_dict.values())
|
262 |
|
263 |
+
if size + new_size > 5e9:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
save_file(state_dict, args.save_dir / filename)
|
265 |
state_dict = dict()
|
266 |
size = 0
|
267 |
shard_idx += 1
|
268 |
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
269 |
|
270 |
+
# assume that new_size < 5e9
|
271 |
+
size += new_size
|
272 |
+
total_size += new_size
|
273 |
+
for k, v in sub_state_dict.items():
|
274 |
+
state_dict[k] = v
|
275 |
+
weight_map[k] = filename
|
276 |
+
|
277 |
save_file(state_dict, args.save_dir / filename)
|
278 |
json.dump(
|
279 |
dict(metadata=dict(total_size=total_size), weight_map=weight_map),
|