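# Training config for the TiM c2i (class-conditional ImageNet-1k) model at
# 512x512, using DC-AE f32c32 latents (see vae_dir below).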
model: 
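  # Flow-matching transport scheduler (OT_FM presumably = optimal-transport
  # flow matching). Assumption: P_mean/P_std parameterize a log-normal
  # time-sampling distribution and sigma_d is the data sigma, in the EDM
  # convention; the w_* values appear to be scheduler-specific weights.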
  transport:
    target: tim.schedulers.transports.OT_FM
    params:
      P_mean: -0.4
      P_std: 1.0 
      sigma_d: 1.0
      T_max: 1.0
      T_min: 0.0
      enhance_target: True
      w_gt: 1.0
      w_cond: 0.75
      w_start: 0.3
      w_end: 0.8
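  # Transition loss combining a diffusion-style and a consistency-style term
  # (diffusion_ratio / consistency_ratio). Assumption: derivative_type "dde"
  # with differential_epsilon selects a finite-difference estimate of the time
  # derivative; weight_time_* control the time-dependent loss weighting.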
  transition_loss: 
    diffusion_ratio: 0.5
    consistency_ratio: 0.1
    derivative_type: dde
    differential_epsilon: 0.005
    weight_time_type: sqrt
    weight_time_tangent: True
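  # TiM transformer backbone. depth/hidden_size/num_heads match DiT-XL;
  # input_size=16 and in_channels=32 correspond to the 16x16x32 latents that
  # the f32c32 DC-AE produces for 512x512 images. Assumption: encoder_depth
  # and z_dim relate to the REPA feature projection; the remaining flags are
  # model-specific.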
  network:  
    target: tim.models.c2i.tim_model.TiM
    params:
      input_size: 16
      patch_size: 1
      in_channels: 32
      class_dropout_prob: 0.1
      num_classes: 1000
      depth: 28
      hidden_size: 1152
      num_heads: 16
      encoder_depth: 8
      qk_norm: True
      z_dim: 768
      new_condition: t-r
      use_new_embed: True
      distance_aware: True
      lora_hidden_size: 384
  # pretrained_vae:
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  # repa encoder
  enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: True
  ema_decay: 0.9999
  
data:
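  # Pre-extracted DC-AE latents of ImageNet-1k at 512x512. Assumption: the raw
  # images under image_dir feed the representation-alignment (REPA) encoder.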
  data_type: latent
  dataset:
    latent_dir: datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512
    image_dir: datasets/imagenet1k/images/train
    image_size: 512
  dataloader:
    num_workers: 4
    batch_size: 64  # Batch size (per device) for the training dataloader.

training:
  tracker: null
  max_train_steps: 750000
  checkpointing_steps: 2000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
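  # Assumption from the key names: with scale_lr True, the effective LR is
  # learning_rate * global_batch_size / learning_rate_base_batch_size,
  # e.g. 1.0e-4 * (64 * n_gpus) / 256.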
  learning_rate: 1.0e-4
  learning_rate_base_batch_size: 256
  scale_lr: True
  lr_scheduler: constant  # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
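  # Assumption: steps whose checkpoints are kept permanently, outside the
  # rolling checkpoints_total_limit window.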
  checkpoint_list: [100000, 250000, 500000]