|
import pathlib
|
|
from os import path
|
|
|
|
import torch
|
|
from diffusers import (
|
|
AutoPipelineForText2Image,
|
|
LCMScheduler,
|
|
StableDiffusionPipeline,
|
|
)
|
|
|
|
|
|
def load_lcm_weights(
    pipeline,
    use_local_model,
    lcm_lora_id,
):
    """Attach the LoRA weights identified by ``lcm_lora_id`` to ``pipeline``.

    The weights are loaded from the ``pytorch_lora_weights.safetensors`` file
    and registered under the adapter name ``"lcm"``.

    Args:
        pipeline: Object exposing ``load_lora_weights``.
        use_local_model: Forwarded as ``local_files_only`` to the loader.
        lcm_lora_id: LoRA identifier handed to the loader as-is.
    """
    pipeline.load_lora_weights(
        lcm_lora_id,
        local_files_only=use_local_model,
        weight_name="pytorch_lora_weights.safetensors",
        adapter_name="lcm",
    )
|
|
|
|
|
|
def get_lcm_lora_pipeline(
    base_model_id: str,
    lcm_lora_id: str,
    use_local_model: bool,
    torch_data_type: torch.dtype,
    pipeline_args=None,
):
    """Build a text-to-image pipeline and load an LCM LoRA adapter into it.

    Two loading paths are supported:

    * ``base_model_id`` ending in ``.safetensors`` is treated as a single
      checkpoint file. It is loaded with
      ``StableDiffusionPipeline.from_single_file`` (with
      ``safety_checker=None``) and then converted with
      ``AutoPipelineForText2Image.from_pipe``.
    * Anything else is passed to
      ``AutoPipelineForText2Image.from_pretrained``.

    After loading, the LoRA weights are attached via ``load_lcm_weights``.
    If the LoRA id contains ``"lcm"`` or ``"hypersd"`` (case-insensitive),
    the scheduler is replaced by an ``LCMScheduler`` built from the current
    scheduler config.

    Args:
        base_model_id: Model identifier, or path to a ``.safetensors`` file.
        lcm_lora_id: LoRA identifier passed to ``load_lcm_weights``.
        use_local_model: Forwarded as ``local_files_only`` to the loaders.
        torch_data_type: Forwarded as ``torch_dtype`` to the loaders.
        pipeline_args: Optional extra keyword arguments. They go to
            ``from_pipe`` on the single-file path and to ``from_pretrained``
            otherwise. ``None`` (the default) means no extra arguments.

    Returns:
        The configured pipeline.

    Raises:
        FileNotFoundError: If ``base_model_id`` has a ``.safetensors`` suffix
            but no file exists at that path.
    """
    # Normalise here instead of using a ``{}`` default, which would be one
    # dict object shared by every call of this function.
    extra_args = {} if pipeline_args is None else pipeline_args

    if pathlib.Path(base_model_id).suffix == ".safetensors":
        # Fail early with a clear message before handing the path to diffusers.
        if not path.exists(base_model_id):
            raise FileNotFoundError(
                f"Model file not found,Please check your model path: {base_model_id}"
            )
        print("Using single file Safetensors model (Supported models - SD 1.5 models)")

        # Load the checkpoint with the concrete SD pipeline first, then
        # convert it to the auto text-to-image pipeline.
        dummy_pipeline = StableDiffusionPipeline.from_single_file(
            base_model_id,
            torch_dtype=torch_data_type,
            safety_checker=None,
            local_files_only=use_local_model,
            use_safetensors=True,
        )
        pipeline = AutoPipelineForText2Image.from_pipe(
            dummy_pipeline,
            **extra_args,
        )
        # Drop the local reference to the intermediate pipeline object.
        del dummy_pipeline
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            base_model_id,
            torch_dtype=torch_data_type,
            local_files_only=use_local_model,
            **extra_args,
        )

    load_lcm_weights(
        pipeline,
        use_local_model,
        lcm_lora_id,
    )

    # The scheduler swap is keyed purely on the LoRA id text.
    lora_name = lcm_lora_id.lower()
    if "lcm" in lora_name or "hypersd" in lora_name:
        print("LCM LoRA model detected so using recommended LCMScheduler")
        pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)

    return pipeline
|
|
|