"""Profiles the local machine (system RAM, GPU VRAM, torch.compile and FP8 friendliness),
estimates the loading memory of the pipeline referenced by `--ckpt_id`, and asks a Gemini
model (through `LLMCodeOptimizer`) to generate loading/inference code optimized for that
hardware."""

import argparse
from pprint import pprint

import torch

from prompts import generate_prompt, system_prompt
from utils.hardware_utils import (
    categorize_ram,
    categorize_vram,
    get_gpu_vram_gb,
    get_system_ram_gb,
    is_compile_friendly_gpu,
    is_fp8_friendly,
)
from utils.llm_utils import LLMCodeOptimizer
from utils.pipeline_utils import determine_pipe_loading_memory


def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Can be a repo id from the Hub or a local path where the checkpoint is stored.",
    )
    parser.add_argument(
        "--gemini_model",
        type=str,
        default="gemini-2.5-flash-preview-05-20",
        help="Gemini model to use. Choose from https://ai.google.dev/gemini-api/docs/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="If the `ckpt_id` has variants, supply one (example: 'fp16') so it is used when estimating the loading memory.",
    )
    parser.add_argument(
        "--disable_bf16",
        action="store_true",
        help="When enabled, bf16 loading is not assumed, which changes the estimated load memory. Prefer not enabling this flag.",
    )
    parser.add_argument(
        "--enable_lossy",
        action="store_true",
        help="When enabled, the generated code will include snippets for enabling quantization.",
    )
    return parser


def main(args):
    if not torch.cuda.is_available():
        raise ValueError("Not supported for non-CUDA devices for now.")

    # Estimate how much memory (in GB) loading the pipeline will take.
    loading_mem_out = determine_pipe_loading_memory(args.ckpt_id, args.variant, args.disable_bf16)
    load_memory = loading_mem_out["total_loading_memory_gb"]

    # Report available system RAM and its category.
    ram_gb = get_system_ram_gb()
    ram_category = categorize_ram(ram_gb)
    if ram_gb is not None:
        print(f"\nSystem RAM: {ram_gb:.2f} GB")
        print(f"RAM Category: {ram_category}")
    else:
        print("\nCould not determine System RAM.")

    # Report available GPU VRAM and its category.
    vram_gb = get_gpu_vram_gb()
    vram_category = categorize_vram(vram_gb)
    if vram_gb is not None:
        print(f"\nGPU VRAM: {vram_gb:.2f} GB")
        print(f"VRAM Category: {vram_category}")
    else:
        print("\nCould not determine GPU VRAM.")

    # Check whether the GPU is a good fit for `torch.compile` and FP8.
    is_compile_friendly = is_compile_friendly_gpu()
    is_fp8_compatible = is_fp8_friendly()

    # Fill the generation prompt with the gathered hardware/memory info and query the LLM.
    llm = LLMCodeOptimizer(model_name=args.gemini_model, system_prompt=system_prompt)
    current_generate_prompt = generate_prompt.format(
        ckpt_id=args.ckpt_id,
        pipeline_loading_memory=load_memory,
        available_system_ram=ram_gb,
        available_gpu_vram=vram_gb,
        enable_lossy_outputs=args.enable_lossy,
        is_fp8_supported=is_fp8_compatible,
        enable_torch_compile=is_compile_friendly,
    )
    pprint(f"{current_generate_prompt=}")
    print(llm(current_generate_prompt))


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)
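
# A minimal example invocation (the script filename below is illustrative; use whatever
# name this file has in the repo):
#
#   python optimize_pipeline.py \
#       --ckpt_id black-forest-labs/FLUX.1-dev \
#       --gemini_model gemini-2.5-flash-preview-05-20 \
#       --enable_lossy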