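# Profiles the local machine (system RAM, GPU VRAM, torch.compile / FP8 friendliness),
# estimates the memory needed to load the given checkpoint, and asks a Gemini model
# (via LLMCodeOptimizer) for code suggestions tailored to that hardware.
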
import argparse
from utils.llm_utils import LLMCodeOptimizer
from prompts import system_prompt, generate_prompt
from utils.pipeline_utils import determine_pipe_loading_memory
from utils.hardware_utils import (
    categorize_vram,
    categorize_ram,
    get_gpu_vram_gb,
    get_system_ram_gb,
    is_compile_friendly_gpu,
    is_fp8_friendly,
)
import torch
from pprint import pprint


def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Can be a repo id from the Hub or a local path where the checkpoint is stored.",
    )
    parser.add_argument(
        "--gemini_model",
        type=str,
        default="gemini-2.5-flash-preview-05-20",
        help="Gemini model to use. Choose from https://ai.google.dev/gemini-api/docs/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="If the `ckpt_id` has variants, supply this flag to estimate compute. Example: 'fp16'.",
    )
    parser.add_argument(
        "--disable_bf16",
        action="store_true",
        help="When enabled, the load memory is affected. Prefer not enabling this flag.",
    )
    parser.add_argument(
        "--enable_lossy",
        action="store_true",
        help="When enabled, the code will include snippets for enabling quantization.",
    )
    return parser


def main(args):
    if not torch.cuda.is_available():
        raise ValueError("Not supported for non-CUDA devices for now.")
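
    # Estimate how much memory loading the pipeline checkpoint will need
    # (honoring the requested variant and the --disable_bf16 flag).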
    loading_mem_out = determine_pipe_loading_memory(args.ckpt_id, args.variant, args.disable_bf16)
    load_memory = loading_mem_out["total_loading_memory_gb"]
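
    # Profile the host so the prompt can reflect the actually available resources.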
    ram_gb = get_system_ram_gb()
    ram_category = categorize_ram(ram_gb)
    if ram_gb is not None:
        print(f"\nSystem RAM: {ram_gb:.2f} GB")
        print(f"RAM Category: {ram_category}")
    else:
        print("\nCould not determine System RAM.")
    vram_gb = get_gpu_vram_gb()
    vram_category = categorize_vram(vram_gb)
    if vram_gb is not None:
        print(f"\nGPU VRAM: {vram_gb:.2f} GB")
        print(f"VRAM Category: {vram_category}")
    else:
        print("\nCould not determine GPU VRAM.")
    is_compile_friendly = is_compile_friendly_gpu()
    is_fp8_compatible = is_fp8_friendly()
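
    # Fill the prompt template with the gathered stats and ask the Gemini model
    # for optimization suggestions.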
    llm = LLMCodeOptimizer(model_name=args.gemini_model, system_prompt=system_prompt)
    current_generate_prompt = generate_prompt.format(
        ckpt_id=args.ckpt_id,
        pipeline_loading_memory=load_memory,
        available_system_ram=ram_gb,
        available_gpu_vram=vram_gb,
        enable_lossy_outputs=args.enable_lossy,
        is_fp8_supported=is_fp8_compatible,
        enable_torch_compile=is_compile_friendly,
    )
    pprint(f"{current_generate_prompt=}")
    print(llm(current_generate_prompt))


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)