import json
from pathlib import Path

from safetensors import safe_open
def generate_safetensors_index(model_path=".", safetensors_files=None):
    """Generate ``model.safetensors.index.json`` from existing safetensors shards.

    The ``metadata`` block is copied from ``pytorch_model.bin.index.json`` in
    *model_path*. The ``weight_map`` is rebuilt by opening each safetensors
    shard and recording which shard holds each tensor name. The resulting
    index is written to ``model.safetensors.index.json`` inside *model_path*.

    Args:
        model_path: Directory containing the bin index and the safetensors
            shards. Every read and the output write are relative to it.
        safetensors_files: Shard file names (relative to *model_path*) to
            index. When ``None``, the names are derived from the bin index's
            ``weight_map`` values by replacing their suffix with
            ``.safetensors`` (e.g. ``pytorch_model-00001-of-00004.bin`` ->
            ``pytorch_model-00001-of-00004.safetensors``).

    Returns:
        The index dict that was written:
        ``{"metadata": ..., "weight_map": {tensor_name: shard_name}}``.

    Raises:
        FileNotFoundError: if ``pytorch_model.bin.index.json`` is missing.
        json.JSONDecodeError: if the bin index is not valid JSON.

    Shards that cannot be opened are reported and skipped (best effort).
    A warning is then printed listing the unreadable shards. A second warning
    reports any tensor names from the bin index that no processed shard
    provided. The output file is still written in these cases.
    """
    model_dir = Path(model_path)

    # The bin index supplies the metadata and, by default, the shard names.
    bin_index_path = model_dir / "pytorch_model.bin.index.json"
    with open(bin_index_path, "r", encoding="utf-8") as f:
        bin_index = json.load(f)
    bin_weight_map = bin_index.get("weight_map", {})

    if safetensors_files is None:
        # Mirror the bin shards one-to-one; sorted so shard order is stable.
        safetensors_files = sorted(
            {
                Path(name).with_suffix(".safetensors").as_posix()
                for name in bin_weight_map.values()
            }
        )

    safetensors_index = {
        "metadata": bin_index.get("metadata", {}),
        "weight_map": {},
    }
    weight_map = safetensors_index["weight_map"]

    failed = []
    for shard in safetensors_files:
        try:
            with safe_open(str(model_dir / shard), framework="pt") as f:
                for tensor_name in f.keys():
                    # The index stores shard names relative to model_dir.
                    weight_map[tensor_name] = shard
        except Exception as e:  # best effort: report the bad shard, keep going
            failed.append(shard)
            print(f"ERROR processing {shard}: {e}")
        else:
            print(f"OK processed {shard}")

    # Make an incomplete index visible instead of writing it silently.
    if failed:
        print(
            f"WARNING: {len(failed)} shard(s) could not be read: "
            f"{', '.join(map(str, failed))}"
        )
    missing = set(bin_weight_map) - set(weight_map)
    if missing:
        print(
            f"WARNING: {len(missing)} tensor(s) listed in the bin index are "
            "absent from the safetensors index"
        )

    out_path = model_dir / "model.safetensors.index.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(safetensors_index, f, indent=2)
    print(f"Generated {out_path} with {len(weight_map)} tensors")
    return safetensors_index
# Script entry point: index the model directory given as the first
# command-line argument, defaulting to ./Finance-Llama-8B when none is given.
if __name__ == "__main__":
    import sys

    model_dir = sys.argv[1] if len(sys.argv) > 1 else "./Finance-Llama-8B"
    generate_safetensors_index(model_dir)