Minimum hardware requirements

#188
by tengai - opened

I'm trying to run the model locally on Windows (RTX 3060 12GB, 16GB of RAM) in full GPU mode, but I constantly get OOM errors, even at the smallest resolution with num_inference_steps=1...

Is it possible at all to run this on 12GB VRAM?

With CPU offload it takes about 11 minutes to create a 512x512 image.
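(For context, "offload to CPU" here presumably means something like the minimal sketch below, using diffusers' FluxPipeline and enable_sequential_cpu_offload; the exact setup behind the 11-minute figure may differ. Sequential offload keeps weights in system RAM and copies each submodule to the GPU only while it runs, which is why it is so slow:)

import torch
from diffusers import FluxPipeline

# Load in bf16; nothing is moved to the GPU yet.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
# Stream each submodule to the GPU only for its forward pass --
# very low VRAM usage, but slow.
pipe.enable_sequential_cpu_offload()

image = pipe(
    "A beautiful landscape, trending on artstation",
    height=512,
    width=512,
    guidance_scale=0.0,       # schnell is guidance-distilled
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
image.save("offload_example.png")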

I am not on Windows, but this works for me; maybe it works for you too:
https://gist.github.com/jason-kane/130b531ac68d9aa932181fbd530b3fba

#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "torch>=2.8.0",
#     "diffusers@git+https://github.com/huggingface/diffusers.git",
#     "transformers==4.55.4",
#     "accelerate>=0.26.0",
#     "bitsandbytes==0.45.5",
#     "protobuf==5.29.4",
#     "sentencepiece",
# ]
# ///
"""
12GB VRAM FLUX.1-schnell example
After the first run it takes 38-41 seconds on a 12GB 3060 GTX 
to generate a 1024x1024 image

If you have UV installed, you can just run this file and it will
self-install.
"""
import gc
import logging
import random

import torch
from diffusers import DiffusionPipeline
from transformers import (
    BitsAndBytesConfig,
    T5EncoderModel,
)

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


def disabled_safety_checker(images, clip_input):
    """Replacement safety checker that never flags anything.

    Note: FLUX pipelines don't ship a safety checker, so assigning this
    is most likely a harmless no-op carried over from SD examples.
    """
    if len(images.shape) == 4:
        num_images = images.shape[0]
        return images, [False] * num_images
    return images, False


def local_flux_schnell(
        clip_prompt,  # short prompt (goes to the CLIP text encoder)
        t5_prompt=None,  # longer, more detailed prompt (goes to T5)
    ):

    quantization_config = BitsAndBytesConfig(load_in_8bit=True)  # 8-bit weights for the large T5 encoder
    if t5_prompt is None:
        t5_prompt = clip_prompt

    model_id = "black-forest-labs/FLUX.1-schnell"    #needs 4 steps only - it is faster than the dev version as the name implies
    
    text_encoder = T5EncoderModel.from_pretrained(
        model_id,
        subfolder="text_encoder_2",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,  # bfloat16 and float16 both work; the former warns but works
    )

    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # must match the T5 encoder's dtype
        text_encoder_2=text_encoder,  # reuse the 8-bit T5 loaded above
        device_map="balanced",
        max_memory={0: "11GiB", "cpu": "48GiB"},  # cap GPU usage; spill the rest to system RAM
    )

    pipe.safety_checker = disabled_safety_checker
    pipe.vae.enable_tiling()  # decode the VAE in tiles for lower peak memory

    log.info(f'Using {clip_prompt=} and {t5_prompt=} to generate a new image...')
    image = pipe(
        clip_prompt,
        prompt_2=t5_prompt,
        num_images_per_prompt=1,
        guidance_scale=0.0,  # must be 0.0 for schnell; the dev variant takes normal CFG values
        num_inference_steps=4,  # 4 is enough for schnell; dev needs ~50
        max_sequence_length=256,  # T5 (text_encoder_2) prompt length; schnell caps at 256
        generator=torch.Generator("cpu").manual_seed(random.randrange(4294967294)),  # fresh random seed each call
    ).images[0]
    
    # Free VRAM so back-to-back calls don't accumulate memory.
    del pipe
    del text_encoder
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    return image

if __name__ == "__main__":
    prompt = "A beautiful landscape, trending on artstation"
    image = local_flux_schnell(prompt)
    image.save("flux_schnell_example.png")
