kemuriririn committed
Commit a70eba7 · 1 Parent(s): c4fe16f
Files changed (2):
  1. tools/download_files.py +9 -71
  2. webui.py +3 -15
tools/download_files.py CHANGED
@@ -3,30 +3,6 @@ import zipfile
 import os
 import argparse
 
-def download_file_from_google_drive(file_id, destination):
-    """
-    Download a shared Google Drive file by its file ID
-
-    Args:
-        file_id (str): ID of the Google Drive file
-        destination (str): local path to save the file to
-    """
-    # Base download URL
-    URL = "https://docs.google.com/uc?export=download"
-
-    session = requests.Session()
-
-    # Issue the initial GET request
-    response = session.get(URL, params={'id': file_id}, stream=True)
-    token = get_confirm_token(response)  # get the confirmation token from the response (if needed)
-
-    if token:  # confirmation required (large files)
-        params = {'id': file_id, 'confirm': token}
-        response = session.get(URL, params=params, stream=True)
-
-    # Save the response content to a file
-    save_response_content(response, destination)
-
 def get_confirm_token(response):
     """
     Check the response for a download confirmation token (cookie)
@@ -57,54 +33,27 @@ def save_response_content(response, destination, chunk_size=32768):
            f.write(chunk)
 
 def download_model_from_modelscope(destination,hf_cache_dir):
-    """
-    Download models from ModelScope (pseudocode, implement against the actual API)
-    Args:
-        model_id (str): ModelScope model ID
-        destination (str): local path to save to
-    """
     print(f"[ModelScope] Downloading models to {destination},model cache dir={hf_cache_dir}")
     from modelscope import snapshot_download
-    os.makedirs(os.path.join(hf_cache_dir, "models--amphion--MaskGCT"), exist_ok=True)
-    os.makedirs(os.path.join(hf_cache_dir, "models--facebook--w2v-bert-2.0"), exist_ok=True)
-    os.makedirs(os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"), exist_ok=True)
-    os.makedirs(os.path.join(hf_cache_dir, "models--funasr--campplus"), exist_ok=True)
-    snapshot_download("IndexTeam/IndexTTS-2", local_dir="checkpoints")
-    snapshot_download("amphion/MaskGCT", local_dir="checkpoints/hf_cache/models--amphion--MaskGCT")
-    snapshot_download("facebook/w2v-bert-2.0",local_dir="checkpoints/hf_cache/models--facebook--w2v-bert-2.0")
-    snapshot_download("nv-community/bigvgan_v2_22khz_80band_256x",local_dir="checkpoints/hf_cache/models--nvidia--bigvgan_v2_22khz_80band_256x")
-    # models--funasr--campplus
-    snapshot_download("nv-community/bigvgan_v2_22khz_80band_256x",local_dir="checkpoints/hf_cache/models--nvidia--bigvgan_v2_22khz_80band_256x")
+    snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
+    snapshot_download("amphion/MaskGCT", local_dir=os.path.join(hf_cache_dir,"models--amphion--MaskGCT"))
+    snapshot_download("facebook/w2v-bert-2.0",local_dir=os.path.join(hf_cache_dir,"models--facebook--w2v-bert-2.0"))
+    snapshot_download("nv-community/bigvgan_v2_22khz_80band_256x",local_dir=os.path.join(hf_cache_dir,"models--nvidia--bigvgan_v2_22khz_80band_256x"))
+    snapshot_download("iic/speech_campplus_sv_zh-cn_16k-common",local_dir=os.path.join(hf_cache_dir,"models--funasr--campplus"))
 
 def download_model_from_huggingface(destination,hf_cache_dir):
-    """
-    Download models from HuggingFace (pseudocode, implement against the actual API)
-    Args:
-        model_id (str): HuggingFace model ID
-        destination (str): local path to save to
-    """
     print(f"[HuggingFace] Downloading models to {destination},model cache dir={hf_cache_dir}")
     from huggingface_hub import snapshot_download
-    os.makedirs(os.path.join(hf_cache_dir,"models--amphion--MaskGCT"), exist_ok=True)
-    os.makedirs(os.path.join(hf_cache_dir,"models--facebook--w2v-bert-2.0"), exist_ok=True)
-    os.makedirs(os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"), exist_ok=True)
-    os.makedirs(os.path.join(hf_cache_dir,"models--funasr--campplus"), exist_ok=True)
     snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
-    print("[HuggingFace] IndexTTS-2 Download finished")
-    # snapshot_download("amphion/MaskGCT", local_dir=os.path.join(hf_cache_dir,"models--amphion--MaskGCT"))
-    # print("[HuggingFace] MaskGCT Download finished")
-    # snapshot_download("facebook/w2v-bert-2.0",local_dir=os.path.join(hf_cache_dir,"models--facebook--w2v-bert-2.0"))
-    snapshot_download("facebook/w2v-bert-2.0")
-    print("[HuggingFace] w2v-bert-2.0 Download finished")
+    snapshot_download("amphion/MaskGCT", local_dir=os.path.join(hf_cache_dir,"models--amphion--MaskGCT"))
+    snapshot_download("facebook/w2v-bert-2.0",local_dir=os.path.join(hf_cache_dir,"models--facebook--w2v-bert-2.0"))
     snapshot_download("nvidia/bigvgan_v2_22khz_80band_256x",local_dir=os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"))
-    print("[HuggingFace] bigvgan_v2_22khz_80band_256x Download finished")
     snapshot_download("funasr/campplus",local_dir=os.path.join(hf_cache_dir,"models--funasr--campplus"))
-    print("[HuggingFace] campplus Download finished")
 
 # Usage example
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Tool for downloading files and models")
-    parser.add_argument('--model_source', choices=['modelscope', 'huggingface'], default=None, help='Model download source')
+    parser = argparse.ArgumentParser(description="Download models and example files")
+    parser.add_argument('-s','--model_source', choices=['modelscope', 'huggingface'], default=None, help='Model source')
     args = parser.parse_args()
 
     if args.model_source:
@@ -112,14 +61,3 @@ if __name__ == "__main__":
         download_model_from_modelscope("checkpoints",os.path.join("checkpoints","hf_cache"))
     elif args.model_source == 'huggingface':
         download_model_from_huggingface("checkpoints",os.path.join("checkpoints","hf_cache"))
-
-    print("Downloading example files from Google Drive...")
-    file_id = "1o_dCMzwjaA2azbGOxAE7-4E7NbJkgdgO"
-    destination = "example_wavs.zip"  # replace with your preferred local path
-    download_file_from_google_drive(file_id, destination)
-    print(f"File downloaded to: {destination}")
-    # Extract the downloaded zip archive into the examples directory
-    examples_dir = "examples"
-    with zipfile.ZipFile(destination, 'r') as zip_ref:
-        zip_ref.extractall(examples_dir)
-        print(f"File extracted to: {examples_dir}")
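For reference, a minimal usage sketch of the updated downloader, assuming it is run from the repository root so that tools/ is importable. It mirrors the paths the script's __main__ block passes in (the same effect as running python tools/download_files.py -s huggingface, or -s modelscope for the ModelScope source):

    import os
    from tools.download_files import download_model_from_huggingface

    destination = "checkpoints"                             # IndexTTS-2 weights land here
    hf_cache_dir = os.path.join("checkpoints", "hf_cache")  # auxiliary models land in models--<org>--<name> subfolders
    download_model_from_huggingface(destination, hf_cache_dir)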
webui.py CHANGED
@@ -32,21 +32,9 @@ parser.add_argument("--cuda_kernel", action="store_true", default=False, help="U
 parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment")
 cmd_args = parser.parse_args()
 
-if not os.path.exists(cmd_args.model_dir):
-    print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
-    sys.exit(1)
-
-for file in [
-    "bpe.model",
-    "gpt.pth",
-    "config.yaml",
-    "s2mel.pth",
-    "wav2vec2bert_stats.pt"
-]:
-    file_path = os.path.join(cmd_args.model_dir, file)
-    if not os.path.exists(file_path):
-        print(f"Required file {file_path} does not exist. Please download it.")
-        sys.exit(1)
+from tools.download_files import download_model_from_huggingface
+download_model_from_huggingface(os.path.join(current_dir,"checkpoints"),
+                                os.path.join(current_dir, "checkpoints","hf_cache"))
 
 import gradio as gr
 from indextts.infer_v2 import IndexTTS2
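The startup download above runs unconditionally. As a hedged illustration only (not part of this commit), a guarded variant that skips the call when a checkpoint marker is already present; current_dir is assumed to match the script-directory variable webui.py already uses, and config.yaml (one of the files the removed check looked for) serves as the hypothetical marker:

    import os
    from tools.download_files import download_model_from_huggingface

    current_dir = os.path.dirname(os.path.abspath(__file__))   # assumed equivalent to webui.py's existing current_dir
    checkpoints_dir = os.path.join(current_dir, "checkpoints")
    # Hypothetical guard: only trigger the download when the marker file is missing.
    if not os.path.exists(os.path.join(checkpoints_dir, "config.yaml")):
        download_model_from_huggingface(checkpoints_dir,
                                        os.path.join(checkpoints_dir, "hf_cache"))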