# Source: TiM configs/c2i/tim_xl_p1_512_mg.yaml
# (last commit 3ed0796, "Update", by blanchon)
# Model definition: transport, training objective, backbone network,
# autoencoder, representation-alignment encoder and EMA settings.
model:
  # Transport object. `target` is a dotted class path and `params` are
  # presumably passed to it as keyword arguments by the config loader
  # (confirm against the training script's instantiation helper).
  transport:
    target: tim.schedulers.transports.OT_FM
    params:
      P_mean: -0.4
      P_std: 1.0
      sigma_d: 1.0
      T_max: 1.0
      T_min: 0.0
      enhance_target: true
      w_gt: 1.0
      w_cond: 0.75
      w_start: 0.3
      w_end: 0.8
  # Settings for the transition training loss. Their exact semantics are
  # defined by the TiM loss code, not by this file.
  transition_loss:
    diffusion_ratio: 0.5
    consistency_ratio: 0.1
    derivative_type: dde
    differential_epsilon: 0.005
    weight_time_type: sqrt
    weight_time_tangent: true
  # Backbone network: dotted class path plus (presumably) constructor kwargs.
  network:
    target: tim.models.c2i.tim_model.TiM
    params:
      # 16 = 512 / 32. This presumably matches the 512px image_size of the
      # data section after the 32x downsampling of the "f32c32" DC-AE named
      # in vae_dir, and in_channels 32 its "c32" latent channels. Keep these
      # values in sync if the resolution or autoencoder changes.
      input_size: 16
      patch_size: 1
      in_channels: 32
      class_dropout_prob: 0.1
      num_classes: 1000
      depth: 28
      hidden_size: 1152
      num_heads: 16
      encoder_depth: 8
      qk_norm: true
      z_dim: 768
      new_condition: t-r
      use_new_embed: true
      distance_aware: true
      lora_hidden_size: 384
  # pretrained_vae:
  # Autoencoder identifier (looks like a Hugging Face repo id). It should
  # match the autoencoder used to produce data.dataset.latent_dir.
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  # REPA representation encoder checkpoint. proj_coeff presumably weights
  # the representation-alignment loss term (confirm in the trainer).
  enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
  proj_coeff: 1.0
  # Exponential moving average (EMA) settings.
  use_ema: true
  ema_decay: 0.9999

# Data pipeline settings.
data:
  data_type: latent
  dataset:
    # Presumably precomputed latents: the directory name matches the
    # autoencoder named in model.vae_dir plus the 512x512 resolution.
    latent_dir: datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512
    image_dir: datasets/imagenet1k/images/train
    image_size: 512
  dataloader:
    num_workers: 4
    batch_size: 64  # Batch size (per device) for the training dataloader.

# Optimization, scheduling, checkpointing and precision settings.
training:
  tracker: null
  max_train_steps: 750000
  checkpointing_steps: 2000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  # Floats are written with a '.' and a signed exponent (1.0e-4, not 1e-4)
  # on purpose: YAML 1.1 loaders such as PyYAML otherwise resolve them as
  # strings.
  learning_rate: 1.0e-4
  # With scale_lr enabled, learning_rate is presumably rescaled by
  # (global batch size / learning_rate_base_batch_size); confirm in trainer.
  learning_rate_base_batch_size: 256
  scale_lr: true
  # One of: "linear", "cosine", "cosine_with_restarts", "polynomial",
  # "constant", "constant_with_warmup".
  lr_scheduler: constant
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  # Optimizer class path; `params` are presumably its keyword arguments.
  optimizer:
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16  # One of: "no", "fp16", "bf16".
  allow_tf32: true
  validation_steps: 500
  checkpoint_list: [100000, 250000, 500000]