import os

from diffusers import DiffusionPipeline
from diffusers.loaders import FluxLoraLoaderMixin


# FluxLoraLoaderMixin (diffusers >= 0.30) provides load_lora_weights();
# the base DiffusionPipeline class does not.
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
    def __init__(
        self,
        vae,
        text_encoder,
        text_encoder_2,
        tokenizer,
        tokenizer_2,
        transformer,
        scheduler,
        **kwargs,
    ):
        super().__init__()
        # Register the components instead of assigning plain attributes so that
        # DiffusionPipeline utilities (save_pretrained(), to(), offloading, ...)
        # know about them.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Additional kwargs handling can go here.
        for k, v in kwargs.items():
            setattr(self, k, v)

    def load_attn_procs(self, path: str):
        # Load LoRA weights from a local file, failing fast if the path is missing.
        if not os.path.exists(path):
            raise FileNotFoundError(f"LoRA file not found: {path}")
        print(f"[FluxPipeline] Loading LoRA from {path}")
        self.load_lora_weights(path)
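

# Usage sketch (not part of the original): the checkpoint id and LoRA path below
# are placeholder assumptions; substitute whatever Flux checkpoint (in the
# diffusers layout) and LoRA file you actually use. Generation itself is out of
# scope here, since this class does not define __call__.
if __name__ == "__main__":
    import torch

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # assumed Flux checkpoint; requires access
        torch_dtype=torch.bfloat16,
    )
    pipe.load_attn_procs("./loras/style_lora.safetensors")  # hypothetical LoRA file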