import collections
import functools
import os
import struct
from pathlib import Path

import requests
import safetensors.torch
import torch
from huggingface_hub import hf_hub_url, model_info

DTYPE_MAP = {"FP32": torch.float32, "FP16": torch.float16, "BF16": torch.bfloat16}
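# Note: these are the safetensors dtype strings as they appear in file headers; only the
# floating-point dtypes this script expects to encounter are mapped here.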


# https://huggingface.co/docs/safetensors/v0.3.2/metadata_parsing#python
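# Safetensors layout (see the docs linked above): the first 8 bytes are a little-endian
# unsigned 64-bit integer giving the length of a JSON header; the header maps each tensor
# name to its {"dtype", "shape", "data_offsets"}, plus an optional "__metadata__" entry.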
def _parse_single_file(url):
    """Fetches and parses the JSON header of a remote safetensors file via two range requests."""
    print(f"{url=}")
    token = os.getenv("HF_TOKEN")
    # Only send an Authorization header when a token is actually available.
    auth_headers = {"Authorization": f"Bearer {token}"} if token else {}
    response = requests.get(url, headers={"Range": "bytes=0-7", **auth_headers})
    response.raise_for_status()
    length_of_header = struct.unpack("<Q", response.content)[0]
    response = requests.get(url, headers={"Range": f"bytes=8-{7 + length_of_header}", **auth_headers})
    response.raise_for_status()
    return response.json()


def _get_dtype_from_safetensor_file(file_path):
    """Inspects a safetensors file and returns the dtype of the first tensor.

    If `file_path` is a URL rather than a local file, only the remote header is queried.
    """
    if str(file_path).startswith(("http://", "https://")):
        metadata = _parse_single_file(file_path)
        # Skip the "__metadata__" entry; the remaining keys are tensor names.
        tensor_keys = sorted(k for k in metadata if k != "__metadata__")
        string_dtype = metadata[tensor_keys[0]]["dtype"]
        return DTYPE_MAP[string_dtype]
    try:
        # load_file is simple and sufficient for this info-gathering purpose.
        state_dict = safetensors.torch.load_file(file_path)
        if not state_dict:
            return "N/A (empty)"

        # Get the dtype from the first tensor in the state dict
        first_tensor = next(iter(state_dict.values()))
        return first_tensor.dtype
    except Exception as e:
        print(f"Warning: Could not determine dtype from {file_path}. Error: {e}")
        return "N/A (error)"


def _process_components(component_files, file_accessor_fn, disable_bf16=False):
    """
    Generic function to process components, calculate size, and determine dtype.

    Args:
        component_files (dict): A dictionary mapping component names to lists of file objects.
        file_accessor_fn (function): A function that takes a file object and returns
                                     a tuple of (local_path_for_inspection, size_in_bytes, relative_filename).
        disable_bf16 (bool): If True, fp32 components are counted at full size instead of
                             assuming they will be loaded in `torch.bfloat16` (which halves
                             their footprint). Use it at your own risk.

    Returns:
        dict: A dictionary containing the total memory and detailed component info.
    """
    components_info = {}
    total_size_bytes = 0

    for name, files in component_files.items():
        # Get dtype by inspecting the first file of the component
        first_file = files[0]

        # The accessor function handles how to get the path (download vs local)
        # and its size and relative name.
        inspection_path, _, _ = file_accessor_fn(first_file)
        dtype = _get_dtype_from_safetensor_file(inspection_path)

        component_size_bytes = 0
        component_file_details = []
        for f in files:
            _, size_bytes, rel_filename = file_accessor_fn(f)
            component_size_bytes += size_bytes
            component_file_details.append({"filename": rel_filename, "size_mb": size_bytes / (1024**2)})

        if dtype == torch.float32 and not disable_bf16:
            print(
                f"The `dtype` for component ({name}) is torch.float32. Since bf16 is not disabled, "
                "we will halve this component's contribution to the total size."
            )
            total_size_bytes += component_size_bytes / 2
        else:
            total_size_bytes += component_size_bytes

        components_info[name] = {
            "size_gb": round(component_size_bytes / (1024**3), 3),
            "dtype": dtype,
            "files": sorted(component_file_details, key=lambda x: x["filename"]),
        }

    return {
        "total_loading_memory_gb": round(total_size_bytes / (1024**3), 3),
        "components": components_info,
    }
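
# For reference, the dictionary returned above has roughly this shape (illustrative,
# hypothetical numbers for a single fp16 component):
# {
#     "total_loading_memory_gb": 5.14,
#     "components": {
#         "unet": {
#             "size_gb": 5.14,
#             "dtype": torch.float16,
#             "files": [{"filename": "unet/diffusion_pytorch_model.safetensors", "size_mb": 5260.0}],
#         },
#     },
# }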


@functools.lru_cache()
def _determine_memory_from_hub_ckpt(ckpt_id, variant=None, disable_bf16=False):
    """
    Determines memory and dtypes for a checkpoint on the Hugging Face Hub.
    """
    files_in_repo = model_info(ckpt_id, files_metadata=True, token=os.getenv("HF_TOKEN")).siblings
    all_safetensors_siblings = [
        s for s in files_in_repo if s.rfilename.endswith(".safetensors") and "/" in s.rfilename
    ]
    if variant:
        all_safetensors_siblings = [f for f in all_safetensors_siblings if variant in f.rfilename]

    component_files = collections.defaultdict(list)
    for sibling in all_safetensors_siblings:
        component_name = Path(sibling.rfilename).parent.name
        component_files[component_name].append(sibling)

    def hub_file_accessor(file_obj):
        """Accessor for Hub files: resolves their URL and returns (url, size, relative filename)."""
        print(f"Querying '{file_obj.rfilename}' for inspection...")
        url = hf_hub_url(ckpt_id, file_obj.rfilename)
        return url, file_obj.size, file_obj.rfilename

    # We only need to inspect one file per component to determine its dtype. This accessor
    # caches the URL of the first file seen for each component and returns it as the
    # inspection path for every file in that component.
    resolved_for_inspection = {}

    def efficient_hub_accessor(file_obj):
        component_name = Path(file_obj.rfilename).parent.name
        if component_name not in resolved_for_inspection:
            url, _, _ = hub_file_accessor(file_obj)
            resolved_for_inspection[component_name] = url

        inspection_path = resolved_for_inspection[component_name]
        return inspection_path, file_obj.size, file_obj.rfilename

    return _process_components(component_files, efficient_hub_accessor, disable_bf16)


@functools.lru_cache()
def _determine_memory_from_local_ckpt(path: str, variant=None, disable_bf16=False):
    """
    Determines memory and dtypes for a local checkpoint.
    """
    ckpt_path = Path(path)
    if not ckpt_path.is_dir():
        return {"error": f"Checkpoint path '{path}' not found or is not a directory."}

    all_safetensors_paths = list(ckpt_path.glob("**/*.safetensors"))
    if variant:
        all_safetensors_paths = [p for p in all_safetensors_paths if variant in p.name]

    component_files = collections.defaultdict(list)
    for file_path in all_safetensors_paths:
        component_name = file_path.parent.name
        component_files[component_name].append(file_path)

    def local_file_accessor(file_path):
        """Accessor for local files: just returns their path and size."""
        return file_path, file_path.stat().st_size, str(file_path.relative_to(ckpt_path))

    return _process_components(component_files, local_file_accessor, disable_bf16)


def determine_pipe_loading_memory(ckpt_id: str, variant=None, disable_bf16=False):
    """
    Determines the memory and dtypes for a pipeline, whether it's local or on the Hub.
    """
    if os.path.isdir(ckpt_id):
        return _determine_memory_from_local_ckpt(ckpt_id, variant, disable_bf16)
    else:
        return _determine_memory_from_hub_ckpt(ckpt_id, variant, disable_bf16)
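
# Example usage of the public entry point (the Hub id is the one exercised in `__main__`;
# the local path is hypothetical):
#
#   info = determine_pipe_loading_memory("Wan-AI/Wan2.1-T2V-14B-Diffusers")
#   print(info["total_loading_memory_gb"], list(info["components"]))
#
#   info = determine_pipe_loading_memory("/path/to/local/pipeline", variant="fp16")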


if __name__ == "__main__":
    output = _determine_memory_from_hub_ckpt("Wan-AI/Wan2.1-T2V-14B-Diffusers")
    total_size_gb = output["total_loading_memory_gb"]
    safetensor_files = output["components"]
    print(f"{total_size_gb=} GB")
    print(f"{safetensor_files=}")
    print("\n")
    # output = _determine_memory_from_local_ckpt("LOCAL_DIR")  # change me.
    # total_size_gb = output["total_loading_memory_gb"]
    # safetensor_files = output["components"]
    # print(f"{total_size_gb=} GB")
    # print(f"{safetensor_files=}")