# Snapshot of commit e6cdcac (1,634 bytes).
import json
import os

from safetensors import safe_open
def generate_safetensors_index(model_path=".", safetensors_files=None):
    """Generate ``model.safetensors.index.json`` from existing safetensors shards.

    ``pytorch_model.bin.index.json`` inside *model_path* is used as the
    reference. Its ``metadata`` is copied unchanged into the new index.
    Unless *safetensors_files* is given, the shard names are derived from the
    bin index's ``weight_map``: each ``.bin`` shard's extension is replaced
    with ``.safetensors``. Every tensor name found in a shard is mapped to
    that shard's file name.

    Processing is best-effort. A shard that cannot be opened is reported and
    skipped. The index is still written, followed by a warning that it is
    incomplete.

    Args:
        model_path: Directory holding the checkpoint. The bin index and the
            shards are read from it, and the new index is written into it.
        safetensors_files: Optional explicit list of shard file names,
            relative to *model_path*. Defaults to the shards implied by the
            bin index.

    Returns:
        The index dict that was written: ``{"metadata": ..., "weight_map": ...}``.

    Raises:
        FileNotFoundError: If ``pytorch_model.bin.index.json`` is missing.
    """
    bin_index_path = os.path.join(model_path, "pytorch_model.bin.index.json")
    with open(bin_index_path, "r", encoding="utf-8") as f:
        bin_index = json.load(f)

    safetensors_index = {
        "metadata": bin_index.get("metadata", {}),
        "weight_map": {},
    }

    if safetensors_files is None:
        # Expect one .safetensors shard per distinct .bin shard named in the
        # bin index, e.g. pytorch_model-00001-of-00004.bin ->
        # pytorch_model-00001-of-00004.safetensors. Sort for a stable order.
        safetensors_files = sorted(
            {
                os.path.splitext(bin_file)[0] + ".safetensors"
                for bin_file in bin_index.get("weight_map", {}).values()
            }
        )

    failed = []
    for safetensor_file in safetensors_files:
        try:
            # The weight map stores names relative to model_path; only the
            # open call needs the joined path.
            shard_path = os.path.join(model_path, safetensor_file)
            with safe_open(shard_path, framework="pt") as f:
                for tensor_name in f.keys():
                    safetensors_index["weight_map"][tensor_name] = safetensor_file
            print(f"✓ Processed {safetensor_file}")
        except Exception as e:  # best-effort: record, report, keep going
            failed.append(safetensor_file)
            print(f"✗ Error processing {safetensor_file}: {e}")

    output_path = os.path.join(model_path, "model.safetensors.index.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(safetensors_index, f, indent=2)

    print(f"✓ Generated model.safetensors.index.json with {len(safetensors_index['weight_map'])} tensors")
    if failed:
        # Make a partial index obvious instead of silently writing it.
        print(f"⚠ Index is incomplete; failed shards: {', '.join(failed)}")
    return safetensors_index
# Run the function when executed as a script.
if __name__ == "__main__":
    # Directory holding the checkpoint; change it to point at your model.
    generate_safetensors_index("./Finance-Llama-8B")