File size: 5,072 Bytes
0c8d55e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from dataclasses import dataclass
from typing import Optional, List


@dataclass
class TrainingConfig:
    """Training-run options for the UniVA denoiser.

    Every field has a default, so ``TrainingConfig()`` is a complete config.
    Because this is a dataclass, the field order below fixes the positional
    order of the generated ``__init__``. Append new fields rather than
    inserting them. String fields whose accepted values are known list them
    in a trailing comment.
    """

    # Reproducibility, experiment tracking and output locations.
    seed: int = 42
    wandb_project: str = "univa-denoiser"
    wandb_name: str = "default_config"
    output_dir: str = "./output"
    logging_dir: str = "./logs"

    # Step accumulation, AdamW hyperparameters, precision and reporting.
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-5
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    adam_weight_decay: float = 1e-2
    mixed_precision: str = "bf16"
    report_to: str = "wandb"
    gradient_checkpointing: bool = False

    # Run length, learning-rate schedule and resumption.
    num_train_epochs: int = 1
    max_train_steps: Optional[int] = None
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 0
    lr_num_cycles: int = 1
    lr_power: float = 1.0
    resume_from_checkpoint: Optional[str] = None

    # Weighting scheme. Accepted values:
    # "sigma_sqrt", "logit_normal", "mode", "cosmap", "null".
    weighting_scheme: Optional[str] = "logit_normal"
    logit_mean: float = 0.0  # presumably read by "logit_normal" — confirm
    logit_std: float = 1.0  # presumably read by "logit_normal" — confirm
    mode_scale: float = 1.29  # presumably read by "mode" — confirm

    # Gradient clipping, checkpointing and condition dropout rates.
    max_grad_norm: float = 1.0
    checkpointing_steps: int = 100
    checkpoints_total_limit: Optional[int] = 500
    drop_condition_rate: float = 0.0
    drop_t5_rate: float = 1.0

    # Validation cadence.
    validation_steps: int = 100
    num_validation_images: int = 1

    noise_reference_images: bool = False
    mask_weight_type: Optional[str] = None  # one of "log", "exp" (or None)
    sigmas_as_weight: bool = False  # used in Flux
    discrete_timestep: bool = True  # used in Flux

    # Optimizer: "adamw" or "prodigy". The prodigy_* fields are presumably
    # only read when optimizer == "prodigy" — confirm in the training script.
    optimizer: str = "adamw"
    prodigy_use_bias_correction: bool = True
    prodigy_safeguard_warmup: bool = True
    prodigy_decouple: bool = True
    prodigy_beta3: Optional[float] = None
    prodigy_d_coef: float = 1.0

    # Profiler output directory (None presumably disables profiling).
    profile_out_dir: Optional[str] = None

    # EMA settings.
    ema_deepspeed_config_file: Optional[str] = None
    ema_update_freq: int = 1
    ema_decay: float = 0.99

