fffiloni committed
Commit d516aa4 · verified · 1 Parent(s): 093bdb0

cache method for model downloads

Files changed (1): app.py (+23 -26)

app.py CHANGED
@@ -1,31 +1,34 @@
+import torch
+import os
+import shutil
+import subprocess
+import gradio as gr
+import json
+import tempfile
 from huggingface_hub import snapshot_download
 
 # Download All Required Models using `snapshot_download`
 
-# Download Wan2.1-I2V-14B-480P model
-wan_model_path = snapshot_download(
-    repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
-    local_dir="./weights/Wan2.1-I2V-14B-480P",
-    #local_dir_use_symlinks=False
-)
-
-# Download Chinese wav2vec2 model
-wav2vec_path = snapshot_download(
-    repo_id="TencentGameMate/chinese-wav2vec2-base",
-    local_dir="./weights/chinese-wav2vec2-base",
-    #local_dir_use_symlinks=False
-)
-
-# Download MeiGen MultiTalk weights
-multitalk_path = snapshot_download(
-    repo_id="MeiGen-AI/MeiGen-MultiTalk",
-    local_dir="./weights/MeiGen-MultiTalk",
-    #local_dir_use_symlinks=False
-)
-
-import os
-import shutil
+def download_and_extract(repo_id, target_dir):
+    """
+    Downloads a model repo (cached) and copies its contents to a local target directory.
+    If the target_dir exists, it will be updated (not re-downloaded if cache is present).
+    """
+    print(f"Downloading {repo_id} into cache...")
+    snapshot_path = snapshot_download(repo_id)
+
+    print(f"Copying files to {target_dir}...")
+    os.makedirs(target_dir, exist_ok=True)
+    shutil.copytree(snapshot_path, target_dir, dirs_exist_ok=True)
+
+    print(f"Done: {repo_id} extracted to {target_dir}")
+    return target_dir
+
+wan_model_path = download_and_extract("Wan-AI/Wan2.1-I2V-14B-480P", "./weights/Wan2.1-I2V-14B-480P")
+wav2vec_path = download_and_extract("TencentGameMate/chinese-wav2vec2-base", "./weights/chinese-wav2vec2-base")
+multitalk_path = download_and_extract("MeiGen-AI/MeiGen-MultiTalk", "./weights/MeiGen-MultiTalk")
 
 # Define paths
 base_model_dir = "./weights/Wan2.1-I2V-14B-480P"
@@ -55,7 +58,6 @@ shutil.copy2(
 print("Copied MultiTalk files into base model directory.")
 
 
-import torch
 
 # Check if CUDA-compatible GPU is available
 if torch.cuda.is_available():
@@ -81,11 +83,7 @@ GPU_TO_VRAM_PARAMS = {
 USED_VRAM_PARAMS = GPU_TO_VRAM_PARAMS[gpu_name]
 print("Using", USED_VRAM_PARAMS, "for num_persistent_param_in_dit")
 
-import subprocess
 
-import json
-import tempfile
-#import os
 
 def create_temp_input_json(prompt: str, cond_image_path: str, cond_audio_path: str) -> str:
     """
@@ -136,7 +134,6 @@ def infer(prompt, cond_image_path, cond_audio_path):
 
     return "multi_long_mediumvra_exp.mp4"
 
-import gradio as gr
 
 with gr.Blocks(title="MultiTalk Inference") as demo:
     gr.Markdown("## 🎤 MultiTalk Inference Demo")
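For context on the caching pattern this commit switches to: calling `snapshot_download(repo_id)` without `local_dir` resolves through the shared Hugging Face hub cache, so repeated calls return the already-downloaded snapshot instead of fetching the weights again. A minimal sketch of that behavior (not part of the commit), using one of the repos above:

from huggingface_hub import snapshot_download

# First call downloads into the hub cache (~/.cache/huggingface/hub by default,
# relocatable via HF_HOME or HF_HUB_CACHE); the second call finds the cached
# snapshot and returns the same path without re-downloading anything.
first = snapshot_download("TencentGameMate/chinese-wav2vec2-base")
again = snapshot_download("TencentGameMate/chinese-wav2vec2-base")
assert first == again
print(first)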
 
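On the copy step: `shutil.copytree` with `dirs_exist_ok=True` (Python 3.8+) merges into an existing directory rather than raising `FileExistsError`, which is what makes `download_and_extract` safe to rerun on every startup when `./weights/...` already exists. A small standalone sketch, with temp directories standing in for the snapshot and weights folders:

import os
import shutil
import tempfile

src = tempfile.mkdtemp()  # stands in for a cached snapshot directory
dst = tempfile.mkdtemp()  # stands in for an existing ./weights/... target
open(os.path.join(src, "model.bin"), "wb").close()

# Without dirs_exist_ok=True the destination must not exist yet; with it,
# files are merged and overwritten, so rerunning the copy is harmless.
shutil.copytree(src, dst, dirs_exist_ok=True)
shutil.copytree(src, dst, dirs_exist_ok=True)
print(os.listdir(dst))  # ['model.bin']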