Spaces:
Runtime error
Runtime error
kundaja-green
committed on
Commit
·
0d8b1b0
1
Parent(s):
ebb79f2
Implement persistent storage workflow to fix rate limiting
Browse files- Dockerfile +7 -31
- start.sh +39 -0
Dockerfile
CHANGED
@@ -14,36 +14,12 @@ COPY requirements.txt .
|
|
14 |
RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
15 |
RUN pip install --no-cache-dir -r requirements.txt
|
16 |
|
17 |
-
#
|
18 |
-
# Download the official Wan2.1 models from their Hugging Face repository
|
19 |
-
# This downloads them into a "Models/Wan" folder inside the container
|
20 |
-
RUN huggingface-cli download wan-video/wan2.1 \
|
21 |
-
--repo-type model \
|
22 |
-
--include "*.pth" "*.json" "*.safetensors" \
|
23 |
-
--local-dir Models/Wan --local-dir-use-symlinks False
|
24 |
-
|
25 |
-
# Copy all your project files (code, dataset configs, etc.) into the container
|
26 |
COPY . .
|
27 |
|
28 |
-
#
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
"--clip", "Models/Wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", \
|
35 |
-
"--t5", "Models/Wan/models_t5_umt5-xxl-enc-bf16.pth", \
|
36 |
-
"--dataset_config", "dataset/testtoml.toml", \
|
37 |
-
"--output_dir", "/data/output", \
|
38 |
-
"--output_name", "My_HF_Lora_v1", \
|
39 |
-
"--save_every_n_epochs", "10", \
|
40 |
-
"--max_train_epochs", "70", \
|
41 |
-
"--network_module", "networks.lora_wan", \
|
42 |
-
"--network_dim", "32", \
|
43 |
-
"--network_alpha", "4", \
|
44 |
-
"--learning_rate", "2e-5", \
|
45 |
-
"--optimizer_type", "adamw", \
|
46 |
-
"--mixed_precision", "bf16", \
|
47 |
-
"--gradient_checkpointing", \
|
48 |
-
"--sdpa" \
|
49 |
-
]
|
|
|
14 |
RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
15 |
RUN pip install --no-cache-dir -r requirements.txt
|
16 |
|
17 |
+
# Copy all project files, including the new start.sh script
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
COPY . .
|
19 |
|
20 |
+
# Make the startup script executable
|
21 |
+
RUN chmod +x start.sh
|
22 |
+
|
23 |
+
# The new command is just to run the script.
|
24 |
+
# The script itself handles downloading models and starting the training.
|
25 |
+
CMD ["./start.sh"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
start.sh
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash
# Container entrypoint: ensure the Wan2.1 I2V model weights exist on the
# persistent /data volume (downloading them once, which avoids re-downloads
# and Hugging Face rate limiting on container restarts), then launch
# LoRA training via accelerate.
#
# Fail fast: abort on any command error, unset variable, or pipeline
# failure, so a broken download never falls through into training.
set -euo pipefail

# Target directory for the models on persistent storage.
MODEL_DIR="/data/Models/Wan"

# Key model file whose presence marks a completed download. Also reused
# below as the --dit checkpoint so the two paths cannot drift apart.
DIT_CKPT="$MODEL_DIR/wan2.1_i2v_720p_14B_fp8_e4m3fn.safetensors"

if [ -f "$DIT_CKPT" ]; then
  echo "Models already exist in persistent storage. Skipping download."
else
  echo "Models not found. Downloading to persistent storage..."
  # Download straight into /data so the weights survive restarts.
  # $MODEL_DIR is quoted (SC2086) so the path is safe even if it ever
  # contains spaces.
  huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P \
    --repo-type model \
    --include "*.pth" "*.json" "*.safetensors" \
    --local-dir "$MODEL_DIR" --local-dir-use-symlinks False
  echo "Download complete."
fi

# Run the actual training command using the models from persistent storage.
echo "Starting training..."
accelerate launch wan_train_network.py \
  --task "i2v-14B" \
  --dit "$DIT_CKPT" \
  --vae "$MODEL_DIR/Wan2.1_VAE.pth" \
  --clip "$MODEL_DIR/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --t5 "$MODEL_DIR/models_t5_umt5-xxl-enc-bf16.pth" \
  --dataset_config "dataset/testtoml.toml" \
  --output_dir "/data/output" \
  --output_name "My_HF_Lora_v1" \
  --save_every_n_epochs "10" \
  --max_train_epochs "70" \
  --network_module "networks.lora_wan" \
  --network_dim "32" \
  --network_alpha "4" \
  --learning_rate "2e-5" \
  --optimizer_type "adamw" \
  --mixed_precision "bf16" \
  --gradient_checkpointing \
  --sdpa