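"""Ask a Gemini model for hardware-aware diffusion pipeline inference code.

The script estimates how much memory loading the requested pipeline checkpoint
takes, profiles the local machine (system RAM, GPU VRAM, torch.compile and FP8
friendliness), fills those values into a generation prompt, and prints the code
returned by the LLM.

Example invocation (a sketch: the script name below is a placeholder, and Gemini
credentials are assumed to be configured however `LLMCodeOptimizer` expects):

    python generate_optimized_code.py --ckpt_id black-forest-labs/FLUX.1-dev --enable_lossy
"""
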
import argparse
from utils.llm_utils import LLMCodeOptimizer
from prompts import system_prompt, generate_prompt
from utils.pipeline_utils import determine_pipe_loading_memory
from utils.hardware_utils import (
    categorize_vram,
    categorize_ram,
    get_gpu_vram_gb,
    get_system_ram_gb,
    is_compile_friendly_gpu,
    is_fp8_friendly,
)
import torch
from pprint import pprint


def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Can be a repo id from the Hub or a local path where the checkpoint is stored.",
    )
    parser.add_argument(
        "--gemini_model",
        type=str,
        default="gemini-2.5-flash-preview-05-20",
        help="Gemini model to use. Choose from https://ai.google.dev/gemini-api/docs/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="If the `ckpt_id` has variants, supply this flag to estimate compute. Example: 'fp16'.",
    )
    parser.add_argument(
        "--disable_bf16",
        action="store_true",
        help="When enabled the load memory is affected. Prefer not enabling this flag.",
    )
    parser.add_argument(
        "--enable_lossy",
        action="store_true",
        help="When enabled, the code will include snippets for enabling quantization.",
    )
    return parser


def main(args):
    if not torch.cuda.is_available():
        raise ValueError("A CUDA device is required; non-CUDA devices are not supported for now.")
    
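    # Estimate how many GB loading the pipeline will take for the given variant/precision settings.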
    loading_mem_out = determine_pipe_loading_memory(args.ckpt_id, args.variant, args.disable_bf16)
    load_memory = loading_mem_out["total_loading_memory_gb"]
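
    # System RAM: printed for visibility and passed to the prompt below as `available_system_ram`.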
    ram_gb = get_system_ram_gb()
    ram_category = categorize_ram(ram_gb)
    if ram_gb is not None:
        print(f"\nSystem RAM: {ram_gb:.2f} GB")
        print(f"RAM Category: {ram_category}")
    else:
        print("\nCould not determine System RAM.")

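    # GPU VRAM: printed for visibility and passed to the prompt below as `available_gpu_vram`.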
    vram_gb = get_gpu_vram_gb()
    vram_category = categorize_vram(vram_gb)
    if vram_gb is not None:
        print(f"\nGPU VRAM: {vram_gb:.2f} GB")
        print(f"VRAM Category: {vram_category}")
    else:
        print("\nCould not determine GPU VRAM.")

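    # Capability checks: is the GPU a good fit for `torch.compile`, and is it FP8-friendly?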
    is_compile_friendly = is_compile_friendly_gpu()
    is_fp8_compatible = is_fp8_friendly()

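    # Fill the generation prompt with everything gathered above and ask the LLM for the optimized code.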
    llm = LLMCodeOptimizer(model_name=args.gemini_model, system_prompt=system_prompt)
    current_generate_prompt = generate_prompt.format(
        ckpt_id=args.ckpt_id,
        pipeline_loading_memory=load_memory,
        available_system_ram=ram_gb,
        available_gpu_vram=vram_gb,
        enable_lossy_outputs=args.enable_lossy,
        is_fp8_supported=is_fp8_compatible,
        enable_torch_compile=is_compile_friendly,
    )
    pprint(f"{current_generate_prompt=}")
    print(llm(current_generate_prompt))


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)