Minimal hardware requirements #188
opened by tengai
I'm trying to run the model locally on Windows (RTX 3060 12GB, 16GB of RAM) in full GPU mode, but I constantly get OOM errors, even at the smallest resolution with num_inference_steps=1...
Is it possible at all to run this on 12GB of VRAM?
With CPU offload it takes about 11 minutes to create a 512x512 image.
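(For context, "offload to CPU" here means diffusers' sequential offload, which streams weights to the GPU one module at a time; roughly like the sketch below. The pipeline setup and prompt are illustrative, not the exact code used.)

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
# Fits in low VRAM, but is very slow: every forward pass shuttles
# weights between CPU RAM and the GPU.
pipe.enable_sequential_cpu_offload()

image = pipe(
    "a photo of a forest",
    guidance_scale=0.0,      # schnell is guidance-distilled; no CFG
    num_inference_steps=4,
    height=512,
    width=512,
).images[0]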
I am not on Windows, but this works for me; maybe it works for you too:
https://gist.github.com/jason-kane/130b531ac68d9aa932181fbd530b3fba
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "torch>=2.8.0",
#     "diffusers@git+https://github.com/huggingface/diffusers.git",
#     "transformers==4.55.4",
#     "accelerate>=0.26.0",
#     "bitsandbytes==0.45.5",
#     "protobuf==5.29.4",
#     "sentencepiece",
# ]
# ///
"""
12GB VRAM FLUX.1-schnell example
After the first run it takes 38-41 seconds on a 12GB 3060 GTX
to generate a 1024x1024 image
If you have UV installed, you can just run this file and it will
self-install.
"""
import gc
import logging
import random
import torch
from diffusers import DiffusionPipeline
from transformers import (
    BitsAndBytesConfig,
    T5EncoderModel,
)
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
def disabled_safety_checker(images, clip_input):
    """Report every image as 'not flagged' so nothing gets blacked out."""
    if len(images.shape) == 4:
        num_images = images.shape[0]
        return images, [False] * num_images
    else:
        return images, False
def local_flux_schnell(
    clip_prompt,      # short prompt
    t5_prompt=None,   # longer, more detailed prompt
):
    # Quantize the large T5 text encoder to 8-bit so it fits alongside
    # the transformer in 12GB of VRAM.
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    if t5_prompt is None:
        t5_prompt = clip_prompt

    # schnell needs only 4 steps, so it is faster than the dev version,
    # as the name implies
    model_id = "black-forest-labs/FLUX.1-schnell"
    text_encoder = T5EncoderModel.from_pretrained(
        model_id,
        subfolder="text_encoder_2",
        quantization_config=quantization_config,
        # bfloat16 and plain float16 both work; the former gives a
        # warning but seems to work
        torch_dtype=torch.bfloat16,
    )

    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # bfloat16 and float16 both work; must match the T5
        text_encoder_2=text_encoder,
        device_map="balanced",
        # Cap GPU usage just under 12GB and spill the rest to CPU RAM.
        max_memory={0: "11GiB", "cpu": "48GiB"},
    )
    pipe.safety_checker = disabled_safety_checker
    pipe.vae.enable_tiling()  # less memory usage at VAE time
    log.info(f"Using {clip_prompt=} and {t5_prompt=} to generate a new image...")
    image = pipe(
        clip_prompt,
        prompt_2=t5_prompt,
        num_images_per_prompt=1,
        guidance_scale=0.0,       # must be 0.0 for schnell; the dev version takes SD-style values
        num_inference_steps=4,    # schnell only needs 4; the dev version needs 50 or so
        max_sequence_length=256,  # T5 encoder (text_encoder_2) limit; max 256 for schnell
        generator=torch.Generator("cpu").manual_seed(random.randrange(4294967294)),
    ).images[0]

    # Release everything so repeated calls don't leak VRAM.
    del pipe
    del text_encoder
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    return image
if __name__ == "__main__":
prompt = "A beautiful landscape, trending on artstation"
image = local_flux_schnell(prompt)
image.save("flux_schnell_example.png")
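If you want more detailed results, the function already takes a second, longer prompt for the T5 encoder; a small usage sketch (the prompts are just examples):

image = local_flux_schnell(
    "a red fox in the snow",  # short CLIP prompt
    t5_prompt=(
        "a red fox standing in deep snow at dawn, soft golden light, "
        "shallow depth of field, highly detailed fur"
    ),
)
image.save("fox.png")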