Gregniuki committed (verified)
Commit 774f8ef · Parent(s): 4c21e38

Update app.py

Files changed (1): app.py +44 -35
app.py CHANGED
@@ -68,29 +68,43 @@ speed = 1
fix_duration = None


-def load_model(page_name, repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://{page_name}/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
-    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
-    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
-    model = CFM(
-        transformer=model_cls(
-            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
-        ),
-        mel_spec_kwargs=dict(
-            target_sample_rate=target_sample_rate,
-            n_mel_channels=n_mel_channels,
-            hop_length=hop_length,
-        ),
-        odeint_kwargs=dict(
-            method=ode_method,
-        ),
-        vocab_char_map=vocab_char_map,
-    ).to(device)
-    dtype = None
-
-    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema = True)
-
-    return model
+DEFAULT_TTS_MODEL = "F5-TTS"
+tts_model_choice = DEFAULT_TTS_MODEL
+
+
+# load models
+
+vocoder = load_vocoder()
+
+
+def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
+    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+    return load_model(DiT, F5TTS_model_cfg, ckpt_path)
+
+
+def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
+    E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+    return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
+
+
+def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
+    ckpt_path, vocab_path = ckpt_path.strip(), vocab_path.strip()
+    if ckpt_path.startswith("hf://"):
+        ckpt_path = str(cached_path(ckpt_path))
+    if vocab_path.startswith("hf://"):
+        vocab_path = str(cached_path(vocab_path))
+    if model_cfg is None:
+        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+    return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
+
+
+F2TTS_ema_model3 = load_f5tts()
+E2TTS_ema_model4 = load_e2tts() if USING_SPACES else None
+custom_ema_model, pre_custom_path = None, ""
+
+chat_model_state = None
+chat_tokenizer_state = None
+


# load models
@@ -99,21 +113,16 @@ F5TTS_model_cfg = dict(
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

-F5TTS_ema_model = load_model(
-    "Gregniuki", "F5-tts_English_German_Polish", "English", DiT, F5TTS_model_cfg, 222600
+F5TTS_ema_model = load_custom(
+    "https://huggingface.co/Gregniuki/F5-tts_English_German_Polish/resolve/main/English/model_222600.pt", "", F5TTS_model_cfg
)
-E2TTS_ema_model = load_model(
-    "Gregniuki", "F5-tts_English_German_Polish", "Polish2", DiT, F5TTS_model_cfg, 1200000
+E2TTS_ema_model = load_custom(
+    "https://huggingface.co/Gregniuki/F5-tts_English_German_Polish/resolve/main/Polish2/model_1200000.pt", "", F5TTS_model_cfg
)
-E2TTS_ema_model2 = load_model(
-    "Gregniuki", "F5-tts_English_German_Polish", "Polish", DiT, F5TTS_model_cfg, 500000
-)
-E2TTS_ema_model3 = load_model(
-    "SWivid", "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
-)
-E2TTS_ema_model4 = load_model(
-    "SWivid", "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
+E2TTS_ema_model2 = load_custom(
+    "https://huggingface.co/Gregniuki/F5-tts_English_German_Polish/resolve/main/Polish/model_500000.pt", "", F5TTS_model_cfg
)
+
def chunk_text(text, max_chars=135):
    """
    Splits the input text into chunks, each with a maximum number of characters.
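
Below is a minimal usage sketch, not part of the commit: it shows how the load_custom() helper introduced above could be called for one more checkpoint. It assumes the snippet runs after app.py's own definitions (load_custom, plus the existing imports of load_model, DiT, and cached_path); the hf:// URI simply mirrors the path format the removed load_model() wrapper used.

# Hypothetical example: load an extra Gregniuki checkpoint via load_custom().
# hf:// URIs are resolved to a cached local file by cached_path(); any other
# string (an https URL or a local path) is passed on to load_model() unchanged,
# just like the three load_custom() calls in the diff above.
extra_model = load_custom(
    "hf://Gregniuki/F5-tts_English_German_Polish/English/model_222600.pt",
    vocab_path="",   # empty string: keep whatever default vocab load_model() applies
    model_cfg=None,  # None: fall back to the default DiT config (dim=1024, depth=22, ...)
)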