import os
import torch
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model
# ─── CONFIG ──────────────────────────────────────────────────────────────────
DATA_DIR = os.getenv("DATA_DIR", "./data")
MODEL_CACHE = os.getenv("MODEL_DIR", "./hidream-model")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
REPO_ID = "HiDream-ai/HiDream-I1-Dev"
# ─── STEP 1: ENSURE YOU HAVE A COMPLETE SNAPSHOT WITH CONFIGS ────────────────
print(f"Downloading full model snapshot to {MODEL_CACHE}")
MODEL_ROOT = snapshot_download(
    repo_id=REPO_ID,
    local_dir=MODEL_CACHE,
    local_dir_use_symlinks=False,  # force a real copy so config.json ends up here
                                   # (deprecated and ignored in recent huggingface_hub)
)
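# The snapshot directory should now contain model_index.json plus the
# scheduler/, vae/, text_encoder/, tokenizer/, and unet/ subfolders that the
# component loaders below expect.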
# ─── STEP 2: LOAD SCHEDULER ──────────────────────────────────────────────────
print("Loading scheduler")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    MODEL_ROOT,
    subfolder="scheduler",
)
# ─── STEP 3: LOAD VAE ────────────────────────────────────────────────────────
print("Loading VAE")
vae = AutoencoderKL.from_pretrained(
    MODEL_ROOT,
    subfolder="vae",
    torch_dtype=torch.float16,
).to("cuda")
# ─── STEP 4: LOAD TEXT ENCODER + TOKENIZER ───────────────────────────────────
print("Loading text encoder + tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    MODEL_ROOT,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
).to("cuda")
tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_ROOT,
    subfolder="tokenizer",
)
# ─── STEP 5: LOAD U-NET ──────────────────────────────────────────────────────
print("Loading U-Net")
unet = UNet2DConditionModel.from_pretrained(
    MODEL_ROOT,
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cuda")
# ─── STEP 6: BUILD THE PIPELINE ──────────────────────────────────────────────
print("Building StableDiffusionPipeline")
pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,       # required constructor args; not used for training
    feature_extractor=None,
    requires_safety_checker=False,
).to("cuda")
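# Optional smoke test that the assembled pipeline runs end to end (commented
# out to keep start-up fast; the prompt and step count are arbitrary):
#
#   with torch.inference_mode():
#       image = pipe("a test prompt", num_inference_steps=4).images[0]
#   image.save("smoke_test.png")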
# ─── STEP 7: APPLY LoRA ADAPTER ──────────────────────────────────────────────
print("Applying LoRA adapter")
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    # target the U-Net's attention projections; the original
    # task_type="CAUSAL_LM" is meant for language models, not a U-Net
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    bias="none",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
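# get_peft_model wraps the U-Net in a PeftModel, so we can confirm that only
# the low-rank adapter weights are trainable before spending GPU time:
pipe.unet.print_trainable_parameters()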
# ─── STEP 8: YOUR TRAINING LOOP (SIMULATED) ──────────────────────────────────
print(f"Loading dataset from: {DATA_DIR}")
for step in range(100):
    # ── here's where you'd load your images, run forward/backward, step the
    # optimizer, etc. (a fuller sketch follows this loop)
    print(f"Training step {step+1}/100")
# ─── STEP 9: SAVE THE FINE-TUNED LoRA WEIGHTS ────────────────────────────────
os.makedirs(OUTPUT_DIR, exist_ok=True)
pipe.save_pretrained(OUTPUT_DIR)
print("β
Training complete. Saved to", OUTPUT_DIR)
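# To reuse the adapter later, it can be restored onto a freshly loaded base
# U-Net (a sketch; assumes pipe.save_pretrained wrote the PEFT adapter into
# the unet/ subfolder, which is where the pipeline saves that component):
#
#   from peft import PeftModel
#   base_unet = UNet2DConditionModel.from_pretrained(
#       MODEL_ROOT, subfolder="unet", torch_dtype=torch.float16
#   )
#   tuned_unet = PeftModel.from_pretrained(base_unet, os.path.join(OUTPUT_DIR, "unet"))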