kundaja-green committed
Commit 0d8b1b0 · Parent(s): ebb79f2

Implement persistent storage workflow to fix rate limiting

Files changed (2):
  1. Dockerfile +7 -31
  2. start.sh +39 -0
Dockerfile CHANGED
@@ -14,36 +14,12 @@ COPY requirements.txt .
 RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
 RUN pip install --no-cache-dir -r requirements.txt
 
-# --- NEW SECTION: DOWNLOAD MODELS ---
-# Download the official Wan2.1 models from their Hugging Face repository
-# This downloads them into a "Models/Wan" folder inside the container
-RUN huggingface-cli download wan-video/wan2.1 \
-    --repo-type model \
-    --include "*.pth" "*.json" "*.safetensors" \
-    --local-dir Models/Wan --local-dir-use-symlinks False
-
-# Copy all your project files (code, dataset configs, etc.) into the container
+# Copy all project files, including the new start.sh script
 COPY . .
 
-# This is the command that will run when the Space starts.
-# It uses the models we just downloaded.
-CMD ["accelerate", "launch", "wan_train_network.py", \
-    "--task", "i2v-14B", \
-    "--dit", "Models/Wan/wan2.1_i2v_720p_14B_fp8_e4m3fn.safetensors", \
-    "--vae", "Models/Wan/Wan2.1_VAE.pth", \
-    "--clip", "Models/Wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", \
-    "--t5", "Models/Wan/models_t5_umt5-xxl-enc-bf16.pth", \
-    "--dataset_config", "dataset/testtoml.toml", \
-    "--output_dir", "/data/output", \
-    "--output_name", "My_HF_Lora_v1", \
-    "--save_every_n_epochs", "10", \
-    "--max_train_epochs", "70", \
-    "--network_module", "networks.lora_wan", \
-    "--network_dim", "32", \
-    "--network_alpha", "4", \
-    "--learning_rate", "2e-5", \
-    "--optimizer_type", "adamw", \
-    "--mixed_precision", "bf16", \
-    "--gradient_checkpointing", \
-    "--sdpa" \
-    ]
+# Make the startup script executable
+RUN chmod +x start.sh
+
+# The new command is just to run the script.
+# The script itself handles downloading models and starting the training.
+CMD ["./start.sh"]
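Why this fixes the rate limiting: the old RUN step fetched the weights at image build time, before the Space's persistent volume exists, so they lived in an ephemeral image layer and every rebuild of the Space pulled the full model set from the Hub again. Downloading at container start into /data (where Spaces mount persistent storage) means the files are fetched once and reused across restarts. A quick way to verify from a shell inside the running Space (a sketch; assumes persistent storage is enabled and mounted at /data):

    df -h /data               # should report the persistent volume, not the container overlay
    ls -lh /data/Models/Wan   # downloaded weights should still be listed after a restart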
start.sh ADDED
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# Define the target directory for the models on persistent storage
+MODEL_DIR="/data/Models/Wan"
+
+# Check if a key model file already exists in the persistent storage
+if [ -f "$MODEL_DIR/wan2.1_i2v_720p_14B_fp8_e4m3fn.safetensors" ]; then
+    echo "Models already exist in persistent storage. Skipping download."
+else
+    echo "Models not found. Downloading to persistent storage..."
+    # If models don't exist, download them to the /data directory
+    huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P \
+        --repo-type model \
+        --include "*.pth" "*.json" "*.safetensors" \
+        --local-dir "$MODEL_DIR" --local-dir-use-symlinks False
+    echo "Download complete."
+fi
+
+# Now, run the actual training command using the models from persistent storage
+echo "Starting training..."
+accelerate launch wan_train_network.py \
+    --task "i2v-14B" \
+    --dit "$MODEL_DIR/wan2.1_i2v_720p_14B_fp8_e4m3fn.safetensors" \
+    --vae "$MODEL_DIR/Wan2.1_VAE.pth" \
+    --clip "$MODEL_DIR/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+    --t5 "$MODEL_DIR/models_t5_umt5-xxl-enc-bf16.pth" \
+    --dataset_config "dataset/testtoml.toml" \
+    --output_dir "/data/output" \
+    --output_name "My_HF_Lora_v1" \
+    --save_every_n_epochs "10" \
+    --max_train_epochs "70" \
+    --network_module "networks.lora_wan" \
+    --network_dim "32" \
+    --network_alpha "4" \
+    --learning_rate "2e-5" \
+    --optimizer_type "adamw" \
+    --mixed_precision "bf16" \
+    --gradient_checkpointing \
+    --sdpa
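One caveat with the guard in this script: it keys on a single file, so if a first run dies after the DiT checkpoint lands but before the VAE/CLIP/T5 files do, the next start would skip the download and training would fail on a missing file. A slightly stricter variant (a sketch, not part of this commit; the .download_complete sentinel is a made-up marker file) only writes a flag after huggingface-cli exits successfully:

    # Reuses MODEL_DIR from the script above; the sentinel name is arbitrary.
    if [ ! -f "$MODEL_DIR/.download_complete" ]; then
        huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P \
            --repo-type model \
            --include "*.pth" "*.json" "*.safetensors" \
            --local-dir "$MODEL_DIR" --local-dir-use-symlinks False \
        && touch "$MODEL_DIR/.download_complete"
    fi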