Spaces:
Runtime error
Runtime error
import tkinter as tk | |
from tkinter import ttk, filedialog, messagebox, Menu | |
import subprocess | |
import threading | |
import json | |
import os | |
import sys | |
import signal | |
# Dark theme color scheme | |
BG_COLOR = "#2C3E50" # Main background (dark gray with blue tint) | |
FG_COLOR = "#ECF0F1" # Light text | |
ACCENT_COLOR = "#2980B9" # Blue accent for tabs | |
ENTRY_BG = "#1B2A38" # Entry field background (darker than main) | |
BUTTON_ACTIVE = "#1B2A38" # Active button background | |
BORDER_COLOR = "#333333" # Dark border color | |
ACTIVE_ENTRY_BG = "white" # Background color for active entry field | |
ACTIVE_ENTRY_FG = "black" # Text color for active entry field | |
class LoRATrainerGUI: | |
def __init__(self, master): | |
self.master = master | |
master.title("Wan 2.1 LoRA Trainer") | |
master.geometry("900x1024") | |
master.configure(bg=BG_COLOR) | |
self.current_process = None | |
self.training_thread = None | |
self.process_group_id = None | |
self.user_scrolled = False # Flag for manual console scrolling | |
# Initialize settings with default values, including conversion settings | |
self.settings = { | |
"DATASET_CONFIG": "dataset/dataset_example.toml", | |
"VAE_MODEL": "Models/Wan/Wan2.1_VAE.pth", | |
"CLIP_MODEL": "Models/Wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", | |
"T5_MODEL": "Models/Wan/models_t5_umt5-xxl-enc-bf16.pth", | |
"DIT_MODEL": "Models/Wan/wan2.1_i2v_720p_14B_fp8_e4m3fn.safetensors", | |
"LORA_OUTPUT_DIR": "Output_LoRAs/", | |
"LORA_NAME": "My_Best_Lora_v1", | |
"MODEL_TYPE": "i2v-14B", | |
"FLOW_SHIFT": 3.0, | |
"LEARNING_RATE": 2e-5, | |
"LORA_LR_RATIO": 4, | |
"NETWORK_DIM": 32, | |
"NETWORK_ALPHA": 4, | |
"MAX_TRAIN_EPOCHS": 70, | |
"SAVE_EVERY_N_EPOCHS": 10, | |
"SEED": 1234, | |
"BLOCKS_SWAP": 16, | |
"RESUME_TRAINING": "", | |
"OPTIMIZER_TYPE": "adamw8bit", | |
"OPTIMIZER_ARGS": "", | |
"ATTENTION_MECHANISM": "none", | |
"LOGGING_DIR": "", | |
"LOG_WITH": "none", | |
"LOG_PREFIX": "", | |
"IMG_IN_TXT_IN_OFFLOADING": False, | |
"LR_SCHEDULER": "constant", | |
"LR_WARMUP_STEPS": "", | |
"LR_DECAY_STEPS": "", | |
"TIMESTEP_SAMPLING": "shift", | |
"DISCRETE_FLOW_SHIFT": "3.0", | |
"WEIGHTING_SCHEME": "none", | |
"METADATA_TITLE": "", | |
"METADATA_AUTHOR": "", | |
"METADATA_DESCRIPTION": "", | |
"METADATA_LICENSE": "", | |
"METADATA_TAGS": "", | |
"INPUT_LORA": "", | |
"OUTPUT_DIR": "", | |
"CONVERTED_LORA_NAME": "", | |
"FP8": True, # Default FP8 setting | |
"SCALED": False # Default Scaled setting | |
} | |
self.model_types = ["t2v-1.3B", "t2v-14B", "i2v-14B", "t2i-14B"] | |
self.optimizer_types = ["adamw", "adamw8bit", "adafactor", "torch.optim.AdamW", "bitsandbytes.optim.AdEMAMix8bit", "bitsandbytes.optim.PagedAdEMAMix8bit", "came"] | |
self.setup_styles() | |
# Create notebook and tabs | |
self.notebook = ttk.Notebook(master) | |
self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) | |
# Создание вкладок с привязкой события клика мыши | |
self.training_tab = ttk.Frame(self.notebook) | |
self.training_tab.bind("<Button-1>", self.remove_focus) # Привязка клика для снятия фокуса | |
self.notebook.add(self.training_tab, text="Training settings") | |
self.advanced_tab = ttk.Frame(self.notebook) | |
self.advanced_tab.bind("<Button-1>", self.remove_focus) # Привязка клика для снятия фокуса | |
self.notebook.add(self.advanced_tab, text="Advanced settings") | |
self.conversion_tab = ttk.Frame(self.notebook) | |
self.conversion_tab.bind("<Button-1>", self.remove_focus) # Привязка клика для снятия фокуса | |
self.notebook.add(self.conversion_tab, text="LoRA Conversion") | |
# Initialize tab contents | |
self.create_training_settings() | |
self.create_advanced_settings() | |
self.create_conversion_settings() | |
# Create context menu for copying console text | |
self.context_menu = Menu(self.master, tearoff=0) | |
self.context_menu.add_command(label="Copy", command=self.copy_selected_text) | |
def remove_focus(self, event): | |
"""Снимает фокус с активного виджета при клике по фону""" | |
self.master.focus_set() | |
def setup_styles(self): | |
"""Set up styles for dark theme""" | |
style = ttk.Style() | |
style.theme_use("clam") | |
style.configure(".", background=BG_COLOR, foreground=FG_COLOR) | |
style.configure("TFrame", background=BG_COLOR) | |
style.configure("TLabel", background=BG_COLOR, foreground=FG_COLOR) | |
style.configure( | |
"TButton", | |
background=BG_COLOR, | |
foreground=FG_COLOR, | |
bordercolor=BORDER_COLOR, | |
borderwidth=1, | |
focusthickness=3, | |
focuscolor=BG_COLOR, | |
padding=[5, 1] | |
) | |
style.map( | |
"TButton", | |
background=[("active", BUTTON_ACTIVE), ("pressed", BUTTON_ACTIVE)], | |
foreground=[("active", FG_COLOR), ("pressed", FG_COLOR)] | |
) | |
style.configure("TCheckbutton", background=BG_COLOR, foreground=FG_COLOR) | |
style.map("TCheckbutton", background=[("active", BG_COLOR)], foreground=[("active", FG_COLOR)]) | |
style.configure("TNotebook", background=BG_COLOR, borderwidth=0) | |
style.configure("TNotebook.Tab", background=BG_COLOR, foreground=FG_COLOR, padding=[5, 2]) | |
style.map("TNotebook.Tab", background=[("selected", ACCENT_COLOR)], foreground=[("selected", FG_COLOR)]) | |
style.configure( | |
"TEntry", | |
fieldbackground=ENTRY_BG, | |
foreground=FG_COLOR, | |
bordercolor=BORDER_COLOR | |
) | |
style.map("TEntry", | |
fieldbackground=[("focus", ACTIVE_ENTRY_BG)], | |
foreground=[("focus", ACTIVE_ENTRY_FG)] | |
) | |
style.configure( | |
"TCombobox", | |
fieldbackground=ENTRY_BG, | |
background=BG_COLOR, | |
foreground=FG_COLOR, | |
bordercolor=BORDER_COLOR | |
) | |
style.map("TCombobox", | |
fieldbackground=[("focus", ACTIVE_ENTRY_BG), ("readonly", ENTRY_BG), ("!disabled", ENTRY_BG)], | |
foreground=[("focus", ACTIVE_ENTRY_FG), ("readonly", FG_COLOR), ("!disabled", FG_COLOR)], | |
selectbackground=[("readonly", ENTRY_BG), ("!disabled", ENTRY_BG)], | |
selectforeground=[("readonly", FG_COLOR), ("!disabled", FG_COLOR)] | |
) | |
style.configure( | |
"Vertical.TScrollbar", | |
background=ENTRY_BG, | |
troughcolor=BG_COLOR, | |
bordercolor=BORDER_COLOR, | |
arrowcolor=FG_COLOR, | |
darkcolor=BG_COLOR, | |
lightcolor=BG_COLOR | |
) | |
style.map( | |
"Vertical.TScrollbar", | |
background=[("active", BUTTON_ACTIVE), ("pressed", BUTTON_ACTIVE)] | |
) | |
def create_training_settings(self): | |
row = 0 | |
ttk.Label(self.training_tab, text="Training Settings", font=("Arial", 12, "bold")).grid( | |
row=row, column=0, columnspan=3, pady=(10, 10) | |
) | |
row += 1 | |
button_frame_top = ttk.Frame(self.training_tab) | |
button_frame_top.grid(row=row, column=0, columnspan=3, pady=5) | |
ttk.Button(button_frame_top, text="Load Settings", command=self.load_settings).pack(side=tk.LEFT, padx=10) | |
ttk.Button(button_frame_top, text="Save Settings", command=self.save_settings).pack(side=tk.LEFT, padx=10) | |
row += 1 | |
settings_config = [ | |
("Dataset Config", "DATASET_CONFIG", "file"), | |
("VAE Model", "VAE_MODEL", "file"), | |
("Clip Model", "CLIP_MODEL", "file"), | |
("T5 Model", "T5_MODEL", "file"), | |
("Dit Model", "DIT_MODEL", "file"), | |
("LoRA Output Dir", "LORA_OUTPUT_DIR", "directory"), | |
("LoRA Name", "LORA_NAME", "text"), | |
("Model Type", "MODEL_TYPE", "dropdown"), | |
("Flow Shift", "FLOW_SHIFT", "float"), | |
("Learning Rate", "LEARNING_RATE", "float"), | |
("LoRA LR Ratio", "LORA_LR_RATIO", "int"), | |
("Network Dim", "NETWORK_DIM", "int"), | |
("Network Alpha", "NETWORK_ALPHA", "float"), | |
("Max Train Epochs", "MAX_TRAIN_EPOCHS", "int"), | |
("Save Every N Epochs", "SAVE_EVERY_N_EPOCHS", "int"), | |
("Seed", "SEED", "int"), | |
("Blocks Swap", "BLOCKS_SWAP", "int"), | |
("Resume Training", "RESUME_TRAINING", "directory"), | |
("Optimizer Type", "OPTIMIZER_TYPE", "dropdown"), | |
("Optimizer Args", "OPTIMIZER_ARGS", "text"), | |
] | |
self.entries = {} | |
for label_text, key, input_type in settings_config: | |
ttk.Label(self.training_tab, text=f"{label_text}:").grid( | |
row=row, column=0, sticky=tk.W, padx=5, pady=2 | |
) | |
if input_type == "dropdown": | |
if key == "MODEL_TYPE": | |
var = tk.StringVar(value=self.settings[key]) | |
self.entries[key] = ttk.Combobox( | |
self.training_tab, textvariable=var, values=self.model_types, state="readonly" | |
) | |
self.entries[key].current(self.model_types.index(self.settings[key])) | |
elif key == "OPTIMIZER_TYPE": | |
var = tk.StringVar(value=self.settings[key]) | |
self.entries[key] = ttk.Combobox( | |
self.training_tab, textvariable=var, values=self.optimizer_types, state="readonly" | |
) | |
self.entries[key].current(self.optimizer_types.index(self.settings[key])) | |
else: | |
self.entries[key] = ttk.Entry(self.training_tab, width=40) | |
self.entries[key].insert(0, self.settings[key]) | |
self.entries[key].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
if input_type in ["file", "directory"]: | |
ttk.Button( | |
self.training_tab, | |
text="Browse", | |
command=lambda k=key, t=input_type: self.browse_file(k, t) | |
).grid(row=row, column=2, sticky=tk.W, padx=5) | |
row += 1 | |
# Weight Optimization Checkboxes | |
ttk.Label(self.training_tab, text="Weight Optimization:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.fp8_var = tk.BooleanVar(value=self.settings["FP8"]) | |
self.scaled_var = tk.BooleanVar(value=self.settings["SCALED"]) | |
self.fp8_check = ttk.Checkbutton(self.training_tab, text="FP8 Base", variable=self.fp8_var, command=self.toggle_scaled) | |
self.fp8_check.grid(row=row, column=1, sticky=tk.W, padx=5, pady=2) | |
self.scaled_check = ttk.Checkbutton(self.training_tab, text="FP8 Scaled", variable=self.scaled_var, state=tk.DISABLED if not self.fp8_var.get() else tk.NORMAL) | |
self.scaled_check.grid(row=row, column=1, sticky=tk.W, padx=100, pady=2) | |
row += 1 | |
self.enable_cache_var = tk.BooleanVar(value=True) | |
ttk.Checkbutton( | |
self.training_tab, text="Enable Cache Preparation", variable=self.enable_cache_var | |
).grid(row=row, column=0, columnspan=3, pady=5) | |
row += 1 | |
button_frame = ttk.Frame(self.training_tab) | |
button_frame.grid(row=row, column=0, columnspan=3, pady=10) | |
ttk.Button(button_frame, text="Start Training", command=self.start_training).pack(side=tk.LEFT, padx=10) | |
ttk.Button(button_frame, text="Stop Training", command=self.stop_training).pack(side=tk.LEFT, padx=10) | |
row += 1 | |
self.console_frame = ttk.Frame(self.training_tab) | |
self.console_frame.grid(row=row, column=0, columnspan=3, padx=5, pady=5, sticky="nsew") | |
self.console_output = tk.Text( | |
self.console_frame, | |
height=10, | |
width=80, | |
bg=ENTRY_BG, | |
fg=FG_COLOR, | |
wrap="word", | |
state="disabled", | |
selectbackground="white", | |
selectforeground="black" | |
) | |
self.console_output.grid(row=0, column=0, sticky="nsew") | |
self.console_scrollbar = ttk.Scrollbar( | |
self.console_frame, | |
orient="vertical", | |
command=self.console_output.yview, | |
style="Vertical.TScrollbar" | |
) | |
self.console_scrollbar.grid(row=0, column=1, sticky="ns") | |
self.console_output.configure(yscrollcommand=self.console_scrollbar.set) | |
self.console_output.bind("<MouseWheel>", self.on_mousewheel) | |
self.console_output.bind("<Button-4>", self.on_mousewheel) # For Linux | |
self.console_output.bind("<Button-5>", self.on_mousewheel) # For Linux | |
self.console_output.bind("<Button-3>", self.show_context_menu) | |
self.training_tab.grid_rowconfigure(row, weight=1) | |
self.training_tab.grid_columnconfigure(1, weight=1) | |
self.console_frame.grid_rowconfigure(0, weight=1) | |
self.console_frame.grid_columnconfigure(0, weight=1) | |
def toggle_scaled(self): | |
"""Enable or disable the Scaled checkbox based on FP8 checkbox state""" | |
if self.fp8_var.get(): | |
self.scaled_check.config(state=tk.NORMAL) | |
else: | |
self.scaled_check.config(state=tk.DISABLED) | |
self.scaled_var.set(False) | |
def create_advanced_settings(self): | |
row = 0 | |
ttk.Label(self.advanced_tab, text="Advanced Settings", font=("Arial", 12, "bold")).grid( | |
row=row, column=0, columnspan=3, pady=(10, 10) | |
) | |
row += 1 | |
# Attention Mechanism | |
ttk.Label(self.advanced_tab, text="Attention Mechanism:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.attention_var = tk.StringVar(value=self.settings["ATTENTION_MECHANISM"]) | |
attention_options = ["none", "sdpa", "flash_attn", "sage_attn", "xformers", "flash3", "split_attn"] | |
self.entries["ATTENTION_MECHANISM"] = ttk.Combobox(self.advanced_tab, textvariable=self.attention_var, values=attention_options, state="readonly") | |
self.entries["ATTENTION_MECHANISM"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
# Logging | |
ttk.Label(self.advanced_tab, text="Logging Directory:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["LOGGING_DIR"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["LOGGING_DIR"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
ttk.Button(self.advanced_tab, text="Browse", command=lambda: self.browse_directory("LOGGING_DIR")).grid(row=row, column=2, sticky=tk.W, padx=5) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="Log With:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.log_with_var = tk.StringVar(value=self.settings["LOG_WITH"]) | |
log_with_options = ["none", "tensorboard", "wandb", "all"] | |
self.entries["LOG_WITH"] = ttk.Combobox(self.advanced_tab, textvariable=self.log_with_var, values=log_with_options, state="readonly") | |
self.entries["LOG_WITH"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="Log Prefix:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["LOG_PREFIX"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["LOG_PREFIX"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
# Memory Management | |
self.img_in_txt_in_offloading_var = tk.BooleanVar(value=self.settings["IMG_IN_TXT_IN_OFFLOADING"]) | |
ttk.Checkbutton(self.advanced_tab, text="Offload img_in and txt_in to CPU", variable=self.img_in_txt_in_offloading_var).grid(row=row, column=0, columnspan=3, pady=5) | |
self.entries["IMG_IN_TXT_IN_OFFLOADING"] = self.img_in_txt_in_offloading_var | |
row += 1 | |
# Learning Rate Scheduler | |
ttk.Label(self.advanced_tab, text="LR Scheduler:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.lr_scheduler_var = tk.StringVar(value=self.settings["LR_SCHEDULER"]) | |
lr_scheduler_options = ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"] | |
self.entries["LR_SCHEDULER"] = ttk.Combobox(self.advanced_tab, textvariable=self.lr_scheduler_var, values=lr_scheduler_options, state="readonly") | |
self.entries["LR_SCHEDULER"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="LR Warmup Steps:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["LR_WARMUP_STEPS"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["LR_WARMUP_STEPS"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="LR Decay Steps:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["LR_DECAY_STEPS"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["LR_DECAY_STEPS"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
# Timestep Sampling | |
ttk.Label(self.advanced_tab, text="Timestep Sampling:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.timestep_sampling_var = tk.StringVar(value=self.settings["TIMESTEP_SAMPLING"]) | |
timestep_sampling_options = ["sigma", "uniform", "sigmoid", "shift"] | |
self.entries["TIMESTEP_SAMPLING"] = ttk.Combobox(self.advanced_tab, textvariable=self.timestep_sampling_var, values=timestep_sampling_options, state="readonly") | |
self.entries["TIMESTEP_SAMPLING"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="Discrete Flow Shift:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["DISCRETE_FLOW_SHIFT"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["DISCRETE_FLOW_SHIFT"].insert(0, self.settings["DISCRETE_FLOW_SHIFT"]) | |
self.entries["DISCRETE_FLOW_SHIFT"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
# Weighting Scheme | |
ttk.Label(self.advanced_tab, text="Weighting Scheme:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.weighting_scheme_var = tk.StringVar(value=self.settings["WEIGHTING_SCHEME"]) | |
weighting_scheme_options = ["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"] | |
self.entries["WEIGHTING_SCHEME"] = ttk.Combobox(self.advanced_tab, textvariable=self.weighting_scheme_var, values=weighting_scheme_options, state="readonly") | |
self.entries["WEIGHTING_SCHEME"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
# Metadata | |
ttk.Label(self.advanced_tab, text="Metadata Title:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["METADATA_TITLE"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["METADATA_TITLE"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="Metadata Author:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["METADATA_AUTHOR"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["METADATA_AUTHOR"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="Metadata Description:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["METADATA_DESCRIPTION"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["METADATA_DESCRIPTION"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="Metadata License:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["METADATA_LICENSE"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["METADATA_LICENSE"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
ttk.Label(self.advanced_tab, text="Metadata Tags:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.entries["METADATA_TAGS"] = ttk.Entry(self.advanced_tab, width=40) | |
self.entries["METADATA_TAGS"].grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
row += 1 | |
# Настройка столбца для автоматического расширения | |
self.advanced_tab.grid_columnconfigure(1, weight=1) | |
def create_conversion_settings(self): | |
"""Create the LoRA Conversion tab with input fields and buttons""" | |
row = 0 | |
ttk.Label(self.conversion_tab, text="LoRA Conversion Settings", font=("Arial", 12, "bold")).grid( | |
row=row, column=0, columnspan=3, pady=(10, 10) | |
) | |
row += 1 | |
# Input LoRA File | |
ttk.Label(self.conversion_tab, text="Input LoRA File:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.input_lora_entry = ttk.Entry(self.conversion_tab, width=40) | |
self.input_lora_entry.grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
self.input_lora_entry.insert(0, self.settings["INPUT_LORA"]) | |
ttk.Button(self.conversion_tab, text="Browse", command=self.browse_input_lora).grid(row=row, column=2, sticky=tk.W, padx=5) | |
row += 1 | |
# Output Directory | |
ttk.Label(self.conversion_tab, text="Output Directory:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.output_dir_entry = ttk.Entry(self.conversion_tab, width=40) | |
self.output_dir_entry.grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
self.output_dir_entry.insert(0, self.settings["OUTPUT_DIR"]) | |
ttk.Button(self.conversion_tab, text="Browse", command=self.browse_output_dir).grid(row=row, column=2, sticky=tk.W, padx=5) | |
row += 1 | |
# Converted LoRA Name | |
ttk.Label(self.conversion_tab, text="Converted LoRA Name:").grid(row=row, column=0, sticky=tk.W, padx=5, pady=2) | |
self.converted_lora_name_entry = ttk.Entry(self.conversion_tab, width=40) | |
self.converted_lora_name_entry.grid(row=row, column=1, sticky=tk.EW, padx=5, pady=2) | |
self.converted_lora_name_entry.insert(0, self.settings["CONVERTED_LORA_NAME"]) | |
row += 1 | |
# Convert Button | |
ttk.Button(self.conversion_tab, text="Convert", command=self.convert_lora).grid(row=row, column=0, columnspan=3, pady=10) | |
# Configure grid to expand horizontally | |
self.conversion_tab.grid_columnconfigure(1, weight=1) | |
# Add entries to self.entries for saving/loading | |
self.entries["INPUT_LORA"] = self.input_lora_entry | |
self.entries["OUTPUT_DIR"] = self.output_dir_entry | |
self.entries["CONVERTED_LORA_NAME"] = self.converted_lora_name_entry | |
def show_context_menu(self, event): | |
"""Show context menu on right-click""" | |
try: | |
self.context_menu.tk_popup(event.x_root, event.y_root) | |
finally: | |
self.context_menu.grab_release() | |
def copy_selected_text(self): | |
"""Copy selected text to clipboard""" | |
if self.console_output.selection_get(): | |
self.master.clipboard_clear() | |
self.master.clipboard_append(self.console_output.selection_get()) | |
def browse_directory(self, setting_name): | |
path = filedialog.askdirectory() | |
if path: | |
self.entries[setting_name].delete(0, tk.END) | |
self.entries[setting_name].insert(0, path) | |
def on_mousewheel(self, event): | |
"""Handle scroll event""" | |
if self.console_output.yview()[1] < 1.0: | |
self.user_scrolled = True | |
else: | |
self.user_scrolled = False | |
def update_console(self, line): | |
"""Update console with scroll handling""" | |
self.console_output.configure(state="normal") | |
self.console_output.insert(tk.END, line) | |
if not self.user_scrolled: | |
self.console_output.yview(tk.END) | |
self.console_output.configure(state="disabled") | |
def browse_file(self, setting_name, input_type): | |
if input_type == "directory": | |
path = filedialog.askdirectory() | |
else: | |
path = filedialog.askopenfilename() | |
if path: | |
self.settings[setting_name] = path | |
self.entries[setting_name].delete(0, tk.END) | |
self.entries[setting_name].insert(0, self.settings[setting_name]) | |
def browse_input_lora(self): | |
"""Browse for input LoRA file""" | |
file_path = filedialog.askopenfilename(filetypes=[("LoRA files", "*.safetensors")]) | |
if file_path: | |
self.input_lora_entry.delete(0, tk.END) | |
self.input_lora_entry.insert(0, file_path) | |
def browse_output_dir(self): | |
"""Browse for output directory""" | |
dir_path = filedialog.askdirectory() | |
if dir_path: | |
self.output_dir_entry.delete(0, tk.END) | |
self.output_dir_entry.insert(0, dir_path) | |
def convert_lora(self): | |
"""Convert the LoRA model using specified settings""" | |
input_path = self.input_lora_entry.get() | |
output_dir = self.output_dir_entry.get() | |
converted_name = self.converted_lora_name_entry.get() | |
if not input_path or not output_dir or not converted_name: | |
messagebox.showerror("Error", "Please fill in all fields.") | |
return | |
output_path = os.path.join(output_dir, converted_name + ".safetensors") | |
command = [ | |
sys.executable, "convert_lora.py", | |
"--input", input_path, | |
"--output", output_path, | |
"--target", "other" | |
] | |
self.run_subprocess(command, "Conversion") | |
def run_subprocess(self, cmd, name, callback=None): | |
"""Run a subprocess and handle its output with UTF-8 encoding""" | |
env = os.environ.copy() | |
env["PYTHONIOENCODING"] = "utf-8" # Устанавливаем UTF-8 для среды выполнения | |
if os.name == 'nt': | |
creationflags = subprocess.CREATE_NEW_PROCESS_GROUP | |
preexec_fn = None | |
else: | |
creationflags = 0 | |
preexec_fn = os.setsid | |
# Запускаем подпроцесс с явным указанием кодировки UTF-8 | |
process = subprocess.Popen( | |
cmd, | |
stdout=subprocess.PIPE, | |
stderr=subprocess.PIPE, | |
text=True, # Включаем текстовый режим для автоматической декодировки | |
bufsize=1, # Построчная буферизация | |
universal_newlines=True, # Поддержка универсальных переносов строк | |
encoding='utf-8', # Явно указываем кодировку UTF-8 для вывода | |
env=env, | |
creationflags=creationflags, | |
preexec_fn=preexec_fn | |
) | |
self.current_process = process | |
if os.name == 'nt': | |
self.process_group_id = process.pid | |
def read_output(pipe, output_type): | |
"""Читает вывод подпроцесса построчно""" | |
while True: | |
line = pipe.readline() | |
if not line: | |
break | |
self.master.after(0, self.update_console, f"{name} {output_type}: {line}") | |
pipe.close() | |
# Запускаем потоки для чтения stdout и stderr | |
threading.Thread(target=read_output, args=(process.stdout, "STDOUT"), daemon=True).start() | |
threading.Thread(target=read_output, args=(process.stderr, "STDERR"), daemon=True).start() | |
def check_process(): | |
"""Проверяет завершение подпроцесса""" | |
process.wait() | |
self.master.after(0, self.update_console, f"{name} process completed.\n") | |
self.current_process = None | |
if callback: | |
callback() | |
threading.Thread(target=check_process, daemon=True).start() | |
def start_training(self): | |
"""Запускает обучение с последовательным выполнением процессов кэширования""" | |
# Check for unsupported optimizer | |
optimizer_type = self.entries["OPTIMIZER_TYPE"].get() | |
if optimizer_type == "came": | |
messagebox.showwarning( | |
"Предупреждение", | |
"Оптимизатор 'came' не поддерживается в текущей версии. Пожалуйста, выберите другой оптимизатор, например 'adamw' или 'adamw8bit'." | |
) | |
return | |
# Update settings from entries | |
self.settings.update({ | |
"MODEL_TYPE": self.entries["MODEL_TYPE"].get(), | |
"FLOW_SHIFT": float(self.entries["FLOW_SHIFT"].get()), | |
"LEARNING_RATE": float(self.entries["LEARNING_RATE"].get()), | |
"LORA_LR_RATIO": int(self.entries["LORA_LR_RATIO"].get()), | |
"NETWORK_DIM": int(self.entries["NETWORK_DIM"].get()), | |
"NETWORK_ALPHA": float(self.entries["NETWORK_ALPHA"].get()), | |
"MAX_TRAIN_EPOCHS": int(self.entries["MAX_TRAIN_EPOCHS"].get()), | |
"SAVE_EVERY_N_EPOCHS": int(self.entries["SAVE_EVERY_N_EPOCHS"].get()), | |
"SEED": int(self.entries["SEED"].get()), | |
"BLOCKS_SWAP": int(self.entries["BLOCKS_SWAP"].get()), | |
"DATASET_CONFIG": self.entries["DATASET_CONFIG"].get(), | |
"VAE_MODEL": self.entries["VAE_MODEL"].get(), | |
"CLIP_MODEL": self.entries["CLIP_MODEL"].get(), | |
"T5_MODEL": self.entries["T5_MODEL"].get(), | |
"DIT_MODEL": self.entries["DIT_MODEL"].get(), | |
"LORA_OUTPUT_DIR": self.entries["LORA_OUTPUT_DIR"].get(), | |
"LORA_NAME": self.entries["LORA_NAME"].get(), | |
"RESUME_TRAINING": self.entries["RESUME_TRAINING"].get(), | |
"OPTIMIZER_TYPE": optimizer_type, | |
"OPTIMIZER_ARGS": self.entries["OPTIMIZER_ARGS"].get(), | |
"ATTENTION_MECHANISM": self.entries["ATTENTION_MECHANISM"].get(), | |
"LOGGING_DIR": self.entries["LOGGING_DIR"].get(), | |
"LOG_WITH": self.entries["LOG_WITH"].get(), | |
"LOG_PREFIX": self.entries["LOG_PREFIX"].get(), | |
"IMG_IN_TXT_IN_OFFLOADING": self.entries["IMG_IN_TXT_IN_OFFLOADING"].get(), | |
"LR_SCHEDULER": self.entries["LR_SCHEDULER"].get(), | |
"LR_WARMUP_STEPS": self.entries["LR_WARMUP_STEPS"].get(), | |
"LR_DECAY_STEPS": self.entries["LR_DECAY_STEPS"].get(), | |
"TIMESTEP_SAMPLING": self.entries["TIMESTEP_SAMPLING"].get(), | |
"DISCRETE_FLOW_SHIFT": self.entries["DISCRETE_FLOW_SHIFT"].get(), | |
"WEIGHTING_SCHEME": self.entries["WEIGHTING_SCHEME"].get(), | |
"METADATA_TITLE": self.entries["METADATA_TITLE"].get(), | |
"METADATA_AUTHOR": self.entries["METADATA_AUTHOR"].get(), | |
"METADATA_DESCRIPTION": self.entries["METADATA_DESCRIPTION"].get(), | |
"METADATA_LICENSE": self.entries["METADATA_LICENSE"].get(), | |
"METADATA_TAGS": self.entries["METADATA_TAGS"].get(), | |
"FP8": self.fp8_var.get(), | |
"SCALED": self.scaled_var.get() | |
}) | |
# Build training command | |
command = [ | |
"accelerate", "launch", | |
"--num_cpu_threads_per_process", "2", | |
"--mixed_precision", "bf16", | |
"wan_train_network.py", | |
"--task", self.settings["MODEL_TYPE"], | |
"--dit", self.settings["DIT_MODEL"], | |
"--dataset_config", self.settings["DATASET_CONFIG"], | |
"--sdpa", | |
"--mixed_precision", "bf16", | |
] | |
# Добавляем параметры для Weight Optimization | |
if self.settings["FP8"]: | |
command.append("--fp8_base") | |
if self.settings["SCALED"]: | |
command.append("--fp8_scaled") | |
command.extend([ | |
"--blocks_to_swap", str(self.settings["BLOCKS_SWAP"]), | |
"--optimizer_type", self.settings["OPTIMIZER_TYPE"], | |
"--learning_rate", str(self.settings["LEARNING_RATE"]), | |
"--gradient_checkpointing", | |
"--max_data_loader_n_workers", "2", | |
"--persistent_data_loader_workers", | |
"--network_module", "networks.lora_wan", | |
"--network_dim", str(self.settings["NETWORK_DIM"]), | |
"--network_alpha", str(self.settings["NETWORK_ALPHA"]), | |
"--network_args", f"loraplus_lr_ratio={self.settings['LORA_LR_RATIO']}", | |
"--timestep_sampling", self.settings["TIMESTEP_SAMPLING"], | |
"--discrete_flow_shift", str(self.settings["DISCRETE_FLOW_SHIFT"]), | |
"--max_train_epochs", str(self.settings["MAX_TRAIN_EPOCHS"]), | |
"--save_every_n_epochs", str(self.settings["SAVE_EVERY_N_EPOCHS"]), | |
"--save_state", | |
"--seed", str(self.settings["SEED"]), | |
"--output_dir", self.settings["LORA_OUTPUT_DIR"], | |
"--output_name", self.settings["LORA_NAME"], | |
]) | |
if self.settings["OPTIMIZER_ARGS"]: | |
command.extend(["--optimizer_args", self.settings["OPTIMIZER_ARGS"]]) | |
attention = self.settings["ATTENTION_MECHANISM"] | |
if attention != "none": | |
command.append(f"--{attention}") | |
logging_dir = self.settings["LOGGING_DIR"] | |
if logging_dir: | |
command.extend(["--logging_dir", logging_dir]) | |
log_with = self.settings["LOG_WITH"] | |
if log_with != "none": | |
command.extend(["--log_with", log_with]) | |
log_prefix = self.settings["LOG_PREFIX"] | |
if log_prefix: | |
command.extend(["--log_prefix", log_prefix]) | |
if self.settings["IMG_IN_TXT_IN_OFFLOADING"]: | |
command.append("--img_in_txt_in_offloading") | |
lr_scheduler = self.settings["LR_SCHEDULER"] | |
if lr_scheduler: | |
command.extend(["--lr_scheduler", lr_scheduler]) | |
lr_warmup_steps = self.settings["LR_WARMUP_STEPS"] | |
if lr_warmup_steps: | |
command.extend(["--lr_warmup_steps", lr_warmup_steps]) | |
lr_decay_steps = self.settings["LR_DECAY_STEPS"] | |
if lr_decay_steps: | |
command.extend(["--lr_decay_steps", lr_decay_steps]) | |
weighting_scheme = self.settings["WEIGHTING_SCHEME"] | |
if weighting_scheme != "none": | |
command.extend(["--weighting_scheme", weighting_scheme]) | |
metadata_title = self.settings["METADATA_TITLE"] | |
if metadata_title: | |
command.extend(["--metadata_title", metadata_title]) | |
metadata_author = self.settings["METADATA_AUTHOR"] | |
if metadata_author: | |
command.extend(["--metadata_author", metadata_author]) | |
metadata_description = self.settings["METADATA_DESCRIPTION"] | |
if metadata_description: | |
command.extend(["--metadata_description", metadata_description]) | |
metadata_license = self.settings["METADATA_LICENSE"] | |
if metadata_license: | |
command.extend(["--metadata_license", metadata_license]) | |
metadata_tags = self.settings["METADATA_TAGS"] | |
if metadata_tags: | |
command.extend(["--metadata_tags", metadata_tags]) | |
if self.settings["RESUME_TRAINING"].strip(): | |
command.append(f"--resume={self.settings['RESUME_TRAINING']}") | |
cache_preparation_command = [ | |
sys.executable, "wan_cache_latents.py", | |
"--dataset_config", self.settings["DATASET_CONFIG"], | |
"--vae", self.settings["VAE_MODEL"], | |
"--clip", self.settings["CLIP_MODEL"] | |
] | |
text_encoder_caching_command = [ | |
sys.executable, "wan_cache_text_encoder_outputs.py", | |
"--dataset_config", self.settings["DATASET_CONFIG"], | |
"--t5", self.settings["T5_MODEL"], | |
"--batch_size", "16", | |
"--fp8_t5" | |
] | |
self.console_output.configure(state="normal") | |
self.console_output.delete(1.0, tk.END) | |
self.console_output.configure(state="disabled") | |
if self.enable_cache_var.get(): | |
self.update_console("Starting cache preparation...\n") | |
def on_text_encoder_caching_complete(): | |
self.update_console("Text encoder caching completed.\nStarting training...\n") | |
self.run_subprocess(command, "Training") | |
def on_cache_preparation_complete(): | |
self.update_console("Cache preparation completed.\nStarting text encoder caching...\n") | |
self.run_subprocess(text_encoder_caching_command, "Text Encoder Caching", on_text_encoder_caching_complete) | |
self.run_subprocess(cache_preparation_command, "Cache Preparation", on_cache_preparation_complete) | |
else: | |
self.update_console("Starting training without caching...\n") | |
self.run_subprocess(command, "Training") | |
def stop_training(self): | |
"""Stop the current running process""" | |
if self.current_process and self.current_process.poll() is None: | |
try: | |
if os.name == 'nt': | |
self.current_process.send_signal(signal.CTRL_BREAK_EVENT) | |
else: | |
os.killpg(os.getpgid(self.current_process.pid), signal.SIGTERM) | |
except Exception as e: | |
self.update_console("Error stopping process: " + str(e) + "\n") | |
try: | |
self.current_process.wait(timeout=5) | |
except subprocess.TimeoutExpired: | |
try: | |
self.current_process.kill() | |
self.current_process.wait() | |
except Exception as e: | |
self.update_console("Error killing process: " + str(e) + "\n") | |
self.current_process = None | |
if self.training_thread: | |
self.training_thread.join(timeout=1) | |
self.training_thread = None | |
self.update_console("Training stopped\n") | |
else: | |
self.update_console("No active process to stop\n") | |
def save_settings(self): | |
"""Save all settings, including conversion settings, to a JSON file""" | |
current_settings = {} | |
for key, entry in self.entries.items(): | |
if isinstance(entry, ttk.Combobox): | |
current_settings[key] = entry.get() | |
elif isinstance(entry, tk.BooleanVar): | |
current_settings[key] = entry.get() | |
else: | |
current_settings[key] = entry.get() | |
current_settings["FP8"] = self.fp8_var.get() | |
current_settings["SCALED"] = self.scaled_var.get() | |
current_settings["ENABLE_CACHE"] = self.enable_cache_var.get() | |
file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")]) | |
if file_path: | |
with open(file_path, "w") as f: | |
json.dump(current_settings, f, indent=4) | |
def load_settings(self): | |
"""Load settings from a JSON file, including conversion settings""" | |
file_path = filedialog.askopenfilename(filetypes=[("JSON files", "*.json")]) | |
if file_path: | |
with open(file_path, "r") as f: | |
loaded_settings = json.load(f) | |
for key, value in loaded_settings.items(): | |
if key in self.entries: | |
if isinstance(self.entries[key], ttk.Combobox): | |
self.entries[key].set(value) | |
elif isinstance(self.entries[key], tk.BooleanVar): | |
self.entries[key].set(value) | |
else: | |
self.entries[key].delete(0, tk.END) | |
self.entries[key].insert(0, value) | |
if "FP8" in loaded_settings: | |
self.fp8_var.set(loaded_settings["FP8"]) | |
if "SCALED" in loaded_settings: | |
self.scaled_var.set(loaded_settings["SCALED"]) | |
if "ENABLE_CACHE" in loaded_settings: | |
self.enable_cache_var.set(loaded_settings["ENABLE_CACHE"]) | |
self.toggle_scaled() # Update Scaled checkbox state based on FP8 | |
root = tk.Tk() | |
gui = LoRATrainerGUI(root) | |
root.mainloop() |