import json
import os

from safetensors import safe_open


def generate_safetensors_index(model_path="."):
    """Generate model.safetensors.index.json from the existing safetensors shards.

    The index maps each tensor name to the shard file that contains it, which is
    how the Hugging Face loaders locate weights in a sharded checkpoint.
    """

    # Load the existing bin index as a reference for the metadata block
    with open(os.path.join(model_path, "pytorch_model.bin.index.json"), "r") as f:
        bin_index = json.load(f)
    
    # Initialize the safetensors index structure
    safetensors_index = {
        "metadata": bin_index.get("metadata", {}),
        "weight_map": {}
    }
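    # The resulting file mirrors the layout of pytorch_model.bin.index.json:
    #   {"metadata": {...}, "weight_map": {"<tensor name>": "<shard filename>"}}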
    
    # Map each safetensors file and get its tensor names
    safetensors_files = [
        "pytorch_model-00001-of-00004.safetensors",
        "pytorch_model-00002-of-00004.safetensors", 
        "pytorch_model-00003-of-00004.safetensors",
        "pytorch_model-00004-of-00004.safetensors"
    ]
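    # The shard names above are hardcoded for this 4-shard checkpoint; they could
    # also be discovered from the directory instead of listed by hand, e.g.:
    #   safetensors_files = sorted(
    #       f for f in os.listdir(model_path) if f.endswith(".safetensors")
    #   )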
    
    for safetensor_file in safetensors_files:
        try:
            # safe_open only reads the header, so listing tensor names is cheap
            # even for multi-gigabyte shards
            with safe_open(os.path.join(model_path, safetensor_file), framework="pt") as f:
                for tensor_name in f.keys():
                    safetensors_index["weight_map"][tensor_name] = safetensor_file
            print(f"✓ Processed {safetensor_file}")
        except Exception as e:
            print(f"✗ Error processing {safetensor_file}: {e}")
    
    # Save the index file next to the shards
    with open(os.path.join(model_path, "model.safetensors.index.json"), "w") as f:
        json.dump(safetensors_index, f, indent=2)
    
    print(f"✓ Generated model.safetensors.index.json with {len(safetensors_index['weight_map'])} tensors")
    return safetensors_index

# Run the function
if __name__ == "__main__":
    # Change this path to your model directory if needed
    generate_safetensors_index("./Finance-Llama-8B")
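
# Optional sanity check (a sketch, assuming `transformers` is installed and the
# model directory now contains both the .safetensors shards and the generated
# index file):
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained("./Finance-Llama-8B")
#
# from_pretrained should then pick up the safetensors shards via the new index.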