import argparse

from utils.llm_utils import LLMCodeOptimizer
from prompts import system_prompt, generate_prompt
from utils.pipeline_utils import determine_pipe_loading_memory
from utils.hardware_utils import (
    categorize_vram,
    categorize_ram,
    get_gpu_vram_gb,
    get_system_ram_gb,
    is_compile_friendly_gpu,
    is_fp8_friendly,
)
import torch
from pprint import pprint


def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Can be a repo id from the Hub or a local path where the checkpoint is stored.",
    )
    parser.add_argument(
        "--gemini_model",
        type=str,
        default="gemini-2.5-flash-preview-05-20",
        help="Gemini model to use. Choose from https://ai.google.dev/gemini-api/docs/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="If the `ckpt_id` has variants, supply one to estimate compute for it. Example: 'fp16'.",
    )
    parser.add_argument(
        "--disable_bf16",
        action="store_true",
        help="Disable bf16 loading. This affects the estimated load memory; prefer not enabling this flag.",
    )
    parser.add_argument(
        "--enable_lossy",
        action="store_true",
        help="When enabled, the generated code will include snippets for enabling quantization.",
    )
    return parser


def main(args):
    if not torch.cuda.is_available():
        raise ValueError("Only CUDA devices are supported for now.")

    # Estimate how much memory is needed just to load the pipeline.
    loading_mem_out = determine_pipe_loading_memory(args.ckpt_id, args.variant, args.disable_bf16)
    load_memory = loading_mem_out["total_loading_memory_gb"]

    # Probe system RAM and bucket it into a category.
    ram_gb = get_system_ram_gb()
    ram_category = categorize_ram(ram_gb)
    if ram_gb is not None:
        print(f"\nSystem RAM: {ram_gb:.2f} GB")
        print(f"RAM Category: {ram_category}")
    else:
        print("\nCould not determine System RAM.")

    # Probe GPU VRAM and bucket it into a category.
    vram_gb = get_gpu_vram_gb()
    vram_category = categorize_vram(vram_gb)
    if vram_gb is not None:
        print(f"\nGPU VRAM: {vram_gb:.2f} GB")
        print(f"VRAM Category: {vram_category}")
    else:
        print("\nCould not determine GPU VRAM.")

    # Check whether the GPU is a good fit for torch.compile and FP8.
    is_compile_friendly = is_compile_friendly_gpu()
    is_fp8_compatible = is_fp8_friendly()

    llm = LLMCodeOptimizer(model_name=args.gemini_model, system_prompt=system_prompt)

    # Fill the generation prompt with the hardware profile and user preferences.
    current_generate_prompt = generate_prompt.format(
        ckpt_id=args.ckpt_id,
        pipeline_loading_memory=load_memory,
        available_system_ram=ram_gb,
        available_gpu_vram=vram_gb,
        enable_lossy_outputs=args.enable_lossy,
        is_fp8_supported=is_fp8_compatible,
        enable_torch_compile=is_compile_friendly,
    )
    pprint(f"{current_generate_prompt=}")
    print(llm(current_generate_prompt))


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)
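
# Example invocation (sketch only: the script filename is assumed; the flags are the ones
# defined in `create_parser` above, and GEMINI_API_KEY is assumed to be set for the Gemini call):
#   python run_optimizer.py --ckpt_id black-forest-labs/FLUX.1-dev --variant fp16 --enable_lossy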