gaunernst committed on
Commit
6fc5e4e
·
1 Parent(s): 2130c75

update script

Browse files
Files changed (1) hide show
  1. convert_flax.py +10 -23
convert_flax.py CHANGED
@@ -127,20 +127,8 @@ def convert_to_hf(path: Path):
127
  print(f"{num_layers=}")
128
  print(f"{num_siglip_layers=}")
129
 
130
- def load_params(*keys: tuple[str, ...], prefix: str | None = None):
131
- # load params with specific keys and params starts with prefix
132
- f1 = lambda k: tuple(subkey.key for subkey in k) in keys
133
- f2 = lambda k: k[0].key.startswith(prefix)
134
-
135
- # set to None to not load that weights
136
- pytree = jax.tree.map_with_path(lambda k, v: v if f1(k) or f2(k) else None, metadata)
137
- return ckpt.restore(path, pytree)
138
-
139
  # NOTE: all gemma3 models use tied embeddings, even for the 27B version.
140
- params = load_params(
141
- ("transformer/final_norm", "scale"),
142
- prefix="transformer/embedder",
143
- )
144
  state_dict = dict()
145
 
146
  if num_siglip_layers > 0:
@@ -164,7 +152,6 @@ def convert_to_hf(path: Path):
164
 
165
  for layer_idx in range(num_layers):
166
  jax_prefix = f"transformer/layer_{layer_idx}/"
167
- params = load_params(prefix=jax_prefix)
168
 
169
  state_dict = dict()
170
  prefix = f"{gemma_prefix}model.layers.{layer_idx}."
@@ -200,7 +187,6 @@ def convert_to_hf(path: Path):
200
 
201
  # vision tower
202
  if num_siglip_layers > 0:
203
- params = load_params(prefix=SIGLIP_PREFIX)
204
  siglip_state_dict = convert_siglip(params, num_siglip_layers)
205
  for k, v in siglip_state_dict.items():
206
  state_dict[f"vision_tower.vision_model.{k}"] = v
@@ -272,21 +258,22 @@ if __name__ == "__main__":
272
  filename = f"model-{shard_idx + 1:05d}.safetensors"
273
  for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
274
  sub_state_dict = convert_awq(sub_state_dict)
 
275
 
276
- for k, v in sub_state_dict.items():
277
- state_dict[k] = v
278
- size += v.nbytes
279
-
280
- total_size += v.nbytes
281
- weight_map[k] = filename
282
-
283
- if size > 5e9:
284
  save_file(state_dict, args.save_dir / filename)
285
  state_dict = dict()
286
  size = 0
287
  shard_idx += 1
288
  filename = f"model-{shard_idx + 1:05d}.safetensors"
289
 
 
 
 
 
 
 
 
290
  save_file(state_dict, args.save_dir / filename)
291
  json.dump(
292
  dict(metadata=dict(total_size=total_size), weight_map=weight_map),
 
127
  print(f"{num_layers=}")
128
  print(f"{num_siglip_layers=}")
129
 
 
 
 
 
 
 
 
 
 
130
  # NOTE: all gemma3 models use tied embeddings, even for the 27B version.
131
+ params = ckpt.restore(path)
 
 
 
132
  state_dict = dict()
133
 
134
  if num_siglip_layers > 0:
 
152
 
153
  for layer_idx in range(num_layers):
154
  jax_prefix = f"transformer/layer_{layer_idx}/"
 
155
 
156
  state_dict = dict()
157
  prefix = f"{gemma_prefix}model.layers.{layer_idx}."
 
187
 
188
  # vision tower
189
  if num_siglip_layers > 0:
 
190
  siglip_state_dict = convert_siglip(params, num_siglip_layers)
191
  for k, v in siglip_state_dict.items():
192
  state_dict[f"vision_tower.vision_model.{k}"] = v
 
258
  filename = f"model-{shard_idx + 1:05d}.safetensors"
259
  for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
260
  sub_state_dict = convert_awq(sub_state_dict)
261
+ new_size = sum(v.nbytes for v in sub_state_dict.values())
262
 
263
+ if size + new_size > 5e9:
 
 
 
 
 
 
 
264
  save_file(state_dict, args.save_dir / filename)
265
  state_dict = dict()
266
  size = 0
267
  shard_idx += 1
268
  filename = f"model-{shard_idx + 1:05d}.safetensors"
269
 
270
+ # assume that new_size < 5e9
271
+ size += new_size
272
+ total_size += new_size
273
+ for k, v in sub_state_dict.items():
274
+ state_dict[k] = v
275
+ weight_map[k] = filename
276
+
277
  save_file(state_dict, args.save_dir / filename)
278
  json.dump(
279
  dict(metadata=dict(total_size=total_size), weight_map=weight_map),