@dataclass
class DatasetConfig:
    """Data-loading options plus optional validation inputs.

    ``dataset_type`` and ``data_txt`` have no default and must be supplied.
    Every other field is optional. All ``validation_*_prompt`` /
    ``validation_*_path`` fields default to None. The task tag embedded in
    each name (t2i, it2i, iit2i, cannyt2i, ...) presumably selects a
    validation task in the training script — confirm there.
    """

    # Required inputs.
    dataset_type: str
    data_txt: str

    # Loader and image-resolution settings.
    batch_size: int = 16
    num_workers: int = 4
    height: int = 512
    width: int = 512
    min_pixels: int = 448 * 448
    max_pixels: int = 448 * 448
    anyres: str = "any_1ratio"
    ocr_enhancer: bool = False
    random_data: bool = False
    padding_side: str = "right"

    validation_t2i_prompt: Optional[str] = None
    validation_it2i_prompt: Optional[str] = None
    validation_image_path: Optional[str] = None
    # pin_memory sits among the validation fields; keep its position so the
    # positional order of the generated __init__ does not change.
    pin_memory: bool = True

    # Unlike the other *_path fields, these two take a list of paths.
    validation_iit2i_prompt: Optional[str] = None
    validation_iit2i_path: Optional[List[str]] = None
    validation_REFiit2i_prompt: Optional[str] = None
    validation_REFiit2i_path: Optional[List[str]] = None

    validation_cannyt2i_prompt: Optional[str] = None
    validation_cannyt2i_path: Optional[str] = None
    validation_poset2i_prompt: Optional[str] = None
    validation_poset2i_path: Optional[str] = None

    validation_it2pose_prompt: Optional[str] = None
    validation_it2pose_path: Optional[str] = None
    validation_it2canny_prompt: Optional[str] = None
    validation_it2canny_path: Optional[str] = None

    validation_NIKEit2i_prompt: Optional[str] = None
    validation_NIKEit2i_path: Optional[str] = None
    validation_TRANSFERit2i_prompt: Optional[str] = None
    validation_TRANSFERit2i_path: Optional[str] = None
    validation_EXTRACTit2i_prompt: Optional[str] = None
    validation_EXTRACTit2i_path: Optional[str] = None
    validation_TRYONit2i_prompt: Optional[str] = None
    validation_TRYONit2i_path: Optional[str] = None
    validation_REPLACEit2i_prompt: Optional[str] = None
    validation_REPLACEit2i_path: Optional[str] = None
    validation_DETit2i_prompt: Optional[str] = None
    validation_DETit2i_path: Optional[str] = None
    validation_SEGit2i_prompt: Optional[str] = None
    validation_SEGit2i_path: Optional[str] = None

@dataclass
class ModelConfig:
    """Pretrained checkpoints and the flags selecting which parts are tuned.

    ``pretrained_lvlm_name_or_path`` and ``pretrained_denoiser_name_or_path``
    are required. Every other field has a default. How the various
    ``tune_*`` / ``with_tune_*`` / ``only_tune_*`` flags interact is not
    defined here; it is presumably resolved by the training script.
    """

    # Required checkpoints, plus an optional SigLIP checkpoint.
    pretrained_lvlm_name_or_path: str
    pretrained_denoiser_name_or_path: str
    pretrained_siglip_name_or_path: Optional[str] = None

    train_vision_tower_mm_projector: bool = False

    guidance_scale: float = 1.0  # used in Flux

    # mlp1 tuning flag and optional pretrained weights.
    tune_mlp1_only: bool = False
    pretrained_mlp1_path: Optional[str] = None

    # mlp2 tuning flags and optional pretrained weights.
    with_tune_mlp2: bool = False
    only_tune_mlp2: bool = False
    pretrained_mlp2_path: Optional[str] = None

    only_tune_image_branch: bool = True  # used in SD3

    # mlp3 tuning flags and optional pretrained weights.
    with_tune_mlp3: bool = False
    only_tune_mlp3: bool = False
    pretrained_mlp3_path: Optional[str] = None

    # Presumably a list of Flux layer indices to train — confirm element type.
    flux_train_layer_idx: Optional[list] = None

    # SigLIP MLP tuning flags and optional pretrained weights.
    with_tune_siglip_mlp: bool = False
    only_tune_siglip_mlp: bool = False
    pretrained_siglip_mlp_path: Optional[str] = None

    # Reference-feature and T5 conditioning switches.
    joint_ref_feature: bool = False
    joint_ref_feature_as_condition: bool = False
    only_use_t5: bool = False

    vlm_residual_image_factor: float = 0.0

    # Precision and compilation toggles.
    vae_fp32: bool = True
    compile_flux: bool = False
    compile_qwen2p5vl: bool = False

    # Optional separate LVLM checkpoint for the EMA copy.
    ema_pretrained_lvlm_name_or_path: Optional[str] = None
    
@dataclass
class UnivaTrainingDenoiseConfig:
    """Top-level config grouping the training, dataset and model sections.

    All three sections are required; none has a default.
    """

    training_config: TrainingConfig
    dataset_config: DatasetConfig
    model_config: ModelConfig