et-fluxShell-endpoint / custom_pipeline.py
staswrs's picture
Update custom_pipeline.py
c10891c verified
raw
history blame contribute delete
962 Bytes
import os

from diffusers import DiffusionPipeline
from diffusers.loaders import FluxLoraLoaderMixin
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
    """Custom Flux pipeline container with a local-path LoRA loading helper.

    Holds the Flux components (VAE, two text encoders with their tokenizers,
    transformer and scheduler). ``FluxLoraLoaderMixin`` is mixed in because it
    is what provides ``load_lora_weights``; ``DiffusionPipeline`` on its own
    does not define that method.
    """

    def __init__(
        self,
        vae,
        text_encoder,
        text_encoder_2,
        tokenizer,
        tokenizer_2,
        transformer,
        scheduler,
        **kwargs,
    ):
        """Register the pipeline components.

        Args:
            vae: Autoencoder component.
            text_encoder: First text encoder.
            text_encoder_2: Second text encoder.
            tokenizer: Tokenizer paired with ``text_encoder``.
            tokenizer_2: Tokenizer paired with ``text_encoder_2``.
            transformer: Denoising transformer.
            scheduler: Noise scheduler.
            **kwargs: Any extra values; each is attached to the pipeline as a
                plain (unregistered) attribute under its keyword name.
        """
        super().__init__()
        # register_modules (rather than plain attribute assignment) records
        # each component in the pipeline config, which DiffusionPipeline
        # uses for `components`, `save_pretrained` and offloading. It also
        # sets each component as an attribute, so `self.vae` etc. still work.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Extra keyword arguments are attached as-is (not registered in the
        # config); additional processing of kwargs can be added here.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def load_attn_procs(self, path: "str | os.PathLike[str]", **kwargs):
        """Load LoRA weights from a local file or directory into the pipeline.

        Args:
            path: Local filesystem path to the LoRA weights. It must exist on
                disk; anything else (e.g. a Hub repo id) is rejected.
            **kwargs: Forwarded unchanged to ``load_lora_weights``
                (for example ``adapter_name`` or ``weight_name``).

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        # Normalize path-like objects to a plain string for the checks,
        # the log message and the loader call.
        path = os.fspath(path)
        if not os.path.exists(path):
            raise FileNotFoundError(f"LoRA file not found: {path}")
        print(f"[FluxPipeline] Loading LoRA from {path}")
        # Provided by FluxLoraLoaderMixin.
        self.load_lora_weights(path, **kwargs)